2017-05-10 48 views

回答

0

不幸的是,在MxNet中還沒有實現混合密度網絡(MDN)。而且,由於MxNet是一項社區活動,因此歡迎您的參與!

從Keras/TF移植代碼應該相當簡單。對於MxNet來說,R綁定目前相當有限,因爲現在不可能創建自定義操作,但考慮這個例子,我沒有看到需要任何自定義操作。

我還沒有運行此代碼,但這裏是如何從你的榜樣MDN模型看起來像使用Python的MxNet符號API:

def mapping(self, X): 
    """pi, mu, sigma = NN(x; theta)""" 
    hidden1 = mx.sym.FullyConnected(data=X, num_hidden=15) # fully-connected layer with 15 hidden units 
    act1 = mx.sym.Activation(data=hidden1, act_type='relu') 
    hidden2 = mx.sym.FullyConnected(data=act1, num_hidden=15) # fully-connected layer with 15 hidden units 
    act2 = mx.sym.Activation(data=hidden2, act_type='relu') 

    self.mus = mx.sym.FullyConnected(data=act2, num_hidden=self.K) # fully-connected layer with 15 hidden units 

    sigma_fc = mx.sym.FullyConnected(data=act2, num_hidden=self.K) 
    self.sigmas = mx.sym.exp(data=sigma_fc) # the variance 

    pi_fc = mx.sym.FullyConnected(data=act2, num_hidden=self.K) 
    self.pi = mx.sym.SoftmaxActivation(data=pi_fc) # the mixture components