From d3e06a51270baea1332eb996b146a51773c372c9 Mon Sep 17 00:00:00 2001
From: Qi Li
Date: Wed, 15 Sep 2021 14:29:57 +0800
Subject: [PATCH] [NPU] fix depthwise_conv2d_grad, test=develop (#35626)

* [NPU] fix depthwise_conv2d_grad, test=develop

* remove debug files, test=develop
---
 paddle/fluid/operators/conv_op_npu.cc         | 186 +++++------
 python/paddle/fluid/dygraph/nn.py             |   8 +
 python/paddle/fluid/layers/nn.py              |   7 +
 .../npu/test_conv2d_op_depthwise_conv_npu.py  | 289 +++++++-----------
 .../tests/unittests/npu/test_conv2d_op_npu.py |  30 +-
 python/paddle/nn/functional/conv.py           |  14 +
 6 files changed, 257 insertions(+), 277 deletions(-)

diff --git a/paddle/fluid/operators/conv_op_npu.cc b/paddle/fluid/operators/conv_op_npu.cc
index 5fc39b5fb4d..86724e06975 100644
--- a/paddle/fluid/operators/conv_op_npu.cc
+++ b/paddle/fluid/operators/conv_op_npu.cc
@@ -19,33 +19,26 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
+using NPUDeviceContext = platform::NPUDeviceContext;
 
-template <typename DeviceContext, typename T>
+template <typename T>
 class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    // input
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* filter = context.Input<Tensor>("Filter");
-    // output
-    Tensor* output = context.Output<Tensor>("Output");
-    output->mutable_data<T>(context.GetPlace());
-    // attr
-    const std::vector<int> stride = context.Attr<std::vector<int>>("strides");
-    std::vector<int> padding = context.Attr<std::vector<int>>("paddings");
-    std::vector<int> dilation = context.Attr<std::vector<int>>("dilations");
-    const std::string data_format = context.Attr<std::string>("data_format");
-    const std::string padding_algorithm =
-        context.Attr<std::string>("padding_algorithm");
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    const Tensor* filter = ctx.Input<Tensor>("Filter");
+    Tensor* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
 
-    // npu stream
-    auto stream =
-        context.template device_context<platform::NPUDeviceContext>().stream();
+    const std::vector<int> stride = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> padding = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilation = ctx.Attr<std::vector<int>>("dilations");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
 
-    // check dimension
     const bool channel_last = data_format == "NHWC";
     if (channel_last) {
-      // NHWC
       PADDLE_ENFORCE_EQ(
           output->dims()[output->dims().size() - 1],
           input->dims()[input->dims().size() - 1],
@@ -56,7 +49,6 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
           output->dims()[output->dims().size() - 1],
           input->dims()[input->dims().size() - 1]));
     } else {
-      // NCHW
       PADDLE_ENFORCE_EQ(
           output->dims()[1], input->dims()[1],
           platform::errors::InvalidArgument(
@@ -66,7 +58,6 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
           output->dims()[1], input->dims()[1]));
     }
 
-    // update padding and dilation
     auto in_dims = input->dims();
     auto filter_dims = filter->dims();
     framework::DDim in_data_dims;
@@ -83,17 +74,6 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
                              in_data_dims, stride, ksize);
 
-    // Transform filter (n, 1, h, w) --> (1, n, h, w)
-    Tensor transformed_filter(filter->type());
-    transformed_filter.mutable_data<T>({filter->dims()[1], filter->dims()[0],
-                                        filter->dims()[2], filter->dims()[3]},
-                                       context.device_context().GetPlace());
-    std::vector<int> perm = {1, 0, 2, 3};
-    const auto& runner_trans = NpuOpRunner(
-        "TransposeD", {*filter}, {transformed_filter}, {{"perm", perm}});
-    runner_trans.Run(stream);
-
-    // construct NPU attr
     std::vector<int> strides(4, 1);
     std::vector<int> dilations(4, 1);
@@ -115,7 +95,18 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
       dilations[3] = dilation[1];
     }
 
-    // CANN OP
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
+
+    // Transform filter (n, 1, h, w) --> (1, n, h, w)
+    Tensor transformed_filter(filter->type());
+    transformed_filter.mutable_data<T>({filter->dims()[1], filter->dims()[0],
+                                        filter->dims()[2], filter->dims()[3]},
+                                       ctx.device_context().GetPlace());
+    std::vector<int> perm = {1, 0, 2, 3};
+    const auto& runner_trans = NpuOpRunner(
+        "TransposeD", {*filter}, {transformed_filter}, {{"perm", perm}});
+    runner_trans.Run(stream);
+
     const auto& runner =
         NpuOpRunner("DepthwiseConv2D", {input_tensor, transformed_filter},
                     {output_tensor}, {{"strides", strides},
@@ -129,27 +120,20 @@ class DepthwiseConvNPUKernel : public framework::OpKernel<T> {
 
 template <typename T>
 class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    // input
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* filter = context.Input<Tensor>("Filter");
-    // output
-    auto output_grad = context.Input<Tensor>(framework::GradVarName("Output"));
-    auto input_grad = context.Output<Tensor>(framework::GradVarName("Input"));
-    auto filter_grad = context.Output<Tensor>(framework::GradVarName("Filter"));
-    // attr
-    const std::vector<int> stride = context.Attr<std::vector<int>>("strides");
-    std::vector<int> padding = context.Attr<std::vector<int>>("paddings");
-    std::vector<int> dilation = context.Attr<std::vector<int>>("dilations");
-    const std::string data_format = context.Attr<std::string>("data_format");
-    const std::string padding_algorithm =
-        context.Attr<std::string>("padding_algorithm");
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* input = ctx.Input<Tensor>("Input");
+    const Tensor* filter = ctx.Input<Tensor>("Filter");
+    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
+    auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
 
-    // npu stream
-    auto stream =
-        context.template device_context<platform::NPUDeviceContext>().stream();
+    const std::vector<int> stride = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> padding = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilation = ctx.Attr<std::vector<int>>("dilations");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
 
-    // check dimension
     const bool channel_last = data_format == "NHWC";
 
     // update padding and dilation
@@ -169,11 +153,13 @@ class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
     UpdatePaddingAndDilation(&padding, &dilation, padding_algorithm,
                              in_data_dims, stride, ksize);
 
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
+
     // Transform filter (n, 1, h, w) --> (1, n, h, w)
     Tensor transformed_filter(filter->type());
     transformed_filter.mutable_data<T>({filter->dims()[1], filter->dims()[0],
                                         filter->dims()[2], filter->dims()[3]},
-                                       context.device_context().GetPlace());
+                                       ctx.device_context().GetPlace());
     std::vector<int> perm = {1, 0, 2, 3};
     const auto& runner_trans = NpuOpRunner(
         "TransposeD", {*filter}, {transformed_filter}, {{"perm", perm}});
     runner_trans.Run(stream);
@@ -200,39 +186,52 @@ class DepthwiseConvGradNPUKernel : public framework::OpKernel<T> {
       dilations[3] = dilation[1];
     }
 
+    // LOG(INFO) << "strides = " << framework::make_ddim(strides).to_str();
+    // LOG(INFO) << "dilations = " << framework::make_ddim(dilations).to_str();
+    // LOG(INFO) << "padding = " << framework::make_ddim(padding).to_str();
+    // LOG(INFO) << "data_format = " << data_format;
+
     if (filter_grad) {
-      filter_grad->mutable_data<T>(context.GetPlace());
-      std::vector<int> filter_shape_vec =
-          framework::vectorize<int>(transformed_filter.dims());
+      filter_grad->mutable_data<T>(ctx.GetPlace());
 
-      const auto& runner = NpuOpRunner(
-          "DepthwiseConv2DBackpropFilterD", {input_tensor, output_grad_tensor},
-          {*filter_grad}, {{"filter_size", filter_shape_vec},
-                           {"strides", strides},
-                           {"pads", padding},
-                           {"dilations", dilations},
-                           {"data_format", data_format}});
-      runner.Run(stream);
+      PADDLE_ENFORCE_EQ(
+          (dilations[2] == 1 && dilations[3] == 1), true,
+          platform::errors::InvalidArgument(
+              "dilation_h and dilation_w in DepthwiseConv2DBackpropFilterD "
+              "must be equal to 1, but got dilation_h %d, dilation_w %d",
+              dilation[2], dilation[3]));
+
+      NpuOpRunner runner;
+      runner.SetType("DepthwiseConv2DBackpropFilterD")
+          .AddInput(input_tensor)
+          .AddInput(output_grad_tensor)
+          .AddOutput(*filter_grad)
+          .AddAttr("filter_size",
+                   framework::vectorize<int>(transformed_filter.dims()))
+          .AddAttr("strides", strides)
+          .AddAttr("dilations", dilations)
+          .AddAttr("pads", padding)
+          .AddAttr("data_format", data_format)
+          .Run(stream);
     }
     if (input_grad) {
-      input_grad->mutable_data<T>(context.GetPlace());
-      std::vector<int> input_shape_vec =
-          framework::vectorize<int>(input->dims());
-
+      input_grad->mutable_data<T>(ctx.GetPlace());
       Tensor input_grad_tensor;
       input_grad_tensor.ShareDataWith(*input_grad);
       if (channel_last) {
         input_grad_tensor.set_layout(DataLayout::kNHWC);
       }
-      const auto& runner =
-          NpuOpRunner("DepthwiseConv2DBackpropInputD",
-                      {transformed_filter, output_grad_tensor},
-                      {input_grad_tensor}, {{"input_size", input_shape_vec},
-                                            {"strides", strides},
-                                            {"pads", padding},
-                                            {"dilations", dilations},
-                                            {"data_format", data_format}});
-      runner.Run(stream);
+      NpuOpRunner runner;
+      runner.SetType("DepthwiseConv2DBackpropInputD")
+          .AddInput(transformed_filter)
+          .AddInput(output_grad_tensor)
+          .AddOutput(input_grad_tensor)
+          .AddAttr("input_size", framework::vectorize<int>(input->dims()))
+          .AddAttr("strides", strides)
+          .AddAttr("dilations", dilations)
+          .AddAttr("pads", padding)
+          .AddAttr("data_format", data_format)
+          .Run(stream);
     }
   }
 };
@@ -241,7 +240,6 @@ template <typename T>
 class NPUConvOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
     const Tensor* input = ctx.Input<Tensor>("Input");
     auto* filter = ctx.Input<Tensor>("Filter");
     auto* output = ctx.Output<Tensor>("Output");
@@ -293,6 +291,7 @@ class NPUConvOpKernel : public framework::OpKernel<T> {
       dilations_vec[3] = dilations[1];
     }
 
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
     const auto& runner =
         NpuOpRunner("Conv2D", {input_tensor, *filter}, {output_tensor},
                     {{"strides", strides_vec},
@@ -300,7 +299,7 @@ class NPUConvOpKernel : public framework::OpKernel<T> {
                      {"dilations", dilations_vec},
                      {"groups", groups},
                      {"data_format", data_format}});
-    runner.Run(dev_ctx.stream());
+    runner.Run(stream);
   }
 };
 
@@ -308,8 +307,6 @@ template <typename T>
 class NPUConvGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
-
     auto input = ctx.Input<Tensor>("Input");
     auto filter = ctx.Input<Tensor>("Filter");
     auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
@@ -363,6 +360,7 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
       dilations_vec[3] = dilations[1];
     }
 
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
     if (filter_grad) {
       filter_grad->mutable_data<T>(ctx.GetPlace());
       std::vector<int> filter_shape_vec =
@@ -376,7 +374,7 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
                      {"dilations", dilations_vec},
                      {"groups", groups},
                      {"data_format", data_format}});
-      runner.Run(dev_ctx.stream());
+      runner.Run(stream);
     }
     if (input_grad) {
       input_grad->mutable_data<T>(ctx.GetPlace());
@@ -396,7 +394,7 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
                      {"dilations", dilations_vec},
                      {"groups", groups},
                      {"data_format", data_format}});
-      runner.Run(dev_ctx.stream());
+      runner.Run(stream);
     }
   }
 };
@@ -404,15 +402,17 @@ class NPUConvGradOpKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_NPU_KERNEL(depthwise_conv2d, ops::DepthwiseConvNPUKernel<float>,
+                       ops::DepthwiseConvNPUKernel<plat::float16>);
+
+REGISTER_OP_NPU_KERNEL(depthwise_conv2d_grad,
+                       ops::DepthwiseConvGradNPUKernel<float>,
+                       ops::DepthwiseConvGradNPUKernel<plat::float16>);
 
-REGISTER_OP_NPU_KERNEL(
-    depthwise_conv2d,
-    ops::DepthwiseConvNPUKernel);
-REGISTER_OP_NPU_KERNEL(
-    depthwise_conv2d_grad,
-    ops::DepthwiseConvGradNPUKernel);
 REGISTER_OP_NPU_KERNEL(conv2d, ops::NPUConvOpKernel<float>,
-                       ops::NPUConvOpKernel<paddle::platform::float16>);
+                       ops::NPUConvOpKernel<plat::float16>);
+
 REGISTER_OP_NPU_KERNEL(conv2d_grad, ops::NPUConvGradOpKernel<float>,
-                       ops::NPUConvGradOpKernel<paddle::platform::float16>);
+                       ops::NPUConvGradOpKernel<plat::float16>);
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index d9a431990c1..3703da08dea 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -203,6 +203,14 @@ class Conv2D(layers.Layer):
         else:
             self._l_type = 'conv2d'
 
+        # NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
+        if core.is_compiled_with_npu():
+            if (self._num_channels == self._groups and
+                    self._num_channels == self._num_filters):
+                l_type = 'depthwise_conv2d'
+            else:
+                l_type = 'conv2d'
+
         self._num_channels = num_channels
         if self._groups is None:
             num_filter_channels = self._num_channels
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 0cb348c5dd3..534ebf231a1 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -1547,6 +1547,13 @@ def conv2d(input,
             core.is_compiled_with_rocm()):
         l_type = 'depthwise_conv2d'
 
+    # NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
+    if core.is_compiled_with_npu():
+        if (num_channels == groups and num_channels == num_filters):
+            l_type = 'depthwise_conv2d'
+        else:
+            l_type = 'conv2d'
+
     helper = LayerHelper(l_type, **locals())
     dtype = helper.input_dtype()
diff --git a/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py
index d1c1e80c218..012a6e59e77 100755
--- a/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_depthwise_conv_npu.py
@@ -20,7 +20,7 @@ import paddle
 import paddle.fluid as fluid
 import sys
 sys.path.append("..")
-from op_test import OpTest, skip_check_grad_ci
+from op_test import OpTest
 from test_conv2d_op import conv2d_forward_naive
 from paddle import ParamAttr
 from paddle.regularizer import L2Decay
@@ -66,14 +66,22 @@ def create_test_padding_VALID_class(parent):
     globals()[cls_name] = TestPaddingVALIDCase
 
 
-@skip_check_grad_ci(
-    reason='''Inference only, it doesn't need to call check_grad.''')
+def create_test_fp16_class(parent):
+    class TestFp16Case(parent):
+        def init_data_type(self):
+            self.dtype = np.float16
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestFp16Case
+
+
 class TestDepthwiseConvNPU(OpTest):
     def setUp(self):
-        self.op_type = "depthwise_conv2d"
-        self.dtype = np.float16
         self.set_npu()
+        self.op_type = "depthwise_conv2d"
         self.init_data_format()
+        self.init_data_type()
         self.init_test_case()
         self.init_test_case_2()
@@ -114,18 +122,52 @@ class TestDepthwiseConvNPU(OpTest):
         self.pad = [1, 1]
         self.dilations = [1, 1]
         self.stride = [2, 2]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        self.groups = 3
+        self.input_size = [2, 12, 5, 5]  # NCHW
+        self.groups = 12
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
-        self.filter_size = [3, f_c, 3, 3]
+        self.filter_size = [12, f_c, 3, 3]
 
     def test_check_output(self):
-        self.check_output_with_place(self.place)
+        self.check_output_with_place(self.place, atol=1e-2)
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        if self.dilations[0] == 1 and self.dilations[1] == 1:
+            self.check_grad_with_place(
+                self.place, {'Input', 'Filter'},
+                'Output',
+                max_relative_error=0.03,
+                numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_filter(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['Input'],
+            'Output',
+            no_grad_set=set(['Filter']),
+            max_relative_error=0.03,
+            numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_input(self):
+        if self.dtype == np.float16:
+            return
+        if self.dilations[0] == 1 and self.dilations[1] == 1:
+            self.check_grad_with_place(
+                self.place, ['Filter'],
+                'Output',
+                no_grad_set=set(['Input']),
+                max_relative_error=0.03,
+                numeric_place=paddle.CPUPlace())
 
     def init_data_format(self):
         self.data_format = "NCHW"
 
+    def init_data_type(self):
+        self.dtype = np.float32
+
     def init_test_case_2(self):
         pass
 
@@ -135,45 +177,44 @@ class TestDepthwiseConvNPU2(TestDepthwiseConvNPU):
         self.pad = [1, 1]
         self.dilations = [1, 1]
         self.stride = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        self.groups = 3
+        self.input_size = [2, 12, 5, 5]  # NCHW
+        self.groups = 12
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
-        self.filter_size = [3, f_c, 3, 3]
+        self.filter_size = [12, f_c, 3, 3]
 
 
 class TestDepthwiseConvNPU3(TestDepthwiseConvNPU):
     def init_test_case(self):
         self.pad = [1, 1]
-        self.dilations = [2, 2]
+        self.dilations = [1, 1]
        self.stride = [2, 2]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        self.groups = 3
+        self.input_size = [2, 12, 5, 5]  # NCHW
+        self.groups = 12
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
-        self.filter_size = [3, f_c, 3, 3]
+        self.filter_size = [12, f_c, 3, 3]
 
 
 class TestDepthwiseConvNPU4(TestDepthwiseConvNPU):
     def init_test_case(self):
         self.pad = [1, 1]
-        self.dilations = [2, 2]
+        self.dilations = [1, 1]
         self.stride = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        self.groups = 3
+        self.input_size = [2, 12, 5, 5]  # NCHW
+        self.groups = 12
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
-        self.filter_size = [3, f_c, 3, 3]
+        self.filter_size = [12, f_c, 3, 3]
 
 
-@skip_check_grad_ci(
-    reason='''Inference only, it doesn't need to call check_grad.''')
 class TestDepthwiseConvNPU_Padding(OpTest):
     def setUp(self):
         self.op_type = "depthwise_conv2d"
-        self.dtype = np.float16
+        self.dtype = np.float32
         self.set_npu()
         self.init_data_format()
+        self.init_data_type()
         self.init_paddings()
         self.init_test_case()
         self.init_test_case_2()
@@ -215,18 +256,50 @@ class TestDepthwiseConvNPU_Padding(OpTest):
         self.pad = [1, 1, 0, 1]
         self.dilations = [1, 1]
         self.stride = [2, 2]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        self.groups = 3
+        self.input_size = [2, 12, 5, 5]  # NCHW
+        self.groups = 12
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
-        self.filter_size = [3, f_c, 3, 3]
+        self.filter_size = [12, f_c, 3, 3]
 
     def test_check_output(self):
-        self.check_output_with_place(self.place)
+        self.check_output_with_place(self.place, atol=1e-2)
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, {'Input', 'Filter'},
+            'Output',
+            max_relative_error=0.03,
+            numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_filter(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['Input'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Filter']),
+            numeric_place=paddle.CPUPlace())
+
+    def test_check_grad_no_input(self):
+        if self.dtype == np.float16:
+            return
+        self.check_grad_with_place(
+            self.place, ['Filter'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Input']),
+            numeric_place=paddle.CPUPlace())
 
     def init_data_format(self):
         self.data_format = "NCHW"
 
+    def init_data_type(self):
+        self.dtype = np.float32
+
     def init_paddings(self):
         self.pad = [1, 1, 0, 1]
         self.padding_algorithm = "EXPLICIT"
@@ -240,11 +313,11 @@ class TestDepthwiseConvNPU2_Padding(TestDepthwiseConvNPU_Padding):
         self.pad = [1, 1, 0, 1]
         self.dilations = [1, 1]
         self.stride = [1, 1]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        self.groups = 3
+        self.input_size = [2, 12, 5, 5]  # NCHW
+        self.groups = 12
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
-        self.filter_size = [3, f_c, 3, 3]
+        self.filter_size = [12, f_c, 3, 3]
 
     def init_paddings(self):
         self.pad = [0, 1, 0, 2]
@@ -256,11 +329,11 @@ class TestDepthwiseConvNPU3_Padding(TestDepthwiseConvNPU_Padding):
         self.pad = [1, 1, 0, 1]
         self.dilations = [1, 1]
         self.stride = [2, 2]
-        self.input_size = [2, 3, 5, 5]  # NCHW
-        self.groups = 3
+        self.input_size = [2, 12, 5, 5]  # NCHW
+        self.groups = 12
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
-        self.filter_size = [3, f_c, 3, 3]
+        self.filter_size = [12, f_c, 3, 3]
 
     def init_paddings(self):
         self.pad = [2, 1, 2, 3]
@@ -283,151 +356,11 @@ create_test_padding_VALID_class(TestDepthwiseConvNPU_Padding)
 create_test_padding_VALID_class(TestDepthwiseConvNPU2_Padding)
 create_test_padding_VALID_class(TestDepthwiseConvNPU3_Padding)
-
-class TestDepthwiseConvNet(unittest.TestCase):
-    def __init__(self, methodName='runTest'):
-        super().__init__(methodName=methodName)
-
-    def _test(self, run_npu=True):
-        main_prog = paddle.static.Program()
-        startup_prog = paddle.static.Program()
-        main_prog.random_seed = SEED
-        startup_prog.random_seed = SEED
-        np.random.seed(SEED)
-
-        a_np = np.random.random(size=(2, 4, 16, 16)).astype('float16')
-        b_np = np.random.random(size=(4, 1, 3, 3)).astype('float16')
-        if not run_npu:
-            a_np = a_np.astype('float32')
-            b_np = b_np.astype('float32')
-        label_np = np.random.randint(10, size=(2, 10)).astype('float32')
-        with paddle.static.program_guard(main_prog, startup_prog):
-            if run_npu:
-                a = paddle.static.data(
-                    name="a", shape=[2, 4, 16, 16], dtype='float16')
-                b = paddle.static.data(
-                    name="b", shape=[4, 1, 3, 3], dtype='float16')
-            else:
-                a = paddle.static.data(
-                    name="a", shape=[2, 4, 16, 16], dtype='float32')
-                b = paddle.static.data(
-                    name="b", shape=[4, 1, 3, 3], dtype='float32')
-            label = paddle.static.data(
-                name="label", shape=[2, 10], dtype='float32')
-
-            a *= 2.0
-            b += 0.01
-            fc_1 = paddle.nn.functional.conv2d(a, b, bias=None, groups=4)
-            if run_npu:
-                fc_1 = paddle.cast(fc_1, dtype='float32')
-            fc_1 = paddle.nn.functional.relu(fc_1)
-            prediction = fluid.layers.fc(input=fc_1, size=10, act='softmax')
-
-            cost = paddle.nn.functional.smooth_l1_loss(
-                input=prediction, label=label)
-            loss = paddle.sum(cost)
-            sgd = fluid.optimizer.SGD(learning_rate=0.00001)
-            sgd.minimize(loss)
-
-        if run_npu:
-            place = paddle.NPUPlace(0)
-        else:
-            place = paddle.CPUPlace()
-        exe = paddle.static.Executor(place)
-        exe.run(startup_prog)
-
-        print("Start run on {}".format(place))
-        for epoch in range(100):
-
-            pred_res, loss_res = exe.run(
-                main_prog,
-                feed={"a": a_np,
-                      "b": b_np,
-                      "label": label_np},
-                fetch_list=[prediction, loss])
-
-        return pred_res, loss_res
-
-    def test_npu(self):
-        cpu_pred, cpu_loss = self._test(False)
-        npu_pred, npu_loss = self._test(True)
-
-        self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-04, atol=1e-03))
-        self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-04, atol=1e-03))
-
-
-class TestDepthwiseConvNet_NHWC(unittest.TestCase):
-    def __init__(self, methodName='runTest'):
-        super().__init__(methodName=methodName)
-
-    def _test(self, run_npu=True):
-        main_prog = paddle.static.Program()
-        startup_prog = paddle.static.Program()
-        main_prog.random_seed = SEED
-        startup_prog.random_seed = SEED
-        np.random.seed(SEED)
-
-        a_np = np.random.random(size=(2, 16, 16, 4)).astype('float16')
-        b_np = np.random.random(size=(4, 1, 3, 3)).astype('float16')
-        if not run_npu:
-            a_np = a_np.astype('float32')
-            b_np = b_np.astype('float32')
-        label_np = np.random.randint(10, size=(2, 10)).astype('float32')
-        with paddle.static.program_guard(main_prog, startup_prog):
-            if run_npu:
-                a = paddle.static.data(
-                    name="a", shape=[2, 16, 16, 4], dtype='float16')
-                b = paddle.static.data(
-                    name="b", shape=[4, 1, 3, 3], dtype='float16')
-            else:
-                a = paddle.static.data(
-                    name="a", shape=[2, 16, 16, 4], dtype='float32')
-                b = paddle.static.data(
-                    name="b", shape=[4, 1, 3, 3], dtype='float32')
-            label = paddle.static.data(
-                name="label", shape=[2, 10], dtype='float32')
-
-            a *= 2.0
-            b += 0.01
-            fc_1 = paddle.nn.functional.conv2d(
-                a, b, bias=None, groups=4, data_format='NHWC')
-            if run_npu:
-                fc_1 = paddle.cast(fc_1, dtype='float32')
-            fc_1 = paddle.nn.functional.relu(fc_1)
-            prediction = fluid.layers.fc(input=fc_1, size=10, act='softmax')
-
-            cost = paddle.nn.functional.smooth_l1_loss(
-                input=prediction, label=label)
-            loss = paddle.sum(cost)
-            sgd = fluid.optimizer.SGD(learning_rate=0.00001)
-            sgd.minimize(loss)
-
-        if run_npu:
-            place = paddle.NPUPlace(0)
-        else:
-            place = paddle.CPUPlace()
-        exe = paddle.static.Executor(place)
-        exe.run(startup_prog)
-
-        print("Start run on {}".format(place))
-        for epoch in range(100):
-
-            pred_res, loss_res = exe.run(
-                main_prog,
-                feed={"a": a_np,
-                      "b": b_np,
-                      "label": label_np},
-                fetch_list=[prediction, loss])
-
-        return pred_res, loss_res
-
-    def test_npu(self):
-        cpu_pred, cpu_loss = self._test(False)
-        npu_pred, npu_loss = self._test(True)
-
-        self.assertTrue(np.allclose(npu_pred, cpu_pred, rtol=1e-04, atol=1e-03))
-        self.assertTrue(np.allclose(npu_loss, cpu_loss, rtol=1e-04, atol=1e-03))
-
+create_test_fp16_class(TestDepthwiseConvNPU)
+create_test_fp16_class(TestDepthwiseConvNPU2)
+create_test_fp16_class(TestDepthwiseConvNPU_Padding)
+create_test_fp16_class(TestDepthwiseConvNPU2_Padding)
+create_test_fp16_class(TestDepthwiseConvNPU3_Padding)
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_npu.py
index dff7438702d..d0dc86055a1 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_conv2d_op_npu.py
@@ -127,24 +127,33 @@ class TestConv2DOp(OpTest):
         self.check_output_with_place(fluid.NPUPlace(0), atol=1e-2)
 
     def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
         self.check_grad_with_place(
             fluid.NPUPlace(0), {'Input', 'Filter'},
             'Output',
-            max_relative_error=0.03)
+            max_relative_error=0.03,
+            numeric_place=paddle.CPUPlace())
 
     def test_check_grad_no_filter(self):
+        if self.dtype == np.float16:
+            return
         self.check_grad_with_place(
             fluid.NPUPlace(0), ['Input'],
             'Output',
             max_relative_error=0.03,
-            no_grad_set=set(['Filter']))
+            no_grad_set=set(['Filter']),
+            numeric_place=paddle.CPUPlace())
 
     def test_check_grad_no_input(self):
+        if self.dtype == np.float16:
+            return
         self.check_grad_with_place(
             fluid.NPUPlace(0), ['Filter'],
             'Output',
             max_relative_error=0.03,
-            no_grad_set=set(['Input']))
+            no_grad_set=set(['Input']),
+            numeric_place=paddle.CPUPlace())
 
     def init_test_case(self):
         self.pad = [0, 0]
@@ -310,23 +319,32 @@ class TestConv2DOp_v2(OpTest):
         self.check_output_with_place(paddle.NPUPlace(0), atol=1e-2)
 
     def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
         self.check_grad_with_place(
             paddle.NPUPlace(0), {'Input', 'Filter'},
             'Output',
-            max_relative_error=0.02)
+            max_relative_error=0.02,
+            numeric_place=paddle.CPUPlace())
 
     def test_check_grad_no_filter(self):
+        if self.dtype == np.float16:
+            return
         self.check_grad_with_place(
             paddle.NPUPlace(0), ['Input'],
             'Output',
             max_relative_error=0.02,
-            no_grad_set=set(['Filter']))
+            no_grad_set=set(['Filter']),
+            numeric_place=paddle.CPUPlace())
 
     def test_check_grad_no_input(self):
+        if self.dtype == np.float16:
+            return
         self.check_grad_with_place(
             paddle.NPUPlace(0), ['Filter'],
             'Output',
-            no_grad_set=set(['Input']))
+            no_grad_set=set(['Input']),
+            numeric_place=paddle.CPUPlace())
 
     def init_test_case(self):
         self.pad = [0, 0]
diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index fcf6f1cdac4..c124ed003d7 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -344,6 +344,13 @@ def conv1d(x,
         l_type = 'depthwise_conv2d'
         use_cudnn = False
 
+    # NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
+    if core.is_compiled_with_npu():
+        if (num_channels == groups and num_channels == num_filters):
+            l_type = 'depthwise_conv2d'
+        else:
+            l_type = 'conv2d'
+
     squeeze_aixs = -2 if channel_last else -1
     x = unsqueeze(x, axis=[squeeze_aixs])
     weight = unsqueeze(weight, axis=[-1])
@@ -562,6 +569,13 @@ def conv2d(x,
         else:
             use_cudnn = False
 
+    # NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
+    if core.is_compiled_with_npu():
+        if (num_channels == groups and num_channels == num_filters):
+            l_type = 'depthwise_conv2d'
+        else:
+            l_type = 'conv2d'
+
     if (core.is_compiled_with_cuda() and get_flags("FLAGS_conv2d_disable_cudnn")
         ["FLAGS_conv2d_disable_cudnn"]):
         use_cudnn = False
-- 
GitLab
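
The Python-side hunks above (dygraph/nn.py, layers/nn.py, nn/functional/conv.py) all add the same rule: on an NPU build, the `depthwise_conv2d` op may only be selected when the input channel count, the number of filters, and `groups` are all equal; any other grouped convolution has to fall back to the plain `conv2d` kernel. The short sketch below restates that selection rule outside the patch for clarity; the function name and its parameters are illustrative and are not part of Paddle's API.

def select_conv_op_type(prior_l_type, num_channels, num_filters, groups,
                        compiled_with_npu):
    """Illustrative restatement of the l_type selection added by this patch.

    NPU only supports depthwise_conv2d when
    input_channel == output_channel == groups; otherwise the generic
    conv2d kernel is used. On non-NPU builds the previously chosen op
    type is left untouched.
    """
    if not compiled_with_npu:
        return prior_l_type
    if num_channels == groups and num_channels == num_filters:
        return 'depthwise_conv2d'
    return 'conv2d'


# The new tests use input [2, 12, 5, 5] with groups=12 and filter
# [12, f_c, 3, 3], i.e. 12 == 12 == 12, so the NPU depthwise path applies.
assert select_conv_op_type('conv2d', 12, 12, 12, True) == 'depthwise_conv2d'
# A grouped-but-not-depthwise case must fall back to conv2d on NPU.
assert select_conv_op_type('depthwise_conv2d', 12, 24, 12, True) == 'conv2d'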