2014-12-23 48 views
4

假設我們有一個std::vector,或任何其他序列容器(有時它將是一個雙流),其中存儲uint64_t元素。從連續的單詞序列中提取任意範圍的比特的最有效方法是什麼?

現在,讓我們將這個向量看作是連續位的一個序列size() * 64。我需要找到由給定的[begin, end)範圍內的位組成的單詞,因爲end - begin <= 64所以它適合一個單詞。

我現在所擁有的解決方案找到了兩個單詞,它們的各個部分將形成結果,並分別掩蓋並組合它們。因爲我需要儘可能高效,所以我試圖編寫沒有任何if分支的所有代碼,以避免導致分支預測失誤,例如,代碼適用於整個範圍適合單詞或跨越兩個單詞的情況字,而不採取不同的路徑。爲此,我需要對那些shiftlshiftr函數進行編碼,它們除了將單詞按指定的量移動外,沒有任何作用,例如>><<操作符,但優雅地處理n大於64的情況,否則將會是未定義的行爲,否則。

另一點是,現在編碼的get()函數也適用於空白範圍,從數學意義上說,例如,不僅如果begin == end,而且如果begin> end,這是調用此函數的主算法所需的。再次,我試圖做到這一點,而不是簡單地分支並返回零。

但是,也看着彙編代碼,所有這些看起來太複雜了,無法執行這樣一個看似簡單的任務。此代碼運行在性能至關重要的算法中,運行速度太慢。 valgrind告訴我們這個函數被稱爲2.3億次,佔總執行時間的40%,所以我真的需要使它更快。

那麼你能幫我找到一個更簡單和/或更有效的方法來完成這項任務嗎? 我不在乎很多關於可移植性。使用x86 SIMD內在函數(SSE3/4/AVX ecc ...)或編譯器內置函數的解決方案都可以,只要我可以使用g++clang編譯它們即可。

我當前的代碼包含如下:

using word_type = uint64_t; 
const size_t W = 64; 

// Shift right, but without being undefined behaviour if n >= 64 
word_type shiftr(word_type val, size_t n) 
{ 
    uint64_t good = n < W; 

    return good * (val >> (n * good)); 
} 

// Shift left, but without being undefined behaviour if n >= 64 
word_type shiftl(word_type val, size_t n) 
{ 
    uint64_t good = n < W; 

    return good * (val << (n * good)); 
} 

// Mask the word preserving only the lower n bits. 
word_type lowbits(word_type val, size_t n) 
{ 
    word_type mask = shiftr(word_type(-1), W - n); 

    return val & mask; 
} 

// Struct for return values of locate() 
struct range_location_t { 
    size_t lindex; // The word where is located the 'begin' position 
    size_t hindex; // The word where is located the 'end' position 
    size_t lbegin; // The position of 'begin' into its word 
    size_t llen; // The length of the lower part of the word 
    size_t hlen; // The length of the higher part of the word 
}; 

// Locate the one or two words that will make up the result 
range_location_t locate(size_t begin, size_t end) 
{ 
    size_t lindex = begin/W; 
    size_t hindex = end/W; 
    size_t lbegin = begin % W; 
    size_t hend = end % W; 

    size_t len = (end - begin) * size_t(begin <= end); 
    size_t hlen = hend * (hindex > lindex); 
    size_t llen = len - hlen; 

    return { lindex, hindex, lbegin, llen, hlen }; 
} 

// Main function. 
template<typename Container> 
word_type get(Container const&container, size_t begin, size_t end) 
{ 
    assert(begin < container.size() * W); 
    assert(end <= container.size() * W); 

    range_location_t loc = locate(begin, end); 

    word_type low = lowbits(container[loc.lindex] >> loc.lbegin, loc.llen); 

    word_type high = shiftl(lowbits(container[loc.hindex], loc.hlen), loc.llen); 

    return high | low; 
} 

非常感謝你。

+0

如何使用['std :: bitset'](http://en.cppreference.com/w/cpp/utility/bitset)作爲中介? –

+0

當你返回'好* ...'時,你不能認爲'好'是0或1:這是不可移植的。改用三元運算符:'return good?...:0;' – Christophe

+2

@Christophe在C++中完美的移植。 ''產生一個'bool',並且從那個轉換到任何其他整數只產生0或1.參見C++ 11,4.7/4。 –

回答

2

正如在聊天中宣佈的,我添加了一個精緻的答案。它包含三個部分,每個部分後面跟着該部分的描述。

第一部分get.h是我的解決方案,但是是一般化的,只需要一次修正。

第二部分got.h是發佈在問題中的原始算法,也可以用於任何無符號類型的STL容器。

第三部分main.cpp包含驗證正確性和度量性能的單元測試。

#include <cstddef> 

using std::size_t; 

template < typename C > 
typename C::value_type get (C const &container, size_t lo, size_t hi) 
{ 

    typedef typename C::value_type item; // a container entry 
    static unsigned const bits = (unsigned)sizeof(item)*8u; // bits in an item 
    static size_t const mask = ~(size_t)0u/bits*bits; // huge multiple of bits 

    // everthing above has been computed at compile time. Now do some work: 

    size_t lo_adr = (lo  )/bits; // the index in the container of ... 
    size_t hi_adr = (hi-(hi>0))/bits; // ... the lower or higher item needed 

    // we read container[hi_adr] first and possibly delete the highest bits: 

    unsigned hi_shift = (unsigned)(mask-hi)%bits; 
    item hi_val = container[hi_adr] <<hi_shift>> hi_shift; 

    // if all bits are in the same item, we delete the lower bits and are done: 

    unsigned lo_shift = (unsigned)lo%bits; 
    if (hi_adr <= lo_adr) return (hi_val>>lo_shift) * (lo<hi); 

    // else we have to read the lower item as well, and combine both 

    return (hi_val<<bits-lo_shift | container[lo_adr]>>lo_shift); 

} 

第1部分,get.h以上,是我原來的解決方案,但廣義與無符號整數類型的任何STL容器的工作。因此,您也可以使用它並測試它的32位整數或128位整數。我仍然使用unsigned爲非常小的數字,但你可以將它們替換爲size_t。該算法幾乎沒有改變,只需稍作修改 - 如果lo是容器中的總位數,我以前的get()將訪問容器大小正上方的項。這現在已經修復。

#include <cstddef> 

using std::size_t; 

// Shift right, but without being undefined behaviour if n >= 64 
template < typename val_type > 
val_type shiftr(val_type val, size_t n) 
{ 
    val_type good = n < sizeof(val_type)*8; 
    return good * (val >> (n * good)); 
} 

// Shift left, but without being undefined behaviour if n >= 64 
template < typename val_type > 
val_type shiftl(val_type val, size_t n) 
{ 
    val_type good = n < sizeof(val_type)*8; 
    return good * (val << (n * good)); 
} 

// Mask the word preserving only the lower n bits. 
template < typename val_type > 
val_type lowbits(val_type val, size_t n) 
{ 
    val_type mask = shiftr<val_type>((val_type)(-1), sizeof(val_type)*8 - n); 
    return val & mask; 
} 

// Struct for return values of locate() 
struct range_location_t { 
    size_t lindex; // The word where is located the 'begin' position 
    size_t hindex; // The word where is located the 'end' position 
    size_t lbegin; // The position of 'begin' into its word 
    size_t llen; // The length of the lower part of the word 
    size_t hlen; // The length of the higher part of the word 
}; 

// Locate the one or two words that will make up the result 
template < typename val_type > 
range_location_t locate(size_t begin, size_t end) 
{ 
    size_t lindex = begin/(sizeof(val_type)*8); 
    size_t hindex = end/(sizeof(val_type)*8); 
    size_t lbegin = begin % (sizeof(val_type)*8); 
    size_t hend = end % (sizeof(val_type)*8); 

    size_t len = (end - begin) * size_t(begin <= end); 
    size_t hlen = hend * (hindex > lindex); 
    size_t llen = len - hlen; 

    range_location_t l = { lindex, hindex, lbegin, llen, hlen }; 
    return l; 
} 

// Main function. 
template < typename C > 
typename C::value_type got (C const&container, size_t begin, size_t end) 
{ 
    typedef typename C::value_type val_type; 
    range_location_t loc = locate<val_type>(begin, end); 
    val_type low = lowbits<val_type>(container[loc.lindex] >> loc.lbegin, loc.llen); 
    val_type high = shiftl<val_type>(lowbits<val_type>(container[loc.hindex], loc.hlen), loc.llen); 
    return high | low; 
} 

這第二部分,上面got.h,是原來的算法中的問題,廣義以及接受任何無符號整數類型的任何STL容器。與get.h類似,除了定義容器類型的單個模板參數之外,此版本不使用任何定義,因此可以輕鬆地對其他項目大小或容器類型進行測試。

#include <vector> 
#include <cstddef> 
#include <stdint.h> 
#include <stdio.h> 
#include <sys/time.h> 
#include <sys/resource.h> 
#include "get.h" 
#include "got.h" 

template < typename Container > class Test { 

    typedef typename Container::value_type val_type; 
    typedef val_type (*fun_type) (Container const &, size_t, size_t); 
    typedef void (Test::*fun_test) (unsigned, unsigned); 
    static unsigned const total_bits = 256; // number of bits in the container 
    static unsigned const entry_bits = (unsigned)sizeof(val_type)*8u; 

    Container _container; 
    fun_type _function; 
    bool _failed; 

    void get_value (unsigned lo, unsigned hi) { 
     _function(_container,lo,hi); // we call this several times ... 
     _function(_container,lo,hi); // ... because we measure ... 
     _function(_container,lo,hi); // ... the performance ... 
     _function(_container,lo,hi); // ... of _function, .... 
     _function(_container,lo,hi); // ... not the performance ... 
     _function(_container,lo,hi); // ... of get_value and ... 
     _function(_container,lo,hi); // ... of the loop that ... 
     _function(_container,lo,hi); // ... calls get_value. 
    } 

    void verify (unsigned lo, unsigned hi) { 
     val_type value = _function(_container,lo,hi); 
     if (lo < hi) { 
     for (unsigned i=lo; i<hi; i++) { 
      val_type val = _container[i/entry_bits] >> i%entry_bits & 1u; 
      if (val != (value&1u)) { 
       printf("lo=%d hi=%d [%d] is'nt %d\n",lo,hi,i,(unsigned)val); 
       _failed = true; 
      } 
      value >>= 1u; 
     } 
     } 
     if (value) { 
     printf("lo=%d hi=%d value contains high bits set to 1\n",lo,hi); 
     _failed = true; 
     } 
    } 

    void run (fun_test fun) { 
     for (unsigned lo=0; lo<total_bits; lo++) { 
     unsigned h0 = 0; 
     if (lo > entry_bits) h0 = lo - (entry_bits+1); 
     unsigned h1 = lo+64; 
     if (h1 > total_bits) h1 = total_bits; 
     for (unsigned hi=h0; hi<=h1; hi++) { 
      (this->*fun)(lo,hi); 
     } 
     } 
    } 

    static uint64_t time_used () { 
     struct rusage ru; 
     getrusage(RUSAGE_THREAD,&ru); 
     struct timeval t = ru.ru_utime; 
     return (uint64_t) t.tv_sec*1000 + t.tv_usec/1000; 
    } 

public: 

    Test (fun_type function): _function(function), _failed() { 
     val_type entry; 
     unsigned index = 0; // position in the whole bit array 
     unsigned value = 0; // last value assigned to a bit 
     static char const entropy[] = "The quick brown Fox jumps over the lazy Dog"; 
     do { 
     if (! (index%entry_bits)) entry = 0; 
     entry <<= 1; 
     entry |= value ^= 1u & entropy[index/7%sizeof(entropy)] >> index%7; 
     ++index; 
     if (! (index%entry_bits)) _container.push_back(entry); 
     } while (index < total_bits); 
    } 

    bool correctness() { 
     _failed = false; 
     run(&Test::verify); 
     return !_failed; 
    } 

    void performance() { 
     uint64_t t1 = time_used(); 
     for (unsigned i=0; i<999; i++) run(&Test::get_value); 
     uint64_t t2 = time_used(); 
     printf("used %d ms\n",(unsigned)(t2-t1)); 
    } 

    void operator() (char const * name) { 
     printf("testing %s\n",name); 
     correctness(); 
     performance(); 
    } 

}; 

int main() 
{ 
    typedef typename std::vector<uint64_t> Container; 
    Test<Container> test(get<Container>); test("get"); 
    Test<Container> tost(got<Container>); tost("got"); 
} 

第三部分,上面main.cpp中,包含一個類的單元測試,並將其應用於get.h和got.h,就是對我的解決方案,並就問題的原代碼,略改性。單元測試驗證正確性並測量速度。他們通過創建一個256位的容器,用一些數據填充它,通過讀取所有可能的比特部分,以適合容器條目以及大量病理案例,並驗證每個結果的正確性來驗證正確性。他們通過再次閱讀相同的章節來測量速度,並報告用戶空間中使用的線程時間。

+0

聖切廷,你真的值得投票。 – nullpotent

+0

簡直太棒了! – gigabytes

4

這取代了get()和get()使用的所有幫助函數。它包含一個條件分支並保存大約16個算術運算,這意味着它通常應該運行得更快。編譯一些優化,它也產生很短的代碼。最後,它解決了導致容器[container.size()]在case == container.size()* W的情況下被訪問的錯誤。

最棘手的部分是「hi-(hi> 0)」,從hi中減去1,除非hi是0.除了hi只是指向一個單詞邊界,即hi,%, 64 == 0。在這種情況下,我們需要從上層容器條目中的0位,因此僅使用下層容器條目即可。通過在計算hi_off之前減1,我們確保條件「hi_off == lo_off」,並且我們遇到了更簡單的情況。

在這種更簡單的情況下,我們只需要一個容器條目,並在兩側刪除一些位。 hi_val是該條目,高位已被刪除,因此唯一需要做的是刪除一些低位。

在不那麼簡單的情況下,我們也必須讀取下面的容器條目,去掉它的未使用的字節,併合並兩個條目。

namespace { 
    size_t const upper_mask = ~(size_t)0u << 6u; 
    unsigned const lower_mask = (unsigned)~upper_mask; 
} 

word_type get (Container const &container, size_t lo, size_t hi) 
{ 
    size_t lo_off = lo  >>6u; assert (lo_off < container.size()); 
    size_t hi_off = hi-(hi>0)>>6u; assert (hi_off < container.size()); 
    unsigned hi_shift = lower_mask&(unsigned)(upper_mask-hi); 
    word_type hi_val = container[hi_off] <<hi_shift>> hi_shift; 
    unsigned lo_shift = lower_mask&(unsigned)lo; 
    if (hi_off == lo_off) return hi_val >> lo_shift; // use hi_val as lower word 
    return (hi_val<<W-lo_shift | container[lo_off]>>lo_shift) * (lo_off<hi_off); 
} 
+0

非常有趣!但是我仍然需要很好地理解它......爲什麼你會以6換? – gigabytes

+0

>> 6u與/ W相同 - 它除以64,計算包含所需位的單詞的索引。通常情況下,編譯器的優化器應該會看到,但是爲了防止它失敗,我會告訴它如何快速除以64。順便說一句,你會注意到我沒有使用特殊的移位功能,這對64位或更多的位移很好。這是因爲在我的實現中,所有的移位操作都少於64位。 –

+0

是的,這種轉變是顯而易見的,對不起......無論如何,我認爲這不需要任何合理的編譯器。 所以,你使用'lower_mask'來取代使用'%'運算符的模數64,對吧?但你爲什麼要做'upper_mask - hi'? – gigabytes

相關問題