2012-09-04 47 views
3

我有 相同的長度 在裝置上的整數dmapdflag的兩個陣列和我與推力裝置指針包裹它們,dmaptdflagt類型推力的返回值的::的remove_if

有dmap數組中值爲-1的一些元素。我想 從dflag數組中刪除這些-1和相應的值從 。

我正在使用remove_if函數來做到這一點,但我無法弄清楚 這個調用的返回值是什麼或我該如何使用這個 得到的返回值。

(我想這些降低數組傳遞到reduce_by_key功能 其中dflagt將被用作密鑰。)

我使用下列調用做的減少。請讓我 知道我可以存儲在一個變量返回的值,並 用它來解決各個陣列dflagdmap

thrust::remove_if( 
    thrust::make_zip_iterator(thrust::make_tuple(dmapt, dflagt)), 
    thrust::make_zip_iterator(thrust::make_tuple(dmapt+numindices, dflagt+numindices)), 
    minus_one_equality_test() 
); 

其中上面使用的謂詞函子被定義爲

struct minus_one_equality_test 
{ 
    typedef typename thrust::tuple<int,int> Tuple; 
    __host__ __device__ 
    bool operator()(const Tuple& a) 
    { 
     return thrust::get<0>(a) == (-1); 
    } 
} 

回答

6

返回值是一個zip_iterator,它標記在remove_if調用期間函數返回true的元組序列的新結束。要訪問底層數組的新結束迭代器,您需要從zip_iterator中檢索元組迭代器;那麼這個元組的內容就是你用來構建zip_iterator的原始數組的新結束迭代器。這是在口頭上比在代碼中多了很多令人費解:

#include <thrust/tuple.h> 
#include <thrust/device_vector.h> 
#include <thrust/device_ptr.h> 
#include <thrust/remove.h> 
#include <thrust/iterator/zip_iterator.h> 
#include <thrust/copy.h> 

#include <iostream> 

struct minus_one_equality_test 
{ 
    typedef thrust::tuple<int,int> Tuple; 
    __host__ __device__ 
    bool operator()(const Tuple& a) 
    { 
     return thrust::get<0>(a) == (-1); 
    }; 
}; 


int main(void) 
{ 
    const int numindices = 10; 

    int mapt[numindices] = { 1, 2, -1, 4, 5, -1, 7, 8, -1, 10 }; 
    int flagt[numindices] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; 

    thrust::device_vector<int> vmapt(10); 
    thrust::device_vector<int> vflagt(10); 

    thrust::copy(mapt, mapt+numindices, vmapt.begin()); 
    thrust::copy(flagt, flagt+numindices, vflagt.begin()); 

    thrust::device_ptr<int> dmapt = vmapt.data(); 
    thrust::device_ptr<int> dflagt = vflagt.data(); 

    typedef thrust::device_vector<int>::iterator VIt; 
    typedef thrust::tuple< VIt, VIt > TupleIt; 
    typedef thrust::zip_iterator<TupleIt> ZipIt; 

    ZipIt Zend = thrust::remove_if( 
     thrust::make_zip_iterator(thrust::make_tuple(dmapt, dflagt)), 
     thrust::make_zip_iterator(thrust::make_tuple(dmapt+numindices, dflagt+numindices)), 
     minus_one_equality_test() 
    ); 

    TupleIt Tend = Zend.get_iterator_tuple(); 
    VIt vmapt_end = thrust::get<0>(Tend); 

    for(VIt x = vmapt.begin(); x != vmapt_end; x++) { 
     std::cout << *x << std::endl; 
    } 

    return 0; 
} 

如果你編譯這個並運行它,你會看到這樣的事情:

$ nvcc -arch=sm_12 remove_if.cu 
$ ./a.out 
1 
2 
4 
5 
7 
8 
10 

在這個例子中,我只有「檢索」的元組的第一個元素的內容被縮短,第二個元素以相同的方式訪問,即。標記矢量的新結束的迭代器是thrust::get<1>(Tend)