diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index c020ff45ad3f3a72bf8a88622df333c1765a3d21..ea9105d79c18310b16bdc81a7ab0e643e72c3965 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -159,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index db040509bc08c3f6ad031c5b97c93574e31337e0..23d9ea88f6701f9f9e5e02948e996878a849ddd6 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -26,14 +23,40 @@ class PReluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} void InferShape(framework::InferShapeContext *ctx) const override { + std::string mode = ctx->Attrs().Get("mode"); + + auto x_dim = ctx->GetInputDim("X"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null"); - PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, - "Size of weight Alpha must be one."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (mode == "all") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, + "For mode 'all', size of weight Alpha must be one."); + } else if (mode == "channel") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == x_dim[1], + "For channel-wise mode, size of weight Alpha must be " + "equal to the number of channels, should be %d", + x_dim[1]); + } else if (mode == "element") { + PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == product(x_dim), + "For element-wise mode, size of weight Alpha must be " + "equal to the number of input, should be %d", + product(x_dim)); + } else { + PADDLE_THROW("Unkown mode %s", mode); + } + ctx->SetOutputDim("Out", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } }; class PReluOpMaker : public framework::OpProtoAndCheckerMaker { @@ -44,9 +67,7 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output tensor of prelu operator."); AddComment(R"DOC( PRelu Operator. - The equation is: - $$ f(x) = \begin{cases} @@ -54,11 +75,15 @@ f(x) = x, \qquad \text{if} \ x >= 0 \end{cases} $$ - The input `X` can carry the LoD (Level of Details) information, or not. And the output shares the LoD information with input `X`. - +There are modes: + all: all elements share same weight + channel: elements in a channel share same weight + element: each element has a weight )DOC"); + AddAttr("mode", "The mode for inputs to share weights.") + .SetDefault("all"); } }; @@ -71,9 +96,23 @@ class PReluGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->SetOutputDim(framework::GradVarName("Alpha"), - ctx->GetInputDim("Alpha")); + auto x_grad_name = framework::GradVarName("X"); + auto alpha_grad_name = framework::GradVarName("Alpha"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + } + if (ctx->HasOutput(alpha_grad_name)) { + ctx->SetOutputDim(alpha_grad_name, ctx->GetInputDim("Alpha")); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu deleted file mode 100644 index 37d934a29046be04a1721b7330c813f663f61aed..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/prelu_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/prelu_op.h" - -REGISTER_OP_CUDA_KERNEL( - prelu, - paddle::operators::PReluKernel); -REGISTER_OP_CUDA_KERNEL(prelu_grad, - paddle::operators::PReluGradKernel< - paddle::platform::CUDADeviceContext, float>); diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h index a6197d354833a2f4173003ad2a970c487ad9a65b..f9076cbc678534fd5490fa0d7adeac0e50909a39 100644 --- a/paddle/fluid/operators/prelu_op.h +++ b/paddle/fluid/operators/prelu_op.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -13,32 +10,16 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/transform.h" - namespace paddle { namespace operators { using Tensor = framework::Tensor; using platform::Transform; -template -class PReluFunctor { - public: - explicit PReluFunctor(const T* alpha) : alpha_(alpha) {} - - HOSTDEVICE T operator()(const T& x) const { - if (x > 0) - return x; - else - return x * (*alpha_); - } - - private: - const T* alpha_; -}; - template class PReluKernel : public framework::OpKernel { public: @@ -50,53 +31,93 @@ class PReluKernel : public framework::OpKernel { const T* x_ptr = x->data(); T* o_ptr = out->mutable_data(context.GetPlace()); - auto* alpha_ptr = alpha->data(); + const T* alpha_ptr = alpha->data(); + std::string mode = context.Attr("mode"); int numel = x->numel(); - - Transform trans; - trans(context.template device_context(), x_ptr, - x_ptr + numel, o_ptr, PReluFunctor(alpha_ptr)); - } -}; - -template -class PReluGradFunctor { - public: - explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {} - - HOSTDEVICE T operator()(const T& out, const T& dout) const { - if (out > 0) - return dout; - else - return dout * (*alpha_); + auto dim = x->dims(); + int index = 0; + int i = 0; + int temp = 0; + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[0] * x_ptr[i]; + } + } } - - private: - const T* alpha_; }; template class PReluGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); auto* dx = context.Output(framework::GradVarName("X")); auto* dout = context.Input(framework::GradVarName("Out")); - + auto* dalpha = context.Output(framework::GradVarName("Alpha")); auto* out = context.Input("Out"); auto* alpha = context.Input("Alpha"); - auto* alpha_ptr = alpha->data(); - - T* dx_ptr = dx->mutable_data(context.GetPlace()); + const T* alpha_ptr = alpha->data(); + const T* x_ptr = x->data(); const T* dout_ptr = dout->data(); const T* out_ptr = out->data(); - int numel = dx->numel(); - - Transform trans; - trans(context.template device_context(), out_ptr, - out_ptr + numel, dout_ptr, dx_ptr, PReluGradFunctor(alpha_ptr)); - - // TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready + std::string mode = context.Attr("mode"); + int numel = x->numel(); + auto dim = x->dims(); + int index = 0; + int i = 0; + int temp = 0; + if (dx) { + T* dx_ptr = dx->mutable_data(context.GetPlace()); + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + dx_ptr[i] = + out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i]; + } + } + } + + index = 0; + if (dalpha) { + T* dalpha_ptr = dalpha->mutable_data(context.GetPlace()); + if (mode == "channel") { + for (i = 0; i < numel; i++) { + temp = numel / (dim[0] * dim[1]); + index = (i / temp) % dim[1]; + dalpha_ptr[index] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else if (mode == "element") { + for (i = 0; i < numel; i++) { + dalpha_ptr[i] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + dalpha_ptr[0] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } + } + + // TODO(Guanzhong): add GPU kernels } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c75e7eeb4384f28ca0dd95e3b79b7de5a3031351..3e50fc91d92d0b338f2c0282e3e61bfda6e5edbc 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -112,6 +112,7 @@ __all__ = [ 'log', 'crop', 'rank_loss', + 'prelu', 'flatten', ] @@ -5364,6 +5365,59 @@ def rank_loss(label, left, right, name=None): return out +def prelu(x, mode, param_attr=None, name=None): + """ + Equation: + + y = \max(0, x) + alpha \min(0, x) + + Args: + x (Variable): The input tensor. + param_attr(ParamAttr|None): The parameter attribute for the learnable + weight (alpha). + mode (string): The mode for weight sharing + all: all elements share same weight + channel:elements in a channel share same weight + element:each element has a weight + name(str|None): A name for this layer(optional). If set None, the layer + will be named automatically. + + Returns: + Variable: The output tensor with the same shape as input. + + Examples: + + .. code-block:: python + + x = fluid.layers.data(name="x", shape=[10,10], dtype="float32") + mode = 'channel' + output = fluid.layers.prelu(x,mode) + """ + helper = LayerHelper('prelu', **locals()) + if mode not in ['all', 'channel', 'element']: + raise ValueError('mode should be one of all, channel, element.') + alpha_shape = [1] + if mode == 'channel': + alpha_shape = [1, x.shape[1], 1, 1] + elif mode == 'element': + alpha_shape = x.shape + dtype = helper.input_dtype(input_param_name='x') + alpha = helper.create_parameter( + attr=param_attr, + shape=alpha_shape, + dtype='float32', + is_bias=False, + default_initializer=Constant(1.0)) + out = helper.create_tmp_variable(dtype) + helper.append_op( + type="prelu", + inputs={"X": x, + 'Alpha': alpha}, + attrs={"mode": mode}, + outputs={"Out": out}) + return out + + def flatten(x, axis=1, name=None): """ **Flatten layer** diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 38a138a8faee7746d9e7630d39206066634956f8..07fd0575d333dacf309620a883e4052c6126739f 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -21,6 +21,7 @@ import paddle.fluid.nets as nets from paddle.fluid.framework import Program, program_guard, default_main_program from paddle.fluid.param_attr import ParamAttr import decorators +from paddle.fluid.initializer import Constant class TestBook(unittest.TestCase): @@ -485,6 +486,20 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_prelu(self): + program = Program() + with program_guard(program): + input = layers.data( + name="input", shape=[5, 200, 100, 100], dtype="float32") + mode = 'channel' + out = layers.prelu( + input, + mode, + param_attr=ParamAttr(initializer=Constant(1.0)), + name='prelu') + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index ae19a553bb826002c562c15ee07759391d51b4d8..cb7de3fc93c0379ea50c88044876d6a8ee617a69 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -20,30 +20,58 @@ from op_test import OpTest class PReluTest(OpTest): def setUp(self): self.op_type = "prelu" - x_np = np.random.normal(size=(10, 10)).astype("float32") - - for pos, val in np.ndenumerate(x_np): - # Since zero point in prelu is not differentiable, avoid randomize - # zero. - while abs(val) < 1e-3: - x_np[pos] = np.random.normal() - val = x_np[pos] - - x_np_sign = np.sign(x_np) - x_np = x_np_sign * np.maximum(x_np, .005) - alpha_np = np.array([.1], dtype="float32") - self.inputs = {'X': x_np, 'Alpha': alpha_np} + self.initTestCase() + x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32") + + # Since zero point in prelu is not differentiable, avoid randomize + # zero. + x_np[np.abs(x_np) < 0.005] = 0.02 + + if self.attrs == {'mode': "all"}: + alpha_np = np.random.rand(1).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + elif self.attrs == {'mode': "channel"}: + alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + else: + alpha_np = np.random.rand(*x_np.shape).astype("float32") + self.inputs = {'X': x_np, 'Alpha': alpha_np} + out_np = np.maximum(self.inputs['X'], 0.) out_np = out_np + np.minimum(self.inputs['X'], 0.) * self.inputs['Alpha'] assert out_np is not self.inputs['X'] self.outputs = {'Out': out_np} + def initTestCase(self): + self.attrs = {'mode': "channel"} + def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X', 'Alpha'], 'Out') + + def test_check_grad_ignore_x(self): + self.check_grad(['Alpha'], 'Out', no_grad_set=set('X')) + + def test_check_grad_ignore_alpha(self): + self.check_grad(['X'], 'Out', no_grad_set=set('Alpha')) + + +class TestCase1(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "all"} + + +class TestCase2(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "channel"} + + +class TestCase3(PReluTest): + def initTestCase(self): + self.attrs = {'mode': "element"} if __name__ == "__main__":