From f2a66ffabbc704f2049addcb319620fde427e844 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 6 Sep 2017 10:53:27 -0700 Subject: [PATCH] Follow comments --- paddle/framework/attribute.h | 4 +-- paddle/framework/ddim.cc | 8 +++-- paddle/framework/ddim.h | 2 +- paddle/framework/eigen.h | 16 ++++----- paddle/framework/tensor_impl.h | 4 +-- paddle/framework/tensor_test.cc | 4 +-- paddle/operators/mul_op.cc | 34 +++++++++--------- paddle/operators/mul_op.h | 36 +++++++++---------- paddle/operators/rowwise_add_op.cc | 16 ++++----- paddle/operators/rowwise_add_op.h | 14 ++++---- .../paddle/v2/framework/tests/test_mul_op.py | 4 +-- 11 files changed, 73 insertions(+), 69 deletions(-) diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 7da34e3f2b..31e8218743 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -44,7 +44,7 @@ class LargerThanChecker { public: explicit LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { - PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail"); + PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails."); } private: @@ -56,7 +56,7 @@ class EqualLargerThanChecker { public: explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} void operator()(T& value) const { - PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail"); + PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fails."); } private: diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 972dac7073..499d4ecbf1 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -284,11 +284,13 @@ DDim::DDim(std::initializer_list init_list) { *this = make_ddim(init_list); } -DDim flatten_to_2d(const DDim& src, int num_row_dims) { +// Reshape a tensor to a matrix. The matrix's first dimension(column length) +// will be the product of tensor's first `num_col_dims` dimensions +DDim flatten_to_2d(const DDim& src, int num_col_dims) { int rank = src.size(); return make_ddim( - {static_cast(product(slice_ddim(src, 0, rank - num_row_dims))), - static_cast(product(slice_ddim(src, rank - num_row_dims, rank)))}); + {static_cast(product(slice_ddim(src, 0, num_col_dims))), + static_cast(product(slice_ddim(src, num_col_dims, rank)))}); } DDim flatten_to_1d(const DDim& src) { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 8f1269d9a1..2dbd5f5f70 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -115,7 +115,7 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); -DDim flatten_to_2d(const DDim& src, int num_row_dims); +DDim flatten_to_2d(const DDim& src, int num_col_dims); DDim flatten_to_1d(const DDim& src); diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index c6f42251da..4b798cd5ae 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -64,21 +64,21 @@ struct EigenTensor { template struct EigenMatrix : public EigenTensor { - static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) { + static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) { int rank = tensor.dims_.size(); - PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, - "`num_row_dims` must be between (0, rank_of_tensor)."); + PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank, + "`num_col_dims` must be between (0, rank_of_tensor)."); return EigenMatrix::From(tensor, - flatten_to_2d(tensor.dims(), num_row_dims)); + flatten_to_2d(tensor.dims(), num_col_dims)); } static typename EigenMatrix::ConstType Reshape(const Tensor& tensor, - int num_row_dims) { + int num_col_dims) { int rank = tensor.dims_.size(); - PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, - "`num_row_dims` must be between (0, rank_of_tensor)."); + PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank, + "`num_col_dims` must be between (0, rank_of_tensor)."); return EigenMatrix::From(tensor, - flatten_to_2d(tensor.dims(), num_row_dims)); + flatten_to_2d(tensor.dims(), num_col_dims)); } }; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index d32fe78f42..f1a7275899 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -149,10 +149,10 @@ inline Tensor& Tensor::Resize(const DDim& dims) { inline const DDim& Tensor::dims() const { return dims_; } template -inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { +inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { Tensor res; res.ShareDataWith(src); - res.Resize(flatten_to_2d(src.dims(), num_row_dims)); + res.Resize(flatten_to_2d(src.dims(), num_col_dims)); return res; } diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index cdd68b303c..a2c2d19dc7 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -263,7 +263,7 @@ TEST(Tensor, CopyFrom) { #endif } -TEST(Tensor, FlattenToMatrix) { +TEST(Tensor, ReshapeToMatrix) { using namespace paddle::framework; using namespace paddle::platform; Tensor src; @@ -271,7 +271,7 @@ TEST(Tensor, FlattenToMatrix) { for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { src_ptr[i] = i; } - Tensor res = FlattenToMatrix(src, 2); + Tensor res = ReshapeToMatrix(src, 2); ASSERT_EQ(res.dims()[0], 2 * 3); ASSERT_EQ(res.dims()[1], 4 * 9); } \ No newline at end of file diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index dfc22decdc..fb96d322e9 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -27,20 +27,20 @@ class MulOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); - int x_num_row_dims = GetAttr("x_num_row_dims"); - int y_num_row_dims = GetAttr("y_num_row_dims"); + int x_num_col_dims = GetAttr("x_num_col_dims"); + int y_num_col_dims = GetAttr("y_num_col_dims"); - PADDLE_ENFORCE(x_dims.size() > x_num_row_dims, + PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, "The rank of input tensor X(%s) should be larger than " - "`mul_op`'s `x_num_row_dims`.", + "`mul_op`'s `x_num_col_dims`.", ctx.op().Input("X")); - PADDLE_ENFORCE(y_dims.size() > y_num_row_dims, + PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, "The rank of input tensor Y(%s) should be larger than " - "`mul_op`'s `y_num_row_dims`.", + "`mul_op`'s `y_num_col_dims`.", ctx.op().Input("Y")); - auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims); - auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims); + auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); + auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); PADDLE_ENFORCE_EQ( x_mat_dims[1], y_mat_dims[0], @@ -57,19 +57,19 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Y", "The second input of mul op"); AddOutput("Out", "The output of mul op"); AddAttr( - "x_num_row_dims", + "x_num_col_dims", "mul_op can take tensors with more than two dimensions as input `X`, " - "in that case, tensors will be flattened to a matrix. The matrix's " - "second dimension(row length) will be the product of tensor's last " - "`num_row_dims` dimensions, and the matrix's first dimension(column " - "length) will be the product of tensor's first `rank - num_row_dims` " + "in that case, tensors will be reshaped to a matrix. The matrix's " + "first dimension(column length) will be the product of tensor's last " + "`num_col_dims` dimensions, and the matrix's second dimension(row " + "length) will be the product of tensor's first `rank - num_col_dims` " "dimensions.") .SetDefault(1) .EqualLargerThan(1); AddAttr( - "y_num_row_dims", + "y_num_col_dims", "mul_op can take tensors with more than two dimensions as input `Y`, " - "in that case, tensors will be flattened to a matrix. Just like input " + "in that case, tensors will be reshaped to a matrix. Just like input " "`X`.") .SetDefault(1) .EqualLargerThan(1); @@ -98,9 +98,9 @@ class MulOpGrad : public framework::OperatorWithKernel { auto *y_grad = ctx.Output(framework::GradVarName("Y")); auto x_mat_dims = - framework::flatten_to_2d(x_dims, GetAttr("x_num_row_dims")); + framework::flatten_to_2d(x_dims, GetAttr("x_num_col_dims")); auto y_mat_dims = - framework::flatten_to_2d(y_dims, GetAttr("y_num_row_dims")); + framework::flatten_to_2d(y_dims, GetAttr("y_num_col_dims")); PADDLE_ENFORCE_EQ( x_mat_dims[0], out_dims[0], diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 62557bb839..6656ecaf1a 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -1,14 +1,14 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - you may obtain a copy of the License at + You may not use this file except in compliance with the License. + You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANy KIND, either express or implied. + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ @@ -33,22 +33,22 @@ class MulKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { const Tensor* x = context.Input("X"); const Tensor* y = context.Input("Y"); - Tensor* Z = context.Output("Out"); + Tensor* z = context.Output("Out"); const Tensor x_matrix = x->dims().size() > 2 - ? framework::FlattenToMatrix( - *x, context.template GetAttr("x_num_row_dims")) + ? framework::ReshapeToMatrix( + *x, context.template GetAttr("x_num_col_dims")) : *x; const Tensor y_matrix = y->dims().size() > 2 - ? framework::FlattenToMatrix( - *y, context.template GetAttr("y_num_row_dims")) + ? framework::ReshapeToMatrix( + *y, context.template GetAttr("y_num_col_dims")) : *y; - Z->mutable_data(context.GetPlace()); + z->mutable_data(context.GetPlace()); auto* device_context = const_cast(context.device_context_); - math::matmul(x_matrix, false, y_matrix, false, 1, Z, 0, + math::matmul(x_matrix, false, y_matrix, false, 1, z, 0, device_context); } }; @@ -57,15 +57,15 @@ template class MulGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - int x_num_row_dims = ctx.template GetAttr("x_num_row_dims"); - int y_num_row_dims = ctx.template GetAttr("y_num_row_dims"); + int x_num_col_dims = ctx.template GetAttr("x_num_col_dims"); + int y_num_col_dims = ctx.template GetAttr("y_num_col_dims"); const Tensor* x = ctx.Input("X"); const Tensor* y = ctx.Input("Y"); const Tensor x_matrix = - x->dims().size() > 2 ? framework::FlattenToMatrix(*x, x_num_row_dims) + x->dims().size() > 2 ? framework::ReshapeToMatrix(*x, x_num_col_dims) : *x; const Tensor y_matrix = - y->dims().size() > 2 ? framework::FlattenToMatrix(*y, y_num_row_dims) + y->dims().size() > 2 ? framework::ReshapeToMatrix(*y, y_num_col_dims) : *y; const Tensor* dout = ctx.Input(framework::GradVarName("Out")); @@ -75,8 +75,8 @@ class MulGradKernel : public framework::OpKernel { const_cast(ctx.device_context_); if (dx) { dx->mutable_data(ctx.GetPlace()); - Tensor dx_matrix = dx->dims().size() > 2 ? framework::FlattenToMatrix( - *dx, x_num_row_dims) + Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix( + *dx, x_num_col_dims) : *dx; // dx = dout * y'. dx: M x K, dout : M x N, y : K x N math::matmul(*dout, false, y_matrix, true, 1, &dx_matrix, 0, @@ -84,8 +84,8 @@ class MulGradKernel : public framework::OpKernel { } if (dy) { dy->mutable_data(ctx.GetPlace()); - Tensor dy_matrix = dy->dims().size() > 2 ? framework::FlattenToMatrix( - *dy, y_num_row_dims) + Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix( + *dy, y_num_col_dims) : *dy; // dy = x' * dout. dy K x N, dout : M x N, x : M x K math::matmul(x_matrix, true, *dout, false, 1, &dy_matrix, 0, diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 209281a45b..fa8f0ff1a8 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -31,11 +31,11 @@ class RowwiseAddOp : public framework::OperatorWithKernel { x_dims.size(), b_dims.size(), "The rank of input `X` must be larger than the one of input `b`."); - int num_row_dims = b_dims.size(); + int num_col_dims = x_dims.size() - b_dims.size(); - PADDLE_ENFORCE_EQ(framework::slice_ddim( - x_dims, x_dims.size() - num_row_dims, x_dims.size()), - b_dims, "The width of two operands must be same"); + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, + "The width of two operands must be same"); PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1"); ctx.Output("Out")->Resize(x_dims); } @@ -72,10 +72,10 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel { x_dims.size(), b_dims.size(), "The rank of input `X` must be larger than the one of input `b`."); - int num_row_dims = b_dims.size(); - PADDLE_ENFORCE_EQ(framework::slice_ddim( - x_dims, x_dims.size() - num_row_dims, x_dims.size()), - b_dims, "The width of two operands must be same"); + int num_col_dims = x_dims.size() - b_dims.size(); + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims, + "The width of two operands must be same"); auto *dx = ctx.Output(framework::GradVarName("X")); auto *db = ctx.Output(framework::GradVarName("b")); if (dx) dx->Resize(x_dims); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index a52a53a7d2..35774b9409 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -33,11 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto out = context.Output("Out"); out->mutable_data(context.GetPlace()); - int num_row_dims = context.Input("b")->dims().size(); + int num_col_dims = context.Input("X")->dims().size() - + context.Input("b")->dims().size(); auto input = - EigenMatrix::Reshape(*context.Input("X"), num_row_dims); + EigenMatrix::Reshape(*context.Input("X"), num_col_dims); auto bias = EigenVector::Flatten(*context.Input("b")); - auto output = EigenMatrix::Reshape(*out, num_row_dims); + auto output = EigenMatrix::Reshape(*out, num_col_dims); const int bias_size = bias.dimension(0); const int rest_size = input.size() / bias_size; @@ -55,14 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel { auto* dout = context.Input(framework::GradVarName("Out")); auto* dx = context.Output(framework::GradVarName("X")); auto* db = context.Output(framework::GradVarName("b")); - int num_row_dims = context.Input("b")->dims().size(); + int num_col_dims = context.Input("X")->dims().size() - + context.Input("b")->dims().size(); - auto out_grad = EigenMatrix::Reshape(*dout, num_row_dims); + auto out_grad = EigenMatrix::Reshape(*dout, num_col_dims); auto place = context.GetEigenDevice(); if (dx) { dx->mutable_data(context.GetPlace()); - EigenMatrix::Reshape(*dx, num_row_dims).device(place) = out_grad; + EigenMatrix::Reshape(*dx, num_col_dims).device(place) = out_grad; } if (db) { diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index 3ea73d94b2..d8057f4ffa 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -26,7 +26,7 @@ class TestMulOp2(unittest.TestCase): 'X': np.random.random((15, 4, 12, 10)).astype("float32"), 'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32") } - self.attrs = {'x_num_row_dims': 2, 'y_num_row_dims': 3} + self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2} self.outputs = { 'Out': np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10), self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9)) @@ -69,7 +69,7 @@ class TestMulGradOp(GradientChecker): class TestMulGradTest2(GradientChecker): def setUp(self): self.op = Operator( - "mul", X="X", Y="Y", Out="Out", x_num_row_dims=2, y_num_row_dims=3) + "mul", X="X", Y="Y", Out="Out", x_num_col_dims=2, y_num_col_dims=2) self.inputs = { "X": np.random.random((15, 4, 12, 10)).astype("float32"), "Y": np.random.random((4, 30, 8, 2, 9)).astype("float32") -- GitLab