提交 b6347fb6 编写于 作者: Z zchen0211

prelu fix

上级 4a237884
...@@ -29,6 +29,8 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -29,6 +29,8 @@ class PReluOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
auto *in = ctx.Input<framework::Tensor>("X"); auto *in = ctx.Input<framework::Tensor>("X");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) should not be null");
auto *out = ctx.Output<framework::LoDTensor>("Out"); auto *out = ctx.Output<framework::LoDTensor>("Out");
out->Resize(in->dims()); out->Resize(in->dims());
} }
...@@ -41,6 +43,8 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -41,6 +43,8 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of prelu operator."); AddInput("X", "The input tensor of prelu operator.");
AddOutput("Out", "The output tensor of prelu operator."); AddOutput("Out", "The output tensor of prelu operator.");
AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
.SetDefault(0.0);
AddComment(R"DOC(PRelu operator AddComment(R"DOC(PRelu operator
The equation is: The equation is:
...@@ -49,8 +53,6 @@ The equation is: ...@@ -49,8 +53,6 @@ The equation is:
f(x) = x , for x >= 0 f(x) = x , for x >= 0
)DOC"); )DOC");
AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
.SetDefault(0.0);
} }
}; };
......
...@@ -24,9 +24,9 @@ using Tensor = framework::Tensor; ...@@ -24,9 +24,9 @@ using Tensor = framework::Tensor;
using platform::Transform; using platform::Transform;
template <typename T> template <typename T>
class Prelu_functor { class PReluFunctor {
public: public:
explicit Prelu_functor(const T& alpha) : alpha_(alpha) {} explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
HOSTDEVICE T operator()(const T& X) const { HOSTDEVICE T operator()(const T& X) const {
if (X > 0) if (X > 0)
...@@ -54,14 +54,14 @@ class PReluKernel : public framework::OpKernel { ...@@ -54,14 +54,14 @@ class PReluKernel : public framework::OpKernel {
int numel = X->numel(); int numel = X->numel();
auto place = context.GetPlace(); auto place = context.GetPlace();
Transform(place, X_ptr, X_ptr + numel, O_ptr, Prelu_functor<T>(alpha)); Transform(place, X_ptr, X_ptr + numel, O_ptr, PReluFunctor<T>(alpha));
} }
}; };
template <typename T> template <typename T>
class Prelu_Grad_functor { class PReluGradFunctor {
public: public:
explicit Prelu_Grad_functor(const T& alpha) : alpha_(alpha) {} explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
HOSTDEVICE T operator()(const T& Out, const T& dOut) const { HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
if (Out > 0) if (Out > 0)
...@@ -92,7 +92,7 @@ class PReluGradKernel : public framework::OpKernel { ...@@ -92,7 +92,7 @@ class PReluGradKernel : public framework::OpKernel {
auto place = context.GetPlace(); auto place = context.GetPlace();
Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr, Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
Prelu_Grad_functor<T>(alpha)); PReluGradFunctor<T>(alpha));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册