From 7ed7c6c73aa2352d5fb97d88ff2f85e9818b1ad7 Mon Sep 17 00:00:00 2001
From: hong <43953930+phlrain@users.noreply.github.com>
Date: Wed, 6 Apr 2022 08:58:59 +0800
Subject: [PATCH] Add conv yaml (#41354)

* update

* add conv yaml

* add backward

* remove useless code

* fix bug

* fix bug

* revert fluid dygraph conv2d

* remove useless infermeta function

* fix meta fn deduplication error

* conv using custom impl

* remove amp include

* fix bug

* use cudnn = true

* fix test mkldnn caching bug
---
 paddle/fluid/operators/conv_op.cc             |   6 -
 paddle/phi/api/lib/api_custom_impl.cc         | 208 ++++++++++++++++++
 paddle/phi/api/lib/api_custom_impl.h          |  26 +++
 paddle/phi/kernels/conv_grad_kernel.h         |   6 +-
 paddle/phi/kernels/cpu/conv_grad_kernel.cc    |   8 +-
 paddle/phi/kernels/gpu/conv_grad_kernel.cu    |   4 +-
 .../kernels/gpu/depthwise_conv_grad_kernel.cu |   2 +-
 paddle/phi/kernels/gpudnn/conv_grad_kernel.cu |   6 +-
 .../phi/kernels/impl/conv_grad_kernel_impl.h  |   2 +-
 paddle/phi/ops/compat/conv2d_sig.cc           |   2 +-
 paddle/phi/ops/compat/conv3d_sig.cc           |   2 +-
 paddle/phi/ops/compat/depthwise_conv2d_sig.cc |   2 +-
 python/paddle/fluid/dygraph/nn.py             |  12 +
 .../tests/unittests/test_conv2d_layer.py      |  14 +-
 .../tests/unittests/test_imperative_mnist.py  |   1 +
 .../tests/unittests/test_imperative_resnet.py |   3 +-
 python/paddle/nn/functional/conv.py           |  25 ++-
 python/paddle/utils/code_gen/api.yaml         |   6 +
 python/paddle/utils/code_gen/backward.yaml    |   6 +
 19 files changed, 312 insertions(+), 29 deletions(-)

diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 9be63a85fc..4057947838 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -844,8 +844,6 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-DECLARE_INFER_SHAPE_FUNCTOR(conv2d, Conv2dInferShapeFunctor,
-                            PD_INFER_META(phi::ConvInferMeta));
 REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
                   ops::ConvOpInferVarType,
                   ops::Conv2DGradMaker<paddle::framework::OpDesc>,
@@ -856,8 +854,6 @@ REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad,
 REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad);
 
 // depthwise convolution op
-DECLARE_INFER_SHAPE_FUNCTOR(depthwise_conv2d, DepthwiseConv2dInferShapeFunctor,
-                            PD_INFER_META(phi::ConvInferMeta));
 REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
                   ops::ConvOpInferVarType,
                   ops::Conv2DGradMaker<paddle::framework::OpDesc>,
@@ -867,8 +863,6 @@ REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad,
                   ops::Conv2DDoubleGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(depthwise_conv2d_grad_grad, ops::ConvOpDoubleGrad);
 
-DECLARE_INFER_SHAPE_FUNCTOR(conv3d, Conv3dInferShapeFunctor,
-                            PD_INFER_META(phi::ConvInferMeta));
 REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
                   ops::ConvOpInferVarType,
                   ops::Conv3DGradMaker<paddle::framework::OpDesc>,
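
Context for the hunk above, as a sketch rather than part of the patch: shape inference for these ops now reaches phi::ConvInferMeta through the custom impls added in api_custom_impl.cc below, so the standalone InferShape functor registrations are redundant. A quick smoke check through the public API (shapes here are illustrative):

    # Sketch: conv2d output shape is still inferred correctly after the
    # InferShape functor removal (8 + 2*1 - 3 + 1 = 8 per spatial dim).
    import numpy as np
    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor(np.random.randn(2, 3, 8, 8).astype("float32"))
    w = paddle.to_tensor(np.random.randn(4, 3, 3, 3).astype("float32"))
    y = F.conv2d(x, w, stride=1, padding=1)
    print(y.shape)  # [2, 4, 8, 8]
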
diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index 46d09c29bc..8ea9204fa9 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -34,6 +34,213 @@ namespace experimental {
 
 ////////////////// Forward api impls //////////////////////
 
+Tensor conv2d_impl(const Tensor& input,
+                   const Tensor& filter,
+                   const std::vector<int>& strides,
+                   const std::vector<int>& paddings,
+                   const std::string& paddding_algorithm,
+                   int groups,
+                   const std::vector<int>& dilations,
+                   const std::string& data_format,
+                   bool use_addto,
+                   int workspace_size_MB,
+                   bool exhaustive_search) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+
+  kernel_data_type = ParseDataType(input);
+
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(input, filter);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
+
+  VLOG(6) << "conv2d API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "conv2d", {kernel_backend, kernel_layout, kernel_data_type}, true);
+  VLOG(6) << "conv2d API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  phi::TensorArgDef args0 = kernel.InputAt(0);
+  phi::TensorArgDef args1 = kernel.InputAt(1);
+  if (kernel_backend == Backend::GPU) {
+    args0.backend = Backend::GPU;
+    args1.backend = Backend::GPU;
+  }
+
+  auto input_input = PrepareData(input, args0, {});
+  auto input_filter = PrepareData(filter, args1, {});
+
+  Tensor api_output;
+  auto kernel_out = SetKernelOutput(kernel_backend, &api_output);
+  phi::MetaTensor meta_out(kernel_out);
+
+  phi::ConvInferMeta(MakeMetaTensor(*input_input),
+                     MakeMetaTensor(*input_filter),
+                     strides,
+                     paddings,
+                     paddding_algorithm,
+                     groups,
+                     dilations,
+                     data_format,
+                     use_addto,
+                     workspace_size_MB,
+                     exhaustive_search,
+                     &meta_out);
+
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const std::vector<int>&,
+                                    const std::vector<int>&,
+                                    const std::string&,
+                                    int,
+                                    const std::vector<int>&,
+                                    const std::string&,
+                                    bool,
+                                    int,
+                                    bool,
+                                    phi::DenseTensor*);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+
+  {
+    (*kernel_fn)(*dev_ctx,
+                 *input_input,
+                 *input_filter,
+                 strides,
+                 paddings,
+                 paddding_algorithm,
+                 groups,
+                 dilations,
+                 data_format,
+                 use_addto,
+                 workspace_size_MB,
+                 exhaustive_search,
+                 kernel_out);
+  }
+
+  return api_output;
+}
+
+std::vector<std::vector<Tensor>> conv2d_grad_impl(
+    const Tensor& input,
+    const Tensor& filter,
+    const Tensor& out_grad,
+    const std::vector<int>& strides,
+    const std::vector<int>& paddings,
+    const std::string& paddding_algorithm,
+    int groups,
+    const std::vector<int>& dilations,
+    const std::string& data_format,
+    bool use_addto,
+    int workspace_size_MB,
+    bool exhaustive_search) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(input, filter, out_grad);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
+
+  VLOG(6) << "conv2d_grad API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "conv2d_grad", {kernel_backend, kernel_layout, kernel_data_type}, true);
+  VLOG(6) << "conv2d_grad API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  phi::TensorArgDef args0 = kernel.InputAt(0);
+  phi::TensorArgDef args1 = kernel.InputAt(1);
+  phi::TensorArgDef args2 = kernel.InputAt(2);
+  if (kernel_backend == Backend::GPU) {
+    args0.backend = Backend::GPU;
+    args1.backend = Backend::GPU;
+    args2.backend = Backend::GPU;
+  }
+
+  auto input_input = PrepareData(input, args0, {});
+  auto input_filter = PrepareData(filter, args1, {});
+  auto input_out_grad = PrepareData(out_grad, args2, {});
+
+  std::vector<std::vector<Tensor>> api_output(2);
+  api_output[0].emplace_back();
+  auto kernel_out_0 = SetKernelOutput(kernel_backend, &api_output[0][0]);
+  api_output[1].emplace_back();
+  auto kernel_out_1 = SetKernelOutput(kernel_backend, &api_output[1][0]);
+  phi::MetaTensor meta_out_0(kernel_out_0);
+  phi::MetaTensor meta_out_1(kernel_out_1);
+
+  phi::GeneralBinaryGradInferMeta(MakeMetaTensor(*input_input),
+                                  MakeMetaTensor(*input_filter),
+                                  &meta_out_0,
+                                  &meta_out_1);
+
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const std::vector<int>&,
+                                    const std::vector<int>&,
+                                    const std::string&,
+                                    int,
+                                    const std::vector<int>&,
+                                    const std::string&,
+                                    bool,
+                                    int,
+                                    bool,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+
+  {
+    (*kernel_fn)(*dev_ctx,
+                 *input_input,
+                 *input_filter,
+                 *input_out_grad,
+                 strides,
+                 paddings,
+                 paddding_algorithm,
+                 groups,
+                 dilations,
+                 data_format,
+                 use_addto,
+                 workspace_size_MB,
+                 exhaustive_search,
+                 kernel_out_0,
+                 kernel_out_1);
+  }
+
+  return api_output;
+}
+
 Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
   auto kernel_key_set = ParseKernelKeyByInputArgs(x);
   kernel_key_set.backend_set =
@@ -61,6 +268,7 @@ Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
                                     phi::DenseTensor*);
   auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
 
+
   (*kernel_fn)(*dev_ctx, *dense_x, place, blocking, kernel_out);
 
   return out;
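
conv2d_grad_impl above returns input_grad and filter_grad packed as a two-slot std::vector<std::vector<Tensor>>, which the eager autograd engine unpacks. A sketch, not part of the patch, of driving that backward path from Python:

    # Sketch: backward through conv2d; in eager mode this ends up in
    # conv2d_grad_impl, which fills input_grad and filter_grad.
    import numpy as np
    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor(
        np.random.randn(2, 3, 8, 8).astype("float32"), stop_gradient=False)
    w = paddle.to_tensor(
        np.random.randn(4, 3, 3, 3).astype("float32"), stop_gradient=False)
    y = F.conv2d(x, w, padding=1)
    x_grad, w_grad = paddle.grad(y.sum(), [x, w])
    print(x_grad.shape, w_grad.shape)  # [2, 3, 8, 8] [4, 3, 3, 3]
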
diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h
index 15b593238c..91b94fd74c 100644
--- a/paddle/phi/api/lib/api_custom_impl.h
+++ b/paddle/phi/api/lib/api_custom_impl.h
@@ -28,6 +28,32 @@ namespace experimental {
 
 ////////////////// Forward api impls //////////////////////
 
+Tensor conv2d_impl(const Tensor& input,
+                   const Tensor& filter,
+                   const std::vector<int>& strides,
+                   const std::vector<int>& paddings,
+                   const std::string& paddding_algorithm,
+                   int groups,
+                   const std::vector<int>& dilations,
+                   const std::string& data_format,
+                   bool use_addto,
+                   int workspace_size_MB,
+                   bool exhaustive_search);
+
+std::vector<std::vector<Tensor>> conv2d_grad_impl(
+    const Tensor& input,
+    const Tensor& filter,
+    const Tensor& out_grad,
+    const std::vector<int>& strides,
+    const std::vector<int>& paddings,
+    const std::string& paddding_algorithm,
+    int groups,
+    const std::vector<int>& dilations,
+    const std::string& data_format,
+    bool use_addto,
+    int workspace_size_MB,
+    bool exhaustive_search);
+
 Tensor copy_to_impl(const Tensor& x, Place place, bool blocking);
 
 std::vector<Tensor> split_impl(const Tensor& x,
diff --git a/paddle/phi/kernels/conv_grad_kernel.h b/paddle/phi/kernels/conv_grad_kernel.h
index bad30989ac..a6b970e099 100644
--- a/paddle/phi/kernels/conv_grad_kernel.h
+++ b/paddle/phi/kernels/conv_grad_kernel.h
@@ -20,9 +20,9 @@ namespace phi {
 
 template <typename T, typename Context>
 void ConvGradKernel(const Context& dev_ctx,
-                    const DenseTensor& out_grad,
                     const DenseTensor& input,
                     const DenseTensor& filter,
+                    const DenseTensor& out_grad,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings,
                     const std::string& paddding_algorithm,
@@ -37,9 +37,9 @@ void ConvGradKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void Conv3DGradKernel(const Context& dev_ctx,
-                      const DenseTensor& out_grad,
                       const DenseTensor& input,
                       const DenseTensor& filter,
+                      const DenseTensor& out_grad,
                       const std::vector<int>& strides,
                       const std::vector<int>& paddings,
                       const std::string& paddding_algorithm,
@@ -54,9 +54,9 @@ void Conv3DGradKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void DepthwiseConvGradKernel(const Context& dev_ctx,
-                             const DenseTensor& out_grad,
                              const DenseTensor& input,
                              const DenseTensor& filter,
+                             const DenseTensor& out_grad,
                              const std::vector<int>& strides,
                              const std::vector<int>& paddings,
                              const std::string& paddding_algorithm,
diff --git a/paddle/phi/kernels/cpu/conv_grad_kernel.cc b/paddle/phi/kernels/cpu/conv_grad_kernel.cc
index 994ad861bd..2d8a9bf1de 100644
--- a/paddle/phi/kernels/cpu/conv_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/conv_grad_kernel.cc
@@ -22,9 +22,9 @@ namespace phi {
 
 template <typename T, typename Context>
 void DepthwiseConvGradKernel(const Context& dev_ctx,
-                             const DenseTensor& out_grad,
                              const DenseTensor& input,
                              const DenseTensor& filter,
+                             const DenseTensor& out_grad,
                              const std::vector<int>& strides,
                              const std::vector<int>& paddings,
                              const std::string& paddding_algorithm,
@@ -38,9 +38,9 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
                              DenseTensor* input_grad,
                              DenseTensor* filter_grad) {
   ConvGradKernel<T>(dev_ctx,
-                    out_grad,
                     input,
                     filter,
+                    out_grad,
                     strides,
                     paddings,
                     paddding_algorithm,
@@ -56,9 +56,9 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void Conv3DGradKernel(const Context& dev_ctx,
-                      const DenseTensor& out_grad,
                       const DenseTensor& input,
                       const DenseTensor& filter,
+                      const DenseTensor& out_grad,
                       const std::vector<int>& strides,
                       const std::vector<int>& paddings,
                       const std::string& paddding_algorithm,
@@ -71,9 +71,9 @@ void Conv3DGradKernel(const Context& dev_ctx,
                       DenseTensor* input_grad,
                       DenseTensor* filter_grad) {
   ConvGradKernel<T>(dev_ctx,
-                    out_grad,
                     input,
                     filter,
+                    out_grad,
                     strides,
                     paddings,
                     paddding_algorithm,
diff --git a/paddle/phi/kernels/gpu/conv_grad_kernel.cu b/paddle/phi/kernels/gpu/conv_grad_kernel.cu
index 4df7bb26ad..677ec4a062 100644
--- a/paddle/phi/kernels/gpu/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/conv_grad_kernel.cu
@@ -22,9 +22,9 @@ namespace phi {
 
 template <typename T, typename Context>
 void Conv3DGradKernel(const Context& dev_ctx,
-                      const DenseTensor& out_grad,
                       const DenseTensor& input,
                       const DenseTensor& filter,
+                      const DenseTensor& out_grad,
                       const std::vector<int>& strides,
                       const std::vector<int>& paddings,
                       const std::string& paddding_algorithm,
@@ -37,9 +37,9 @@ void Conv3DGradKernel(const Context& dev_ctx,
                       DenseTensor* input_grad,
                       DenseTensor* filter_grad) {
   ConvGradKernel<T>(dev_ctx,
-                    out_grad,
                     input,
                     filter,
+                    out_grad,
                     strides,
                     paddings,
                     paddding_algorithm,
diff --git a/paddle/phi/kernels/gpu/depthwise_conv_grad_kernel.cu b/paddle/phi/kernels/gpu/depthwise_conv_grad_kernel.cu
index 4f27b6fde9..5fc5482a08 100644
--- a/paddle/phi/kernels/gpu/depthwise_conv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/depthwise_conv_grad_kernel.cu
@@ -24,9 +24,9 @@ namespace phi {
 
 template <typename T, typename Context>
 void DepthwiseConvGradKernel(const Context& dev_ctx,
-                             const DenseTensor& out_grad,
                              const DenseTensor& input,
                              const DenseTensor& filter,
+                             const DenseTensor& out_grad,
                              const std::vector<int>& strides_t,
                              const std::vector<int>& paddings_t,
                              const std::string& padding_algorithm,
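
The grad kernels above all move out_grad after input and filter, matching the argument order of the conv2d_grad entry added to backward.yaml at the end of this patch. For reference, a depthwise call whose backward goes through DepthwiseConvGradKernel (a sketch, not part of the patch; on CPU that kernel forwards to ConvGradKernel as shown above):

    # Sketch: depthwise conv (groups == in_channels, one filter per
    # channel); its backward uses the reordered (input, filter, out_grad).
    import numpy as np
    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor(
        np.random.randn(2, 4, 8, 8).astype("float32"), stop_gradient=False)
    w = paddle.to_tensor(np.random.randn(4, 1, 3, 3).astype("float32"))
    y = F.conv2d(x, w, padding=1, groups=4)
    y.sum().backward()
    print(x.grad.shape)  # [2, 4, 8, 8]
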
diff --git a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
index a99a1e5f94..e09c33380b 100644
--- a/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/gpudnn/conv_grad_kernel.cu
@@ -43,9 +43,9 @@ namespace phi {
 
 template <typename T, typename Context>
 void ConvCudnnGradKernel(const Context& ctx,
-                         const DenseTensor& output_grad,
                          const DenseTensor& input,
                          const DenseTensor& filter,
+                         const DenseTensor& output_grad,
                          const std::vector<int>& strides_t,
                          const std::vector<int>& paddings_t,
                          const std::string& padding_algorithm,
@@ -595,9 +595,9 @@ void ConvCudnnGradKernel(const Context& ctx,
 
 template <typename T, typename Context>
 void Conv3DCudnnGradKernel(const Context& dev_ctx,
-                           const DenseTensor& out_grad,
                            const DenseTensor& input,
                            const DenseTensor& filter,
+                           const DenseTensor& out_grad,
                            const std::vector<int>& strides,
                            const std::vector<int>& paddings,
                            const std::string& paddding_algorithm,
@@ -610,9 +610,9 @@ void Conv3DCudnnGradKernel(const Context& dev_ctx,
                            DenseTensor* input_grad,
                            DenseTensor* filter_grad) {
   ConvCudnnGradKernel<T>(dev_ctx,
-                         out_grad,
                          input,
                          filter,
+                         out_grad,
                          strides,
                          paddings,
                          paddding_algorithm,
diff --git a/paddle/phi/kernels/impl/conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/conv_grad_kernel_impl.h
index 2deebb996a..6674500c3c 100644
--- a/paddle/phi/kernels/impl/conv_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/conv_grad_kernel_impl.h
@@ -26,9 +26,9 @@ namespace phi {
 
 template <typename T, typename Context>
 void ConvGradKernel(const Context& dev_ctx,
-                    const DenseTensor& output_grad,
                     const DenseTensor& input,
                     const DenseTensor& filter_t,
+                    const DenseTensor& output_grad,
                     const std::vector<int>& strides,
                     const std::vector<int>& paddings_t,
                     const std::string& padding_algorithm,
diff --git a/paddle/phi/ops/compat/conv2d_sig.cc b/paddle/phi/ops/compat/conv2d_sig.cc
index 67b99f1dd6..19e20fddcb 100644
--- a/paddle/phi/ops/compat/conv2d_sig.cc
+++ b/paddle/phi/ops/compat/conv2d_sig.cc
@@ -46,7 +46,7 @@ KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
 
 KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("conv2d_grad",
-                         {GradVarName("Output"), "Input", "Filter"},
+                         {"Input", "Filter", GradVarName("Output")},
                          {"strides",
                           "paddings",
                           "padding_algorithm",
diff --git a/paddle/phi/ops/compat/conv3d_sig.cc b/paddle/phi/ops/compat/conv3d_sig.cc
index a036afac82..b24c08b60c 100644
--- a/paddle/phi/ops/compat/conv3d_sig.cc
+++ b/paddle/phi/ops/compat/conv3d_sig.cc
@@ -33,7 +33,7 @@ KernelSignature Conv3dOpArgumentMapping(const ArgumentMappingContext& ctx) {
 
 KernelSignature Conv3dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("conv2d_grad",
-                         {GradVarName("Output"), "Input", "Filter"},
+                         {"Input", "Filter", GradVarName("Output")},
                          {"strides",
                           "paddings",
                           "padding_algorithm",
diff --git a/paddle/phi/ops/compat/depthwise_conv2d_sig.cc b/paddle/phi/ops/compat/depthwise_conv2d_sig.cc
index e2b6801f73..d2d7451eca 100644
--- a/paddle/phi/ops/compat/depthwise_conv2d_sig.cc
+++ b/paddle/phi/ops/compat/depthwise_conv2d_sig.cc
@@ -36,7 +36,7 @@ KernelSignature DepthwiseConv2dOpArgumentMapping(
 KernelSignature DepthwiseConv2dGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("depthwise_conv2d_grad",
-                         {GradVarName("Output"), "Input", "Filter"},
+                         {"Input", "Filter", GradVarName("Output")},
                          {"strides",
                           "paddings",
                           "padding_algorithm",
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index 0ae3cf6ba2..df6af698ab 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -240,6 +240,18 @@ class Conv2D(layers.Layer):
             is_bias=True)
 
     def forward(self, input):
+        if in_dygraph_mode() and self._l_type == "conv2d":
+            pre_bias = _C_ops.final_state_conv2d(
+                input, self.weight, self._stride, self._padding, "EXPLICIT",
+                self._groups if self._groups else 1, self._dilation, "NCHW",
+                False, -1, False)
+            if self.bias is not None:
+                pre_act = F.elementwise_add(pre_bias, self.bias, axis=1)
+            else:
+                pre_act = pre_bias
+            return dygraph_utils._append_activation_in_dygraph(
+                pre_act, self._act, use_mkldnn=self._use_mkldnn)
+
         if _non_static_mode() and (self._l_type == 'conv2d' or
                                    self._l_type == 'depthwise_conv2d'):
             attrs = ('strides', self._stride, 'paddings', self._padding,
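
The new branch in Conv2D.forward above short-circuits eager mode straight to _C_ops.final_state_conv2d, with use_addto=False, workspace_size_MB=-1, and exhaustive_search=False hard-coded. Layer-level usage that exercises it (a sketch, not part of the patch):

    # Sketch: fluid dygraph Conv2D; under eager mode its forward now
    # takes the final_state_conv2d branch added above.
    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph import Conv2D

    with fluid.dygraph.guard():
        conv = Conv2D(num_channels=3, num_filters=4, filter_size=3, padding=1)
        x = fluid.dygraph.to_variable(
            np.random.randn(2, 3, 8, 8).astype("float32"))
        print(conv(x).shape)  # [2, 4, 8, 8]
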
diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_layer.py b/python/paddle/fluid/tests/unittests/test_conv2d_layer.py
index 892fa649a6..508bd7b1e6 100644
--- a/python/paddle/fluid/tests/unittests/test_conv2d_layer.py
+++ b/python/paddle/fluid/tests/unittests/test_conv2d_layer.py
@@ -19,6 +19,7 @@ import paddle.nn.functional as F
 import paddle.fluid.initializer as I
 import unittest
 import paddle
+from paddle.fluid.framework import _test_eager_guard
 
 
 def _reverse_repeat_list(t, n):
@@ -166,7 +167,8 @@ class Conv2DTestCase(unittest.TestCase):
         return y_np
 
     def paddle_nn_layer(self):
-        x_var = dg.to_variable(self.input)
+        x_var = paddle.to_tensor(self.input)
+        x_var.stop_gradient = False
         conv = nn.Conv2D(
             self.num_channels,
             self.num_filters,
@@ -181,17 +183,23 @@ class Conv2DTestCase(unittest.TestCase):
         if not self.no_bias:
             conv.bias.set_value(self.bias)
         y_var = conv(x_var)
+        y_var.backward()
         y_np = y_var.numpy()
-        return y_np
+        t1 = x_var.gradient()
+        return y_np, t1
 
     def _test_equivalence(self, place):
         place = fluid.CPUPlace()
         result1 = self.fluid_layer(place)
         result2 = self.functional(place)
         with dg.guard(place):
-            result3 = self.paddle_nn_layer()
+            result3, g1 = self.paddle_nn_layer()
+            with _test_eager_guard():
+                res_eager, g2 = self.paddle_nn_layer()
         np.testing.assert_array_almost_equal(result1, result2)
         np.testing.assert_array_almost_equal(result2, result3)
+        self.assertTrue(np.allclose(result3, res_eager))
+        self.assertTrue(np.allclose(g1, g2))
 
     def runTest(self):
         place = fluid.CPUPlace()
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
index 06836ed85a..f9bd5e4597 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
@@ -265,4 +265,5 @@ class TestImperativeMnist(unittest.TestCase):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
index 3a643c5316..e48e75c661 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_resnet.py
@@ -164,7 +164,7 @@ class BottleneckBlock(fluid.Layer):
 
 
 class ResNet(fluid.Layer):
-    def __init__(self, layers=50, class_dim=102, use_cudnn=False):
+    def __init__(self, layers=50, class_dim=102, use_cudnn=True):
         super(ResNet, self).__init__()
 
         self.layers = layers
@@ -438,4 +438,5 @@ class TestDygraphResnet(unittest.TestCase):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
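
test_conv2d_layer now runs paddle_nn_layer twice, once in legacy dygraph and once under _test_eager_guard, and asserts that outputs and input gradients match. The same comparison pattern in isolation (a sketch, not part of the patch):

    # Sketch: legacy-dygraph vs eager-mode conv2d, mirroring the
    # _test_eager_guard pattern added to test_conv2d_layer.py.
    import numpy as np
    import paddle
    import paddle.nn.functional as F
    from paddle.fluid.framework import _test_eager_guard

    def run_conv():
        x = paddle.ones([2, 3, 8, 8], dtype="float32")
        w = paddle.full([4, 3, 3, 3], 0.5, dtype="float32")
        return F.conv2d(x, w, padding=1).numpy()

    legacy_out = run_conv()
    with _test_eager_guard():
        eager_out = run_conv()
    np.testing.assert_allclose(legacy_out, eager_out, rtol=1e-5)
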
diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py
index 414f5cefff..086ae78919 100644
--- a/python/paddle/nn/functional/conv.py
+++ b/python/paddle/nn/functional/conv.py
@@ -29,6 +29,8 @@ from paddle import get_flags
 from paddle import in_dynamic_mode
 from paddle.device import is_compiled_with_cuda
 from paddle.device import is_compiled_with_npu
+from paddle import in_dynamic_mode
+from paddle import get_flags
 from paddle.device import is_compiled_with_rocm
 from paddle.fluid.framework import _global_flags
 from paddle.fluid.framework import _in_legacy_dygraph
@@ -120,6 +122,15 @@ def _conv_nd(x,
             name=None):
 
     # Due to the poor performance of NHWC, we transpose the input to NCHW.
+    if in_dygraph_mode() and op_type == "conv2d":
+        pre_bias = _C_ops.final_state_conv2d(
+            x, weight, stride, padding, padding_algorithm, groups, dilation,
+            data_format, False, -1, False)
+        if bias is not None:
+            out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
+            return out
+        else:
+            return pre_bias
     if in_dynamic_mode():
         attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
                  'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn',
@@ -562,8 +573,6 @@ def conv2d(x,
     use_cudnn = True if (is_compiled_with_cuda() and
                          cudnn_version is not None) else False
 
-    use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]
-
     # update attrs
     padding, padding_algorithm = _update_padding_nd(padding, channel_last, 2)
     stride = convert_to_list(stride, 2, 'stride')
@@ -577,6 +586,18 @@ def conv2d(x,
             use_cudnn = True
         else:
             use_cudnn = False
+    else:
+        if in_dygraph_mode():
+            pre_bias = _C_ops.final_state_conv2d(
+                x, weight, stride, padding, padding_algorithm, groups, dilation,
+                data_format, False, -1, False)
+            if bias is not None:
+                out = nn.elementwise_add(pre_bias, bias, axis=channel_dim)
+                return out
+            else:
+                return pre_bias
+
+    use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]
 
     # NPU only supports depthwise_conv2d when "input_channel = output_channel = groups"
     if is_compiled_with_npu():
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index e5cb8756da..a3e5c3fad7 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -346,6 +346,12 @@
   kernel :
     func : conj
 
+- api : conv2d
+  args : (Tensor input, Tensor filter, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
+  output : Tensor
+  invoke : conv2d_impl(input, filter, strides, paddings, paddding_algorithm, groups, dilations, data_format, use_addto, workspace_size_MB, exhaustive_search)
+  backward : conv2d_grad
+
 - api : conv2d_transpose
   args : (Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
   output : Tensor(out)
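
The args list in the api.yaml entry above is the exact positional order the generated eager op accepts; the trailing False, -1, False at the Python call sites in this patch correspond to use_addto, workspace_size_MB, and exhaustive_search. A direct call for illustration (a sketch, not part of the patch; _C_ops.final_state_conv2d is the code-generated binding and is available only in eager mode):

    # Sketch: calling the yaml-generated eager op directly; positional
    # args follow the api.yaml entry (use_addto=False,
    # workspace_size_MB=-1, exhaustive_search=False).
    import paddle
    from paddle import _C_ops
    from paddle.fluid.framework import _test_eager_guard

    with _test_eager_guard():
        x = paddle.ones([2, 3, 8, 8], dtype="float32")
        w = paddle.ones([4, 3, 3, 3], dtype="float32")
        y = _C_ops.final_state_conv2d(x, w, [1, 1], [1, 1], "EXPLICIT", 1,
                                      [1, 1], "NCHW", False, -1, False)
        print(y.shape)  # [2, 4, 8, 8]
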
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 875f06cecf..f49b804937 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -218,6 +218,12 @@
   output : Tensor[](x_grad)
   invoke : concat_grad_impl(x, out_grad, axis)
 
+- backward_api : conv2d_grad
+  forward : conv2d (Tensor input, Tensor filter, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search) -> Tensor(out)
+  args : (Tensor input, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, str paddding_algorithm, int groups, int[] dilations, str data_format, bool use_addto, int workspace_size_MB, bool exhaustive_search)
+  output : Tensor(input_grad), Tensor(filter_grad)
+  invoke : conv2d_grad_impl(input, filter, out_grad, strides, paddings, paddding_algorithm, groups, dilations, data_format, use_addto, workspace_size_MB, exhaustive_search)
+
 - backward_api : conv2d_transpose_grad
   forward : conv2d_transpose(Tensor x, Tensor filter, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format) -> Tensor(out)
   args : (Tensor x, Tensor filter, Tensor out_grad, int[] strides, int[] paddings, int[] output_padding, int[] output_size, str padding_algorithm, int groups, int[] dilations, str data_format)
-- 
GitLab