Unverified commit 6815c8ab, authored by zhangyikun02, committed by GitHub

add mish and mish_grad for XPU, test=kunlun (#45098)

Parent 3649099f
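
For reference, the activation added here computes mish(x) = x * tanh(softplus(x)), where softplus(x) = log(1 + exp(x)). Per the ref_mish helper added in the tests below, softplus(x) is replaced by x itself once x exceeds the op's `threshold` attribute (the helper defaults it to 20), presumably to keep exp(x) from overflowing.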
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220810")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220812")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
@@ -19,7 +19,7 @@ endif()
 if(NOT DEFINED XPU_XDNN_BASE_URL)
   set(XPU_XDNN_BASE_URL_WITHOUT_DATE
       "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220810")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220812")
 else()
   set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
 endif()
......
@@ -404,6 +404,49 @@ struct XPULogGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct XPUMishFunctor : public BaseActivationFunctor<T> {
+  void operator()(const framework::ExecutionContext &ctx) const {
+    const auto *x = ctx.Input<Tensor>("X");
+    auto *y = ctx.Output<Tensor>("Out");
+    const T *x_data = x->data<T>();
+    T *y_data = y->mutable_data<T>(ctx.GetPlace());
+    float threshold = ctx.Attr<float>("threshold");
+    auto xpu_context =
+        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
+    int r = xpu::mish(xpu_context, x_data, y_data, x->numel(), threshold);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mish");
+  }
+};
+
+template <typename T>
+struct XPUMishGradFunctor : public BaseActivationFunctor<T> {
+  void operator()(const framework::ExecutionContext &ctx) const {
+    const auto *x = ctx.Input<Tensor>("X");
+    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    const T *x_data = x->data<T>();
+    const T *y_grad = dOut->data<T>();
+    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());
+    float threshold = ctx.Attr<float>("threshold");
+    auto xpu_context =
+        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
+    int r = xpu::mish_grad(xpu_context,
+                           reinterpret_cast<const float *>(x_data),
+                           reinterpret_cast<const float *>(
+                               x_data),  // mish_grad does not need y_data
+                           reinterpret_cast<const float *>(y_grad),
+                           reinterpret_cast<float *>(x_grad),
+                           dX->numel(),
+                           threshold);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mish_grad");
+  }
+};
+
 template <typename T>
 struct XPUPowFunctor : public BaseActivationFunctor<T> {
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -589,6 +632,7 @@ REGISTER_ACTIVATION_XPU_KERNEL(hard_swish,
 REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu,
                                XPULeakyReluFunctor,
                                XPULeakyReluGradFunctor)
+REGISTER_ACTIVATION_XPU_KERNEL(mish, XPUMishFunctor, XPUMishGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(reciprocal,
                                XPUReciprocalFunctor,
                                XPUReciprocalGradFunctor)
......
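
In XPUMishGradFunctor above, x_data is passed for both the x and y arguments: as the inline comment notes, xpu::mish_grad does not consume the forward output and can recompute what it needs from x. For intuition, here is a NumPy sketch of the analytic gradient implied by the ref_mish definition used in the tests below — my own derivation, not the XDNN kernel source (which is not part of this diff):

```python
import numpy as np

def ref_mish_grad(x, dout, threshold=20.0):
    # mish(x) = x * tanh(sp(x)); sp is softplus, clamped to the identity
    # above `threshold`, matching ref_mish. Derived by hand as a sketch.
    x = np.asarray(x, dtype=np.float64)
    sp = np.where(x <= threshold, np.log1p(np.exp(np.minimum(x, threshold))), x)
    dsp = np.where(x <= threshold, 1.0 / (1.0 + np.exp(-x)), 1.0)  # d(sp)/dx
    t = np.tanh(sp)
    # d/dx [x * tanh(sp)] = tanh(sp) + x * (1 - tanh(sp)^2) * d(sp)/dx
    return dout * (t + x * (1.0 - t * t) * dsp)

# Sanity check against a central finite difference.
x = np.random.uniform(-2.0, 2.0, 1000)
mish = lambda v: v * np.tanh(np.log1p(np.exp(v)))
eps = 1e-6
numeric = (mish(x + eps) - mish(x - eps)) / (2.0 * eps)
assert np.allclose(ref_mish_grad(x, np.ones_like(x)), numeric, atol=1e-5)
```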
@@ -111,6 +111,10 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"conv2d_transpose",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"deformable_conv_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"deformable_conv",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"depthwise_conv2d_grad",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"depthwise_conv2d",
@@ -342,6 +346,8 @@ XPUOpMap& get_kl2_ops() {
     {"merged_momentum",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"mish_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"mish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"mul",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
@@ -559,6 +565,8 @@ XPUOpMap& get_kl2_ops() {
     {"update_loss_scaling",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                    pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"uniform_random",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
     {"unsqueeze2_grad",
      XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
                    pOpKernelType(vartype::INT64, XPUPlace()),
......
@@ -1100,5 +1100,57 @@ def ref_thresholded_relu(x, threshold=1.0):
     return out
 
 
+class XPUTestMishOP(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'mish'
+        self.use_dynamic_create_class = False
+
+    class XPUTestMishBase(TestActivationOPBase):
+
+        def set_case(self):
+            self.op_type = "mish"
+            self.dtype = self.in_type
+            self.init_config()
+
+            threshold = np.random.uniform(0, 1)
+            out = ref_mish(self.x, threshold)
+            self.inputs = {'X': self.x}
+            self.outputs = {'Out': out}
+            self.attrs = {'use_xpu': True, 'threshold': threshold}
+
+        def init_config(self):
+            self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
+
+    class XPUTestMish2(XPUTestMishBase):
+
+        def init_config(self):
+            self.x = np.random.uniform(-2, 2, [1024, 8]).astype(self.dtype)
+
+    class XPUTestMish3(XPUTestMishBase):
+
+        def init_config(self):
+            self.x = np.random.uniform(-2, 2,
+                                       [4, 512, 15, 15]).astype(self.dtype)
+
+    class XPUTestMish4(XPUTestMishBase):
+
+        def init_config(self):
+            self.x = np.random.uniform(-2, 2,
+                                       [4, 256, 22, 22]).astype(self.dtype)
+
+
+support_types = get_xpu_op_support_types('mish')
+for stype in support_types:
+    create_test_class(globals(), XPUTestMishOP, stype)
+
+
+def ref_mish(x, threshold=20):
+    sp = np.select([x <= threshold, x > threshold], [np.log(1 + np.exp(x)), x])
+    out = x * np.tanh(sp)
+    return out
+
+
 if __name__ == "__main__":
     unittest.main()
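
A quick numeric check of the ref_mish reference used above (not part of the commit): below the threshold it agrees with the unclamped definition, and far above the threshold the linear branch takes over:

```python
import numpy as np

def ref_mish(x, threshold=20):
    # same reference implementation as in the test above
    sp = np.select([x <= threshold, x > threshold], [np.log(1 + np.exp(x)), x])
    return x * np.tanh(sp)

x = np.linspace(-5.0, 5.0, 101)
# matches the unclamped form x * tanh(log(1 + e^x)) in this range
assert np.allclose(ref_mish(x), x * np.tanh(np.log1p(np.exp(x))))
# far above the threshold, softplus(x) ~ x and tanh(x) ~ 1, so mish(x) ~ x
assert np.allclose(ref_mish(np.array([25.0])), np.array([25.0]))
```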
......
@@ -24,6 +24,7 @@ import paddle.fluid as fluid
 from op_test_xpu import OpTest, XPUOpTest
 import paddle
 from paddle.fluid import Program, program_guard
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
 
 
 def dmc_bilinear(data_im, height, width, h, w):
@@ -111,181 +112,189 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
     return out
 
 
-class TestModulatedDeformableConvOp(XPUOpTest):
-    def setUp(self):
-        self.op_type = "deformable_conv"
-        self.dtype = np.float32
-        self.init_group()
-        self.init_dilation()
-        self.init_test_case()
-
-        conv_param = {
-            'stride': self.stride,
-            'pad': self.pad,
-            'dilation': self.dilations
-        }
-
-        input = np.random.random(self.input_size).astype(self.dtype)
-        offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
-        mask = 10 * np.random.random(self.mask_size).astype(self.dtype)
-        filter = np.random.random(self.filter_size).astype(self.dtype)
-        output = dconv_im2col_gemm(input, offset, mask, filter, self.groups,
-                                   conv_param)
-        output = output.astype(self.dtype)
-
-        self.inputs = {
-            'Input': OpTest.np_dtype_to_fluid_dtype(input),
-            'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
-            'Mask': OpTest.np_dtype_to_fluid_dtype(mask),
-            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
-        }
-        self.attrs = {
-            'strides': self.stride,
-            'paddings': self.pad,
-            'groups': self.groups,
-            'deformable_groups': self.deformable_groups,
-            'im2col_step': self.im2col_step,
-            'dilations': self.dilations,
-        }
-        self.outputs = {'Output': output}
-
-    def has_cuda(self):
-        return core.is_compiled_with_cuda() and (self.use_cudnn
-                                                 or self.use_cuda)
-
-    def test_check_output(self):
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
-
-    def test_check_grad(self):
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(place,
-                                       {'Input', 'Offset', 'Mask', 'Filter'},
-                                       'Output',
-                                       max_relative_error=0.06)
-
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.dilations = [1, 1]
-        self.input_size = [2, 8, 4, 4]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [8, f_c, 3, 3]
-        self.im2col_step = 1
-        self.deformable_groups = 1
-        offset_c = 2 * self.deformable_groups * self.filter_size[
-            2] * self.filter_size[3]
-        mask_c = self.deformable_groups * self.filter_size[
-            2] * self.filter_size[3]
-        self.offset_size = [
-            self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
-        ]
-        self.mask_size = [
-            self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
-        ]
-
-    def init_dilation(self):
-        self.dilations = [1, 1]
-
-    def init_group(self):
-        self.groups = 1
-
-
-class TestWithDilation(TestModulatedDeformableConvOp):
-    def init_test_case(self):
-        self.pad = [2, 2]
-        self.stride = [1, 1]
-        self.input_size = [4, 3, 4, 4]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [6, f_c, 3, 3]
-        self.im2col_step = 1
-        self.deformable_groups = 1
-        offset_c = 2 * self.deformable_groups * self.filter_size[
-            2] * self.filter_size[3]
-        mask_c = self.deformable_groups * self.filter_size[
-            2] * self.filter_size[3]
-        self.offset_size = [
-            self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
-        ]
-        self.mask_size = [
-            self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
-        ]
-
-    def init_dilation(self):
-        self.dilations = [2, 2]
-
-
-class TestWith3x3(TestModulatedDeformableConvOp):
-    def init_test_case(self):
-        self.pad = [1, 1]
-        self.stride = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        assert np.mod(self.input_size[1], self.groups) == 0
-        f_c = self.input_size[1] // self.groups
-        self.filter_size = [6, f_c, 3, 3]
-        self.im2col_step = 1
-        self.deformable_groups = 1
-        offset_c = 2 * self.deformable_groups * self.filter_size[
-            2] * self.filter_size[3]
-        mask_c = self.deformable_groups * self.filter_size[
-            2] * self.filter_size[3]
-        self.offset_size = [
-            self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
-        ]
-        self.mask_size = [
-            self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
-        ]
-
-
-class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
-    def test_error(self):
-        def test_invalid_input():
-            paddle.enable_static()
-            input = [1, 3, 32, 32]
-            offset = fluid.data(name='offset',
-                                shape=[None, 3, 32, 32],
-                                dtype='float32')
-            mask = fluid.data(name='mask',
-                              shape=[None, 3, 32, 32],
-                              dtype='float32')
-            loss = fluid.layers.deformable_conv(input,
-                                                offset,
-                                                mask,
-                                                num_filters=4,
-                                                filter_size=1)
-
-        self.assertRaises(TypeError, test_invalid_input)
-
-        def test_invalid_offset():
-            paddle.enable_static()
-            input = fluid.data(name='input',
-                               shape=[None, 3, 32, 32],
-                               dtype='int32')
-            offset = fluid.data(name='offset',
-                                shape=[None, 3, 32, 32],
-                                dtype='float32')
-            mask = fluid.data(name='mask',
-                              shape=[None, 3, 32, 32],
-                              dtype='float32')
-            loss = fluid.layers.deformable_conv(input,
-                                                offset,
-                                                mask,
-                                                num_filters=4,
-                                                filter_size=1)
-
-        self.assertRaises(TypeError, test_invalid_offset)
+class XPUTestModulatedDeformableConvOp(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'deformable_conv'
+        self.use_dynamic_create_class = False
+
+    class TestModulatedDeformableConvOp(XPUOpTest):
+
+        def setUp(self):
+            self.op_type = "deformable_conv"
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.init_group()
+            self.init_dilation()
+            self.init_test_case()
+
+            conv_param = {
+                'stride': self.stride,
+                'pad': self.pad,
+                'dilation': self.dilations
+            }
+
+            input = np.random.random(self.input_size).astype(self.dtype)
+            offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
+            mask = 10 * np.random.random(self.mask_size).astype(self.dtype)
+            filter = np.random.random(self.filter_size).astype(self.dtype)
+            output = dconv_im2col_gemm(input, offset, mask, filter, self.groups,
+                                       conv_param)
+            output = output.astype(self.dtype)
+
+            self.inputs = {
+                'Input': OpTest.np_dtype_to_fluid_dtype(input),
+                'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
+                'Mask': OpTest.np_dtype_to_fluid_dtype(mask),
+                'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+            }
+            self.attrs = {
+                'strides': self.stride,
+                'paddings': self.pad,
+                'groups': self.groups,
+                'deformable_groups': self.deformable_groups,
+                'im2col_step': self.im2col_step,
+                'dilations': self.dilations,
+            }
+            self.outputs = {'Output': output}
+
+        def test_check_output(self):
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_output_with_place(self.place)
+
+        def test_check_grad(self):
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_grad_with_place(
+                    self.place, {'Input', 'Offset', 'Mask', 'Filter'},
+                    'Output',
+                    max_relative_error=0.06)
+
+        def init_test_case(self):
+            self.pad = [1, 1]
+            self.stride = [1, 1]
+            self.dilations = [1, 1]
+            self.input_size = [2, 8, 4, 4]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [8, f_c, 3, 3]
+            self.im2col_step = 1
+            self.deformable_groups = 1
+            offset_c = 2 * self.deformable_groups * self.filter_size[
+                2] * self.filter_size[3]
+            mask_c = self.deformable_groups * self.filter_size[
+                2] * self.filter_size[3]
+            self.offset_size = [
+                self.input_size[0], offset_c, self.input_size[2],
+                self.input_size[3]
+            ]
+            self.mask_size = [
+                self.input_size[0], mask_c, self.input_size[2],
+                self.input_size[3]
+            ]
+
+        def init_dilation(self):
+            self.dilations = [1, 1]
+
+        def init_group(self):
+            self.groups = 1
+
+    class TestWithDilation(TestModulatedDeformableConvOp):
+
+        def init_test_case(self):
+            self.pad = [2, 2]
+            self.stride = [1, 1]
+            self.input_size = [4, 3, 4, 4]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [6, f_c, 3, 3]
+            self.im2col_step = 1
+            self.deformable_groups = 1
+            offset_c = 2 * self.deformable_groups * self.filter_size[
+                2] * self.filter_size[3]
+            mask_c = self.deformable_groups * self.filter_size[
+                2] * self.filter_size[3]
+            self.offset_size = [
+                self.input_size[0], offset_c, self.input_size[2],
+                self.input_size[3]
+            ]
+            self.mask_size = [
+                self.input_size[0], mask_c, self.input_size[2],
+                self.input_size[3]
+            ]
+
+        def init_dilation(self):
+            self.dilations = [2, 2]
+
+    class TestWith3x3(TestModulatedDeformableConvOp):
+
+        def init_test_case(self):
+            self.pad = [1, 1]
+            self.stride = [1, 1]
+            self.input_size = [2, 3, 5, 5]  # NCHW
+            assert np.mod(self.input_size[1], self.groups) == 0
+            f_c = self.input_size[1] // self.groups
+            self.filter_size = [6, f_c, 3, 3]
+            self.im2col_step = 1
+            self.deformable_groups = 1
+            offset_c = 2 * self.deformable_groups * self.filter_size[
+                2] * self.filter_size[3]
+            mask_c = self.deformable_groups * self.filter_size[
+                2] * self.filter_size[3]
+            self.offset_size = [
+                self.input_size[0], offset_c, self.input_size[2],
+                self.input_size[3]
+            ]
+            self.mask_size = [
+                self.input_size[0], mask_c, self.input_size[2],
+                self.input_size[3]
+            ]
+
+
+class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
+
+    def test_error(self):
+
+        def test_invalid_input():
+            paddle.enable_static()
+            input = [1, 3, 32, 32]
+            offset = fluid.data(name='offset',
+                                shape=[None, 3, 32, 32],
+                                dtype='float32')
+            mask = fluid.data(name='mask',
+                              shape=[None, 3, 32, 32],
+                              dtype='float32')
+            loss = fluid.layers.deformable_conv(input,
+                                                offset,
+                                                mask,
+                                                num_filters=4,
+                                                filter_size=1)
+
+        self.assertRaises(TypeError, test_invalid_input)
+
+        def test_invalid_offset():
+            paddle.enable_static()
+            input = fluid.data(name='input',
+                               shape=[None, 3, 32, 32],
+                               dtype='int32')
+            offset = fluid.data(name='offset',
+                                shape=[None, 3, 32, 32],
+                                dtype='float32')
+            mask = fluid.data(name='mask',
+                              shape=[None, 3, 32, 32],
+                              dtype='float32')
+            loss = fluid.layers.deformable_conv(input,
+                                                offset,
+                                                mask,
+                                                num_filters=4,
+                                                filter_size=1)
+
+        self.assertRaises(TypeError, test_invalid_offset)
+
+
+support_types = get_xpu_op_support_types('deformable_conv')
+for stype in support_types:
+    create_test_class(globals(), XPUTestModulatedDeformableConvOp, stype)
+
+
 if __name__ == '__main__':
     unittest.main()
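
Both test files in this commit follow the same dtype-parameterized pattern: an XPUOpTestWrapper subclass names the op, and create_test_class is invoked once per dtype returned by get_xpu_op_support_types. A hypothetical sketch of what such a helper does with the wrapper — the real implementation lives in xpu.get_test_cover_info and is not shown in this diff:

```python
import unittest

def create_test_class_sketch(module_globals, wrapper_cls, dtype_str):
    # Hypothetical stand-in: for every TestCase nested in the wrapper,
    # register a per-dtype subclass whose `in_type` drives setUp()/set_case().
    # The real helper also consults the XPU op list and handles numpy dtypes.
    op_name = wrapper_cls().op_name
    for name, inner in vars(wrapper_cls).items():
        if isinstance(inner, type) and issubclass(inner, unittest.TestCase):
            cls_name = '%s_%s_%s' % (name, op_name, dtype_str)
            module_globals[cls_name] = type(cls_name, (inner,),
                                            {'in_type': dtype_str})
```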