未验证 提交 276017bb 编写于 作者: Z zhangyikun02 提交者: GitHub

conv2d support FP16 on xpu and update unittest for conv2d, test=kunlun (#40395)

上级 1eb96eec
...@@ -19,14 +19,16 @@ namespace operators { ...@@ -19,14 +19,16 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class GemmConvXPUKernel : public framework::OpKernel<T> { class GemmConvXPUKernel : public framework::OpKernel<T> {
using XPUT = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor *input = context.Input<Tensor>("Input");
// The filter will be reshaped in the calculations, // The filter will be reshaped in the calculations,
// so here use an assignment operation, // so here use an assignment operation,
// that avoids modifying the variable in the Scope. // that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter"); Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output"); Tensor *output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups"); int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
...@@ -53,11 +55,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> { ...@@ -53,11 +55,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
const int img_h = static_cast<int>(input->dims()[2]); const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]); const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]); const int f = static_cast<int>(filter.dims()[0]);
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d<float, float, float, int16_t>( const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
dev_ctx.x_context(), input->data<float>(), filter.data<float>(), const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
output->data<float>(), batch_size, img_c, img_h, img_w, f, ksize, XPUT *output_data = reinterpret_cast<XPUT *>(output->data<T>());
strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
auto &dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d<XPUT, XPUT, XPUT, int16_t>(
dev_ctx.x_context(), input_data, filter_data, output_data, batch_size,
img_c, img_h, img_w, f, ksize, strides, paddings, dilations, groups,
nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS, r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]", platform::errors::External("XPU conv kernel return wrong value[%d %s]",
...@@ -67,14 +74,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> { ...@@ -67,14 +74,16 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class GemmConvGradXPUKernel : public framework::OpKernel<T> { class GemmConvGradXPUKernel : public framework::OpKernel<T> {
using XPUT = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor *input = context.Input<Tensor>("Input");
const Tensor* output_grad = const Tensor *output_grad =
context.Input<Tensor>(framework::GradVarName("Output")); context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = Tensor *input_grad =
context.Output<Tensor>(framework::GradVarName("Input")); context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = Tensor *filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter")); context.Output<Tensor>(framework::GradVarName("Filter"));
// The filter and filter_grad will be reshaped in the calculations, // The filter and filter_grad will be reshaped in the calculations,
// so here use an assignment operation, // so here use an assignment operation,
...@@ -107,19 +116,27 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> { ...@@ -107,19 +116,27 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
const int img_h = static_cast<int>(input->dims()[2]); const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]); const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]); const int f = static_cast<int>(filter.dims()[0]);
const XPUT *input_data = reinterpret_cast<const XPUT *>(input->data<T>());
const XPUT *filter_data = reinterpret_cast<const XPUT *>(filter.data<T>());
const XPUT *output_grad_data =
reinterpret_cast<const XPUT *>(output_grad->data<T>());
XPUT *input_grad_data = nullptr;
if (input_grad) { if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
input_grad_data = reinterpret_cast<XPUT *>(input_grad->data<T>());
} }
XPUT *filter_grad_data = nullptr;
if (filter_grad) { if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace()); filter_grad->mutable_data<T>(context.GetPlace());
filter_grad_data = reinterpret_cast<XPUT *>(filter_grad->data<T>());
} }
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::conv2d_grad<float, float, float, int16_t>( int r = xpu::conv2d_grad<XPUT, XPUT, XPUT, int16_t>(
dev_ctx.x_context(), input->data<T>(), filter.data<T>(), dev_ctx.x_context(), input_data, filter_data, output_grad_data,
output_grad->data<T>(), input_grad ? input_grad->data<T>() : nullptr, input_grad_data, filter_grad_data, batch_size, img_c, img_h, img_w, f,
filter_grad ? filter_grad->data<T>() : nullptr, batch_size, img_c, ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr,
img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
nullptr, nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS, r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]", platform::errors::External("XPU conv kernel return wrong value[%d %s]",
...@@ -130,14 +147,22 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> { ...@@ -130,14 +147,22 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
depthwise_conv2d, conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
REGISTER_OP_XPU_KERNEL( paddle::platform::float16>);
conv2d, ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
conv2d_grad, conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
depthwise_conv2d_grad, depthwise_conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif #endif
...@@ -51,16 +51,20 @@ XPUOpMap& get_kl2_ops() { ...@@ -51,16 +51,20 @@ XPUOpMap& get_kl2_ops() {
{"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d_transpose_grad", {"conv2d_transpose_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_transpose", {"conv2d_transpose",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"depthwise_conv2d_grad", {"depthwise_conv2d_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"depthwise_conv2d", {"depthwise_conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"dropout_grad", {"dropout_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"dropout", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
...@@ -23,6 +23,7 @@ import paddle.fluid as fluid ...@@ -23,6 +23,7 @@ import paddle.fluid as fluid
from op_test_xpu import XPUOpTest from op_test_xpu import XPUOpTest
import paddle import paddle
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
def conv2d_forward_naive(input, def conv2d_forward_naive(input,
...@@ -159,8 +160,15 @@ def create_test_padding_VALID_class(parent): ...@@ -159,8 +160,15 @@ def create_test_padding_VALID_class(parent):
globals()[cls_name] = TestPaddingVALIDCase globals()[cls_name] = TestPaddingVALIDCase
class TestConv2DOp(XPUOpTest): class XPUTestConv2DOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'conv2d'
self.use_dynamic_create_class = False
class TestConv2DOp(XPUOpTest):
def setUp(self): def setUp(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.op_type = "conv2d" self.op_type = "conv2d"
self.use_cudnn = False self.use_cudnn = False
self.exhaustive_search = False self.exhaustive_search = False
...@@ -168,7 +176,6 @@ class TestConv2DOp(XPUOpTest): ...@@ -168,7 +176,6 @@ class TestConv2DOp(XPUOpTest):
self.use_mkldnn = False self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False self.fuse_relu_before_depthwise_conv = False
self.data_format = "AnyLayout" self.data_format = "AnyLayout"
self.dtype = np.float32
self.init_kernel_type() self.init_kernel_type()
self.init_group() self.init_group()
self.init_dilation() self.init_dilation()
...@@ -180,6 +187,7 @@ class TestConv2DOp(XPUOpTest): ...@@ -180,6 +187,7 @@ class TestConv2DOp(XPUOpTest):
'dilation': self.dilations 'dilation': self.dilations
} }
np.random.seed(100)
input = np.random.random(self.input_size).astype(self.dtype) input = np.random.random(self.input_size).astype(self.dtype)
if not self.has_cuda(): if not self.has_cuda():
self.fuse_relu_before_depthwise_conv = False self.fuse_relu_before_depthwise_conv = False
...@@ -190,10 +198,12 @@ class TestConv2DOp(XPUOpTest): ...@@ -190,10 +198,12 @@ class TestConv2DOp(XPUOpTest):
input2 = np.maximum(input, 0.0) input2 = np.maximum(input, 0.0)
else: else:
input2 = input input2 = input
filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype) np.random.seed(1)
filter = np.random.uniform(-1, 1,
self.filter_size).astype(self.dtype)
output, _, _, _, _ = conv2d_forward_naive(input2, filter, self.groups, output, _, _, _, _ = conv2d_forward_naive(input2, filter,
conv2d_param) self.groups, conv2d_param)
output = output.astype(self.dtype) output = output.astype(self.dtype)
self.inputs = { self.inputs = {
...@@ -221,37 +231,38 @@ class TestConv2DOp(XPUOpTest): ...@@ -221,37 +231,38 @@ class TestConv2DOp(XPUOpTest):
def test_check_output(self): def test_check_output(self):
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
paddle.enable_static() paddle.enable_static()
place = paddle.XPUPlace(0) self.check_output_with_place(self.place)
self.check_output_with_place(place)
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True): self.no_need_check_grad == True):
return return
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
paddle.enable_static() paddle.enable_static()
place = paddle.XPUPlace(0) self.check_grad_with_place(self.place, {'Input', 'Filter'},
self.check_grad_with_place(place, {'Input', 'Filter'}, 'Output') 'Output')
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True): self.no_need_check_grad == True):
return return
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
paddle.enable_static() paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set=set(['Filter'])) self.place, ['Input'],
'Output',
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
if self.dtype == np.float16 or (hasattr(self, "no_need_check_grad") and if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True): self.no_need_check_grad == True):
return return
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
paddle.enable_static() paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set=set(['Input'])) self.place, ['Filter'],
'Output',
no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
self.pad = [0, 0] self.pad = [0, 0]
...@@ -273,8 +284,7 @@ class TestConv2DOp(XPUOpTest): ...@@ -273,8 +284,7 @@ class TestConv2DOp(XPUOpTest):
def init_kernel_type(self): def init_kernel_type(self):
pass pass
class TestWithPad(TestConv2DOp):
class TestWithPad(TestConv2DOp):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1] self.pad = [1, 1]
self.stride = [1, 1] self.stride = [1, 1]
...@@ -283,8 +293,7 @@ class TestWithPad(TestConv2DOp): ...@@ -283,8 +293,7 @@ class TestWithPad(TestConv2DOp):
f_c = self.input_size[1] // self.groups f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3] self.filter_size = [6, f_c, 3, 3]
class TestWithStride(TestConv2DOp):
class TestWithStride(TestConv2DOp):
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1] self.pad = [1, 1]
self.stride = [2, 2] self.stride = [2, 2]
...@@ -293,8 +302,7 @@ class TestWithStride(TestConv2DOp): ...@@ -293,8 +302,7 @@ class TestWithStride(TestConv2DOp):
f_c = self.input_size[1] // self.groups f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3] self.filter_size = [6, f_c, 3, 3]
class TestWith1x1(TestConv2DOp):
class TestWith1x1(TestConv2DOp):
def init_test_case(self): def init_test_case(self):
self.pad = [0, 0] self.pad = [0, 0]
self.stride = [1, 1] self.stride = [1, 1]
...@@ -307,24 +315,22 @@ class TestWith1x1(TestConv2DOp): ...@@ -307,24 +315,22 @@ class TestWith1x1(TestConv2DOp):
self.groups = 1 self.groups = 1
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
# def init_op_type(self):
# self.op_type = "conv_cudnn"
# ---- test asymmetric padding ---- # ---- test asymmetric padding ----
class XPUTestConv2DOp_v2(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'conv2d'
self.use_dynamic_create_class = False
class TestConv2DOp_v2(XPUOpTest):
class TestConv2DOp_v2(XPUOpTest):
def setUp(self): def setUp(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.op_type = "conv2d" self.op_type = "conv2d"
self.use_cudnn = False self.use_cudnn = False
self.exhaustive_search = False self.exhaustive_search = False
self.use_cuda = False self.use_cuda = False
self.use_mkldnn = False self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False self.fuse_relu_before_depthwise_conv = False
self.dtype = np.float32
self.init_kernel_type() self.init_kernel_type()
self.init_group() self.init_group()
self.init_dilation() self.init_dilation()
...@@ -339,6 +345,7 @@ class TestConv2DOp_v2(XPUOpTest): ...@@ -339,6 +345,7 @@ class TestConv2DOp_v2(XPUOpTest):
'dilation': self.dilations 'dilation': self.dilations
} }
np.random.seed(100)
input = np.random.random(self.input_size).astype(self.dtype) input = np.random.random(self.input_size).astype(self.dtype)
if not self.has_cuda(): if not self.has_cuda():
self.fuse_relu_before_depthwise_conv = False self.fuse_relu_before_depthwise_conv = False
...@@ -349,10 +356,12 @@ class TestConv2DOp_v2(XPUOpTest): ...@@ -349,10 +356,12 @@ class TestConv2DOp_v2(XPUOpTest):
input2 = np.maximum(input, 0.0) input2 = np.maximum(input, 0.0)
else: else:
input2 = input input2 = input
filter = np.random.uniform(-1, 1, self.filter_size).astype(self.dtype) np.random.seed(8)
filter = np.random.uniform(-1, 1,
self.filter_size).astype(self.dtype)
output, _, _, _, _ = conv2d_forward_naive( output, _, _, _, _ = conv2d_forward_naive(
input2, filter, self.groups, conv2d_param, self.padding_algorithm, input2, filter, self.groups, conv2d_param,
self.data_format) self.padding_algorithm, self.data_format)
output = output.astype(self.dtype) output = output.astype(self.dtype)
self.inputs = { self.inputs = {
...@@ -382,37 +391,41 @@ class TestConv2DOp_v2(XPUOpTest): ...@@ -382,37 +391,41 @@ class TestConv2DOp_v2(XPUOpTest):
# TODO(wangzhongpu): support mkldnn op in dygraph mode # TODO(wangzhongpu): support mkldnn op in dygraph mode
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
paddle.enable_static() paddle.enable_static()
place = paddle.XPUPlace(0) self.check_output_with_place(place=self.place)
self.check_output_with_place(place)
def test_check_grad(self): def test_check_grad(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode # TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16: if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return return
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
paddle.enable_static() paddle.enable_static()
place = paddle.XPUPlace(0) self.check_grad_with_place(self.place, {'Input', 'Filter'},
self.check_grad_with_place(place, {'Input', 'Filter'}, 'Output') 'Output')
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode # TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16: if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return return
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
paddle.enable_static() paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
place, ['Input'], 'Output', no_grad_set=set(['Filter'])) self.place, ['Input'],
'Output',
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode # TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.dtype == np.float16: if (hasattr(self, "no_need_check_grad") and
self.no_need_check_grad == True):
return return
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
paddle.enable_static() paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place( self.check_grad_with_place(
place, ['Filter'], 'Output', no_grad_set=set(['Input'])) self.place, ['Filter'],
'Output',
no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
self.pad = [0, 0] self.pad = [0, 0]
...@@ -441,14 +454,12 @@ class TestConv2DOp_v2(XPUOpTest): ...@@ -441,14 +454,12 @@ class TestConv2DOp_v2(XPUOpTest):
def init_test_case_2(self): def init_test_case_2(self):
pass pass
class TestConv2DOp_AsyPadding(TestConv2DOp_v2):
class TestConv2DOp_AsyPadding(TestConv2DOp_v2):
def init_paddings(self): def init_paddings(self):
self.pad = [0, 0, 0, 0] self.pad = [0, 0, 0, 0]
self.padding_algorithm = "EXPLICIT" self.padding_algorithm = "EXPLICIT"
class TestWithPad_AsyPadding(TestConv2DOp_v2):
class TestWithPad_AsyPadding(TestConv2DOp_v2):
def init_test_case(self): def init_test_case(self):
self.stride = [1, 1] self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW self.input_size = [2, 3, 5, 5] # NCHW
...@@ -460,8 +471,7 @@ class TestWithPad_AsyPadding(TestConv2DOp_v2): ...@@ -460,8 +471,7 @@ class TestWithPad_AsyPadding(TestConv2DOp_v2):
self.pad = [1, 1, 1, 1] self.pad = [1, 1, 1, 1]
self.padding_algorithm = "EXPLICIT" self.padding_algorithm = "EXPLICIT"
class TestWithStride_AsyPadding(TestConv2DOp_v2):
class TestWithStride_AsyPadding(TestConv2DOp_v2):
def init_test_case(self): def init_test_case(self):
self.stride = [2, 2] self.stride = [2, 2]
self.input_size = [2, 3, 6, 6] # NCHW self.input_size = [2, 3, 6, 6] # NCHW
...@@ -474,6 +484,11 @@ class TestWithStride_AsyPadding(TestConv2DOp_v2): ...@@ -474,6 +484,11 @@ class TestWithStride_AsyPadding(TestConv2DOp_v2):
self.padding_algorithm = "EXPLICIT" self.padding_algorithm = "EXPLICIT"
support_types = get_xpu_op_support_types('conv2d')
for stype in support_types:
create_test_class(globals(), XPUTestConv2DOp, stype)
create_test_class(globals(), XPUTestConv2DOp_v2, stype)
#---------- test SAME VALID ----------- #---------- test SAME VALID -----------
#create_test_padding_SAME_class(TestConv2DOp_AsyPadding) #create_test_padding_SAME_class(TestConv2DOp_AsyPadding)
#create_test_padding_SAME_class(TestWithPad_AsyPadding) #create_test_padding_SAME_class(TestWithPad_AsyPadding)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册