提交 c1083765 编写于 作者: J jerrywgz 提交者: qingqing01

Add three modes for prelu_op (#12630)

* Add three modes for prelu_op.
上级 d0684930
...@@ -159,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul ...@@ -159,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul
paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -26,14 +23,40 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -26,14 +23,40 @@ class PReluOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
// Validates the inputs/outputs of prelu and infers the shape of Out.
//
// Out always receives the dims and the LoD of X. The required number of
// elements in Alpha depends on the "mode" attribute:
//   "all"     -> exactly 1
//   "channel" -> x_dim[1] (X must therefore have at least 2 dimensions)
//   "element" -> product(x_dim)
// Any other mode is rejected with PADDLE_THROW.
void InferShape(framework::InferShapeContext *ctx) const override {
  // Check presence first: the dims of a missing variable must not be read.
  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
  PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
  PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");

  const std::string mode = ctx->Attrs().Get<std::string>("mode");
  const auto x_dim = ctx->GetInputDim("X");
  const auto alpha_numel = product(ctx->GetInputDim("Alpha"));

  // The size constraint on Alpha is mode specific; an unconditional
  // "size must be one" check would reject the channel and element modes.
  if (mode == "all") {
    PADDLE_ENFORCE(alpha_numel == 1,
                   "For mode 'all', size of weight Alpha must be one.");
  } else if (mode == "channel") {
    // x_dim[1] is only meaningful when X has a channel dimension.
    PADDLE_ENFORCE(x_dim.size() >= 2,
                   "For channel-wise mode, Input(X) must have at least 2 "
                   "dimensions.");
    PADDLE_ENFORCE(alpha_numel == x_dim[1],
                   "For channel-wise mode, size of weight Alpha must be "
                   "equal to the number of channels, should be %d",
                   x_dim[1]);
  } else if (mode == "element") {
    PADDLE_ENFORCE(alpha_numel == product(x_dim),
                   "For element-wise mode, size of weight Alpha must be "
                   "equal to the number of input, should be %d",
                   product(x_dim));
  } else {
    PADDLE_THROW("Unknown mode %s", mode);
  }

  ctx->SetOutputDim("Out", x_dim);
  ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
// Picks the kernel from the element type of input X; the place is always
// platform::CPUPlace().
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  const auto *input_x = ctx.Input<Tensor>("X");
  const auto x_data_type = framework::ToDataType(input_x->type());
  return framework::OpKernelType(x_data_type, platform::CPUPlace());
}
}; };
class PReluOpMaker : public framework::OpProtoAndCheckerMaker { class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -44,9 +67,7 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -44,9 +67,7 @@ class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output tensor of prelu operator."); AddOutput("Out", "The output tensor of prelu operator.");
AddComment(R"DOC( AddComment(R"DOC(
PRelu Operator. PRelu Operator.
The equation is: The equation is:
$$ $$
f(x) = f(x) =
\begin{cases} \begin{cases}
...@@ -54,11 +75,15 @@ f(x) = ...@@ -54,11 +75,15 @@ f(x) =
x, \qquad \text{if} \ x >= 0 x, \qquad \text{if} \ x >= 0
\end{cases} \end{cases}
$$ $$
The input `X` can carry the LoD (Level of Details) information, The input `X` can carry the LoD (Level of Details) information,
or not. And the output shares the LoD information with input `X`. or not. And the output shares the LoD information with input `X`.
There are modes:
all: all elements share same weight
channel: elements in a channel share same weight
element: each element has a weight
)DOC"); )DOC");
AddAttr<std::string>("mode", "The mode for inputs to share weights.")
.SetDefault("all");
} }
}; };
...@@ -71,9 +96,23 @@ class PReluGradOp : public framework::OperatorWithKernel { ...@@ -71,9 +96,23 @@ class PReluGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); auto x_grad_name = framework::GradVarName("X");
ctx->SetOutputDim(framework::GradVarName("Alpha"), auto alpha_grad_name = framework::GradVarName("Alpha");
ctx->GetInputDim("Alpha"));
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
}
if (ctx->HasOutput(alpha_grad_name)) {
ctx->SetOutputDim(alpha_grad_name, ctx->GetInputDim("Alpha"));
}
}
protected:
// Picks the gradient kernel from the element type of input X; the place is
// always platform::CPUPlace().
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const override {
  const auto *input_x = ctx.Input<Tensor>("X");
  const auto x_data_type = framework::ToDataType(input_x->type());
  return framework::OpKernelType(x_data_type, platform::CPUPlace());
}
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/prelu_op.h"
// Register the float instantiations of PReluKernel / PReluGradKernel for
// CUDADeviceContext under the op names "prelu" and "prelu_grad".
// NOTE(review): the kernel templates come from prelu_op.h; confirm that they
// are safe to run on device memory before relying on these registrations.
REGISTER_OP_CUDA_KERNEL(
    prelu,
    paddle::operators::PReluKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(prelu_grad,
                        paddle::operators::PReluGradKernel<
                            paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -13,32 +10,16 @@ See the License for the specific language governing permissions and ...@@ -13,32 +10,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using platform::Transform; using platform::Transform;
// Unary functor applying PReLU with one shared slope: yields x when x > 0
// and x * (*alpha) otherwise. The alpha pointer is stored as given and is
// dereferenced on every call.
template <typename T>
class PReluFunctor {
 public:
  explicit PReluFunctor(const T* alpha) : alpha_(alpha) {}

  HOSTDEVICE T operator()(const T& x) const {
    return x > 0 ? x : x * (*alpha_);
  }

 private:
  const T* alpha_;  // slope applied to non-positive inputs
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class PReluKernel : public framework::OpKernel<T> { class PReluKernel : public framework::OpKernel<T> {
public: public:
...@@ -50,53 +31,93 @@ class PReluKernel : public framework::OpKernel<T> { ...@@ -50,53 +31,93 @@ class PReluKernel : public framework::OpKernel<T> {
const T* x_ptr = x->data<T>(); const T* x_ptr = x->data<T>();
T* o_ptr = out->mutable_data<T>(context.GetPlace()); T* o_ptr = out->mutable_data<T>(context.GetPlace());
auto* alpha_ptr = alpha->data<T>(); const T* alpha_ptr = alpha->data<T>();
std::string mode = context.Attr<std::string>("mode");
int numel = x->numel(); int numel = x->numel();
auto dim = x->dims();
Transform<DeviceContext> trans; int index = 0;
trans(context.template device_context<DeviceContext>(), x_ptr, int i = 0;
x_ptr + numel, o_ptr, PReluFunctor<T>(alpha_ptr)); int temp = 0;
} if (mode == "channel") {
}; for (i = 0; i < numel; i++) {
temp = numel / (dim[0] * dim[1]);
template <typename T> index = (i / temp) % dim[1];
class PReluGradFunctor { o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
public: }
explicit PReluGradFunctor(const T* alpha) : alpha_(alpha) {} } else if (mode == "element") {
for (i = 0; i < numel; i++) {
HOSTDEVICE T operator()(const T& out, const T& dout) const { o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i];
if (out > 0) }
return dout; } else {
else for (i = 0; i < numel; i++) {
return dout * (*alpha_); o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[0] * x_ptr[i];
}
}
} }
private:
const T* alpha_;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class PReluGradKernel : public framework::OpKernel<T> { class PReluGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* dx = context.Output<Tensor>(framework::GradVarName("X")); auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* dout = context.Input<Tensor>(framework::GradVarName("Out")); auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dalpha = context.Output<Tensor>(framework::GradVarName("Alpha"));
auto* out = context.Input<Tensor>("Out"); auto* out = context.Input<Tensor>("Out");
auto* alpha = context.Input<Tensor>("Alpha"); auto* alpha = context.Input<Tensor>("Alpha");
auto* alpha_ptr = alpha->data<T>(); const T* alpha_ptr = alpha->data<T>();
const T* x_ptr = x->data<T>();
T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
const T* dout_ptr = dout->data<T>(); const T* dout_ptr = dout->data<T>();
const T* out_ptr = out->data<T>(); const T* out_ptr = out->data<T>();
int numel = dx->numel(); std::string mode = context.Attr<std::string>("mode");
int numel = x->numel();
Transform<DeviceContext> trans; auto dim = x->dims();
trans(context.template device_context<DeviceContext>(), out_ptr, int index = 0;
out_ptr + numel, dout_ptr, dx_ptr, PReluGradFunctor<T>(alpha_ptr)); int i = 0;
int temp = 0;
// TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready if (dx) {
T* dx_ptr = dx->mutable_data<T>(context.GetPlace());
if (mode == "channel") {
for (i = 0; i < numel; i++) {
temp = numel / (dim[0] * dim[1]);
index = (i / temp) % dim[1];
dx_ptr[i] =
out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
}
} else if (mode == "element") {
for (i = 0; i < numel; i++) {
dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[i] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i];
}
}
}
index = 0;
if (dalpha) {
T* dalpha_ptr = dalpha->mutable_data<T>(context.GetPlace());
if (mode == "channel") {
for (i = 0; i < numel; i++) {
temp = numel / (dim[0] * dim[1]);
index = (i / temp) % dim[1];
dalpha_ptr[index] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
} else if (mode == "element") {
for (i = 0; i < numel; i++) {
dalpha_ptr[i] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
} else {
for (i = 0; i < numel; i++) {
dalpha_ptr[0] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
}
}
}
// TODO(Guanzhong): add GPU kernels
} }
}; };
......
...@@ -112,6 +112,7 @@ __all__ = [ ...@@ -112,6 +112,7 @@ __all__ = [
'log', 'log',
'crop', 'crop',
'rank_loss', 'rank_loss',
'prelu',
'flatten', 'flatten',
] ]
...@@ -5364,6 +5365,59 @@ def rank_loss(label, left, right, name=None): ...@@ -5364,6 +5365,59 @@ def rank_loss(label, left, right, name=None):
return out return out
def prelu(x, mode, param_attr=None, name=None):
    r"""
    Parametric ReLU layer.

    Equation:

        y = \max(0, x) + alpha \min(0, x)

    Args:
        x (Variable): The input tensor.
        mode (string): The mode for weight sharing
            all: all elements share same weight
            channel: elements in a channel share same weight
            element: each element has a weight
        param_attr (ParamAttr|None): The parameter attribute for the
            learnable weight (alpha).
        name (str|None): A name for this layer(optional). If set None, the
            layer will be named automatically.

    Returns:
        Variable: The output tensor with the same shape as input.

    Raises:
        ValueError: If `mode` is not one of 'all', 'channel', 'element', or
            if `mode` is 'channel' and `x` has fewer than 2 dimensions.

    Examples:

        .. code-block:: python

            x = fluid.layers.data(name="x", shape=[10,10], dtype="float32")
            mode = 'channel'
            output = fluid.layers.prelu(x,mode)
    """
    if mode not in ['all', 'channel', 'element']:
        raise ValueError('mode should be one of all, channel, element.')
    if mode == 'channel' and len(x.shape) < 2:
        raise ValueError(
            'channel mode requires an input with at least 2 dimensions.')

    # locals() must only contain the call arguments at this point, because
    # they are forwarded to LayerHelper as keyword arguments.
    helper = LayerHelper('prelu', **locals())

    # Shape of the learnable weight depends on how it is shared.
    alpha_shape = [1]
    if mode == 'channel':
        alpha_shape = [1, x.shape[1], 1, 1]
    elif mode == 'element':
        alpha_shape = x.shape

    dtype = helper.input_dtype(input_param_name='x')
    # The weight follows the input dtype, like the output variable below,
    # instead of being pinned to float32.
    alpha = helper.create_parameter(
        attr=param_attr,
        shape=alpha_shape,
        dtype=dtype,
        is_bias=False,
        default_initializer=Constant(1.0))
    out = helper.create_tmp_variable(dtype)
    helper.append_op(
        type="prelu",
        inputs={"X": x,
                'Alpha': alpha},
        attrs={"mode": mode},
        outputs={"Out": out})
    return out
def flatten(x, axis=1, name=None): def flatten(x, axis=1, name=None):
""" """
**Flatten layer** **Flatten layer**
......
...@@ -21,6 +21,7 @@ import paddle.fluid.nets as nets ...@@ -21,6 +21,7 @@ import paddle.fluid.nets as nets
from paddle.fluid.framework import Program, program_guard, default_main_program from paddle.fluid.framework import Program, program_guard, default_main_program
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
import decorators import decorators
from paddle.fluid.initializer import Constant
class TestBook(unittest.TestCase): class TestBook(unittest.TestCase):
...@@ -485,6 +486,20 @@ class TestBook(unittest.TestCase): ...@@ -485,6 +486,20 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out) self.assertIsNotNone(out)
print(str(program)) print(str(program))
def test_prelu(self):
    # Build a channel-mode prelu layer inside a fresh Program, check that a
    # result variable is returned, then print the program.
    prog = Program()
    with program_guard(prog):
        data = layers.data(
            name="input", shape=[5, 200, 100, 100], dtype="float32")
        sharing_mode = 'channel'
        result = layers.prelu(
            data,
            sharing_mode,
            param_attr=ParamAttr(initializer=Constant(1.0)),
            name='prelu')
        self.assertIsNotNone(result)
    print(str(prog))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,30 +20,58 @@ from op_test import OpTest ...@@ -20,30 +20,58 @@ from op_test import OpTest
class PReluTest(OpTest):
    """Output and gradient checks for the prelu operator.

    Subclasses select the weight-sharing mode by overriding initTestCase,
    which must set self.attrs['mode'] to 'all', 'channel' or 'element'.
    """

    def setUp(self):
        self.op_type = "prelu"
        self.initTestCase()
        x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32")

        # prelu is not differentiable at zero, so move every value that is
        # too close to zero away from it.
        x_np[np.abs(x_np) < 0.005] = 0.02

        # Alpha has one element, one element per channel, or one element
        # per input element, depending on the mode.
        mode = self.attrs['mode']
        if mode == "all":
            alpha_shape = (1, )
        elif mode == "channel":
            alpha_shape = (1, x_np.shape[1], 1, 1)
        else:
            alpha_shape = x_np.shape
        alpha_np = np.random.rand(*alpha_shape).astype("float32")
        self.inputs = {'X': x_np, 'Alpha': alpha_np}

        # Reference result: max(0, x) + alpha * min(0, x), with Alpha
        # broadcast against X.
        out_np = np.maximum(self.inputs['X'], 0.)
        out_np = out_np + np.minimum(self.inputs['X'],
                                     0.) * self.inputs['Alpha']
        assert out_np is not self.inputs['X']
        self.outputs = {'Out': out_np}

    def initTestCase(self):
        self.attrs = {'mode': "channel"}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X', 'Alpha'], 'Out')

    def test_check_grad_ignore_x(self):
        self.check_grad(['Alpha'], 'Out', no_grad_set=set(['X']))

    def test_check_grad_ignore_alpha(self):
        # set('Alpha') would build a set of single characters, so the name
        # is wrapped in a list.
        self.check_grad(['X'], 'Out', no_grad_set=set(['Alpha']))
class TestCase1(PReluTest):
    """Runs the PReluTest checks with mode 'all'."""

    def initTestCase(self):
        self.attrs = dict(mode="all")
class TestCase2(PReluTest):
    """Runs the PReluTest checks with mode 'channel'."""

    def initTestCase(self):
        self.attrs = dict(mode="channel")
class TestCase3(PReluTest):
    """Runs the PReluTest checks with mode 'element'."""

    def initTestCase(self):
        self.attrs = dict(mode="element")
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册