From faba4b116a0961ae854d8e777f34301cb1aef027 Mon Sep 17 00:00:00 2001
From: Bai Yifan
Date: Wed, 15 Jan 2020 15:03:31 +0800
Subject: [PATCH] Remove disable flag in test_fsp_op.py (#22171)

* fix fsp_op, test=develop

* fix fsp grad op maker, test=develop

* update op_use_default_grad_op_maker.spec, test=develop
---
 .../fluid/op_use_default_grad_op_maker.spec |  1 -
 paddle/fluid/operators/fsp_op.cc            | 32 ++++++++++++++++---
 paddle/fluid/operators/fsp_op.h             |  4 +++
 .../fluid/tests/unittests/test_fsp_op.py    |  1 -
 4 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec
index c6ffd5483ad..076206ca0aa 100644
--- a/paddle/fluid/op_use_default_grad_op_maker.spec
+++ b/paddle/fluid/op_use_default_grad_op_maker.spec
@@ -1,5 +1,4 @@
 cos_sim
-fsp
 gru
 match_matrix_tensor
 maxout
diff --git a/paddle/fluid/operators/fsp_op.cc b/paddle/fluid/operators/fsp_op.cc
index 020b03e3eea..59fc3b7ab7b 100644
--- a/paddle/fluid/operators/fsp_op.cc
+++ b/paddle/fluid/operators/fsp_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/fsp_op.h"
+#include <memory>

 namespace paddle {
 namespace operators {
@@ -114,14 +115,37 @@ class FSPOpGrad : public framework::OperatorWithKernel {
   }
 };

+template <typename T>
+class FSPGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+
+    op->SetType("fsp_grad");
+
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Y", this->Input("Y"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+
+    op->SetAttrMap(this->Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OPERATOR(
-    fsp, ops::FSPOp, ops::FSPOpMaker,
-    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
-    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
+REGISTER_OPERATOR(fsp, ops::FSPOp, ops::FSPOpMaker,
+                  ops::FSPGradOpMaker<paddle::framework::OpDesc>,
+                  ops::FSPGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad);
 REGISTER_OP_CPU_KERNEL(
     fsp, ops::FSPOpKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/fsp_op.h b/paddle/fluid/operators/fsp_op.h
index 544af2b7d9b..55bd23784d4 100644
--- a/paddle/fluid/operators/fsp_op.h
+++ b/paddle/fluid/operators/fsp_op.h
@@ -46,6 +46,7 @@ class FSPOpKernel : public framework::OpKernel<T> {
     x_mat_desc.width_ = height * width;
     x_mat_desc.batch_size_ = batch_size;
     x_mat_desc.stride_ = x_channel * height * width;
+    x_mat_desc.trans_ = false;

     math::MatDescriptor y_mat_desc;
     y_mat_desc.height_ = height * width;
@@ -93,12 +94,14 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
       d_out_mat_desc.width_ = y_channel;
       d_out_mat_desc.batch_size_ = batch_size;
       d_out_mat_desc.stride_ = x_channel * y_channel;
+      d_out_mat_desc.trans_ = false;

       math::MatDescriptor y_mat_desc;
       y_mat_desc.height_ = y_channel;
       y_mat_desc.width_ = h * w;
       y_mat_desc.batch_size_ = batch_size;
       y_mat_desc.stride_ = y_channel * h * w;
+      y_mat_desc.trans_ = false;

       blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc,
                   static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0));
@@ -125,6 +128,7 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
       x_mat_desc.width_ = h * w;
       x_mat_desc.batch_size_ = batch_size;
       x_mat_desc.stride_ = x_channel * h * w;
+      x_mat_desc.trans_ = false;

       blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc,
                   static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0));
diff --git a/python/paddle/fluid/tests/unittests/test_fsp_op.py b/python/paddle/fluid/tests/unittests/test_fsp_op.py
index b4be4af69a5..3503c4ade4a 100644
--- a/python/paddle/fluid/tests/unittests/test_fsp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fsp_op.py
@@ -34,7 +34,6 @@ def fsp_matrix(a, b):
     return np.mean(a_r * b_r, axis=1)


-@unittest.skip("Disable temporarily.")
 class TestFSPOp(OpTest):
     def setUp(self):
         self.op_type = "fsp"
--
GitLab
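Note (not part of the patch): the fsp op computes, per sample, the Flow of
Solution Procedure (FSP) matrix between two feature maps. Below is a minimal
NumPy sketch of that quantity, equivalent to the fsp_matrix() reference that
test_fsp_op.py checks the op against; the function name and shapes here are
illustrative, not Paddle API.

    import numpy as np

    def fsp_matrix_sketch(a, b):
        # a: (batch, a_channel, h, w); b: (batch, b_channel, h, w)
        batch, a_channel, h, w = a.shape
        b_channel = b.shape[1]
        a_flat = a.reshape(batch, a_channel, h * w)   # (N, C1, H*W)
        b_flat = b.reshape(batch, b_channel, h * w)   # (N, C2, H*W)
        # Batched matmul averaged over the h * w spatial positions,
        # giving a (batch, a_channel, b_channel) result. The C++ kernel
        # realizes the same product via blas.MatMul with MatDescriptor;
        # the patch now sets each descriptor's trans_ flag explicitly
        # instead of relying on its default value.
        return np.matmul(a_flat, b_flat.transpose(0, 2, 1)) / (h * w)

On the Python API side the op is exposed as fluid.layers.fsp_matrix(x, y),
which returns a (batch, x_channel, y_channel) tensor.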