未验证 提交 a829357e 编写于 作者: Z Zhong Hui 提交者: GitHub

register the op version for some ops

register the op version for some ops
上级 5579edfb
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/clip_op.h"
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -122,3 +124,15 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ClipGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(clip)
.AddCheckpoint(
R"ROC(
Upgrade clip add a new input [Min])ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput("Min",
"Pass the mix, min value as input, not attribute. Min is "
"dispensable.")
.NewInput("Max",
"Pass the mix, min value as input, not attribute. Max is "
"dispensable."));
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle {
......@@ -128,18 +129,28 @@ class CompareOp : public framework::OperatorWithKernel {
} // namespace operators
} // namespace paddle
#define REGISTER_COMPARE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
#define REGISTER_COMPARE_OP_VERSION(op_type) \
REGISTER_OP_VERSION(op_type) \
.AddCheckpoint( \
R"ROC(Upgrade compare ops, add a new attribute [force_cpu])ROC", \
paddle::framework::compatible::OpVersionDesc().NewAttr( \
"force_cpu", \
"In order to force fill output variable to cpu memory.", \
false));
#define REGISTER_COMPARE_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); \
REGISTER_COMPARE_OP_VERSION(op_type);
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册