Commit 1b797468 authored by zchen0211

prelu

Parent b6347fb6
@@ -29,6 +29,11 @@ class PReluOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
     auto *in = ctx.Input<framework::Tensor>("X");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"),
+                            "Input(Alpha) should not be null");
+    auto *alpha = ctx.Input<framework::Tensor>("Alpha");
+    PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one.");
     PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                             "Output(Out) should not be null");
     auto *out = ctx.Output<framework::LoDTensor>("Out");
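A note on the new check: `alpha->numel() == 1` limits this op to the channel-shared form of PRelu, where Alpha carries a single scalar applied to every element of X; a per-channel Alpha weight would fail this enforcement.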
@@ -36,15 +41,13 @@ class PReluOp : public framework::OperatorWithKernel {
   }
 };

-template <typename AttrType>
 class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor of prelu operator.");
+    AddInput("Alpha", "The alpha weight of prelu operator.");
     AddOutput("Out", "The output tensor of prelu operator.");
-    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
-        .SetDefault(0.0);
     AddComment(R"DOC(PRelu operator

 The equation is:
@@ -66,11 +69,15 @@ class PReluGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) should not be null");
-    auto *X_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto *X = ctx.Input<framework::Tensor>("X");
+    auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *x = ctx.Input<framework::Tensor>("X");
+    auto *dalpha =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("Alpha"));
+    auto *alpha = ctx.Input<framework::Tensor>("Alpha");

-    X_grad->Resize(X->dims());
+    dx->Resize(x->dims());
+    dalpha->Resize(alpha->dims());
   }
 };
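This hunk gives Grad(Alpha) the right shape, but the kernels in the hunks shown below only fill Grad(X); no computation of Grad(Alpha) appears in this diff. For a shared scalar alpha the math is simple, sketched here under that assumption (ScalarAlphaGrad is a hypothetical helper, not part of the commit): where x <= 0 the forward pass computes out = alpha * x, so d(out)/d(alpha) = x and the chain rule gives dalpha as the sum of x * dout over exactly those elements.

#include <cstddef>

// Hypothetical sketch, not from this commit: gradient of a shared
// scalar alpha. Where x <= 0, out = alpha * x, so d(out)/d(alpha) = x
// and dalpha accumulates x * dout over those elements only.
template <typename T>
T ScalarAlphaGrad(const T* x, const T* dout, std::size_t n) {
  T dalpha = T(0);
  for (std::size_t i = 0; i < n; ++i) {
    if (x[i] <= 0) dalpha += x[i] * dout[i];
  }
  return dalpha;
}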
@@ -79,7 +86,7 @@ class PReluGradOp : public framework::OperatorWithKernel {

 namespace ops = paddle::operators;
-REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker<float>, prelu_grad,
+REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad,
             ops::PReluGradOp);
 REGISTER_OP_CPU_KERNEL(prelu,
                        ops::PReluKernel<paddle::platform::CPUPlace, float>);
......
@@ -28,33 +28,35 @@ class PReluFunctor {
  public:
   explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}

-  HOSTDEVICE T operator()(const T& X) const {
-    if (X > 0)
-      return X;
+  HOSTDEVICE T operator()(const T& x) const {
+    if (x > 0)
+      return x;
     else
-      return X * alpha_;
+      return x * alpha_;
   }

  private:
   T alpha_;
 };

-template <typename Place, typename T, typename AttrType = T>
+template <typename Place, typename T>
 class PReluKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* X = context.Input<Tensor>("X");
-    auto* Out = context.Output<Tensor>("Out");
+    auto* x = context.Input<Tensor>("X");
+    auto* alpha = context.Input<Tensor>("Alpha");
+    auto* out = context.Output<Tensor>("Out");

-    const T* X_ptr = X->data<T>();
-    T* O_ptr = Out->mutable_data<T>(context.GetPlace());
+    const T* x_ptr = x->data<T>();
+    T* o_ptr = out->mutable_data<T>(context.GetPlace());

-    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
+    auto alpha_val = alpha->data<T>()[0];
+    // auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));

-    int numel = X->numel();
+    int numel = x->numel();

     auto place = context.GetPlace();
-    Transform(place, X_ptr, X_ptr + numel, O_ptr, PReluFunctor<T>(alpha));
+    Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_val));
   }
 };
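To see the forward pass end to end outside the framework, here is a minimal self-contained rendering of the same functor-plus-transform pattern. std::transform stands in for Paddle's Transform dispatcher, HOSTDEVICE is dropped, the if/else is folded into a ternary, and the input data is made up:

#include <algorithm>
#include <iostream>
#include <vector>

// Same shape as the PReluFunctor above, minus the device qualifiers.
template <typename T>
struct PReluFunctor {
  explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
  T operator()(const T& x) const { return x > 0 ? x : x * alpha_; }
  T alpha_;
};

int main() {
  std::vector<float> x = {-2.f, -0.5f, 0.f, 1.f, 3.f};
  std::vector<float> out(x.size());
  // In the kernel this scalar comes from alpha->data<T>()[0].
  float alpha_val = 0.25f;
  std::transform(x.begin(), x.end(), out.begin(),
                 PReluFunctor<float>(alpha_val));
  for (float v : out) std::cout << v << ' ';  // -0.5 -0.125 0 1 3
  std::cout << '\n';
  return 0;
}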
@@ -63,36 +65,36 @@ class PReluGradFunctor {
  public:
   explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}

-  HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
-    if (Out > 0)
-      return dOut;
+  HOSTDEVICE T operator()(const T& out, const T& dout) const {
+    if (out > 0)
+      return dout;
     else
-      return dOut * alpha_;
+      return dout * alpha_;
   }

  private:
   T alpha_;
 };

-template <typename Place, typename T, typename AttrType = T>
+template <typename Place, typename T>
 class PReluGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
-    auto* dO = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* Out = context.Input<Tensor>("Out");
+    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out = context.Input<Tensor>("Out");
+    auto* alpha = context.Input<Tensor>("Alpha");
+    auto alpha_val = alpha->data<T>()[0];

-    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
-
-    T* dX_ptr = dX->mutable_data<T>(context.GetPlace());
-    const T* dO_ptr = dO->data<T>();
-    const T* O_ptr = Out->data<T>();
-    int numel = dX->numel();
+    T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
+    const T* dout_ptr = dout->data<T>();
+    const T* out_ptr = out->data<T>();
+    int numel = dx->numel();

     auto place = context.GetPlace();
-    Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
-              PReluGradFunctor<T>(alpha));
+    Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
+              PReluGradFunctor<T>(alpha_val));
   }
 };
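One subtlety worth noting: PReluGradFunctor branches on out rather than x. That works because with a non-negative alpha the forward pass never flips the sign, so out > 0 exactly when x > 0. Worked through once: x = -2 with alpha = 0.25 gives out = -0.5, the out > 0 test fails, and dx = 0.25 * dout, as required. If alpha were negative, a negative input would produce a positive output and the test would pick the wrong branch, so alpha >= 0 is an implicit assumption of this backward pass.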
......