提交 d9fa6159 编写于 作者: Q qijun

add Flatten method to EigenVector

上级 1981eaf9
...@@ -23,7 +23,7 @@ namespace framework { ...@@ -23,7 +23,7 @@ namespace framework {
// EigenDim converts paddle::platform::DDim into Eigen::DSizes. // EigenDim converts paddle::platform::DDim into Eigen::DSizes.
template <int D> template <int D>
struct EigenDim { struct EigenDim {
typedef Eigen::DSizes<Eigen::DenseIndex, D> Type; using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
static Type From(const DDim& dims) { static Type From(const DDim& dims) {
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
...@@ -69,12 +69,23 @@ struct EigenVector { ...@@ -69,12 +69,23 @@ struct EigenVector {
using ConstType = using ConstType =
Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>; Eigen::Aligned>;
// From is to transfer a one dimension Tensor into a one dimension EigenVector
static Type From(Tensor& tensor) { return EigenTensor<T, 1>::From(tensor); } static Type From(Tensor& tensor) { return EigenTensor<T, 1>::From(tensor); }
// Flatten is to reshape a Tensor into a one dimension EigenVector
static Type Flatten(Tensor& tensor) {
return EigenTensor<T, 1>::From(
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
}
static ConstType From(const Tensor& tensor) { static ConstType From(const Tensor& tensor) {
return EigenTensor<T, 1>::From(tensor); return EigenTensor<T, 1>::From(tensor);
} }
static ConstType Flatten(const Tensor& tensor) {
return EigenTensor<T, 1>::From(
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
}
}; };
// Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix. // Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix.
......
...@@ -32,6 +32,10 @@ TEST(Eigen, Tensor) { ...@@ -32,6 +32,10 @@ TEST(Eigen, Tensor) {
} }
EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t); EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
for (int i = 0; i < 1 * 2 * 3; i++) {
EXPECT_EQ(et(i), i);
}
// TODO: check the content of et. // TODO: check the content of et.
} }
...@@ -39,5 +43,5 @@ TEST(Eigen, Vector) {} ...@@ -39,5 +43,5 @@ TEST(Eigen, Vector) {}
TEST(Eigen, Matrix) {} TEST(Eigen, Matrix) {}
} // namespace platform } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -30,9 +30,10 @@ public: ...@@ -30,9 +30,10 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::From(*output).device( framework::EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>(*input0) + framework::EigenVector<T>(*input1); framework::EigenVector<T>::Flatten(input0) +
framework::EigenVector<T>::Flatten(input1);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册