Unverified commit 6815c8ab, authored by zhangyikun02, committed by GitHub

add mish and mish_grad for XPU, test=kunlun (#45098)

Parent 3649099f
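This commit wires the `mish` activation and its gradient to the Kunlun XPU backend through XDNN's `xpu::mish`/`xpu::mish_grad`, registers both ops in the KL2 op list, and adds XPU unit tests (it also bumps the XDNN SDK date and registers `deformable_conv` and `uniform_random` kernels). For orientation, here is a minimal NumPy sketch of the math the op computes, mirroring the `ref_mish` helper in the tests below; `mish_reference` is an illustrative name, not part of the diff:

```python
import numpy as np

def mish_reference(x, threshold=20.0):
    # mish(x) = x * tanh(softplus(x)); above `threshold`, softplus(x) is
    # treated as x for numerical stability, matching ref_mish in the tests.
    sp = np.where(x <= threshold,
                  np.log1p(np.exp(np.minimum(x, threshold))),  # softplus
                  x)
    return x * np.tanh(sp)
```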
@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
if(NOT DEFINED XPU_BASE_URL)
  set(XPU_BASE_URL_WITHOUT_DATE
      "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220810")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220812")
else()
  set(XPU_BASE_URL "${XPU_BASE_URL}")
endif()
@@ -19,7 +19,7 @@ endif()
if(NOT DEFINED XPU_XDNN_BASE_URL)
  set(XPU_XDNN_BASE_URL_WITHOUT_DATE
      "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
-  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220810")
+  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220812")
else()
  set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif()
...
@@ -404,6 +404,49 @@ struct XPULogGradFunctor : public BaseActivationFunctor<T> {
  }
};
+template <typename T>
+struct XPUMishFunctor : public BaseActivationFunctor<T> {
+  void operator()(const framework::ExecutionContext &ctx) const {
+    const auto *x = ctx.Input<Tensor>("X");
+    auto *y = ctx.Output<Tensor>("Out");
+    const T *x_data = x->data<T>();
+    T *y_data = y->mutable_data<T>(ctx.GetPlace());
+    float threshold = ctx.Attr<float>("threshold");
+    auto xpu_context =
+        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
+    int r = xpu::mish(xpu_context, x_data, y_data, x->numel(), threshold);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mish");
+  }
+};
+template <typename T>
+struct XPUMishGradFunctor : public BaseActivationFunctor<T> {
+  void operator()(const framework::ExecutionContext &ctx) const {
+    const auto *x = ctx.Input<Tensor>("X");
+    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    const T *x_data = x->data<T>();
+    const T *y_grad = dOut->data<T>();
+    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());
+    float threshold = ctx.Attr<float>("threshold");
+    auto xpu_context =
+        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
+    int r = xpu::mish_grad(xpu_context,
+                           reinterpret_cast<const float *>(x_data),
+                           reinterpret_cast<const float *>(
+                               x_data),  // mish_grad does not need y_data
+                           reinterpret_cast<const float *>(y_grad),
+                           reinterpret_cast<float *>(x_grad),
+                           dX->numel(),
+                           threshold);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mish_grad");
+  }
+};
template <typename T>
struct XPUPowFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
@@ -589,6 +632,7 @@ REGISTER_ACTIVATION_XPU_KERNEL(hard_swish,
REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu,
                               XPULeakyReluFunctor,
                               XPULeakyReluGradFunctor)
+REGISTER_ACTIVATION_XPU_KERNEL(mish, XPUMishFunctor, XPUMishGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(reciprocal,
                               XPUReciprocalFunctor,
                               XPUReciprocalGradFunctor)
...
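With the functors above registered through `REGISTER_ACTIVATION_XPU_KERNEL`, the op becomes reachable from the Python API. A minimal usage sketch, assuming a Paddle build with XPU support and an attached Kunlun device:

```python
import paddle
import paddle.nn.functional as F

paddle.set_device('xpu:0')  # assumes a Kunlun XPU device is available
x = paddle.uniform([4, 8], min=-2.0, max=2.0)
x.stop_gradient = False
y = F.mish(x)        # forward dispatches to xpu::mish
y.sum().backward()   # backward dispatches to xpu::mish_grad
```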
@@ -111,6 +111,10 @@ XPUOpMap& get_kl2_ops() {
     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
    {"conv2d_transpose",
     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"deformable_conv_grad",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"deformable_conv",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
    {"depthwise_conv2d_grad",
     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
    {"depthwise_conv2d",
@@ -342,6 +346,8 @@ XPUOpMap& get_kl2_ops() {
    {"merged_momentum",
     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                   pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"mish_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+    {"mish", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
    {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
    {"mul",
     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
@@ -559,6 +565,8 @@ XPUOpMap& get_kl2_ops() {
    {"update_loss_scaling",
     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                   pOpKernelType(vartype::FP16, XPUPlace())})},
+    {"uniform_random",
+     XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
    {"unsqueeze2_grad",
     XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
                   pOpKernelType(vartype::INT64, XPUPlace()),
...
@@ -1100,5 +1100,57 @@ def ref_thresholded_relu(x, threshold=1.0):
    return out

+class XPUTestMishOP(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'mish'
+        self.use_dynamic_create_class = False
+
+    class XPUTestMishBase(TestActivationOPBase):
+
+        def set_case(self):
+            self.op_type = "mish"
+            self.dtype = self.in_type
+            self.init_config()
+            threshold = np.random.uniform(0, 1)
+            out = ref_mish(self.x, threshold)
+
+            self.inputs = {'X': self.x}
+            self.outputs = {'Out': out}
+            self.attrs = {'use_xpu': True, 'threshold': threshold}
+
+        def init_config(self):
+            self.x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
+
+    class XPUTestMish2(XPUTestMishBase):
+
+        def init_config(self):
+            self.x = np.random.uniform(-2, 2, [1024, 8]).astype(self.dtype)
+
+    class XPUTestMish3(XPUTestMishBase):
+
+        def init_config(self):
+            self.x = np.random.uniform(-2, 2,
+                                       [4, 512, 15, 15]).astype(self.dtype)
+
+    class XPUTestMish4(XPUTestMishBase):
+
+        def init_config(self):
+            self.x = np.random.uniform(-2, 2,
+                                       [4, 256, 22, 22]).astype(self.dtype)
+
+
+support_types = get_xpu_op_support_types('mish')
+for stype in support_types:
+    create_test_class(globals(), XPUTestMishOP, stype)
+
+
+def ref_mish(x, threshold=20):
+    sp = np.select([x <= threshold, x > threshold], [np.log(1 + np.exp(x)), x])
+    out = x * np.tanh(sp)
+    return out
+
if __name__ == "__main__":
    unittest.main()
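The tests above exercise the backward kernel through Paddle's numeric gradient checker rather than an analytic reference. For completeness, here is a sketch of the analytic gradient that a reference implementation could use; `ref_mish_grad` is hypothetical and not part of this diff:

```python
import numpy as np

def ref_mish_grad(x, dout, threshold=20.0):
    # y = x * tanh(sp), sp = softplus(x), so
    # dy/dx = tanh(sp) + x * (1 - tanh(sp)**2) * d(sp)/dx,
    # where d(sp)/dx = sigmoid(x) for x <= threshold and 1 above it.
    sp = np.where(x <= threshold,
                  np.log1p(np.exp(np.minimum(x, threshold))), x)
    tsp = np.tanh(sp)
    dsp = np.where(x <= threshold,
                   0.5 * (1.0 + np.tanh(0.5 * x)),  # sigmoid(x), stable form
                   1.0)
    return dout * (tsp + x * (1.0 - tsp**2) * dsp)
```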
...
@@ -24,6 +24,7 @@ import paddle.fluid as fluid
from op_test_xpu import OpTest, XPUOpTest
import paddle
from paddle.fluid import Program, program_guard
+from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
def dmc_bilinear(data_im, height, width, h, w):
@@ -111,11 +112,18 @@ def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
    return out
-class TestModulatedDeformableConvOp(XPUOpTest):
+class XPUTestModulatedDeformableConvOp(XPUOpTestWrapper):
+
+    def __init__(self):
+        self.op_name = 'deformable_conv'
+        self.use_dynamic_create_class = False
+
+    class TestModulatedDeformableConvOp(XPUOpTest):

-    def setUp(self):
-        self.op_type = "deformable_conv"
-        self.dtype = np.float32
-        self.init_group()
-        self.init_dilation()
-        self.init_test_case()
+        def setUp(self):
+            self.op_type = "deformable_conv"
+            self.dtype = self.in_type
+            self.place = paddle.XPUPlace(0)
+            self.init_group()
+            self.init_dilation()
+            self.init_test_case()
@@ -150,22 +158,16 @@ class TestModulatedDeformableConvOp(XPUOpTest):
            }
            self.outputs = {'Output': output}
-    def has_cuda(self):
-        return core.is_compiled_with_cuda() and (self.use_cudnn
-                                                 or self.use_cuda)
-
-    def test_check_output(self):
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_output_with_place(place)
+        def test_check_output(self):
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_output_with_place(self.place)

-    def test_check_grad(self):
-        if core.is_compiled_with_xpu():
-            paddle.enable_static()
-            place = paddle.XPUPlace(0)
-            self.check_grad_with_place(place,
-                                       {'Input', 'Offset', 'Mask', 'Filter'},
-                                       'Output',
-                                       max_relative_error=0.06)
+        def test_check_grad(self):
+            if core.is_compiled_with_xpu():
+                paddle.enable_static()
+                self.check_grad_with_place(
+                    self.place, {'Input', 'Offset', 'Mask', 'Filter'},
+                    'Output',
+                    max_relative_error=0.06)
@@ -184,10 +186,12 @@ class TestModulatedDeformableConvOp(XPUOpTest):
            mask_c = self.deformable_groups * self.filter_size[
                2] * self.filter_size[3]
            self.offset_size = [
-                self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
+                self.input_size[0], offset_c, self.input_size[2],
+                self.input_size[3]
            ]
            self.mask_size = [
-                self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
+                self.input_size[0], mask_c, self.input_size[2],
+                self.input_size[3]
            ]
        def init_dilation(self):
@@ -196,8 +200,7 @@ class TestModulatedDeformableConvOp(XPUOpTest):
        def init_group(self):
            self.groups = 1
-class TestWithDilation(TestModulatedDeformableConvOp):
+    class TestWithDilation(TestModulatedDeformableConvOp):
        def init_test_case(self):
            self.pad = [2, 2]
@@ -213,17 +216,18 @@ class TestWithDilation(TestModulatedDeformableConvOp):
            mask_c = self.deformable_groups * self.filter_size[
                2] * self.filter_size[3]
            self.offset_size = [
-                self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
+                self.input_size[0], offset_c, self.input_size[2],
+                self.input_size[3]
            ]
            self.mask_size = [
-                self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
+                self.input_size[0], mask_c, self.input_size[2],
+                self.input_size[3]
            ]
        def init_dilation(self):
            self.dilations = [2, 2]
-class TestWith3x3(TestModulatedDeformableConvOp):
+    class TestWith3x3(TestModulatedDeformableConvOp):
        def init_test_case(self):
            self.pad = [1, 1]
@@ -239,14 +243,15 @@ class TestWith3x3(TestModulatedDeformableConvOp):
            mask_c = self.deformable_groups * self.filter_size[
                2] * self.filter_size[3]
            self.offset_size = [
-                self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
+                self.input_size[0], offset_c, self.input_size[2],
+                self.input_size[3]
            ]
            self.mask_size = [
-                self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
+                self.input_size[0], mask_c, self.input_size[2],
+                self.input_size[3]
            ]
class TestModulatedDeformableConvInvalidInput(unittest.TestCase):

    def test_error(self):
@@ -287,5 +292,9 @@ class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
        self.assertRaises(TypeError, test_invalid_offset)
+support_types = get_xpu_op_support_types('deformable_conv')
+for stype in support_types:
+    create_test_class(globals(), XPUTestModulatedDeformableConvOp, stype)
if __name__ == '__main__':
    unittest.main()