Commit 1b797468 authored by zchen0211

prelu

Parent b6347fb6
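This commit reworks the PReLU operator so that the slope alpha is no longer a float attribute baked into the op definition but a learnable, single-element input tensor named "Alpha" (with a matching Alpha@GRAD output shaped in the backward op). As a reference for the diffs below, a minimal sketch of the elementwise function the operator computes; `prelu_ref` is a hypothetical name, not part of the commit:

```cpp
// PReLU: identity for positive inputs, a learned linear slope otherwise.
template <typename T>
T prelu_ref(T x, T alpha) {
  return x > 0 ? x : alpha * x;
}
```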
@@ -29,6 +29,11 @@ class PReluOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
     auto *in = ctx.Input<framework::Tensor>("X");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"),
+                            "Input(Alpha) should not be null");
+    auto *alpha = ctx.Input<framework::Tensor>("Alpha");
+    PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one.");
     PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                             "Output(Out) should not be null");
     auto *out = ctx.Output<framework::LoDTensor>("Out");
@@ -36,15 +41,13 @@ class PReluOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename AttrType>
 class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input tensor of prelu operator.");
+    AddInput("Alpha", "The alpha weight of prelu operator.");
     AddOutput("Out", "The output tensor of prelu operator.");
-    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
-        .SetDefault(0.0);
     AddComment(R"DOC(PRelu operator
 
 The equation is:
@@ -66,11 +69,15 @@ class PReluGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) should not be null");
-    auto *X_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto *X = ctx.Input<framework::Tensor>("X");
+    auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *x = ctx.Input<framework::Tensor>("X");
+    auto *dalpha =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("Alpha"));
+    auto *alpha = ctx.Input<framework::Tensor>("Alpha");
 
-    X_grad->Resize(X->dims());
+    dx->Resize(x->dims());
+    dalpha->Resize(alpha->dims());
   }
 };
@@ -79,7 +86,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
 namespace ops = paddle::operators;
 
-REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker<float>, prelu_grad,
+REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad,
             ops::PReluGradOp);
 REGISTER_OP_CPU_KERNEL(prelu,
                        ops::PReluKernel<paddle::platform::CPUPlace, float>);
......
@@ -28,33 +28,35 @@ class PReluFunctor {
  public:
   explicit PReluFunctor(const T& alpha) : alpha_(alpha) {}
 
-  HOSTDEVICE T operator()(const T& X) const {
-    if (X > 0)
-      return X;
+  HOSTDEVICE T operator()(const T& x) const {
+    if (x > 0)
+      return x;
     else
-      return X * alpha_;
+      return x * alpha_;
   }
 
  private:
   T alpha_;
 };
 
-template <typename Place, typename T, typename AttrType = T>
+template <typename Place, typename T>
 class PReluKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* X = context.Input<Tensor>("X");
-    auto* Out = context.Output<Tensor>("Out");
+    auto* x = context.Input<Tensor>("X");
+    auto* alpha = context.Input<Tensor>("Alpha");
+    auto* out = context.Output<Tensor>("Out");
 
-    const T* X_ptr = X->data<T>();
-    T* O_ptr = Out->mutable_data<T>(context.GetPlace());
+    const T* x_ptr = x->data<T>();
+    T* o_ptr = out->mutable_data<T>(context.GetPlace());
 
-    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
+    auto alpha_val = alpha->data<T>()[0];
+    // auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
 
-    int numel = X->numel();
+    int numel = x->numel();
 
     auto place = context.GetPlace();
-    Transform(place, X_ptr, X_ptr + numel, O_ptr, PReluFunctor<T>(alpha));
+    Transform(place, x_ptr, x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_val));
   }
 };
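The forward kernel now pulls the scalar out of the Alpha tensor (`alpha->data<T>()[0]`) and applies PReluFunctor across the flat buffer via paddle's Transform, which dispatches to the CPU or CUDA backend according to the place. A rough host-side equivalent of one kernel invocation, sketched with std::transform under the single-shared-alpha assumption that the `numel() == 1` check above enforces (values are illustrative):

```cpp
#include <algorithm>
#include <vector>

int main() {
  std::vector<float> x = {-2.f, -0.5f, 0.f, 1.f, 3.f};
  std::vector<float> out(x.size());
  float alpha = 0.25f;  // stand-in for alpha->data<T>()[0]

  // Same elementwise rule as PReluFunctor, applied over the flat buffer.
  std::transform(x.begin(), x.end(), out.begin(),
                 [alpha](float v) { return v > 0 ? v : alpha * v; });
  // out is now {-0.5f, -0.125f, 0.f, 1.f, 3.f}
  return 0;
}
```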
@@ -63,36 +65,36 @@ class PReluGradFunctor {
  public:
   explicit PReluGradFunctor(const T& alpha) : alpha_(alpha) {}
 
-  HOSTDEVICE T operator()(const T& Out, const T& dOut) const {
-    if (Out > 0)
-      return dOut;
+  HOSTDEVICE T operator()(const T& out, const T& dout) const {
+    if (out > 0)
+      return dout;
     else
-      return dOut * alpha_;
+      return dout * alpha_;
   }
 
  private:
   T alpha_;
 };
 
-template <typename Place, typename T, typename AttrType = T>
+template <typename Place, typename T>
 class PReluGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
-    auto* dO = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* Out = context.Input<Tensor>("Out");
+    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out = context.Input<Tensor>("Out");
+    auto* alpha = context.Input<Tensor>("Alpha");
+    auto alpha_val = alpha->data<T>()[0];
 
-    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
-    T* dX_ptr = dX->mutable_data<T>(context.GetPlace());
-    const T* dO_ptr = dO->data<T>();
-    const T* O_ptr = Out->data<T>();
-    int numel = dX->numel();
+    T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
+    const T* dout_ptr = dout->data<T>();
+    const T* out_ptr = out->data<T>();
+    int numel = dx->numel();
 
     auto place = context.GetPlace();
-    Transform(place, O_ptr, O_ptr + numel, dO_ptr, dX_ptr,
-              PReluGradFunctor<T>(alpha));
+    Transform(place, out_ptr, out_ptr + numel, dout_ptr, dx_ptr,
+              PReluGradFunctor<T>(alpha_val));
   }
 };
......
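A closing note on the backward pass as shown in these hunks. PReluGradFunctor keys its branch off the forward output rather than the input; this agrees with the x-based rule whenever alpha is positive (out > 0 exactly when x > 0 in that case). Also, while PReluGradOp now resizes Alpha@GRAD, the kernel in this commit only fills X@GRAD; the alpha gradient, which would accumulate x * dout over the non-positive region, is not computed here yet. A reference sketch of the rule the functor does implement (`prelu_grad_ref` is a hypothetical name, not part of the commit):

```cpp
// dL/dx = dL/dout          if out > 0  (identity branch)
//       = alpha * dL/dout  otherwise   (scaled branch)
template <typename T>
T prelu_grad_ref(T out, T dout, T alpha) {
  return out > 0 ? dout : alpha * dout;
}
```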