提交 d9fa6159 编写于 作者: Q qijun

add Flatten method to EigenVector

上级 1981eaf9
......@@ -23,7 +23,7 @@ namespace framework {
// EigenDim converts paddle::platform::DDim into Eigen::DSizes.
template <int D>
struct EigenDim {
typedef Eigen::DSizes<Eigen::DenseIndex, D> Type;
using Type = Eigen::DSizes<Eigen::DenseIndex, D>;
static Type From(const DDim& dims) {
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
......@@ -69,12 +69,23 @@ struct EigenVector {
using ConstType =
Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>;
// From is to transfer a one dimension Tensor into a one dimension EigenVector
static Type From(Tensor& tensor) { return EigenTensor<T, 1>::From(tensor); }
// Flatten is to reshape a Tensor into a one dimension EigenVector
static Type Flatten(Tensor& tensor) {
return EigenTensor<T, 1>::From(
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
}
static ConstType From(const Tensor& tensor) {
return EigenTensor<T, 1>::From(tensor);
}
static ConstType Flatten(const Tensor& tensor) {
return EigenTensor<T, 1>::From(
tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
}
};
// Interpret paddle::platform::Tensor as EigenMatrix and EigenConstMatrix.
......
......@@ -32,6 +32,10 @@ TEST(Eigen, Tensor) {
}
EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
for (int i = 0; i < 1 * 2 * 3; i++) {
EXPECT_EQ(et(i), i);
}
// TODO: check the content of et.
}
......@@ -39,5 +43,5 @@ TEST(Eigen, Vector) {}
TEST(Eigen, Matrix) {}
} // namespace platform
} // namespace framework
} // namespace paddle
......@@ -30,9 +30,10 @@ public:
output->mutable_data<T>(context.GetPlace());
framework::EigenVector<T>::From(*output).device(
framework::EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>(*input0) + framework::EigenVector<T>(*input1);
framework::EigenVector<T>::Flatten(input0) +
framework::EigenVector<T>::Flatten(input1);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册