提交 b3a86b6d 编写于 作者: W wwhu

fix CI

上级 c8c4b6e4
...@@ -27,7 +27,7 @@ class ClipByNormOp : public framework::OperatorWithKernel { ...@@ -27,7 +27,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
"Input(X) of ClipByNormOp should not be null."); "Input(X) of ClipByNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipByNormOp should not be null."); "Output(Out) of ClipByNormOp should not be null.");
auto max_norm = Attr<float>("max_norm"); auto max_norm = ctx->Attrs().Get<float>("max_norm");
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0."); PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
...@@ -35,7 +35,6 @@ class ClipByNormOp : public framework::OperatorWithKernel { ...@@ -35,7 +35,6 @@ class ClipByNormOp : public framework::OperatorWithKernel {
} }
}; };
template <typename AttrType>
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker { class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ClipByNormOpMaker(framework::OpProto* proto, ClipByNormOpMaker(framework::OpProto* proto,
...@@ -46,7 +45,7 @@ class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -46,7 +45,7 @@ class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
"The number of dimensions must be between [1, 9]."); "The number of dimensions must be between [1, 9].");
AddOutput("Out", AddOutput("Out",
"(Tensor) The output of clip_by_norm op with shape as input(X)"); "(Tensor) The output of clip_by_norm op with shape as input(X)");
AddAttr<AttrType>("max_norm", "(float) The maximum norm value."); AddAttr<float>("max_norm", "(float) The maximum norm value.");
AddComment(R"DOC( AddComment(R"DOC(
ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'. ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'.
If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be
...@@ -66,6 +65,6 @@ where norm('X') represents the L2 norm of 'X'. ...@@ -66,6 +65,6 @@ where norm('X') represents the L2 norm of 'X'.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp, REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
ops::ClipByNormOpMaker<float>); ops::ClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>); clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册