未验证 提交 276017bb 编写于 作者: Z zhangyikun02 提交者: GitHub

conv2d support FP16 on xpu and update unittest for conv2d, test=kunlun (#40395)

上级 1eb96eec
......@@ -19,14 +19,16 @@ namespace operators {
template <typename DeviceContext, typename T>
class GemmConvXPUKernel : public framework::OpKernel<T> {
using XPUT = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *input = context.Input<Tensor>("Input");
// The filter will be reshaped in the calculations,
// so here use an assignment operation,
// that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
Tensor *output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
......@@ -53,11 +55,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]);
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d<float, float, float, int16_t>(
dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
output->data<float>(), batch_size, img_c, img_h, img_w, f, ksize,
strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());
auto &dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]",
......@@ -67,14 +74,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T>
class GemmConvGradXPUKernel : public framework::OpKernel<T> {
using XPUT = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *input = context.Input<Tensor>("Input");
const Tensor *output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
Tensor *input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
Tensor *filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
// The filter and filter_grad will be reshaped in the calculations,
// so here use an assignment operation,
......@@ -107,19 +116,27 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]);
const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
const XPUT *output_grad_data =
reinterpret_cast<const XPUT *>(output_grad->data<T>());
XPUT *input_grad_data = nullptr;
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
input_grad_data = reinterpret_cast<XPUT *>(input_grad->data<T>());
}
XPUT *filter_grad_data = nullptr;
if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace());
filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
}
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d_grad<float, float, float, int16_t>(
dev_ctx.x_context(), input->data<T>(), filter.data<T>(),
output_grad->data<T>(), input_grad ? input_grad->data<T>() : nullptr,
filter_grad ? filter_grad->data<T>() : nullptr, batch_size, img_c,
img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr,
nullptr, nullptr, nullptr, nullptr, true);
auto &dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(
dev_ctx.x_context(), input_data, filter_data, output_grad_data,
input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]",
......@@ -130,14 +147,22 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -51,16 +51,20 @@ XPUOpMap& get_kl2_ops() {
{"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d_transpose_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_transpose",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"depthwise_conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"dropout_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
......@@ -23,6 +23,7 @@ import paddle.fluid as fluid
from op_test_xpu import XPUOpTest
import paddle
from paddle.fluid import Program, program_guard
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
def conv2d_forward_naive(input,
......@@ -159,8 +160,15 @@ def create_test_padding_VALID_class(parent):
globals()[cls_name] = TestPaddingVALIDCase
class TestConv2DOp(XPUOpTest):
class XPUTestConv2DOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'conv2d'
self.use_dynamic_create_class = False
class TestConv2DOp(XPUOpTest):
def setUp(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.op_type = "conv2d"
self.use_cudnn = False
self.exhaustive_search = False
......@@ -168,7 +176,6 @@ class TestConv2DOp(XPUOpTest):
self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False
self.data_format = "AnyLayout"
self.dtype = np.float32
self.init_kernel_type()
self.init_group()
self.init_dilation()
......@@ -180,6 +187,7 @@ class TestConv2DOp(XPUOpTest):
'dilation': self.dilations
}
np.random.seed(100)
input = np.random.random(self.input_size).astype(self.dtype)
if not self.has_cuda():
self.fuse_relu_before_depthwise_conv = False
......@@ -190,10 +198,12 @@ class TestConv2DOp(XPUOpTest):
input2 = np.maximum(input, 0.0)
else:
input2 = input
filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
np.random.seed(1)
filter = np.random.uniform(-1, 1,
self.filter_size).astype(self.dtype)
output, _, _, _, _ = conv2d_forward_naive(input2, filter, self.groups,
conv2d_param)
output, _, _, _, _ = conv2d_forward_naive(input2, filter,
self.groups, conv2d_param)
output = output.astype(self.dtype)
self.inputs = {
......@@ -221,37 +231,38 @@ class TestConv2DOp(XPUOpTest):
def test_check_output(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(self.place)
def test_check_grad(self):
if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and
if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, {'Input', 'Filter'}, 'Output')
self.check_grad_with_place(self.place, {'Input', 'Filter'},
'Output')
def test_check_grad_no_filter(self):
if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and
if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set=set(['Filter']))
self.place, ['Input'],
'Output',
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and
if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set=set(['Input']))
self.place, ['Filter'],
'Output',
no_grad_set=set(['Input']))
def init_test_case(self):
self.pad = [0, 0]
......@@ -273,8 +284,7 @@ class TestConv2DOp(XPUOpTest):
def init_kernel_type(self):
pass
class TestWithPad(TestConv2DOp):
class TestWithPad(TestConv2DOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
......@@ -283,8 +293,7 @@ class TestWithPad(TestConv2DOp):
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
class TestWithStride(TestConv2DOp):
class TestWithStride(TestConv2DOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
......@@ -293,8 +302,7 @@ class TestWithStride(TestConv2DOp):
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
class TestWith1x1(TestConv2DOp):
class TestWith1x1(TestConv2DOp):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
......@@ -307,24 +315,22 @@ class TestWith1x1(TestConv2DOp):
self.groups = 1
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
# def init_op_type(self):
# self.op_type = "conv_cudnn"
# ---- test asymmetric padding ----
class XPUTestConv2DOp_v2(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'conv2d'
self.use_dynamic_create_class = False
class TestConv2DOp_v2(XPUOpTest):
class TestConv2DOp_v2(XPUOpTest):
def setUp(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.op_type = "conv2d"
self.use_cudnn = False
self.exhaustive_search = False
self.use_cuda = False
self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False
self.dtype = np.float32
self.init_kernel_type()
self.init_group()
self.init_dilation()
......@@ -339,6 +345,7 @@ class TestConv2DOp_v2(XPUOpTest):
'dilation': self.dilations
}
np.random.seed(100)
input = np.random.random(self.input_size).astype(self.dtype)
if not self.has_cuda():
self.fuse_relu_before_depthwise_conv = False
......@@ -349,10 +356,12 @@ class TestConv2DOp_v2(XPUOpTest):
input2 = np.maximum(input, 0.0)
else:
input2 = input
filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype)
np.random.seed(8)
filter = np.random.uniform(-1, 1,
self.filter_size).astype(self.dtype)
output, _, _, _, _ = conv2d_forward_naive(
input2, filter, self.groups, conv2d_param, self.padding_algorithm,
self.data_format)
input2, filter, self.groups, conv2d_param,
self.padding_algorithm, self.data_format)
output = output.astype(self.dtype)
self.inputs = {
......@@ -382,37 +391,41 @@ class TestConv2DOp_v2(XPUOpTest):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place=self.place)
def test_check_grad(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16:
if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, {'Input', 'Filter'}, 'Output')
self.check_grad_with_place(self.place, {'Input', 'Filter'},
'Output')
def test_check_grad_no_filter(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16:
if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set=set(['Filter']))
self.place, ['Input'],
'Output',
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16:
if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set=set(['Input']))
self.place, ['Filter'],
'Output',
no_grad_set=set(['Input']))
def init_test_case(self):
self.pad = [0, 0]
......@@ -441,14 +454,12 @@ class TestConv2DOp_v2(XPUOpTest):
def init_test_case_2(self):
pass
class TestConv2DOp_AsyPadding(TestConv2DOp_v2):
class TestConv2DOp_AsyPadding(TestConv2DOp_v2):
def init_paddings(self):
self.pad = [0, 0, 0, 0]
self.padding_algorithm = "EXPLICIT"
class TestWithPad_AsyPadding(TestConv2DOp_v2):
class TestWithPad_AsyPadding(TestConv2DOp_v2):
def init_test_case(self):
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
......@@ -460,8 +471,7 @@ class TestWithPad_AsyPadding(TestConv2DOp_v2):
self.pad = [1, 1, 1, 1]
self.padding_algorithm = "EXPLICIT"
class TestWithStride_AsyPadding(TestConv2DOp_v2):
class TestWithStride_AsyPadding(TestConv2DOp_v2):
def init_test_case(self):
self.stride = [2, 2]
self.input_size = [2, 3, 6, 6] # NCHW
......@@ -474,6 +484,11 @@ class TestWithStride_AsyPadding(TestConv2DOp_v2):
self.padding_algorithm = "EXPLICIT"
support_types = get_xpu_op_support_types('conv2d')
for stype in support_types:
create_test_class(globals(), XPUTestConv2DOp, stype)
create_test_class(globals(), XPUTestConv2DOp_v2, stype)
#---------- test SAME VALID -----------
#create_test_padding_SAME_class(TestConv2DOp_AsyPadding)
#create_test_padding_SAME_class(TestWithPad_AsyPadding)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册