2017-09-06 67 views
0

我試圖使用 Eigen 實現可修改的自定義表達式,類似於這個 question。基本上,我想要的是類似於 tutorial 中的索引(Indexing)示例,但能夠爲選定的係數賦予新值。(標題:Eigen:可修改的自定義表達式)

正如上面提到的那個問題中被採納的答案所建議的,我研究了 Transpose 的實現並嘗試了很多方法,但都沒有成功。基本上,我的嘗試都以類似 'Eigen::internal::evaluator<SrcXprType>::evaluator(const Eigen::internal::evaluator<SrcXprType> &)': cannot convert argument 1 from 'const Eigen::Indexing<Derived>' to 'Eigen::Indexing<Derived> &' 的錯誤告終。問題可能在於我的 evaluator 結構似乎是唯讀的。

// The question's (non-working) attempt: an Eigen evaluator for a custom
// `Indexing<ArgType>` expression whose coefficients should also be writable.
// NOTE: this is a sketch, not compilable code -- each `... very clever stuff ...`
// placeholder stands for an elided index computation, and those `return`
// statements are missing their trailing semicolons.
namespace Eigen { 
namespace internal { 
    // Specialization of Eigen's evaluator for the custom Indexing<ArgType> expression.
    template<typename ArgType> 
    struct evaluator<Indexing<ArgType> > 
     : evaluator_base<Indexing<ArgType> > 
    { 
     typedef Indexing<ArgType> XprType; 
     // How the argument is nested for evaluation, and that type with
     // references/cv-qualifiers stripped (used to pick the argument's evaluator).
     typedef typename nested_eval<ArgType, XprType::ColsAtCompileTime>::type ArgTypeNested; 
     typedef typename remove_all<ArgTypeNested>::type ArgTypeNestedCleaned; 
     typedef typename XprType::CoeffReturnType CoeffReturnType; 
     typedef typename traits<ArgType>::Scalar Scalar; 
     enum { 
      // Reading one coefficient costs the same as reading one from the argument.
      CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost, 
      // Only a storage order is advertised; no other flag bits are set here.
      Flags = Eigen::ColMajor 
     }; 

     // NOTE(review): this constructor takes a non-const `XprType&`. The compiler
     // error quoted in the question shows Eigen passing a
     // `const Eigen::Indexing<Derived>`, which cannot bind here. Following the
     // comment thread, the constructor should take `const XprType&` (as IndexedView
     // does), and writability should come from the Indexing class nesting a
     // non-const `m_arg`.
     evaluator(XprType& xpr) 
      : m_argImpl(xpr.m_arg), m_rows(xpr.rows()) 
     { } 
     // Read access: forwards to the argument evaluator at the position produced by
     // the elided index computation.
     const Scalar& coeffRef(Index row, Index col) const 
     { 
      return m_argImpl.coeffRef(... very clever stuff ...) 
     } 

     // Write access: same forwarding, but returns a mutable reference.
     Scalar& coeffRef(Index row, Index col) 
     { 
      return m_argImpl.coeffRef(... very clever stuff ...) 
     } 

     // Evaluator of the wrapped argument expression.
     evaluator<ArgTypeNestedCleaned> m_argImpl; 
     // Row count of the Indexing expression, captured at construction
     // (presumably consumed by the elided index computation).
     const Index m_rows; 
    }; 
} 
} 

此外,我已經把所有出現的 typedef typename Eigen::internal::ref_selector<ArgType>::type 都改成了 ...::non_const_type,但沒有效果。

由於 Eigen 庫的複雜性,我無法弄清楚如何正確地編寫表達式和評估器(evaluator)。我不明白爲什麼我的評估器是唯讀的,也不知道如何得到一個可寫的評估器。如果有人能提供一個可修改的自定義表達式的最小示例,那就太好了。

+0

不必費心了,去看看 devel 分支吧,這個功能已經[在那裡](http://eigen.tuxfamily.org/dox-devel/classEigen_1_1DenseBase.html#a0b44220621cd59a75cd0f48cc499518f)了。 – ggael

+0

@ggael 很棒,謝謝!不過,我真的很想知道如何讓(例如)該教程中的 Circulant 表達式變成可寫的(即使這個特定情況沒有太大意義)。我試圖把 IndexedView 中的一些概念移植到 Circulant 示例中,但沒有成功。 – florestan

+0

我想你需要更新'Indexing'類來支持非const嵌套表達式,所以'm_arg'可以是非const的,就像我們在'IndexedView'中做的那樣。然後在評估器中,ctor應該採用'const XprType&'。基本上只要按照'IndexedView'的例子。 – ggael

回答

0

藉助 ggael 的提示,我已經成功添加了自己的可修改表達式。我基本上是改編自 Eigen 開發分支中的 IndexedView。

由於最初請求的功能已被 IndexedView 涵蓋,我寫了一個可修改的循環移位函數,作爲可修改自定義表達式的簡單示例。大部分代碼直接來自 IndexedView,所以功勞歸於其原作者。

// circ_shift.h 
#pragma once 
#include <Eigen/Core> 

namespace helper 
{ 
     namespace detail 
    { 
     template <typename T> 
     constexpr std::true_type is_matrix(Eigen::MatrixBase<T>); 
     std::false_type constexpr is_matrix(...); 

     template <typename T> 
     constexpr std::true_type is_array(Eigen::ArrayBase<T>); 
     std::false_type constexpr is_array(...); 
    } 


    template <typename T> 
    struct is_matrix : decltype(detail::is_matrix(std::declval<std::remove_cv_t<T>>())) 
    { 
    }; 

    template <typename T> 
    struct is_array : decltype(detail::is_array(std::declval<std::remove_cv_t<T>>())) 
    { 
    }; 

    template <typename T> 
    using is_matrix_or_array = std::bool_constant<is_array<T>::value || is_matrix<T>::value>; 



    /* 
    * Index something if it's not an scalar 
    */ 
    template <typename T, typename std::enable_if<is_matrix_or_array<T>::value, int>::type = 0> 
    auto index_if_necessary(T&& thing, Eigen::Index idx) 
    { 
     return thing(idx); 
    } 

    /* 
    * Overload for scalar. 
    */ 
    template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value, int>::type = 0> 
    auto index_if_necessary(T&& thing, Eigen::Index) 
    { 
     return thing; 
    } 
} 

namespace Eigen 
{ 
    template <typename XprType, typename RowIndices, typename ColIndices> 
    class CircShiftedView; 

    namespace internal 
    { 
     template <typename XprType, typename RowIndices, typename ColIndices> 
     struct traits<CircShiftedView<XprType, RowIndices, ColIndices>> 
      : traits<XprType> 
     { 
      enum 
      { 
       RowsAtCompileTime = traits<XprType>::RowsAtCompileTime, 
       ColsAtCompileTime = traits<XprType>::ColsAtCompileTime, 
       MaxRowsAtCompileTime = RowsAtCompileTime != Dynamic ? int(RowsAtCompileTime) : int(traits<XprType>::MaxRowsAtCompileTime), 
       MaxColsAtCompileTime = ColsAtCompileTime != Dynamic ? int(ColsAtCompileTime) : int(traits<XprType>::MaxColsAtCompileTime), 

       XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0, 
       IsRowMajor = (MaxRowsAtCompileTime == 1 && MaxColsAtCompileTime != 1) ? 1 
           : (MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1) ? 0 
           : XprTypeIsRowMajor, 


       FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0, 
       FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0, 
       Flags = (traits<XprType>::Flags & HereditaryBits) | FlagsLvalueBit | FlagsRowMajorBit 
      }; 
     }; 
    } 

    template <typename XprType, typename RowShift, typename ColShift, typename StorageKind> 
    class CircShiftedViewImpl; 


    /**
     * Writable, circularly shifted view of an Eigen expression.
     *
     * Coefficient (row, col) of the view refers to coefficient
     * (getRowIdx(row, col), getColIdx(row, col)) of the wrapped expression.
     * Both indices are wrapped around the corresponding dimension.
     * - RowShift: either a scalar (same shift for every column) or an Eigen vector
     *   indexed by the view's column.
     * - ColShift: either a scalar or an Eigen vector indexed by the view's row.
     * Every shift must lie within +-(size - 1) of its dimension; this is checked with
     * assert() on construction.
     */
    template <typename XprType, typename RowShift, typename ColShift>
    class CircShiftedView : public CircShiftedViewImpl<XprType, RowShift, ColShift, typename internal::traits<XprType>::StorageKind>
    {
    public:
        typedef typename CircShiftedViewImpl<XprType, RowShift, ColShift, typename internal::traits<XprType>::StorageKind>::Base Base;
        EIGEN_GENERIC_PUBLIC_INTERFACE(CircShiftedView)
        EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CircShiftedView)

        // Non-const nesting so that coefficients written through the view reach xpr.
        typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
        typedef typename internal::remove_all<XprType>::type NestedExpression;

        /**
         * \param xpr       expression to view. NOTE(review): for plain objects this is
         *                  presumably nested by reference, so keep xpr alive while the
         *                  view is in use -- confirm against internal::ref_selector.
         * \param rowShift  scalar, or vector with one row shift per column
         * \param colShift  scalar, or vector with one column shift per row
         */
        template <typename T0, typename T1>
        CircShiftedView(XprType& xpr, const T0& rowShift, const T1& colShift)
            : m_xpr(xpr), m_rowShift(rowShift), m_colShift(colShift)
        {
            // The shift bounds guarantee that a single wrap in wrapIndex() brings any
            // shifted index back into range.
            for (Index c = 0; c < xpr.cols(); ++c)
            {
                assert(std::abs(helper::index_if_necessary(m_rowShift, c)) < m_xpr.rows()); // row shift must be within +- rows()-1
            }
            for (Index r = 0; r < xpr.rows(); ++r)
            {
                assert(std::abs(helper::index_if_necessary(m_colShift, r)) < m_xpr.cols()); // col shift must be within +- cols()-1
            }
        }

        /** \returns number of rows */
        Index rows() const { return m_xpr.rows(); }

        /** \returns number of columns */
        Index cols() const { return m_xpr.cols(); }

        /** \returns the nested expression */
        const typename internal::remove_all<XprType>::type&
        nestedExpression() const { return m_xpr; }

        /** \returns the nested expression */
        typename internal::remove_reference<XprType>::type&
        nestedExpression() { return m_xpr.const_cast_derived(); }

        /** \returns the row of the wrapped expression that view coefficient (row, col) maps to. */
        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        Index getRowIdx(Index row, Index col) const
        {
            assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < m_xpr.cols());
            return wrapIndex(row - helper::index_if_necessary(m_rowShift, col), m_xpr.rows());
        }

        /** \returns the column of the wrapped expression that view coefficient (row, col) maps to. */
        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        Index getColIdx(Index row, Index col) const
        {
            assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < m_xpr.cols());
            return wrapIndex(col - helper::index_if_necessary(m_colShift, row), m_xpr.cols());
        }

    protected:
        MatrixTypeNested m_xpr;
        RowShift m_rowShift;
        ColShift m_colShift;

    private:
        /**
         * Wraps an index lying in [-n, 2n) back into [0, n) with at most one
         * adjustment; the constructor's shift bounds keep inputs in that range.
         */
        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        static Index wrapIndex(Index i, Index n)
        {
            if (i >= n)
                return i - n;
            if (i < 0)
                return i + n;
            return i;
        }
    };


    // Generic API dispatcher: whatever the storage kind, the implementation class
    // simply derives from the base that internal::generic_xpr_base selects for
    // CircShiftedView, and re-exports it as `Base`.
    template <typename XprType, typename RowShift, typename ColShift, typename StorageKind>
    class CircShiftedViewImpl
        : public internal::generic_xpr_base<CircShiftedView<XprType, RowShift, ColShift>>::type
    {
    public:
        using Base = typename internal::generic_xpr_base<CircShiftedView<XprType, RowShift, ColShift>>::type;
    };

    namespace internal
    {
        /**
         * Index-based evaluator for CircShiftedView. Every coefficient access is
         * forwarded to the evaluator of the nested expression, at the position
         * computed by the view's getRowIdx()/getColIdx().
         */
        template <typename ArgType, typename RowShift, typename ColShift>
        struct unary_evaluator<CircShiftedView<ArgType, RowShift, ColShift>, IndexBased>
            : evaluator_base<CircShiftedView<ArgType, RowShift, ColShift>>
        {
            typedef CircShiftedView<ArgType, RowShift, ColShift> XprType;
            typedef typename XprType::Scalar Scalar;
            typedef typename XprType::CoeffReturnType CoeffReturnType;

            enum
            {
                // One nested read, plus one index comparison and one index addition.
                CoeffReadCost = evaluator<ArgType>::CoeffReadCost + NumTraits<Index>::AddCost + NumTraits<Index>::AddCost,

                // Only the hereditary bits of the nested evaluator are passed on.
                Flags = (evaluator<ArgType>::Flags & HereditaryBits),

                Alignment = 0
            };

            EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr)
                : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
            {
                EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
            }

            /** Read access: the nested coefficient at the shifted position. */
            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
            CoeffReturnType coeff(Index row, Index col) const
            {
                const Index srcRow = m_xpr.getRowIdx(row, col);
                const Index srcCol = m_xpr.getColIdx(row, col);
                return m_argImpl.coeff(srcRow, srcCol);
            }

            /** Write access: a mutable reference to the nested coefficient at the shifted position. */
            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
            Scalar& coeffRef(Index row, Index col)
            {
                assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < m_xpr.cols());

                const Index srcRow = m_xpr.getRowIdx(row, col);
                const Index srcCol = m_xpr.getColIdx(row, col);
                return m_argImpl.coeffRef(srcRow, srcCol);
            }

        protected:
            evaluator<ArgType> m_argImpl;  // evaluator of the wrapped expression
            const XprType& m_xpr;          // the view itself, used for index mapping
        };
    } // end namespace internal
} // end namespace Eigen 


/**
 * Creates a writable, circularly shifted view of `x`.
 *
 * \param x  expression to view; it must be a non-const lvalue, and the view is
 *           built over x.derived()
 * \param r  row shift: a scalar, or a vector with one entry per column
 * \param c  column shift: a scalar, or a vector with one entry per row
 * \returns an Eigen::CircShiftedView over x.derived()
 */
template <typename XprType, typename RowShift, typename ColShift>
auto circShift(Eigen::DenseBase<XprType>& x, RowShift r, ColShift c)
{
    using View = Eigen::CircShiftedView<XprType, RowShift, ColShift>;
    XprType& target = x.derived();
    return View(target, r, c);
}

和:

// main.cpp 
#include "stdafx.h" 
#include "Eigen/Core" 
#include <iostream> 
#include "circ_shift.h" 

using namespace Eigen; 


int main() 
{ 

    ArrayXXf x(4, 2); 
    x.transpose() << 1, 2, 3, 4, 10, 20, 30, 40; 


    Vector2i rowShift; 
    rowShift << 3, -3; // rotate col 1 by 3 and col 2 by -3 

    Index colShift = 1; // flip columns 

    auto shifted = circShift(x, rowShift, colShift); 

    std::cout << "shifted: " << std::endl << shifted << std::endl; 

    shifted.block(2,0,2,1) << -1, -2; // will appear in row 3 and 0. 
    shifted.col(1) << 2,4,6,8; // shifted col 1 is col 0 of the original 

    std::cout << "modified original:" << std::endl << x << std::endl; 

    return 0; 
}