2017-03-12 104 views
1

我想用Python將3通道RGB圖像轉換爲索引圖像。它用於處理訓練標籤,用於語義分割的深層網絡。通過索引圖像我的意思是它有一個通道,每個像素是索引,應該從零開始。當然他們應該有相同的大小。該轉換是基於在Python字典以下映射:將RGB圖像轉換爲索引圖像

color2index = { 
     (255, 255, 255) : 0, 
     (0,  0, 255) : 1, 
     (0, 255, 255) : 2, 
     (0, 255, 0) : 3, 
     (255, 255, 0) : 4, 
     (255, 0, 0) : 5 
    } 

我已經實現了一個幼稚功能:

def im2index(im): 
    """ 
    turn a 3 channel RGB image to 1 channel index image 
    """ 
    assert len(im.shape) == 3 
    height, width, ch = im.shape 
    assert ch == 3 
    m_lable = np.zeros((height, width, 1), dtype=np.uint8) 
    for w in range(width): 
     for h in range(height): 
      b, g, r = im[h, w, :] 
      m_lable[h, w, :] = color2index[(r, g, b)] 
    return m_lable 

輸入im是由cv2.imread()創建的numpy的陣列。但是,這段代碼真的很慢。 由於im是numpy的陣列我先試着用ufuncnumpy的是這樣的:

RGB2index = np.frompyfunc(lambda x: color2index(tuple(x))) 
indices = RGB2index(im) 

但事實是,ufunc需要每次只有一個元素。我被啓用一次給函數三個參數(RGB值)。

那麼還有其他方法可以做優化嗎? 如果存在更高效的數據結構,那麼映射並非如此。我注意到一個Python字典的訪問不需要太多時間,但是從numpy array元組(它是可散列的)可以。

PS: 我得到的一個想法是在CUDA中實現一個內核。但它會更復雜。

UPDATA1: Dan Mašek's Answer工作正常。但首先我們必須將RGB圖像轉換爲灰度圖。當兩種顏色具有相同的灰度值時,這可能會有問題。

我在這裏粘貼工作代碼。希望它可以幫助別人。

lut = np.ones(256, dtype=np.uint8) * 255 
lut[[255,29,179,150,226,76]] = np.arange(6, dtype=np.uint8) 
im_out = cv2.LUT(cv2.cvtColor(im, cv2.COLOR_BGR2GRAY), lut) 
+0

因此,輸入將包含您列出的那6種不同的顏色?如果是這樣,從RBG到灰度的轉換將爲您提供以下灰度值:[255,29,179,150,226,76] - 6個不同的值。然後通過'cv2.LUT'運行它,將其重新映射到0-5。 –

+1

更改訂單或循環,即使這會加快您的代碼 – smttsp

+1

@DanMašek謝謝!你的解決方案工作正常在我沒有意識到轉換爲灰度時,RGB的權重不同。灰度圖像的體積爲0-255。這意味着類的最大數量是256.儘管如此,在大多數情況下它仍然可以。問題可能是某些顏色可能具有相同的灰度值。 – Kun

回答

0

這裏有一個小工具功能將圖像轉換(np.array)到每個像素標籤(指數),也可以是一個熱碼:

def rgb2label(img, color_codes = None, one_hot_encode=False): 
    if color_codes is None: 
     color_codes = {val:i for i,val in enumerate(set(tuple(v) for m2d in img for v in m2d))} 
    n_labels = len(color_codes) 
    result = np.ndarray(shape=img.shape[:2], dtype=int) 
    result[:,:] = -1 
    for rgb, idx in color_codes.items(): 
     result[(img==rgb).all(2)] = idx 

    if one_hot_encode: 
     one_hot_labels = np.zeros((img.shape[0],img.shape[1],n_labels)) 
     # one-hot encoding 
     for c in range(n_labels): 
      one_hot_labels[: , : , c ] = (result == c).astype(int) 
     result = one_hot_labels 

    return result, color_codes 


img = cv2.imread("input_rgb_for_labels.png") 
img_labels, color_codes = rgb2label(img) 
print(color_codes) # e.g. to see what the codebook is 

img1 = cv2.imread("another_rgb_for_labels.png") 
img1_labels, _ = rgb2label(img1, color_codes) # use the same codebook 

它計算(和如果提供了None,則返回顏色代碼簿。

相關問題