提交 c7dfec11 编写于 作者: Z zchen0211

fix

上级 384368f4
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class PreluOp : public framework::OperatorWithKernel { class PReluOp : public framework::OperatorWithKernel {
public: public:
PreluOp(const std::string &type, const framework::VariableNameMap &inputs, PReluOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
...@@ -34,13 +34,13 @@ class PreluOp : public framework::OperatorWithKernel { ...@@ -34,13 +34,13 @@ class PreluOp : public framework::OperatorWithKernel {
}; };
// template <typename AttrType> // template <typename AttrType>
class PreluOpMaker : public framework::OpProtoAndCheckerMaker { class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
PreluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of prelu operator."); AddInput("X", "The input tensor of prelu operator.");
AddOutput("Out", "The output tensor of prelu operator."); AddOutput("Out", "The output tensor of prelu operator.");
AddComment(R"DOC(Prelu operator AddComment(R"DOC(PRelu operator
The equation is: The equation is:
f(x) = alpha * x , for x < 0 f(x) = alpha * x , for x < 0
...@@ -52,7 +52,7 @@ f(x) = x , for x >= 0 ...@@ -52,7 +52,7 @@ f(x) = x , for x >= 0
}; };
// The operator to calculate gradients of a prelu operator. // The operator to calculate gradients of a prelu operator.
class PreluGradOp : public framework::OperatorWithKernel { class PReluGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -72,9 +72,9 @@ class PreluGradOp : public framework::OperatorWithKernel { ...@@ -72,9 +72,9 @@ class PreluGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(prelu, ops::PreluOp, ops::PreluOpMaker, prelu_grad, REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad,
ops::PreluGradOp); ops::PReluGradOp);
REGISTER_OP_CPU_KERNEL(prelu, REGISTER_OP_CPU_KERNEL(prelu,
ops::PreluKernel<paddle::platform::CPUPlace, float>); ops::PReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(prelu_grad, REGISTER_OP_CPU_KERNEL(prelu_grad,
ops::PreluGradKernel<paddle::platform::CPUPlace, float>); ops::PReluGradKernel<paddle::platform::CPUPlace, float>);
...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class PreluKernel : public framework::OpKernel { class PReluKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X"); auto* X = context.Input<Tensor>("X");
...@@ -45,7 +45,7 @@ class PreluKernel : public framework::OpKernel { ...@@ -45,7 +45,7 @@ class PreluKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class PreluGradKernel : public framework::OpKernel { class PReluGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dX = context.Output<Tensor>(framework::GradVarName("X")); auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册