未验证 提交 aad0ae2a 编写于 作者: C Chen Weihang 提交者: GitHub

add conv2d for infrt (#40776)

上级 0331cfda
......@@ -23,8 +23,6 @@ limitations under the License. */
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace phi {
namespace detail {
......@@ -469,6 +467,31 @@ void ConvInferMeta(const MetaTensor& input,
out->set_dtype(input.dtype());
}
void ConvInferInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
MetaTensor* out,
MetaConfig config) {
ConvInferMeta(input,
filter,
strides,
paddings,
paddding_algorithm,
groups,
dilations,
data_format,
/*use_addto=*/false,
/*workspace_size_MB=*/512, // useless in infermeta
/*exhaustive_search=*/false,
out,
config);
}
void ConvTransposeInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const std::vector<int>& strides,
......@@ -1670,3 +1693,4 @@ void ValueCompareInferMeta(const MetaTensor& x,
PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
PD_REGISTER_INFER_META_FN(conv2d, phi::ConvInferMeta);
PD_REGISTER_INFER_META_FN(conv2d_infer, phi::ConvInferInferMeta);
......@@ -83,6 +83,17 @@ void ConvInferMeta(const MetaTensor& input,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ConvInferInferMeta(const MetaTensor& input,
const MetaTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
MetaTensor* out,
MetaConfig config = MetaConfig());
void ConvTransposeInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const std::vector<int>& strides,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/conv_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
namespace phi {
template <typename T, typename Context>
void ConvInferKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out) {
ConvKernel<T, Context>(dev_ctx,
input,
filter,
strides,
paddings,
paddding_algorithm,
groups,
dilations,
data_format,
/*use_addto=*/false,
/*workspace_size_MB=*/paddle::platform::
GetDefaultConvWorkspaceSizeLimitMB(),
/*exhaustive_search=*/false,
out);
}
} // namespace phi
PD_REGISTER_KERNEL(
conv2d_infer, CPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
conv2d_infer, GPU, ALL_LAYOUT, phi::ConvInferKernel, float, double) {}
#endif
......@@ -64,4 +64,16 @@ void DepthwiseConvKernel(const Context& dev_ctx,
bool fuse_relu,
DenseTensor* out);
template <typename T, typename Context>
void ConvInferKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& paddding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out);
} // namespace phi
......@@ -17,18 +17,31 @@
namespace phi {
KernelSignature Conv2dOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d",
{"Input", "Filter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format",
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
{"Output"});
if (!ctx.HasAttr("use_addto") || !ctx.HasAttr("workspace_size_MB") ||
!ctx.HasAttr("exhaustive_search")) {
return KernelSignature("conv2d_infer",
{"Input", "Filter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format"},
{"Output"});
} else {
return KernelSignature("conv2d",
{"Input", "Filter"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format",
"use_addto",
"workspace_size_MB",
"exhaustive_search"},
{"Output"});
}
}
KernelSignature Conv2dGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册