我試圖編譯從Numeric.AD以下小例子:最小Numeric.AD例如不會編譯
import Numeric.AD
timeAndGrad f l = grad f l
main = putStrLn "hi"
,我碰到這個錯誤:
test.hs:3:24:
Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
s a)
-> Numeric.AD.Internal.Reverse.Reverse s a’
with actual type ‘t’
because type variable ‘s’ would escape its scope
This (rigid, skolem) type variable is bound by
a type expected by the context:
Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
f (Numeric.AD.Internal.Reverse.Reverse s a)
-> Numeric.AD.Internal.Reverse.Reverse s a
at test.hs:3:19-26
Relevant bindings include
l :: f a (bound at test.hs:3:15)
f :: t (bound at test.hs:3:13)
timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
In the first argument of ‘grad’, namely ‘f’
In the expression: grad f l
任何線索爲什麼會發生這種情況?通過觀察前面的例子據我瞭解,這是「扁平化」 grad
的類型:
grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a
,但我真的需要做這樣的事情在我的代碼。事實上,這是不能編譯的最簡單的例子。我想要做的更復雜的事情是這樣的:
example :: SomeType
example f x args = (do stuff with the gradient and gradient "function")
where gradient = grad f x
gradientFn = grad f
(other where clauses involving gradient and gradient "function")
這裏有一個稍微複雜一點的版本類型簽名,它編譯。
{-# LANGUAGE RankNTypes #-}
import Numeric.AD
import Numeric.AD.Internal.Reverse
-- compiles but I can't figure out how to use it in code
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
grad2 f l = grad f l
-- compiles with the right type, but the resulting gradient is all 0s...
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
grad2' f l = grad f' l
where f' = Lift . f . extractAll
-- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work
extractAll :: [Reverse t a] -> [a]
extractAll xs = map extract xs
where extract (Lift x) = x -- non-exhaustive pattern match
dist :: (Show a, Num a, Floating a) => [a] -> a
dist [x, y] = sqrt(x^2 + y^2)
-- incorrect output: [0.0, 0.0]
main = putStrLn $ show $ grad2' dist [1,2]
但是,我無法弄清楚如何使用第一個版本,grad2
,在代碼,因爲我不知道該如何處理Reverse s a
。第二個版本grad2'
具有正確的類型,因爲我使用內部構造函數Lift
來創建Reverse s a
,但我不能理解內部(特別是參數s
)的工作方式,因爲輸出漸變全部爲0。使用其他構造函數Reverse
(此處未顯示)也會產生錯誤的漸變。
或者,有沒有人們使用ad
代碼的庫/代碼的例子?我認爲我的用例是非常普遍的用例。
如果您向timeAndGrad提供類型簽名,會發生什麼情況?一級方法可能會帶來更多運氣。 – ocharles
我編輯我的問題,添加一個類型簽名和另一種方法(這也不起作用)。 – kye