Commit b6347fb6 authored by zchen0211

prelu fix

Parent 4a237884
@@ -29,6 +29,8 @@ class PReluOp : public framework::OperatorWithKernel {
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    auto *in = ctx.Input<framework::Tensor>("X");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
+                            "Output(Out) should not be null");
    auto *out = ctx.Output<framework::LoDTensor>("Out");
    out->Resize(in->dims());
  }
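The two `+` lines guard the output variable the same way the input is already guarded, so a missing `Out` fails with a clear enforce message instead of a null dereference at `out->Resize(...)`. A minimal standalone sketch of this check-then-resize pattern, in plain C++ with illustrative names (not Paddle's API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tensor; only the shape matters here.
struct Tensor {
  std::vector<int64_t> dims;
};

// PRelu is element-wise, so the output shape must equal the input shape.
void InferPReluShape(const Tensor* x, Tensor* out) {
  assert(x != nullptr && "Input(X) should not be null");
  assert(out != nullptr && "Output(Out) should not be null");
  out->dims = x->dims;  // mirrors out->Resize(in->dims())
}
```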
@@ -41,6 +43,8 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input tensor of prelu operator.");
    AddOutput("Out", "The output tensor of prelu operator.");
+    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
+        .SetDefault(0.0);
    AddComment(R"DOC(PRelu operator

The equation is:
@@ -49,8 +53,6 @@ The equation is:
f(x) = x , for x >= 0

)DOC");
-    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
-        .SetDefault(0.0);
  }
};
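For reference, the full equation in the DOC string (the hunk above shows only its second branch) is the standard PRelu definition:

    f(x) = alpha * x , for x < 0
    f(x) = x         , for x >= 0

With the default alpha = 0.0 set above, PRelu reduces to plain ReLU; a positive alpha, say 0.25, lets negative inputs through scaled, e.g. f(-2) = -0.5.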
......
@@ -24,9 +24,9 @@ using Tensor = framework::Tensor;
using platform::Transform;

template <typename T>
-class Prelu_functor {
+class PReluFunctor {
 public:
-  explicit Prelu_functor(const T& alpha) : alpha_(alpha) {}
+  explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}

  HOSTDEVICE T operator()(const T& X) const {
    if (X > 0)
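The hunk cuts off inside `operator()`; what the renamed functor computes is the standard PRelu forward rule. A self-contained sketch under that assumption (plain C++, `HOSTDEVICE` dropped, not copied from the Paddle source):

```cpp
// Sketch of the forward rule PReluFunctor applies to each element.
// Assumes the standard PRelu definition: identity for x > 0, alpha * x otherwise.
template <typename T>
class PReluFunctor {
 public:
  explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}

  T operator()(const T& x) const {
    return x > 0 ? x : alpha_ * x;
  }

 private:
  T alpha_;
};
```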
@@ -54,14 +54,14 @@ class PReluKernel : public framework::OpKernel {
    int numel = X->numel();
    auto place = context.GetPlace();

-    Transform(place, X_ptr, X_ptr + numel, O_ptr, Prelu_functor<T>(alpha));
+    Transform(place, X_ptr, X_ptr + numel, O_ptr, PReluFunctor<T>(alpha));
  }
};

template <typename T>
-class Prelu_Grad_functor {
+class PReluGradFunctor {
 public:
-  explicit Prelu_Grad_functor(const T& alpha) : alpha_(alpha) {}
+  explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}

  HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
    if (Out > 0)
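This gradient functor's body is likewise truncated at `if (Out > 0)`. The standard PRelu backward rule is dX = dOut where the activation is positive and alpha * dOut otherwise; note it branches on `Out` rather than `X`, which has the same sign whenever alpha > 0. A sketch under that assumption:

```cpp
// Sketch of the backward rule PReluGradFunctor applies to each element.
// df/dx = 1 for x > 0 and alpha otherwise; sign(Out) stands in for sign(X).
template <typename T>
class PReluGradFunctor {
 public:
  explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}

  T operator()(const T& out, const T& dout) const {
    return out > 0 ? dout : alpha_ * dout;
  }

 private:
  T alpha_;
};
```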
@@ -92,7 +92,7 @@ class PReluGradKernel : public framework::OpKernel {
    auto place = context.GetPlace();

    Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
-              Prelu_Grad_functor<T>(alpha));
+              PReluGradFunctor<T>(alpha));
  }
};
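This overload of `Transform` reads two input ranges (`Out` and `dOut`) and writes `dX`; on the CPU its effect matches the binary form of `std::transform`. A self-contained stand-in with illustrative names (`prelu_backward` is not a Paddle function):

```cpp
#include <algorithm>
#include <vector>

// CPU stand-in for the binary Transform call in PReluGradKernel:
// dx[i] = out[i] > 0 ? dout[i] : alpha * dout[i]
std::vector<float> prelu_backward(const std::vector<float>& out,
                                  const std::vector<float>& dout,
                                  float alpha) {
  std::vector<float> dx(out.size());
  std::transform(out.begin(), out.end(), dout.begin(), dx.begin(),
                 [alpha](float o, float g) { return o > 0 ? g : alpha * g; });
  return dx;
}
```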