提交 5aacd64b 编写于 作者: F fengjiayi

Follow comments

上级 0c13660a
...@@ -87,11 +87,11 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -87,11 +87,11 @@ template <typename T, int MajorType = Eigen::RowMajor,
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> { struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
// Flatten reshapes a Tensor into an EigenVector. // Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) { static typename EigenVector::Type Flatten(Tensor& tensor) {
return EigenVector::From(tensor, {static_cast<int>(product(tensor.dims_))}); return EigenVector::From(tensor, {product(tensor.dims_)});
} }
static typename EigenVector::ConstType Flatten(const Tensor& tensor) { static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
return EigenVector::From(tensor, {static_cast<int>(product(tensor.dims_))}); return EigenVector::From(tensor, {product(tensor.dims_)});
} }
}; };
......
...@@ -110,8 +110,7 @@ TEST(Eigen, Matrix) { ...@@ -110,8 +110,7 @@ TEST(Eigen, Matrix) {
TEST(Eigen, MatrixReshape) { TEST(Eigen, MatrixReshape) {
Tensor t; Tensor t;
float* p = float* p = t.mutable_data<float>({2, 3, 6, 4}, platform::CPUPlace());
t.mutable_data<float>(make_ddim({2, 3, 6, 4}), platform::CPUPlace());
for (int i = 0; i < 2 * 3 * 6 * 4; ++i) { for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
p[i] = static_cast<float>(i); p[i] = static_cast<float>(i);
} }
......
...@@ -267,7 +267,7 @@ TEST(Tensor, ReshapeToMatrix) { ...@@ -267,7 +267,7 @@ TEST(Tensor, ReshapeToMatrix) {
using namespace paddle::framework; using namespace paddle::framework;
using namespace paddle::platform; using namespace paddle::platform;
Tensor src; Tensor src;
int* src_ptr = src.mutable_data<int>(make_ddim({2, 3, 4, 9}), CPUPlace()); int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i; src_ptr[i] = i;
} }
......
...@@ -27,8 +27,8 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -27,8 +27,8 @@ class MulOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto x_dims = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx.Input<Tensor>("Y")->dims();
int x_num_col_dims = GetAttr<int>("x_num_col_dims"); int x_num_col_dims = Attr<int>("x_num_col_dims");
int y_num_col_dims = GetAttr<int>("y_num_col_dims"); int y_num_col_dims = Attr<int>("y_num_col_dims");
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
"The rank of input tensor X(%s) should be larger than " "The rank of input tensor X(%s) should be larger than "
...@@ -58,19 +58,19 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -58,19 +58,19 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output of mul op"); AddOutput("Out", "The output of mul op");
AddAttr<int>( AddAttr<int>(
"x_num_col_dims", "x_num_col_dims",
"mul_op can take tensors with more than two dimensions as input `X`, " R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
"in that case, tensors will be reshaped to a matrix. The matrix's " in that case, tensors will be reshaped to a matrix. The matrix's first
"first dimension(column length) will be the product of tensor's last " dimension(column length) will be the product of tensor's last
"`num_col_dims` dimensions, and the matrix's second dimension(row " `num_col_dims` dimensions, and the matrix's second dimension(row length)
"length) will be the product of tensor's first `rank - num_col_dims` " will be the product of tensor's first `rank - num_col_dims` dimensions.
"dimensions.") )DOC")
.SetDefault(1) .SetDefault(1)
.EqualLargerThan(1); .EqualLargerThan(1);
AddAttr<int>( AddAttr<int>(
"y_num_col_dims", "y_num_col_dims",
"mul_op can take tensors with more than two dimensions as input `Y`, " R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
"in that case, tensors will be reshaped to a matrix. Just like input " in that case, tensors will be reshaped to a matrix. Just like input `X`.
"`X`.") )DOC")
.SetDefault(1) .SetDefault(1)
.EqualLargerThan(1); .EqualLargerThan(1);
AddComment(R"DOC( AddComment(R"DOC(
...@@ -98,9 +98,9 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -98,9 +98,9 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y")); auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto x_mat_dims = auto x_mat_dims =
framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_col_dims")); framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims"));
auto y_mat_dims = auto y_mat_dims =
framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_col_dims")); framework::flatten_to_2d(y_dims, Attr<int>("y_num_col_dims"));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_mat_dims[0], out_dims[0], x_mat_dims[0], out_dims[0],
......
...@@ -37,12 +37,12 @@ class MulKernel : public framework::OpKernel { ...@@ -37,12 +37,12 @@ class MulKernel : public framework::OpKernel {
const Tensor x_matrix = const Tensor x_matrix =
x->dims().size() > 2 x->dims().size() > 2
? framework::ReshapeToMatrix<T>( ? framework::ReshapeToMatrix<T>(
*x, context.template GetAttr<int>("x_num_col_dims")) *x, context.template Attr<int>("x_num_col_dims"))
: *x; : *x;
const Tensor y_matrix = const Tensor y_matrix =
y->dims().size() > 2 y->dims().size() > 2
? framework::ReshapeToMatrix<T>( ? framework::ReshapeToMatrix<T>(
*y, context.template GetAttr<int>("y_num_col_dims")) *y, context.template Attr<int>("y_num_col_dims"))
: *y; : *y;
z->mutable_data<T>(context.GetPlace()); z->mutable_data<T>(context.GetPlace());
...@@ -57,8 +57,8 @@ template <typename Place, typename T> ...@@ -57,8 +57,8 @@ template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel { class MulGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template GetAttr<int>("x_num_col_dims"); int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template GetAttr<int>("y_num_col_dims"); int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
const Tensor* x = ctx.Input<Tensor>("X"); const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y"); const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor x_matrix = const Tensor x_matrix =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册