提交 68d96385 编写于 作者: Y Yang Yang

remove REGISTER_OP and REGISTER_OP_EX

上级 4b1a32db
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
...@@ -69,8 +70,7 @@ class GradOpDescMakerBase { ...@@ -69,8 +70,7 @@ class GradOpDescMakerBase {
" for input argument with a list of variables, " " for input argument with a list of variables, "
" drop_empty_grad is not allowed because it makes" " drop_empty_grad is not allowed because it makes"
" the correspondence bewteen a variable and its gradient" " the correspondence bewteen a variable and its gradient"
" ambiguous. Use REGISTER_OP_EX to register the op" " ambiguous."
" or call InputGrad(?,false) in GradOpDescMaker."
" Op type %s", " Op type %s",
fwd_op_.Type()); fwd_op_.Type());
......
...@@ -143,32 +143,6 @@ class OpKernelRegistrar : public Registrar { ...@@ -143,32 +143,6 @@ class OpKernelRegistrar : public Registrar {
return 0; \ return 0; \
} }
/**
* Macro to register Operator. When the input is duplicable, you should
* use REGISTER_OP_EX with drop_empty_grad=false instead.
*/
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class, true)
// When an argument is duplicable, we need to use this version.
// Perhaps we can omit DropEmptyIG template parameter and
// only have one version of REGISTER_OP.
#define REGISTER_OP_EX(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class, drop_empty_grad) \
REGISTER_OPERATOR(grad_op_type, grad_op_class); \
class _GradOpDescMaker_##grad_op_type##_ \
: public ::paddle::framework::DefaultGradOpDescMaker<drop_empty_grad> { \
using ::paddle::framework::DefaultGradOpDescMaker< \
drop_empty_grad>::DefaultGradOpDescMaker; \
\
protected: \
virtual std::string GradOpType() const { return #grad_op_type; } \
}; \
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
op_maker_class);
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OPERATOR(op_type, op_class, op_maker_class) REGISTER_OPERATOR(op_type, op_class, op_maker_class)
......
...@@ -103,8 +103,10 @@ class ConcatOpGrad : public framework::OperatorWithKernel { ...@@ -103,8 +103,10 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad, REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
ops::ConcatOpGrad, false) paddle::framework::DefaultGradOpDescMaker<
false> /* set false to disable empty grad */)
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad)
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>) concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>)
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -124,9 +124,11 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel { ...@@ -124,9 +124,11 @@ class SequenceConcatGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_EX(sequence_concat, ops::SequenceConcatOp, REGISTER_OPERATOR(sequence_concat, ops::SequenceConcatOp,
ops::SequenceConcatOpMaker, sequence_concat_grad, ops::SequenceConcatOpMaker,
ops::SequenceConcatGradOp, false); paddle::framework::DefaultGradOpDescMaker<
false> /* set false to disable empty grad */)
REGISTER_OPERATOR(sequence_concat_grad, ops::SequenceConcatGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_concat, sequence_concat,
ops::SequenceConcatOpKernel<paddle::platform::CPUDeviceContext, float>); ops::SequenceConcatOpKernel<paddle::platform::CPUDeviceContext, float>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册