From 2d6b4f23f0341bbfb0def185e6b5ed4b1c7020e8 Mon Sep 17 00:00:00 2001 From: zhoukunsheng Date: Fri, 12 Apr 2019 16:20:16 +0800 Subject: [PATCH] test=develop bug fix: reduce_all, reduce_any register GRAD_OP, but have not defined GradKernel --- paddle/fluid/operators/reduce_ops/reduce_all_op.cc | 2 +- paddle/fluid/operators/reduce_ops/reduce_any_op.cc | 2 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 9 +++++++++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc index b087fbbb94..a3ca9ae067 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_all_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_all_op.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/operators/reduce_ops/reduce_all_op.h" -REGISTER_REDUCE_OP(reduce_all); +REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_all); REGISTER_OP_CPU_KERNEL(reduce_all, ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc index d865dcb3c9..34f0fffc9a 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_any_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_any_op.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/operators/reduce_ops/reduce_any_op.h" -REGISTER_REDUCE_OP(reduce_any); +REGISTER_REDUCE_OP_WITHOUT_GRAD(reduce_any); REGISTER_OP_CPU_KERNEL(reduce_any, ops::ReduceKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 540742c4cd..81e9933276 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -270,3 +270,12 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(op_name, ops::ReduceOp, __##op_name##Maker__, \ paddle::framework::DefaultGradOpDescMaker); \ REGISTER_OPERATOR(op_name##_grad, ops::ReduceGradOp) + +#define REGISTER_REDUCE_OP_WITHOUT_GRAD(op_name) \ + class __##op_name##Maker__ : public ops::ReduceOpMaker { \ + protected: \ + virtual std::string GetName() const { return #op_name; } \ + virtual std::string GetOpType() const { return "Reduce " #op_name; } \ + }; \ + REGISTER_OPERATOR(op_name, ops::ReduceOp, __##op_name##Maker__, \ + paddle::framework::DefaultGradOpDescMaker); -- GitLab