2017-06-01 76 views
3

我使用Keras Sequential模型來訓練多個多類分類器。將類標籤附加到Keras模型

在評估中,Keras輸出一個置信度向量,我可以從argmax中推斷出正確的類ID。然後,我可以使用查找表來接收實際的類標籤(例如字符串)。

到目前爲止,解決方案是加載訓練好的模型,然後單獨加載查找表。由於我有相當多的分類器,我寧願將兩個結構都保存在一個文件中。

所以我在尋找的是一種將實際標籤查找矢量集成到Keras模型中的方法。這將允許我有一個分類器文件,它能夠獲取一些輸入數據並返回該數據的正確類標籤。

解決這個問題的一種方法是將模型和查找表存儲在一個元組中,並將該元組寫入一個pickle中,但這看起來不太優雅。

回答

5

所以我嘗試了一下我自己的解決方案,這似乎工作。儘管我希望更簡單一些。

第二次打開模型文件並不是真的最佳我認爲。如果任何人都可以做得更好,無論如何,請做。

import h5py 

from keras.models import load_model 
from keras.models import save_model 


def load_model_ext(filepath, custom_objects=None): 
    model = load_model(filepath, custom_objects=None) 
    f = h5py.File(filepath, mode='r') 
    meta_data = None 
    if 'my_meta_data' in f.attrs: 
     meta_data = f.attrs.get('my_meta_data') 
    f.close() 
    return model, meta_data 


def save_model_ext(model, filepath, overwrite=True, meta_data=None): 
    save_model(model, filepath, overwrite) 
    if meta_data is not None: 
     f = h5py.File(filepath, mode='a') 
     f.attrs['my_meta_data'] = meta_data 
     f.close() 
+0

接受我自己的答案缺乏替代品。如果有人提出更好的解決方案,我會接受他們的。 – Cerno

+1

我試圖解決的同樣的問題。但你的解決方案不適用於我: '''save_model_ext(mod1,filepath ='test_model.h5',meta_data = {0:'c1',1:'c2'})''' 產生一個錯誤: ''' TypeError:Object dtype dtype('O')沒有原生HDF5等價物 ''' 您的函數期望得到'meta_data'是什麼類型? – slymore

+0

你好。您必須使用可以轉換爲HDF5的數據。 dtype =「O」表示您的數據包含一個顯然無效的Python對象。如果我記得,我用python字典沒有問題。這真的是你嘗試的代碼還是事實更復雜? – Cerno