diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h
index cde3dfa1d3d19b1bee9fd23dad52ecbbe628c3a9..2b788a76cafe198abb9aed8ba842e37cc6ff73a6 100644
--- a/paddle/framework/attribute.h
+++ b/paddle/framework/attribute.h
@@ -45,7 +45,19 @@ class GreaterThanChecker {
  public:
  explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
  void operator()(T& value) const {
-    PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
+    PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails.");
+  }
+
+ private:
+  T lower_bound_;
+};
+
+template <typename T>
+class EqualGreaterThanChecker {
+ public:
+  explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
+  void operator()(T& value) const {
+    PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails.");
   }
 
  private:
@@ -115,6 +127,11 @@ class TypedAttrChecker {
     return *this;
   }
 
+  TypedAttrChecker& EqualGreaterThan(const T& lower_bound) {
+    value_checkers_.push_back(EqualGreaterThanChecker<T>(lower_bound));
+    return *this;
+  }
+
   // we can add more common limits, like LessThan(), Between()...
   TypedAttrChecker& SetDefault(const T& default_value) {
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 85b7de79743bb0390d66b8999f2e8342a51d14a9..fc3d508553c0e966978b28d58127bdbff10d45f1 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -283,5 +283,14 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
 DDim::DDim(std::initializer_list<int> init_list) {
   *this = make_ddim(init_list);
 }
+
+DDim flatten_to_2d(const DDim& src, int num_col_dims) {
+  int rank = src.size();
+  return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
+                    product(slice_ddim(src, num_col_dims, rank))});
+}
+
+DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
+
 }  // namespace framework
 }  // namespace paddle
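Note on the new helper: `flatten_to_2d` collapses the leading `num_col_dims` axes of a shape into the matrix's first dimension and the trailing axes into its second. A minimal standalone sketch of that semantics, independent of the framework (`FlattenTo2D` and `Shape2D` are illustrative names, not Paddle API):

    // Standalone analogue of flatten_to_2d: the first `num_col_dims` axes
    // form the row count, the remaining axes form the row length.
    #include <cassert>
    #include <utility>
    #include <vector>

    using Shape2D = std::pair<long, long>;  // hypothetical stand-in for DDim

    Shape2D FlattenTo2D(const std::vector<long>& dims, int num_col_dims) {
      assert(num_col_dims > 0 && num_col_dims < static_cast<int>(dims.size()));
      long rows = 1, cols = 1;
      for (size_t i = 0; i < dims.size(); ++i) {
        (static_cast<int>(i) < num_col_dims ? rows : cols) *= dims[i];
      }
      return {rows, cols};
    }

    int main() {
      // {2, 3, 6, 4} with num_col_dims = 2 becomes a (2*3) x (6*4) matrix,
      // which is exactly what the MatrixReshape test below asserts.
      Shape2D m = FlattenTo2D({2, 3, 6, 4}, 2);
      assert(m.first == 6 && m.second == 24);
      return 0;
    }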
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index db30c523948b1d437615aa0e9bfecb5e25569296..ca29e7e8c7776de6adf3e3b0e8f11f0d4d8487c3 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -115,6 +115,12 @@ int arity(const DDim& ddim);
 
 std::ostream& operator<<(std::ostream&, const DDim&);
 
+// Reshape a tensor to a matrix. The matrix's first dimension (column length)
+// will be the product of the tensor's first `num_col_dims` dimensions.
+DDim flatten_to_2d(const DDim& src, int num_col_dims);
+
+DDim flatten_to_1d(const DDim& src);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index 2d8d9ae10c56e0632414a5bbc754d35bfa9ce6a5..54bbeafcabdeeb1e2c1017c156b3512c83dada3a 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -63,20 +63,35 @@ struct EigenTensor {
 
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
-struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {};
+struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
+  static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) {
+    int rank = tensor.dims_.size();
+    PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+                   "`num_col_dims` must be between (0, rank_of_tensor).");
+    return EigenMatrix::From(tensor,
+                             flatten_to_2d(tensor.dims(), num_col_dims));
+  }
+
+  static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+                                                 int num_col_dims) {
+    int rank = tensor.dims_.size();
+    PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+                   "`num_col_dims` must be between (0, rank_of_tensor).");
+    return EigenMatrix::From(tensor,
+                             flatten_to_2d(tensor.dims(), num_col_dims));
+  }
+};
 
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
   // Flatten reshapes a Tensor into an EigenVector.
   static typename EigenVector::Type Flatten(Tensor& tensor) {
-    return EigenVector::From(
-        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+    return EigenVector::From(tensor, {product(tensor.dims_)});
   }
 
   static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
-    return EigenVector::From(
-        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+    return EigenVector::From(tensor, {product(tensor.dims_)});
   }
 };
diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc
index dc1957691b1a202826e10e84c21ac8874df9e378..bc4a2db32cfba66bef2c444e1f822e0d2a57b91e 100644
--- a/paddle/framework/eigen_test.cc
+++ b/paddle/framework/eigen_test.cc
@@ -108,5 +108,24 @@ TEST(Eigen, Matrix) {
   }
 }
 
+TEST(Eigen, MatrixReshape) {
+  Tensor t;
+  float* p = t.mutable_data<float>({2, 3, 6, 4}, platform::CPUPlace());
+  for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
+    p[i] = static_cast<float>(i);
+  }
+
+  EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2);
+
+  ASSERT_EQ(2 * 3, em.dimension(0));
+  ASSERT_EQ(6 * 4, em.dimension(1));
+
+  for (int i = 0; i < 2 * 3; i++) {
+    for (int j = 0; j < 6 * 4; j++) {
+      ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
+    }
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 643f875491724bf443bd7727391734377ee6180c..ce938b21437195fed8c1adad4329fd139f3f96ab 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -43,6 +43,9 @@ class Tensor {
   template <typename T, size_t D, int MajorType, typename IndexType>
   friend struct EigenTensor;
 
+  template <typename T, int MajorType, typename IndexType>
+  friend struct EigenMatrix;
+
   template <typename T, int MajorType, typename IndexType>
   friend struct EigenVector;
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 94f436294f350e2a39785a09959efb3b17bd00a5..637f04ae0037bd402d855b8bcde8087bfe8328d1 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -148,5 +148,13 @@ inline Tensor& Tensor::Resize(const DDim& dims) {
 
 inline const DDim& Tensor::dims() const { return dims_; }
 
+template <typename T>
+inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+  Tensor res;
+  res.ShareDataWith<T>(src);
+  res.Resize(flatten_to_2d(src.dims(), num_col_dims));
+  return res;
+}
+
 }  // namespace framework
 }  // namespace paddle
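`ReshapeToMatrix` composes `ShareDataWith` with `Resize`, so the returned matrix aliases the source tensor's storage rather than copying it. A rough sketch of that pattern, with a hypothetical `TensorView` standing in for `framework::Tensor`:

    // Share-then-resize: the "reshaped" matrix is a view, not a copy.
    #include <cassert>
    #include <memory>
    #include <vector>

    struct TensorView {
      std::shared_ptr<float> data;  // storage shared between views
      std::vector<long> dims;
    };

    TensorView ReshapeToMatrixView(const TensorView& src, int num_col_dims) {
      TensorView res;
      res.data = src.data;  // like ShareDataWith: alias, don't copy
      long rows = 1, cols = 1;
      for (size_t i = 0; i < src.dims.size(); ++i) {
        (static_cast<int>(i) < num_col_dims ? rows : cols) *= src.dims[i];
      }
      res.dims = {rows, cols};  // like Resize to the flattened 2-D shape
      return res;
    }

    int main() {
      TensorView t{std::shared_ptr<float>(new float[2 * 3 * 4 * 9],
                                          std::default_delete<float[]>()),
                   {2, 3, 4, 9}};
      TensorView m = ReshapeToMatrixView(t, 2);
      assert(m.data.get() == t.data.get());  // same storage
      assert(m.dims[0] == 6 && m.dims[1] == 36);
      return 0;
    }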
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 7db38d5caeebccf710334e854faf785ef0f64063..55302ea47120f420e952b26830c8ea4cbcce6435 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
   }
 #endif
 }
+
+TEST(Tensor, ReshapeToMatrix) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  Tensor src;
+  int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+  for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
+    src_ptr[i] = i;
+  }
+  Tensor res = ReshapeToMatrix<int>(src, 2);
+  ASSERT_EQ(res.dims()[0], 2 * 3);
+  ASSERT_EQ(res.dims()[1], 4 * 9);
+}
\ No newline at end of file
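The operator changes below consume these helpers: both operands of `mul` are flattened to matrices and the usual (M x K) * (K x N) rule applies. A self-contained sketch of the shape arithmetic the reworked `MulOp::InferShape` performs (illustrative helper names, not framework API):

    #include <cassert>
    #include <vector>

    long Prod(const std::vector<long>& d, int b, int e) {
      long p = 1;
      for (int i = b; i < e; ++i) p *= d[i];
      return p;
    }

    std::vector<long> InferMulOut(const std::vector<long>& x, int x_cols,
                                  const std::vector<long>& y, int y_cols) {
      long m = Prod(x, 0, x_cols);
      long k1 = Prod(x, x_cols, static_cast<int>(x.size()));
      long k2 = Prod(y, 0, y_cols);
      long n = Prod(y, y_cols, static_cast<int>(y.size()));
      assert(k1 == k2);  // first matrix's width must equal second's height
      return {m, n};
    }

    int main() {
      // Mirrors TestMulOp2 below: X is (15, 4, 12, 10), Y is (4, 30, 8, 2, 9),
      // x_num_col_dims = y_num_col_dims = 2, so Out is (15*4) x (8*2*9).
      std::vector<long> out =
          InferMulOut({15, 4, 12, 10}, 2, {4, 30, 8, 2, 9}, 2);
      assert(out[0] == 60 && out[1] == 144);
      return 0;
    }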
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 28a47cdff2e9b7a965ff9f99e787bb8315010823..710a56a0e8e2d17162d7d000df226f1537104eb9 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto dim0 = ctx.Input<Tensor>("X")->dims();
-    auto dim1 = ctx.Input<Tensor>("Y")->dims();
-    PADDLE_ENFORCE_EQ(dim0.size(), 2,
-                      "input X(%s) should be a tensor with 2 dims, a matrix",
-                      ctx.op().Input("X"));
-    PADDLE_ENFORCE_EQ(dim1.size(), 2,
-                      "input Y(%s) should be a tensor with 2 dims, a matrix",
-                      ctx.op().Input("Y"));
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto y_dims = ctx.Input<Tensor>("Y")->dims();
+    int x_num_col_dims = Attr<int>("x_num_col_dims");
+    int y_num_col_dims = Attr<int>("y_num_col_dims");
+
+    PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
+                   "The rank of input tensor X(%s) should be larger than "
+                   "`mul_op`'s `x_num_col_dims`.",
+                   ctx.op().Input("X"));
+    PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
+                   "The rank of input tensor Y(%s) should be larger than "
+                   "`mul_op`'s `y_num_col_dims`.",
+                   ctx.op().Input("Y"));
+
+    auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+    auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
+
     PADDLE_ENFORCE_EQ(
-        dim0[1], dim1[0],
+        x_mat_dims[1], y_mat_dims[0],
         "First matrix's width must be equal with second matrix's height.");
-    ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
+    ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
   }
 };
 
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of mul op");
     AddInput("Y", "The second input of mul op");
     AddOutput("Out", "The output of mul op");
+    AddAttr<int>(
+        "x_num_col_dims",
+        R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
+              in that case, tensors will be reshaped to a matrix. The matrix's first
+              dimension (column length) will be the product of the tensor's first
+              `num_col_dims` dimensions, and the matrix's second dimension (row length)
+              will be the product of the tensor's remaining `rank - num_col_dims`
+              dimensions.
+        )DOC")
+        .SetDefault(1)
+        .EqualGreaterThan(1);
+    AddAttr<int>(
+        "y_num_col_dims",
+        R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
+              in that case, tensors will be reshaped to a matrix, just like input `X`.
+        )DOC")
+        .SetDefault(1)
+        .EqualGreaterThan(1);
     AddComment(R"DOC(
 Two Element Mul Operator.
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel {
     auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
     auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    PADDLE_ENFORCE(x_dims[0] == out_dims[0],
-                   "Out@GRAD M X N must equal to X dims 0, M ");
-    PADDLE_ENFORCE(y_dims[1] == out_dims[1],
-                   "Out@GRAD M X N must equal to Y dims 1, N ");
+
+    auto x_mat_dims =
+        framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims"));
+    auto y_mat_dims =
+        framework::flatten_to_2d(y_dims, Attr<int>("y_num_col_dims"));
+
+    PADDLE_ENFORCE_EQ(
+        x_mat_dims[0], out_dims[0],
+        "The first dimension of Out@GRAD must equal to the first dimension of "
+        "the first operand.");
+    PADDLE_ENFORCE_EQ(
+        y_mat_dims[1], out_dims[1],
+        "The second dimension of Out@GRAD must equal to the second "
+        "dimension of the second operand.");
 
     if (x_grad) x_grad->Resize(x_dims);
     if (y_grad) y_grad->Resize(y_dims);
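For the gradient op above, the consistency requirement is that `Out@GRAD` be M x N when X flattens to M x K and Y flattens to K x N. The two `PADDLE_ENFORCE_EQ` checks, spelled out on the same shapes the tests use (a hedged sketch, names illustrative):

    #include <cassert>
    #include <vector>

    long Prod(const std::vector<long>& d, size_t b, size_t e) {
      long p = 1;
      for (size_t i = b; i < e; ++i) p *= d[i];
      return p;
    }

    int main() {
      std::vector<long> x_dims{15, 4, 12, 10};   // flattens to 60 x 120
      std::vector<long> y_dims{4, 30, 8, 2, 9};  // flattens to 120 x 144
      std::vector<long> out_dims{60, 144};       // shape of Out and Out@GRAD
      size_t x_num_col_dims = 2, y_num_col_dims = 2;
      // x_mat_dims[0] == out_dims[0] and y_mat_dims[1] == out_dims[1]:
      assert(Prod(x_dims, 0, x_num_col_dims) == out_dims[0]);
      assert(Prod(y_dims, y_num_col_dims, y_dims.size()) == out_dims[1]);
      return 0;
    }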
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 05a79e13b3470e39a5ebd0394ba05629553a5075..3c01f868bda8cba488b3403df456d63d6b082fa6 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -1,7 +1,7 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 
    Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
+   You may not use this file except in compliance with the License.
    You may obtain a copy of the License at
 
    http://www.apache.org/licenses/LICENSE-2.0
@@ -31,13 +31,25 @@ template <typename Place, typename T>
 class MulKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<Tensor>("X");
-    auto* y = context.Input<Tensor>("Y");
-    auto* z = context.Output<Tensor>("Out");
+    const Tensor* x = context.Input<Tensor>("X");
+    const Tensor* y = context.Input<Tensor>("Y");
+    Tensor* z = context.Output<Tensor>("Out");
+    const Tensor x_matrix =
+        x->dims().size() > 2
+            ? framework::ReshapeToMatrix<T>(
+                  *x, context.template Attr<int>("x_num_col_dims"))
+            : *x;
+    const Tensor y_matrix =
+        y->dims().size() > 2
+            ? framework::ReshapeToMatrix<T>(
+                  *y, context.template Attr<int>("y_num_col_dims"))
+            : *y;
+
     z->mutable_data<T>(context.GetPlace());
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
-    math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context);
+    math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, z, 0,
+                           device_context);
   }
 };
 
@@ -45,23 +57,39 @@ template <typename Place, typename T>
 class MulGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
+    int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
+    const Tensor* x = ctx.Input<Tensor>("X");
+    const Tensor* y = ctx.Input<Tensor>("Y");
+    const Tensor x_matrix =
+        x->dims().size() > 2
+            ? framework::ReshapeToMatrix<T>(*x, x_num_col_dims)
+            : *x;
+    const Tensor y_matrix =
+        y->dims().size() > 2
+            ? framework::ReshapeToMatrix<T>(*y, y_num_col_dims)
+            : *y;
+    const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     auto* device_context =
         const_cast<platform::DeviceContext*>(ctx.device_context_);
     if (dx) {
       dx->mutable_data<T>(ctx.GetPlace());
+      Tensor dx_matrix =
+          dx->dims().size() > 2
+              ? framework::ReshapeToMatrix<T>(*dx, x_num_col_dims)
+              : *dx;
       // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
-      math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context);
+      math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
+                             device_context);
     }
     if (dy) {
       dy->mutable_data<T>(ctx.GetPlace());
+      Tensor dy_matrix =
+          dy->dims().size() > 2
+              ? framework::ReshapeToMatrix<T>(*dy, y_num_col_dims)
+              : *dy;
       // dy = x' * dout. dy K x N, dout : M x N, x : M x K
-      math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context);
+      math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
+                             device_context);
     }
   }
 };
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 30b4b404315a9f041e21d79b75fd06307e33f7f9..fa8f0ff1a858143af427b51025279c726f1628e0 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto dim0 = ctx.Input<Tensor>("X")->dims();
-    auto dim1 = ctx.Input<Tensor>("b")->dims();
-
-    PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
-    PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
-    PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
-    PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
-    ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto b_dims = ctx.Input<Tensor>("b")->dims();
+    PADDLE_ENFORCE_GT(
+        x_dims.size(), b_dims.size(),
+        "The rank of input `X` must be larger than the one of input `b`.");
+
+    int num_col_dims = x_dims.size() - b_dims.size();
+
+    PADDLE_ENFORCE_EQ(
+        framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+        "The width of the two operands must be the same.");
+    PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
+    ctx.Output<Tensor>("Out")->Resize(x_dims);
   }
 };
 
@@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) should not be null");
-    auto dims0 = ctx.Input<Tensor>("X")->dims();
-    auto dims1 = ctx.Input<Tensor>("b")->dims();
-    PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto b_dims = ctx.Input<Tensor>("b")->dims();
+    PADDLE_ENFORCE_GT(
+        x_dims.size(), b_dims.size(),
+        "The rank of input `X` must be larger than the one of input `b`.");
+
+    int num_col_dims = x_dims.size() - b_dims.size();
+    PADDLE_ENFORCE_EQ(
+        framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+        "The width of the two operands must be the same.");
     auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
-    if (dx) dx->Resize(dims0);
-    if (db) db->Resize(dims1);
+    if (dx) dx->Resize(x_dims);
+    if (db) db->Resize(b_dims);
   }
 };
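The rank of `b` is no longer fixed at one: `rowwise_add` now derives `num_col_dims` as `rank(X) - rank(b)` and requires the trailing dimensions of `X` to match `b` exactly, which is what the `slice_ddim` comparison enforces. A standalone sketch of that check (illustrative names, not framework API):

    #include <cassert>
    #include <vector>

    bool RowwiseAddShapesOk(const std::vector<long>& x_dims,
                            const std::vector<long>& b_dims) {
      if (x_dims.size() <= b_dims.size()) return false;  // rank(X) > rank(b)
      size_t offset = x_dims.size() - b_dims.size();     // num_col_dims
      for (size_t i = 0; i < b_dims.size(); ++i) {
        if (x_dims[offset + i] != b_dims[i]) return false;
      }
      return true;
    }

    int main() {
      // Mirrors TestRowwiseAddOp2 below: X is (13, 6, 7, 8) and b is (7, 8).
      assert(RowwiseAddShapesOk({13, 6, 7, 8}, {7, 8}));
      assert(!RowwiseAddShapesOk({13, 6, 7, 8}, {6, 8}));
      return 0;
    }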
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 4e926d9f2947f37b71e81c0fa592b0c66b19c640..35774b940926f77167b8f19597027e74d3477e5b 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,10 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     auto out = context.Output<Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
-
-    auto input = EigenMatrix<T>::From(*context.Input<Tensor>("X"));
-    auto bias = EigenVector<T>::From(*context.Input<Tensor>("b"));
-    auto output = EigenMatrix<T>::From(*out);
+    int num_col_dims = context.Input<Tensor>("X")->dims().size() -
+                       context.Input<Tensor>("b")->dims().size();
+    auto input =
+        EigenMatrix<T>::Reshape(*context.Input<Tensor>("X"), num_col_dims);
+    auto bias = EigenVector<T>::Flatten(*context.Input<Tensor>("b"));
+    auto output = EigenMatrix<T>::Reshape(*out, num_col_dims);
 
     const int bias_size = bias.dimension(0);
     const int rest_size = input.size() / bias_size;
@@ -54,12 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel {
     auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     auto* db = context.Output<Tensor>(framework::GradVarName("b"));
+    int num_col_dims = context.Input<Tensor>("X")->dims().size() -
+                       context.Input<Tensor>("b")->dims().size();
 
-    auto out_grad = EigenMatrix<T>::From(*dout);
+    auto out_grad = EigenMatrix<T>::Reshape(*dout, num_col_dims);
     auto place = context.GetEigenDevice<Place>();
+
     if (dx) {
       dx->mutable_data<T>(context.GetPlace());
-      EigenMatrix<T>::From(*dx).device(place) = out_grad;
+      EigenMatrix<T>::Reshape(*dx, num_col_dims).device(place) = out_grad;
     }
 
     if (db) {
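After `EigenMatrix<T>::Reshape`, the kernel sees the input as `rest_size` rows of length `bias_size` and adds the flattened bias to every row. A plain-loop sketch of that broadcast, standing in for the Eigen expression the kernel actually uses:

    #include <cassert>
    #include <vector>

    std::vector<float> RowwiseAdd(const std::vector<float>& x,
                                  const std::vector<float>& b) {
      assert(x.size() % b.size() == 0);  // x is rest_size rows of bias_size
      std::vector<float> out(x.size());
      const size_t bias_size = b.size();
      for (size_t i = 0; i < x.size(); ++i) {
        out[i] = x[i] + b[i % bias_size];  // same bias for every row
      }
      return out;
    }

    int main() {
      // Two "rows" of length three, as if X was reshaped to 2 x 3.
      std::vector<float> out = RowwiseAdd({0, 1, 2, 3, 4, 5}, {10, 20, 30});
      assert(out[0] == 10 && out[4] == 24);
      return 0;
    }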
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index b58e4266d1588a4b6151f5f896537ded6ddd3896..8c827e242e866b267e0fc4b73c31bafa0ccc7c48 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -2,6 +2,7 @@ import unittest
 import numpy as np
 from gradient_checker import GradientChecker, create_op
 from op_test_util import OpTestMeta
+from paddle.v2.framework.op import Operator
 
 
 class TestMulOp(unittest.TestCase):
@@ -16,6 +17,22 @@ class TestMulOp(unittest.TestCase):
         self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
 
 
+class TestMulOp2(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "mul"
+        self.inputs = {
+            'X': np.random.random((15, 4, 12, 10)).astype("float32"),
+            'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32")
+        }
+        self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2}
+        self.outputs = {
+            'Out': np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10),
+                          self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9))
+        }
+
+
 class TestMulGradOp(GradientChecker):
     def setUp(self):
         self.op = create_op("mul")
@@ -49,7 +66,38 @@ class TestMulGradOp(GradientChecker):
             no_grad_set={"Y"})
 
 
-# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
+class TestMulGradTest2(GradientChecker):
+    def setUp(self):
+        self.op = Operator(
+            "mul", X="X", Y="Y", Out="Out", x_num_col_dims=2, y_num_col_dims=2)
+        self.inputs = {
+            "X": np.random.random((15, 4, 12, 10)).astype("float32"),
+            "Y": np.random.random((4, 30, 8, 2, 9)).astype("float32")
+        }
+
+    def test_cpu_gpu_compare(self):
+        self.compare_grad(self.op, self.inputs)
+
+    def test_normal(self):
+        self.check_grad(
+            self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+    def test_ignore_x(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["Y"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"X"})
+
+    def test_ignore_y(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["X"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"Y"})
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index 2ddb85e2e7a98a08bd1d6e24e6f812f6021142e8..8378c1cd21c21fd31da9b82d2cdaaff332f291d7 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -16,6 +16,18 @@ class TestRowwiseAddOp(unittest.TestCase):
         self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
 
 
+class TestRowwiseAddOp2(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "rowwise_add"
+        self.inputs = {
+            'X': np.random.random((13, 6, 7, 8)).astype("float32"),
+            'b': np.random.random((7, 8)).astype("float32")
+        }
+        self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
+
+
 class TestRowwiseAddGradOp(GradientChecker):
     def setUp(self):
         self.op = create_op("rowwise_add")
@@ -34,5 +46,23 @@ class TestRowwiseAddGradOp(GradientChecker):
         self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
 
 
+class TestRowwiseAddGradOp2(GradientChecker):
+    def setUp(self):
+        self.op = create_op("rowwise_add")
+        self.inputs = {
+            "X": np.random.uniform(0.1, 1, [2, 3, 2, 5]).astype("float32"),
+            "b": np.random.uniform(0.1, 1, [2, 5]).astype("float32")
+        }
+
+    def test_normal(self):
+        self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
+
+    def test_ignore_b(self):
+        self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
+
+    def test_ignore_x(self):
+        self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
+
+
 if __name__ == '__main__':
     unittest.main()