From aad0ae2acab6d968be6f888b23d572d0670e9bb5 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 22 Mar 2022 16:41:27 +0800 Subject: [PATCH] add conv2d for infrt (#40776) --- paddle/phi/infermeta/binary.cc | 28 +++++++++++++- paddle/phi/infermeta/binary.h | 11 ++++++ paddle/phi/kernels/conv_kernel.cc | 57 +++++++++++++++++++++++++++++ paddle/phi/kernels/conv_kernel.h | 12 ++++++ paddle/phi/ops/compat/conv2d_sig.cc | 37 +++++++++++++------ 5 files changed, 131 insertions(+), 14 deletions(-) create mode 100644 paddle/phi/kernels/conv_kernel.cc diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 36a049eca0f..5221076f10d 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -23,8 +23,6 @@ limitations under the License. */ #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/common_shape.h" -#include "paddle/phi/kernels/cpu/conv_util.h" - namespace phi { namespace detail { @@ -469,6 +467,31 @@ void ConvInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void ConvInferInferMeta(const MetaTensor& input, + const MetaTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::string& paddding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + MetaTensor* out, + MetaConfig config) { + ConvInferMeta(input, + filter, + strides, + paddings, + paddding_algorithm, + groups, + dilations, + data_format, + /*use_addto=*/false, + /*workspace_size_MB=*/512, // useless in infermeta + /*exhaustive_search=*/false, + out, + config); +} + void ConvTransposeInferMeta(const MetaTensor& x, const MetaTensor& filter, const std::vector& strides, @@ -1670,3 +1693,4 @@ void ValueCompareInferMeta(const MetaTensor& x, PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta); PD_REGISTER_INFER_META_FN(conv2d, phi::ConvInferMeta); +PD_REGISTER_INFER_META_FN(conv2d_infer, phi::ConvInferInferMeta); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 9a54c4c5fa6..f9a93984377 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -83,6 +83,17 @@ void ConvInferMeta(const MetaTensor& input, MetaTensor* out, MetaConfig config = MetaConfig()); +void ConvInferInferMeta(const MetaTensor& input, + const MetaTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::string& paddding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void ConvTransposeInferMeta(const MetaTensor& x, const MetaTensor& filter, const std::vector& strides, diff --git a/paddle/phi/kernels/conv_kernel.cc b/paddle/phi/kernels/conv_kernel.cc new file mode 100644 index 00000000000..7268384f401 --- /dev/null +++ b/paddle/phi/kernels/conv_kernel.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/conv_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/fluid/platform/cudnn_workspace_helper.h" + +namespace phi { + +template +void ConvInferKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::string& paddding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + DenseTensor* out) { + ConvKernel(dev_ctx, + input, + filter, + strides, + paddings, + paddding_algorithm, + groups, + dilations, + data_format, + /*use_addto=*/false, + /*workspace_size_MB=*/paddle::platform:: + GetDefaultConvWorkspaceSizeLimitMB(), + /*exhaustive_search=*/false, + out); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + conv2d_infer, CPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {} +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL( + conv2d_infer, GPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {} +#endif diff --git a/paddle/phi/kernels/conv_kernel.h b/paddle/phi/kernels/conv_kernel.h index eb0bfdd0275..508b3a42a21 100644 --- a/paddle/phi/kernels/conv_kernel.h +++ b/paddle/phi/kernels/conv_kernel.h @@ -64,4 +64,16 @@ void DepthwiseConvKernel(const Context& dev_ctx, bool fuse_relu, DenseTensor* out); +template +void ConvInferKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const std::vector& strides, + const std::vector& paddings, + const std::string& paddding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/ops/compat/conv2d_sig.cc b/paddle/phi/ops/compat/conv2d_sig.cc index a755fdb19ec..67b99f1dd61 100644 --- a/paddle/phi/ops/compat/conv2d_sig.cc +++ b/paddle/phi/ops/compat/conv2d_sig.cc @@ -17,18 +17,31 @@ namespace phi { KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("conv2d", - {"Input", "Filter"}, - {"strides", - "paddings", - "padding_algorithm", - "groups", - "dilations", - "data_format", - "use_addto", - "workspace_size_MB", - "exhaustive_search"}, - {"Output"}); + if (!ctx.HasAttr("use_addto") || !ctx.HasAttr("workspace_size_MB") || + !ctx.HasAttr("exhaustive_search")) { + return KernelSignature("conv2d_infer", + {"Input", "Filter"}, + {"strides", + "paddings", + "padding_algorithm", + "groups", + "dilations", + "data_format"}, + {"Output"}); + } else { + return KernelSignature("conv2d", + {"Input", "Filter"}, + {"strides", + "paddings", + "padding_algorithm", + "groups", + "dilations", + "data_format", + "use_addto", + "workspace_size_MB", + "exhaustive_search"}, + {"Output"}); + } } KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) { -- GitLab