Commit 34d68f24 authored by wwhu

fix doc and code style

Parent 65451b5c
@@ -39,15 +39,14 @@ template <typename AttrType>
 class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ClipByNormOpMaker(framework::OpProto* proto,
-                   framework::OpAttrChecker* op_checker)
+                    framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "(Tensor)The input of clip_by_norm op."
+             "(Tensor) The input of clip_by_norm op."
              "The number of dimensions must be between [1, 9].");
     AddOutput("Out",
-              "(Tensor)The output of clip_by_norm op with shape as input(X)");
-    AddAttr<AttrType>(
-        "max_norm", "(float)The maximum norm value.");
+              "(Tensor) The output of clip_by_norm op with shape as input(X)");
+    AddAttr<AttrType>("max_norm", "(float)The maximum norm value.");
     AddComment(R"DOC(
 ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'.
 If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be
@@ -62,29 +61,11 @@ where norm('X') represents the L2 norm of 'X'.
   }
 };
 
-class ClipByNormOpGrad : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx->GetInputDim("X");
-    if (ctx->HasOutput(framework::GradVarName("X"))) {
-      ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm,
-                             ops::ClipByNormOp,
+REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
                              ops::ClipByNormOpMaker<float>);
-REGISTER_OP_CPU_KERNEL(clip_by_norm,
-                       ops::ClipByNormKernel
-                       <paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
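
As context for the DOC text above (my note, not part of the commit): a minimal NumPy sketch of the operator's semantics, mirroring the reference computation in the Python test further down; `clip_by_norm_reference` is a hypothetical helper name, not a Paddle API.

```python
import numpy as np

def clip_by_norm_reference(x, max_norm):
    # If the L2 norm of x exceeds max_norm, scale x linearly so the
    # L2 norm of the result equals max_norm; otherwise pass x through.
    norm = np.sqrt(np.sum(np.square(x)))
    if norm > max_norm:
        return max_norm * x / norm
    return x

x = np.random.random((100, )).astype('float32')
out = clip_by_norm_reference(x, max_norm=1.0)
assert np.sqrt(np.sum(np.square(out))) <= 1.0 + 1e-5
```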
@@ -15,6 +15,5 @@
 #include "paddle/operators/clip_by_norm_op.h"
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(clip_by_norm,
-                       ops::ClipByNormKernel
-                       <paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    clip_by_norm, ops::ClipByNormKernel<paddle::platform::GPUPlace, float>);
@@ -25,9 +25,6 @@ using Tensor = framework::Tensor;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
 class ClipByNormKernel : public framework::OpKernel<T> {
@@ -18,21 +18,19 @@ class TestClipByNormOp(OpTest):
             output = self.max_norm * input / norm
         else:
             output = input
-        self.outputs = {
-            'Out': output
-        }
+        self.outputs = {'Out': output}
 
     def test_check_output(self):
         self.check_output()
 
     def initTestCase(self):
-        self.shape = (100,)
+        self.shape = (100, )
         self.max_norm = 1.0
 
 class TestCase1(TestClipByNormOp):
     def initTestCase(self):
-        self.shape = (100,)
+        self.shape = (100, )
         self.max_norm = 1e20
 
 if __name__ == '__main__':
     unittest.main()
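
A side note on the two test cases (my commentary, not part of the commit): assuming the test feeds uniform random input in [0, 1), a vector of shape (100, ) has an L2 norm of roughly sqrt(100/3) ≈ 5.8, so `max_norm = 1.0` exercises the scaling branch while `max_norm = 1e20` exercises the pass-through branch. A quick sketch under that assumption:

```python
import numpy as np

x = np.random.random((100, )).astype('float32')
norm = np.sqrt(np.sum(np.square(x)))  # ~5.8 for uniform [0, 1) input

# max_norm = 1.0 (TestClipByNormOp): norm > max_norm, so the
# output is rescaled to have unit L2 norm.
out_scaled = 1.0 * x / norm
assert abs(np.sqrt(np.sum(np.square(out_scaled))) - 1.0) < 1e-5

# max_norm = 1e20 (TestCase1): norm <= max_norm, so the output
# is the input, unchanged.
out_same = x
assert np.array_equal(out_same, x)
```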