You need to sign in or sign up before continuing.
未验证 提交 8d512b8f 编写于 作者: W wangshengxiang 提交者: GitHub

add prelu & prelu_grad op for xpu (#49672)

上级 ac9debee
...@@ -418,6 +418,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -418,6 +418,9 @@ XPUOpMap& get_kl2_ops() {
{"pow_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"pow_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"pow2_decay_with_linear_warmup", XPUKernelSet({phi::DataType::FLOAT32})}, {"pow2_decay_with_linear_warmup", XPUKernelSet({phi::DataType::FLOAT32})},
{"prior_box", XPUKernelSet({phi::DataType::FLOAT32})}, {"prior_box", XPUKernelSet({phi::DataType::FLOAT32})},
{"prelu", XPUKernelSet({phi::DataType::FLOAT32})},
{"prelu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"range", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64})}, {"range", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT64})},
{"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})}, {"reciprocal", XPUKernelSet({phi::DataType::FLOAT32})},
{"reciprocal_grad", {"reciprocal_grad",
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prelu_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void PReluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const DenseTensor& out_grad,
const std::string& data_format,
const std::string& mode,
DenseTensor* x_grad,
DenseTensor* alpha_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
const T* x_ptr = x.data<T>();
const T* alpha_ptr = alpha.data<T>();
const T* out_grad_ptr = out_grad.data<T>();
T* x_grad_ptr = dev_ctx.template Alloc<T>(x_grad);
T* alpha_grad_ptr = dev_ctx.template Alloc<T>(alpha_grad);
auto x_dim = x.dims();
auto x_rank = x_dim.size();
std::vector<int> x_shape(x_rank);
for (int i = 0; i < x_rank; i++) {
x_shape[i] = x_dim[i];
}
auto alpha_dim = alpha.dims();
auto alpha_rank = alpha_dim.size();
std::vector<int> alpha_shape(alpha_rank);
for (int i = 0; i < x_rank; i++) {
alpha_shape[i] = alpha_dim[i];
}
// mode = 0: channel_nchw, slope_shape = {c}, default. meanwhile, xhsape = {n,
// c, h, w}
// mode = 1, channel_nhwc, slope_shape = {c}, meanwhile, xhsape = {n, h, w, c}
// mode = 2, elementwise, slope_shape = {c*h*w}
// mode = 3, single slope, slope_shape = {1}
int xpu_mode = 0;
if (mode == "channel") {
if (data_format == "NCHW") {
xpu_mode = 0;
} else {
// NHWC
xpu_mode = 1;
}
} else if (mode == "element") {
xpu_mode = 2;
} else {
xpu_mode = 3;
}
int r = xpu::prelu_grad(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_ptr),
reinterpret_cast<const XPUType*>(
out_grad_ptr), /* const T* y, not used in xpu kernel */
reinterpret_cast<const XPUType*>(alpha_ptr),
reinterpret_cast<const XPUType*>(out_grad_ptr),
reinterpret_cast<XPUType*>(x_grad_ptr),
reinterpret_cast<XPUType*>(alpha_grad_ptr),
x_shape,
xpu_mode);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "prelu_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(prelu_grad,
XPU,
ALL_LAYOUT,
phi::PReluGradKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/prelu_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void PReluKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& alpha,
const std::string& data_format,
const std::string& mode,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const T* x_ptr = x.data<T>();
const T* alpha_ptr = alpha.data<T>();
T* y_ptr = dev_ctx.template Alloc<T>(out);
auto x_dim = x.dims();
auto x_rank = x_dim.size();
std::vector<int> x_shape(x_rank);
for (int i = 0; i < x_rank; i++) {
x_shape[i] = x_dim[i];
}
auto alpha_dim = alpha.dims();
auto alpha_rank = alpha_dim.size();
std::vector<int> alpha_shape(x_rank, 1); // same size with x_shape
for (int i = 0; i < alpha_rank; i++) {
alpha_shape[i] = alpha_dim[i];
}
int r = xpu::prelu(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_ptr),
reinterpret_cast<const XPUType*>(alpha_ptr),
reinterpret_cast<XPUType*>(y_ptr),
x_shape,
alpha_shape);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "prelu");
}
} // namespace phi
PD_REGISTER_KERNEL(prelu, XPU, ALL_LAYOUT, phi::PReluKernel, float) {}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program
paddle.enable_static()
class XPUTestPReluOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = "prelu"
self.use_dynamic_create_class = False
class TestPReluOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "prelu"
self.init_dtype()
self.eager_mode = True
# override
self.init_input_shape()
self.init_attr()
self.x = np.random.uniform(-10.0, 10.0, self.x_shape).astype(
self.dtype
)
# Since zero point in prelu is not differentiable, avoid randomize zero.
self.x[np.abs(self.x) < 0.005] = 0.02
if self.attrs == {
'mode': "all",
"data_format": "NCHW",
} or self.attrs == {'mode': "all", "data_format": "NHWC"}:
self.alpha = np.random.uniform(-1, -0.5, (1))
elif self.attrs == {'mode': "channel", "data_format": "NCHW"}:
self.alpha = np.random.uniform(
-1, -0.5, [1, self.x_shape[1], 1, 1]
)
elif self.attrs == {'mode': "channel", "data_format": "NHWC"}:
self.alpha = np.random.uniform(
-1, -0.5, [1, 1, 1, self.x_shape[-1]]
)
else:
self.alpha = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
# eager check don't support mode = 'all'
self.eager_mode = False
self.alpha = self.alpha.astype(self.dtype)
self.inputs = {'X': self.x, 'Alpha': self.alpha}
reshaped_alpha = self.inputs['Alpha']
if self.attrs == {'mode': "channel", "data_format": "NCHW"}:
reshaped_alpha = np.reshape(
self.inputs['Alpha'],
[1, self.x_shape[1]] + [1] * len(self.x_shape[2:]),
)
elif self.attrs == {'mode': "channel", "data_format": "NHWC"}:
reshaped_alpha = np.reshape(
self.inputs['Alpha'],
[1] + [1] * len(self.x_shape[1:-1]) + [self.x_shape[-1]],
)
self.alpha = np.random.uniform(
-10.0, 10.0, [1, self.x_shape[1], 1, 1]
).astype(self.dtype)
out_np = np.maximum(self.inputs['X'], 0.0)
out_np = out_np + np.minimum(self.inputs['X'], 0.0) * reshaped_alpha
assert out_np is not self.inputs['X']
self.outputs = {'Out': out_np}
def init_input_shape(self):
self.x_shape = [2, 3, 5, 6]
def init_attr(self):
self.attrs = {'mode': "channel", 'data_format': "NCHW"}
def set_xpu(self):
self.__class__.no_need_check_grad = False
self.place = paddle.XPUPlace(0)
def init_dtype(self):
self.dtype = self.in_type
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X', 'Alpha'], 'Out', check_eager=self.eager_mode
)
class TestModeChannelNHWC(TestPReluOp):
def init_input_shape(self):
self.x_shape = [2, 3, 4, 5]
def init_attr(self):
self.attrs = {'mode': "channel", "data_format": "NHWC"}
class TestModeAll(TestPReluOp):
def init_input_shape(self):
self.x_shape = [2, 3, 4, 5]
def init_attr(self):
self.attrs = {'mode': "all", "data_format": "NCHW"}
class TestModeAllNHWC(TestPReluOp):
def init_input_shape(self):
self.x_shape = [2, 3, 4, 50]
def init_attr(self):
self.attrs = {'mode': "all", "data_format": "NHWC"}
class TestModeElt(TestPReluOp):
def init_input_shape(self):
self.x_shape = [3, 2, 5, 10]
def init_attr(self):
self.attrs = {'mode': "element", "data_format": "NCHW"}
class TestModeEltNHWC(TestPReluOp):
def init_input_shape(self):
self.x_shape = [3, 2, 5, 10]
def init_attr(self):
self.attrs = {'mode': "element", "data_format": "NHWC"}
def prelu_t(x, mode, param_attr=None, name=None, data_format='NCHW'):
helper = fluid.layer_helper.LayerHelper('prelu', **locals())
alpha_shape = [1, x.shape[1], 1, 1]
dtype = helper.input_dtype(input_param_name='x')
alpha = helper.create_parameter(
attr=helper.param_attr,
shape=alpha_shape,
dtype='float32',
is_bias=False,
default_initializer=fluid.initializer.ConstantInitializer(0.25),
)
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="prelu",
inputs={"X": x, 'Alpha': alpha},
attrs={"mode": mode, 'data_format': data_format},
outputs={"Out": out},
)
return out
# error message test if mode is not one of 'all', 'channel', 'element'
class TestModeError(unittest.TestCase):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.x_np = np.ones([1, 2, 3, 4]).astype('float32')
def test_mode_error(self):
main_program = Program()
with fluid.program_guard(main_program, Program()):
x = fluid.data(name='x', shape=[2, 3, 4, 5])
try:
y = prelu_t(x, 'any')
except Exception as e:
assert e.args[0].find('InvalidArgument') != -1
def test_data_format_error1(self):
main_program = Program()
with fluid.program_guard(main_program, Program()):
x = fluid.data(name='x', shape=[2, 3, 4, 5])
try:
y = prelu_t(x, 'channel', data_format='N')
except Exception as e:
assert e.args[0].find('InvalidArgument') != -1
def test_data_format_error2(self):
main_program = Program()
with fluid.program_guard(main_program, Program()):
x = fluid.data(name='x', shape=[2, 3, 4, 5])
try:
y = paddle.static.nn.prelu(x, 'channel', data_format='N')
except ValueError as e:
pass
support_types = get_xpu_op_support_types("prelu")
for stype in support_types:
create_test_class(globals(), XPUTestPReluOp, stype)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册