diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..831958e3a46f058d6fff681263962c2830d4e7fd
--- /dev/null
+++ b/paddle/operators/prelu_op.cc
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/prelu_op.h"
+#include "paddle/operators/net_op.h"
+
+namespace paddle {
+namespace operators {
+
+class PreluOp : public framework::OperatorWithKernel {
+ public:
+  PreluOp(const std::string &type, const framework::VariableNameMap &inputs,
+          const framework::VariableNameMap &outputs,
+          const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto *in = ctx.Input<framework::Tensor>("X");
+    auto *out = ctx.Output<framework::Tensor>("Out");
+    out->Resize(in->dims());
+  }
+};
+
+template <typename AttrType>
+class PreluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  PreluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input tensor of prelu operator.").NotInGradient();
+    AddOutput("Out", "The output tensor of prelu operator.").NotInGradient();
+    AddComment(R"DOC(Prelu operator
+
+The equation is:
+f(x) = alpha * x , for x < 0
+f(x) = x , for x >= 0
+)DOC");
+    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
+        .SetDefault(0.0);
+  }
+};
+
+// The operator to calculate gradients of a prelu operator.
+class PreluGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto X_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto X = ctx.Input<framework::Tensor>("X");
+
+    X_grad->Resize(X->dims());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP(prelu, ops::PreluOp, ops::PreluOpMaker<float>, prelu_grad,
+            ops::PreluGradOp);
+REGISTER_OP_CPU_KERNEL(prelu,
+                       ops::PreluKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    prelu_grad, ops::PreluGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/prelu_op.cu b/paddle/operators/prelu_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..54a9089bdb66c47cbe21d87d2f6813b7aec4a299
--- /dev/null
+++ b/paddle/operators/prelu_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/prelu_op.h"
+
+REGISTER_OP_GPU_KERNEL(
+    prelu, paddle::operators::PreluKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0bb6f61e3061d2e173d42c53029cf102b08afa89
--- /dev/null
+++ b/paddle/operators/prelu_op.h
@@ -0,0 +1,71 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class PreluKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* X = context.Input<Tensor>("X");
+    auto* Out = context.Output<Tensor>("Out");
+
+    Out->mutable_data<T>(context.GetPlace());
+
+    auto alpha = static_cast<T>(context.Attr<float>("alpha"));
+
+    auto X_vec = EigenVector<T>::Flatten(*X);
+    auto Out_vec = EigenVector<T>::Flatten(*Out);
+
+    auto place = context.GetEigenDevice<Place>();
+
+    Out_vec.device(place) = X_vec.cwiseMax(0.f) + X_vec.cwiseMin(0.f) * alpha;
+  }
+};
+
+template <typename Place, typename T>
+class PreluGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* dO = context.Input<Tensor>(framework::GradVarName("Out"));
+
+    auto* Out = context.Input<Tensor>("Out");
+
+    auto alpha = static_cast<T>(context.Attr<float>("alpha"));
+
+    dX->mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < dX->numel(); ++i) {
+      if (Out->data<T>()[i] > 0) {
+        dX->data<T>()[i] = dO->data<T>()[i];
+      } else {
+        dX->data<T>()[i] = dO->data<T>()[i] * alpha;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b3916696a69c958d5609ecdf961a25c28352184
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_prelu_op.py
@@ -0,0 +1,23 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class PReluTest(OpTest):
+    def setUp(self):
+        self.op_type = "prelu"
+        self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
+        self.attrs = {'alpha': 0.1}
+        out_np = np.maximum(self.inputs['X'], 0.)
+        out_np = out_np + np.minimum(self.inputs['X'], 0.) * self.attrs['alpha']
+        self.outputs = {'Out': out_np}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+if __name__ == "__main__":
+    unittest.main()
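For reference, the element-wise math that `PreluKernel` and `PreluGradKernel` implement can be sketched in NumPy; it is the same formulation the unit test checks against. This is an illustrative sketch only, not part of the patch, and the helper names `prelu_forward` and `prelu_backward` are made up for the example:

```python
import numpy as np


def prelu_forward(x, alpha):
    # f(x) = x for x >= 0, alpha * x for x < 0
    return np.maximum(x, 0.) + np.minimum(x, 0.) * alpha


def prelu_backward(out, d_out, alpha):
    # Mirrors the scalar loop in PreluGradKernel: pass dOut through
    # where the forward output was positive, scale it by alpha elsewhere.
    return np.where(out > 0, d_out, d_out * alpha)


x = np.random.uniform(-1, 1, (10, 10)).astype("float32")
out = prelu_forward(x, 0.1)
dx = prelu_backward(out, np.ones_like(out), 0.1)
```

Note that for alpha > 0 the sign of Out matches the sign of X, so branching on Out (as the grad kernel does) is equivalent to branching on X; `check_grad` in the test exercises exactly this path.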