2016-12-26 74 views
1

我試圖編譯從Numeric.AD以下小例子:最小Numeric.AD例如不會編譯

import Numeric.AD 
timeAndGrad f l = grad f l 
main = putStrLn "hi" 

,我碰到這個錯誤:

test.hs:3:24: 
    Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse 
             s a) 
            -> Numeric.AD.Internal.Reverse.Reverse s a’ 
       with actual type ‘t’ 
     because type variable ‘s’ would escape its scope 
    This (rigid, skolem) type variable is bound by 
     a type expected by the context: 
     Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape => 
     f (Numeric.AD.Internal.Reverse.Reverse s a) 
     -> Numeric.AD.Internal.Reverse.Reverse s a 
     at test.hs:3:19-26 
    Relevant bindings include 
     l :: f a (bound at test.hs:3:15) 
     f :: t (bound at test.hs:3:13) 
     timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1) 
    In the first argument of ‘grad’, namely ‘f’ 
    In the expression: grad f l 

任何線索爲什麼會發生這種情況?通過觀察前面的例子據我瞭解,這是「扁平化」 grad的類型:

grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a

,但我真的需要做這樣的事情在我的代碼。事實上,這是不能編譯的最簡單的例子。我想要做的更復雜的事情是這樣的:

example :: SomeType 
example f x args = (do stuff with the gradient and gradient "function") 
    where gradient = grad f x 
      gradientFn = grad f 
      (other where clauses involving gradient and gradient "function") 

這裏有一個稍微複雜一點的版本類型簽名,它編譯。

{-# LANGUAGE RankNTypes #-} 

import Numeric.AD 
import Numeric.AD.Internal.Reverse 

-- compiles but I can't figure out how to use it in code 
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a] 
grad2 f l = grad f l 

-- compiles with the right type, but the resulting gradient is all 0s... 
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a] 
grad2' f l = grad f' l 
     where f' = Lift . f . extractAll 
     -- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work 

extractAll :: [Reverse t a] -> [a] 
extractAll xs = map extract xs 
      where extract (Lift x) = x -- non-exhaustive pattern match 

dist :: (Show a, Num a, Floating a) => [a] -> a 
dist [x, y] = sqrt(x^2 + y^2) 

-- incorrect output: [0.0, 0.0] 
main = putStrLn $ show $ grad2' dist [1,2] 

但是,我無法弄清楚如何使用第一個版本,grad2,在代碼,因爲我不知道該如何處理Reverse s a。第二個版本grad2'具有正確的類型,因爲我使用內部構造函數Lift來創建Reverse s a,但我不能理解內部(特別是參數s)的工作方式,因爲輸出漸變全部爲0。使用其他構造函數Reverse(此處未顯示)也會產生錯誤的漸變。

或者,有沒有人們使用ad代碼的庫/代碼的例子?我認爲我的用例是非常普遍的用例。

+2

如果您向timeAndGrad提供類型簽名,會發生什麼情況?一級方法可能會帶來更多運氣。 – ocharles

+0

我編輯我的問題,添加一個類型簽名和另一種方法(這也不起作用)。 – kye

回答

2

隨着where f' = Lift . f . extractAll你基本上創建了一個後門進入自動分化基礎類型,拋出所有的派生物,只保留常量值。如果你使用這個爲grad,你得到一個零結果並不奇怪!

明智的辦法是隻使用grad,因爲它是:

dist :: Floating a => [a] -> a 
dist [x, y] = sqrt $ x^2 + y^2 
-- preferrable is of course `dist = sqrt . sum . map (^2)` 

main = print $ grad dist [1,2] 
-- output: [0.4472135954999579,0.8944271909999159] 

你並不真的需要知道什麼更復雜的使用自動分化。只要你只區分NumFloating多態函數,一切都會按原樣運行。如果您需要區分作爲參數傳入的函數,則需要將該參數設爲rank-2多態(可以選擇切換到ad函數的rank-1版本,但我敢說這不太優雅,並沒有真正獲得你的好處)。

{-# LANGUAGE Rank2Types, UnicodeSyntax #-} 

mainWith :: (∀n . Floating n => [n] -> n) -> IO() 
mainWith f = print $ grad f [1,2] 

main = mainWith dist 
+0

是的,我需要區分作爲參數傳入的函數。你能解釋一下更多你的意思嗎?「使這個參數排名-2多態?」我也嘗試切換到grad的1級版本,並且它要求我指定函數具有類型([Forward a] - > Forward a)。 – kye

+0

這個類型不允許在這裏使用類型爲(Num a => [a] - > a)的函數,但是我可以在代碼中傳遞它。我不知道爲什麼類型的行爲是這樣的。 – kye