未验证 提交 a829357e 编写于 作者: Z Zhong Hui 提交者: GitHub

register the op version for some ops

register the op version for some ops
上级 5579edfb
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/clip_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -122,3 +124,15 @@ REGISTER_OP_CPU_KERNEL( ...@@ -122,3 +124,15 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>, clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ClipGradKernel<paddle::platform::CPUDeviceContext, double>); ops::ClipGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(clip)
.AddCheckpoint(
R"ROC(
Upgrade clip add a new input [Min])ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput("Min",
"Pass the mix, min value as input, not attribute. Min is "
"dispensable.")
.NewInput("Max",
"Pass the mix, min value as input, not attribute. Max is "
"dispensable."));
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle { namespace paddle {
...@@ -128,18 +129,28 @@ class CompareOp : public framework::OperatorWithKernel { ...@@ -128,18 +129,28 @@ class CompareOp : public framework::OperatorWithKernel {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_COMPARE_OP(op_type, _equation) \ #define REGISTER_COMPARE_OP_VERSION(op_type) \
struct _##op_type##Comment { \ REGISTER_OP_VERSION(op_type) \
static char type[]; \ .AddCheckpoint( \
static char equation[]; \ R"ROC(Upgrade compare ops, add a new attribute [force_cpu])ROC", \
}; \ paddle::framework::compatible::OpVersionDesc().NewAttr( \
char _##op_type##Comment::type[]{#op_type}; \ "force_cpu", \
char _##op_type##Comment::equation[]{_equation}; \ "In order to force fill output variable to cpu memory.", \
REGISTER_OPERATOR( \ false));
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ #define REGISTER_COMPARE_OP(op_type, _equation) \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \ struct _##op_type##Comment { \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); \
REGISTER_COMPARE_OP_VERSION(op_type);
REGISTER_COMPARE_OP(less_than, "Out = X < Y"); REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor); REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册