2017-06-21 44 views
4

我正在解決文本分類問題。我用我自己的model_fn使用Estimator類定義了我的分類器。我想使用Google的預先訓練好的word2vec嵌入作爲初始值,然後針對當前的任務對其進行進一步優化。加載預先訓練好的word2vec在Estimator中初始化embedding_lookup model_fn

我看到這篇文章:Using a pre-trained word embedding (word2vec or Glove) in TensorFlow
它解釋瞭如何在'原始'TensorFlow代碼中去解決它。但是,我真的很想使用Estimator類。

作爲一個擴展,我想在Cloud ML引擎上訓練這個代碼,是否有一種很好的方式來傳遞具有初始值的相當大的文件?

比方說,我們有這樣的事:

def build_model_fn(): 
    def _model_fn(features, labels, mode, params): 
     input_layer = features['feat'] #shape=[-1, params["sequence_length"]] 
     #... what goes here to initialize W 

     embedded = tf.nn.embedding_lookup(W, input_layer) 
     ... 
     return predictions 

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(), 
    model_dir=MODEL_DIR, 
    params=params) 
estimator.fit(input_fn=read_data, max_steps=2500) 

回答

7

曲面嵌入足夠,唯一可行的辦法是用它們來初始化在圖中的tf.Variable一般較大。這將允許你利用分佈式的參數服務器等。

對於這個(和其他任何東西),我建議你使用新的「核心」估計器,因爲這會使事情變得更容易。

從您提供的鏈接的答案,知道我們想要一個變量不是一個常量,我們可以採取的方法:

(2)使用飼料字典初始化變量,或 (3)從檢查點加載可變


我將選擇(3)第一,因爲它更容易,更好地:

在你model_fn,只需初始化使用Tensor返回的變量由tf.contrib.framework.load_variable打電話。這就要求:

  1. ,你必須與你的嵌入
  2. 你知道檢查站內的嵌入變量的完全合格名有效TF檢查點。

的代碼非常簡單:

def model_fn(mode, features, labels, hparams): 
    embeddings = tf.Variable(tf.contrib.framework.load_variable(
     'gs://my-bucket/word2vec_checkpoints/', 
     'a/fully/qualified/scope/embeddings' 
)) 
    .... 
    return tf.estimator.EstimatorSpec(...) 

但是這種方法不會爲你工作,如果你的嵌入不是由另一TF模型產生的,因此選項(2)。


對於(2),我們需要使用tf.train.Scaffold其基本上保持所有的選項用於開始tf.Session(其估計有意隱藏有許多原因)的配置對象。

您可以在model_fn返回的tf.train.EstimatorSpec中指定Scaffold

我們在我們的model_fn中創建一個佔位符,並將其設爲 對我們的嵌入變量進行初始化操作,然後通過Scaffold傳遞init_feed_dict。例如

def model_fn(mode, features, labels, hparams): 
    embed_ph = tf.placeholder(
     shape=[hparams.vocab_size, hparams.embedding_size], 
     dtype=tf.float32) 
    embeddings = tf.Variable(embed_ph) 
    # Define your model 
    return tf.estimator.EstimatorSpec(
     ..., # normal EstimatorSpec args 
     scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array}) 
) 

這裏發生的是init_feed_dict將填充embed_ph佔位符在運行時的值,那麼這將使embeddings.initialization_op(佔位符的分配),以運行。


+0

謝謝,只是一個很小的事情:它應該是'tf.estimator.EstimatorSpec(...,支架= tf.train.Scaffold(ini​​t_feed_dict = {embed_ph:my_embedding_numpy_array})' – Tristan

+0

感謝特里斯坦淨距那語法,即使我有解釋大聲笑。 –

相關問題