提交 b3a86b6d 编写于 作者: W wwhu

fix CI

上级 c8c4b6e4
......@@ -27,7 +27,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
"Input(X) of ClipByNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ClipByNormOp should not be null.");
auto max_norm = Attr<float>("max_norm");
auto max_norm = ctx->Attrs().Get<float>("max_norm");
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims);
......@@ -35,7 +35,6 @@ class ClipByNormOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ClipByNormOpMaker(framework::OpProto* proto,
......@@ -46,7 +45,7 @@ class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
"The number of dimensions must be between [1, 9].");
AddOutput("Out",
"(Tensor) The output of clip_by_norm op with shape as input(X)");
AddAttr<AttrType>("max_norm", "(float) The maximum norm value.");
AddAttr<float>("max_norm", "(float) The maximum norm value.");
AddComment(R"DOC(
ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'.
If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be
......@@ -66,6 +65,6 @@ where norm('X') represents the L2 norm of 'X'.
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
ops::ClipByNormOpMaker<float>);
ops::ClipByNormOpMaker);
REGISTER_OP_CPU_KERNEL(
clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册