diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 47d1f301167ee630cc39fb15efd7f58218ae0964..972dac7073a2c6f420e8a6c2dc518db8e339b9e2 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -291,5 +291,9 @@ DDim flatten_to_2d(const DDim& src, int num_row_dims) {
       static_cast<int>(product(slice_ddim(src, rank - num_row_dims, rank)))});
 }
 
+DDim flatten_to_1d(const DDim& src) {
+  return make_ddim({static_cast<int>(product(src))});
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index cf786d140ed136f9db4586000f5c7784bdd250b5..8f1269d9a1da887da199f5b738ff5505c43d1de8 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -117,6 +117,8 @@
 std::ostream& operator<<(std::ostream&, const DDim&);
 
 DDim flatten_to_2d(const DDim& src, int num_row_dims);
 
+DDim flatten_to_1d(const DDim& src);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index 656aef42127a823bf5d4cd1eb7e3a96abde7e169..c6f42251da272129affa5562c01b82261b657feb 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -71,6 +71,15 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
     return EigenMatrix::From(tensor,
                              flatten_to_2d(tensor.dims(), num_row_dims));
   }
+
+  static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+                                                 int num_row_dims) {
+    int rank = tensor.dims_.size();
+    PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank,
+                   "`num_row_dims` must be between (0, rank_of_tensor).");
+    return EigenMatrix::From(tensor,
+                             flatten_to_2d(tensor.dims(), num_row_dims));
+  }
 };
 
 template <typename T, int MajorType = Eigen::RowMajor,
@@ -77,13 +86,11 @@ struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
           typename IndexType = Eigen::DenseIndex>
 struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
   // Flatten reshapes a Tensor into an EigenVector.
   static typename EigenVector::Type Flatten(Tensor& tensor) {
-    return EigenVector::From(
-        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+    return EigenVector::From(tensor, {static_cast<int>(product(tensor.dims_))});
   }
 
   static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
-    return EigenVector::From(
-        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+    return EigenVector::From(tensor, {static_cast<int>(product(tensor.dims_))});
   }
 };
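The `flatten_to_*` helpers and `EigenMatrix::Reshape` only compute collapsed shapes; no element data moves. A minimal numpy sketch of the shape arithmetic (the function names mirror the C++ helpers but are otherwise illustrative):

```python
import numpy as np

def flatten_to_2d(shape, num_row_dims):
    # Collapse the leading (rank - num_row_dims) axes into the first
    # dimension and the trailing num_row_dims axes into the second.
    rank = len(shape)
    assert 0 < num_row_dims < rank
    return (int(np.prod(shape[:rank - num_row_dims])),
            int(np.prod(shape[rank - num_row_dims:])))

def flatten_to_1d(shape):
    # Collapse every axis into a single dimension.
    return (int(np.prod(shape)),)

print(flatten_to_2d((13, 6, 7, 8), 2))  # (78, 56), the view EigenMatrix::Reshape builds
print(flatten_to_1d((13, 6, 7, 8)))     # (4368,),  the view EigenVector::Flatten builds
```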
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 30b4b404315a9f041e21d79b75fd06307e33f7f9..209281a45b67d8b3b49512ddda1ecf294d3a31f7 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto dim0 = ctx.Input<Tensor>("X")->dims();
-    auto dim1 = ctx.Input<Tensor>("b")->dims();
-
-    PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
-    PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
-    PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
-    PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
-    ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto b_dims = ctx.Input<Tensor>("b")->dims();
+    PADDLE_ENFORCE_GT(
+        x_dims.size(), b_dims.size(),
+        "The rank of input `X` must be larger than the one of input `b`.");
+
+    int num_row_dims = b_dims.size();
+
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(
+                          x_dims, x_dims.size() - num_row_dims, x_dims.size()),
+                      b_dims, "The width of two operands must be same");
+    PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
+    ctx.Output<Tensor>("Out")->Resize(x_dims);
   }
 };
 
@@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) should not be null");
-    auto dims0 = ctx.Input<Tensor>("X")->dims();
-    auto dims1 = ctx.Input<Tensor>("b")->dims();
-    PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto b_dims = ctx.Input<Tensor>("b")->dims();
+    PADDLE_ENFORCE_GT(
+        x_dims.size(), b_dims.size(),
+        "The rank of input `X` must be larger than the one of input `b`.");
+
+    int num_row_dims = b_dims.size();
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(
+                          x_dims, x_dims.size() - num_row_dims, x_dims.size()),
+                      b_dims, "The width of two operands must be same");
     auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
-    if (dx) dx->Resize(dims0);
-    if (db) db->Resize(dims1);
+    if (dx) dx->Resize(x_dims);
+    if (db) db->Resize(b_dims);
   }
 };
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 4e926d9f2947f37b71e81c0fa592b0c66b19c640..a52a53a7d275426f310c9b82bc125e59f9ae7fb1 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,10 +33,11 @@ class RowwiseAddKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     auto out = context.Output<Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
-
-    auto input = EigenMatrix<T>::From(*context.Input<Tensor>("X"));
-    auto bias = EigenVector<T>::From(*context.Input<Tensor>("b"));
-    auto output = EigenMatrix<T>::From(*out);
+    int num_row_dims = context.Input<Tensor>("b")->dims().size();
+    auto input =
+        EigenMatrix<T>::Reshape(*context.Input<Tensor>("X"), num_row_dims);
+    auto bias = EigenVector<T>::Flatten(*context.Input<Tensor>("b"));
+    auto output = EigenMatrix<T>::Reshape(*out, num_row_dims);
 
     const int bias_size = bias.dimension(0);
     const int rest_size = input.size() / bias_size;
@@ -54,12 +55,14 @@ class RowwiseAddGradKernel : public framework::OpKernel {
     auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     auto* db = context.Output<Tensor>(framework::GradVarName("b"));
+    int num_row_dims = context.Input<Tensor>("b")->dims().size();
 
-    auto out_grad = EigenMatrix<T>::From(*dout);
+    auto out_grad = EigenMatrix<T>::Reshape(*dout, num_row_dims);
     auto place = context.GetEigenDevice<Place>();
+
     if (dx) {
       dx->mutable_data<T>(context.GetPlace());
-      EigenMatrix<T>::From(*dx).device(place) = out_grad;
+      EigenMatrix<T>::Reshape(*dx, num_row_dims).device(place) = out_grad;
     }
 
     if (db) {
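Taken together, `InferShape` now accepts any `X` whose trailing dims equal the shape of `b`, and the kernel adds `b` to every collapsed row. A numpy sketch of the same check and computation (`rowwise_add` is a hypothetical helper, not part of the operator code):

```python
import numpy as np

def rowwise_add(x, b):
    # Mirrors InferShape: X must outrank b, and X's trailing
    # dims must match b's shape exactly.
    assert x.ndim > b.ndim, "rank of X must be larger than rank of b"
    assert x.shape[x.ndim - b.ndim:] == b.shape, \
        "The width of two operands must be same"
    # Mirrors the kernel: flatten to 2-D, then broadcast b across rows.
    rest_size = x.size // b.size
    out = x.reshape(rest_size, b.size) + b.reshape(1, b.size)
    return out.reshape(x.shape)

x = np.random.rand(13, 6, 7, 8).astype("float32")
b = np.random.rand(7, 8).astype("float32")
np.testing.assert_allclose(rowwise_add(x, b), x + b, rtol=1e-6)
```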
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index 2ddb85e2e7a98a08bd1d6e24e6f812f6021142e8..8378c1cd21c21fd31da9b82d2cdaaff332f291d7 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -16,6 +16,18 @@ class TestRowwiseAddOp(unittest.TestCase):
         self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
 
 
+class TestRowwiseAddOp2(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "rowwise_add"
+        self.inputs = {
+            'X': np.random.random((13, 6, 7, 8)).astype("float32"),
+            'b': np.random.random((7, 8)).astype("float32")
+        }
+        self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
+
+
 class TestRowwiseAddGradOp(GradientChecker):
     def setUp(self):
         self.op = create_op("rowwise_add")
@@ -34,5 +46,23 @@ class TestRowwiseAddGradOp(GradientChecker):
         self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
 
 
+class TestRowwiseAddGradOp2(GradientChecker):
+    def setUp(self):
+        self.op = create_op("rowwise_add")
+        self.inputs = {
+            "X": np.random.uniform(0.1, 1, [2, 3, 2, 5]).astype("float32"),
+            "b": np.random.uniform(0.1, 1, [2, 5]).astype("float32")
+        }
+
+    def test_normal(self):
+        self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
+
+    def test_ignore_b(self):
+        self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
+
+    def test_ignore_x(self):
+        self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
+
+
 if __name__ == '__main__':
     unittest.main()
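The new tests exercise exactly this broadcast: `np.add` spreads `b` over the leading axes of `X`, which is what the reshaped Eigen kernel computes. For the gradients, `dX` is `dOut` passed through unchanged, while `db` must reduce `dOut` over the collapsed leading axes; the `if (db)` branch lies outside the hunks above, so the reduction below is a sketch of the expected semantics rather than the kernel's literal code:

```python
import numpy as np

x_shape, b_shape = (2, 3, 2, 5), (2, 5)
dout = np.random.rand(*x_shape).astype("float32")  # upstream gradient of Out

dx = dout  # pass-through gradient for X
# Sum dOut over the collapsed "row" axis to accumulate b's gradient.
db = dout.reshape(-1, int(np.prod(b_shape))).sum(axis=0).reshape(b_shape)

assert dx.shape == x_shape and db.shape == b_shape
```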