From 4e5d6743436dbed747002c3e040aca8e9b23244d Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Sat, 19 Feb 2022 11:54:31 +0800
Subject: [PATCH] [Pten] Adjust the params of creation kernel for inference
 (#39573)

* remove manual_api

* change sig map of full and empty

* fix fill_any_like_xpu_op

* fix fill_any_like_xpu_op

* fix problem of fill_any_like_xpu_op

* fix conflict

* polish code
---
 paddle/fluid/eager/api/utils/tensor_utils.cc |  2 +-
 paddle/fluid/operators/fill_any_like_op.h    |  3 +-
 .../fluid/operators/fill_any_like_op_xpu.cc  |  3 +-
 paddle/pten/infermeta/nullary.cc             |  3 +-
 paddle/pten/infermeta/nullary.h              |  5 +-
 paddle/pten/infermeta/unary.cc               |  7 +--
 paddle/pten/infermeta/unary.h                |  5 +-
 paddle/pten/kernels/cpu/full_kernel.cc       | 57 ++++++++++++++++++-
 paddle/pten/kernels/empty_kernel.cc          |  6 +-
 paddle/pten/kernels/empty_kernel.h           | 27 +++++----
 paddle/pten/kernels/full_kernel.h            | 26 ++++-----
 paddle/pten/kernels/gpu/full_kernel.cu       | 11 ++--
 paddle/pten/kernels/xpu/full_kernel.cc       |  3 +
 paddle/pten/ops/compat/empty_sig.cc          |  6 +-
 paddle/pten/ops/compat/fill_any_like_sig.cc  |  2 +-
 paddle/pten/ops/compat/fill_constant_sig.cc  | 22 ++++---
 python/paddle/utils/code_gen/api.yaml        | 36 ++++++------
 python/paddle/utils/code_gen/api_base.py     |  4 +-
 .../utils/code_gen/wrapped_infermeta_gen.py  | 26 +++++----
 19 files changed, 157 insertions(+), 97 deletions(-)

diff --git a/paddle/fluid/eager/api/utils/tensor_utils.cc b/paddle/fluid/eager/api/utils/tensor_utils.cc
index bd4e9f0af9..801b608b7c 100644
--- a/paddle/fluid/eager/api/utils/tensor_utils.cc
+++ b/paddle/fluid/eager/api/utils/tensor_utils.cc
@@ -43,7 +43,7 @@ paddle::experimental::Tensor CreateTensorWithValue(
     bool is_leaf) {
   paddle::experimental::Tensor out = paddle::experimental::full(
       paddle::framework::vectorize(ddim), paddle::experimental::Scalar(value),
-      dtype, pten::TransToPtenBackend(place), layout);
+      dtype, pten::TransToPtenBackend(place));
 
   auto meta = EagerUtils::autograd_meta(&out);
   if (is_leaf) {
diff --git a/paddle/fluid/operators/fill_any_like_op.h b/paddle/fluid/operators/fill_any_like_op.h
index 3ebda1a074..1e5b43c81d 100644
--- a/paddle/fluid/operators/fill_any_like_op.h
+++ b/paddle/fluid/operators/fill_any_like_op.h
@@ -33,6 +33,7 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
                                              float, T>::type>::type;
 
   void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::Tensor>("X");
     auto* out = context.Output<framework::Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
 
@@ -65,7 +66,7 @@ class FillAnyLikeKernel : public framework::OpKernel<T> {
     pten::FullLikeKernel<T>(
         static_cast<const typename paddle::framework::ConvertToPtenContext<
             DeviceContext>::TYPE&>(dev_ctx),
-        value, out);
+        *x, value, pten::DataType::UNDEFINED, out);
   }
 };
 
diff --git a/paddle/fluid/operators/fill_any_like_op_xpu.cc b/paddle/fluid/operators/fill_any_like_op_xpu.cc
index b4788d0445..fba5d8ece3 100644
--- a/paddle/fluid/operators/fill_any_like_op_xpu.cc
+++ b/paddle/fluid/operators/fill_any_like_op_xpu.cc
@@ -31,6 +31,7 @@ class FillAnyLikeXPUKernel : public framework::OpKernel<T> {
   using XPUInTDType = typename XPUTypeTrait<T>::Type;
 
   void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::Tensor>("X");
     auto* out = context.Output<framework::Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
 
@@ -63,7 +64,7 @@ class FillAnyLikeXPUKernel : public framework::OpKernel<T> {
     pten::FullLikeKernel<XPUInTDType>(
         static_cast<const typename paddle::framework::ConvertToPtenContext<
             platform::XPUDeviceContext>::TYPE&>(dev_ctx),
-        value, out);
+        *x, value, pten::DataType::UNDEFINED, out);
   }
 };
 
diff --git a/paddle/pten/infermeta/nullary.cc b/paddle/pten/infermeta/nullary.cc
index 6823c6252e..e924914d8f 100644
--- a/paddle/pten/infermeta/nullary.cc
+++ b/paddle/pten/infermeta/nullary.cc
@@ -28,9 +28,8 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,
 
 void CreateInferMeta(const ScalarArray& shape,
                      DataType dtype,
-                     DataLayout layout,
                      MetaTensor* out) {
-  CreateInferMetaBase(shape.GetData(), dtype, layout, out);
+  CreateInferMetaBase(shape.GetData(), dtype, DataLayout::NCHW, out);
 }
 
 }  // namespace pten
diff --git a/paddle/pten/infermeta/nullary.h b/paddle/pten/infermeta/nullary.h
index 965e240e90..a7530abcf1 100644
--- a/paddle/pten/infermeta/nullary.h
+++ b/paddle/pten/infermeta/nullary.h
@@ -33,9 +33,6 @@ void CreateInferMetaBase(const std::vector<int64_t>& shape,
                          DataLayout layout,
                          MetaTensor* out);
 
-void CreateInferMeta(const ScalarArray& shape,
-                     DataType dtype,
-                     DataLayout layout,
-                     MetaTensor* out);
+void CreateInferMeta(const ScalarArray& shape, DataType dtype, MetaTensor* out);
 
 }  // namespace pten
diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc
index ec9ba519b9..cd35d4fef7 100644
--- a/paddle/pten/infermeta/unary.cc
+++ b/paddle/pten/infermeta/unary.cc
@@ -79,13 +79,10 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
   out->set_layout(x.layout());
 }
 
-void CreateLikeInferMeta(const MetaTensor& x,
-                         DataType dtype,
-                         DataLayout layout,
-                         MetaTensor* out) {
+void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
   out->set_dims(x.dims());
   out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
-  out->set_layout(layout == DataLayout::UNDEFINED ? x.layout() : layout);
+  out->set_layout(x.layout());
 }
 
 static pten::framework::DDim ValidateShape(
diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h
index 5bdf1d491c..2bc4b53f8f 100644
--- a/paddle/pten/infermeta/unary.h
+++ b/paddle/pten/infermeta/unary.h
@@ -41,10 +41,7 @@ void FlattenInferMeta(const MetaTensor& x,
 
 void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
 
-void CreateLikeInferMeta(const MetaTensor& x,
-                         DataType dtype,
-                         DataLayout layout,
-                         MetaTensor* out);
+void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);
 
 void InferMetaFromVecValue(const MetaTensor& x,
                            const std::vector<int64_t>& shape,
diff --git a/paddle/pten/kernels/cpu/full_kernel.cc b/paddle/pten/kernels/cpu/full_kernel.cc
index 62e1bbf1d9..49a613f868 100644
--- a/paddle/pten/kernels/cpu/full_kernel.cc
+++ b/paddle/pten/kernels/cpu/full_kernel.cc
@@ -16,7 +16,62 @@ limitations under the License. */
 
 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/core/kernel_registry.h"
-#include "paddle/pten/kernels/impl/full_kernel_impl.h"
+
+#include "paddle/pten/kernels/funcs/eigen/common.h"
+#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
+
+namespace pten {
+
+template <typename T, typename Context, typename VType>
+void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
+  dev_ctx.template Alloc<T>(tensor);
+  auto t = pten::EigenVector<T>::Flatten(*tensor);
+  t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(val));
+}
+
+template <typename T, typename Context>
+void FullKernel(const Context& dev_ctx,
+                const ScalarArray& shape,
+                const Scalar& val,
+                DataType dtype,
+                DenseTensor* out) {
+  out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData()));
+  FullValue<T>(dev_ctx, out, val.to<T>());
+}
+
+template <typename T, typename Context>
+void FullLikeKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const Scalar& val,
+                    DataType dtype,
+                    DenseTensor* out) {
+  auto value = val.to<float>();
+  using CommonType = typename std::common_type<
+      float,
+      typename std::conditional<std::is_same<T, pten::dtype::float16>::value,
+                                float,
+                                T>::type>::type;
+
+  auto common_type_value = static_cast<CommonType>(value);
+
+  PADDLE_ENFORCE_EQ(
+      (common_type_value >=
+       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
+          (common_type_value <=
+           static_cast<CommonType>(std::numeric_limits<T>::max())),
+      true,
+      pten::errors::InvalidArgument(
+          "The filled value is out of range for target type, "
+          "current kernel type is %s, the range should between %f "
+          "and %f, but now value is %f.",
+          typeid(T).name(),
+          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
+          static_cast<CommonType>(std::numeric_limits<T>::max()),
+          static_cast<float>(value)));
+  FullValue<T>(dev_ctx, out, value);
+}
+
+}  // namespace pten
 
 PT_REGISTER_KERNEL(full,
                    CPU,
diff --git a/paddle/pten/kernels/empty_kernel.cc b/paddle/pten/kernels/empty_kernel.cc
index 03fe240a88..0f49a4380c 100644
--- a/paddle/pten/kernels/empty_kernel.cc
+++ b/paddle/pten/kernels/empty_kernel.cc
@@ -23,12 +23,16 @@ namespace pten {
 template <typename T, typename Context>
 void EmptyKernel(const Context& dev_ctx,
                  const ScalarArray& shape,
+                 DataType dtype,
                  DenseTensor* out) {
   out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData()));
 }
 
 template <typename T, typename Context>
-void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
+void EmptyLikeKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     DataType dtype,
+                     DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
 }
 
diff --git a/paddle/pten/kernels/empty_kernel.h b/paddle/pten/kernels/empty_kernel.h
index 98f7e03d77..a7ba512424 100644
--- a/paddle/pten/kernels/empty_kernel.h
+++ b/paddle/pten/kernels/empty_kernel.h
@@ -25,10 +25,14 @@ namespace pten {
 template <typename T, typename Context>
 void EmptyKernel(const Context& dev_ctx,
                  const ScalarArray& shape,
+                 DataType dtype,
                  DenseTensor* out);
 
 template <typename T, typename Context>
-void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out);
+void EmptyLikeKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     DataType dtype,
+                     DenseTensor* out);
 
 // TODO(chenweihang): the tensor creation method need to be replaced later,
 // all kernel api call Empty here instead of making tensor self
@@ -52,27 +56,22 @@ DenseTensor Empty(const Context& dev_ctx) {
 
 template <typename T, typename Context>
 DenseTensor Empty(const Context& dev_ctx,
                   const ScalarArray& shape,
-                  DataType dtype = DataType::FLOAT32,
-                  Backend backend = Backend::CPU,  // Is backend needed here?
-                  DataLayout layout = DataLayout::NCHW) {
+                  DataType dtype = DataType::FLOAT32) {
   auto dense_out = Empty<T, Context>(dev_ctx);
   MetaTensor meta_out(&dense_out);
-  CreateInferMeta(shape, dtype, layout, &meta_out);
-  EmptyKernel<T, Context>(dev_ctx, shape, &dense_out);
+  CreateInferMeta(shape, dtype, &meta_out);
+  EmptyKernel<T, Context>(dev_ctx, shape, dtype, &dense_out);
   return dense_out;
 }
 
 template <typename T, typename Context>
-DenseTensor EmptyLike(
-    const Context& dev_ctx,
-    const DenseTensor& x,
-    DataType dtype = DataType::UNDEFINED,
-    Backend backend = Backend::UNDEFINED,  // Is backend needed here?
-    DataLayout layout = DataLayout::UNDEFINED) {
+DenseTensor EmptyLike(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      DataType dtype = DataType::UNDEFINED) {
   auto dense_out = Empty<T, Context>(dev_ctx);
   MetaTensor meta_out(&dense_out);
-  CreateLikeInferMeta(x, dtype, layout, &meta_out);
-  EmptyLikeKernel<T, Context>(dev_ctx, &dense_out);
+  CreateLikeInferMeta(x, dtype, &meta_out);
+  EmptyLikeKernel<T, Context>(dev_ctx, x, dtype, &dense_out);
   return dense_out;
 }
 
diff --git a/paddle/pten/kernels/full_kernel.h b/paddle/pten/kernels/full_kernel.h
index 030eb4b1c7..b8b78e311a 100644
--- a/paddle/pten/kernels/full_kernel.h
+++ b/paddle/pten/kernels/full_kernel.h
@@ -27,39 +27,37 @@
 template <typename T, typename Context>
 void FullKernel(const Context& dev_ctx,
                 const ScalarArray& shape,
                 const Scalar& val,
+                DataType dtype,
                 DenseTensor* out);
 
 template <typename T, typename Context>
 void FullLikeKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
                     const Scalar& val,
+                    DataType dtype,
                     DenseTensor* out);
 
 template <typename T, typename Context>
 DenseTensor Full(const Context& dev_ctx,
                  const ScalarArray& shape,
                  const Scalar& val,
-                 DataType dtype = DataType::FLOAT32,
-                 Backend backend = Backend::CPU,  // Is backend needed here?
-                 DataLayout layout = DataLayout::NCHW) {
+                 DataType dtype = DataType::FLOAT32) {
   auto dense_out = Empty<T, Context>(dev_ctx);
   MetaTensor meta_out(&dense_out);
-  CreateInferMeta(shape, dtype, layout, &meta_out);
-  FullKernel<T, Context>(dev_ctx, shape, val, &dense_out);
+  CreateInferMeta(shape, dtype, &meta_out);
+  FullKernel<T, Context>(dev_ctx, shape, val, dtype, &dense_out);
   return dense_out;
 }
 
 template <typename T, typename Context>
-DenseTensor FullLike(
-    const Context& dev_ctx,
-    const DenseTensor& x,
-    const Scalar& val,
-    DataType dtype = DataType::UNDEFINED,
-    Backend backend = Backend::UNDEFINED,  // Is backend needed here?
-    DataLayout layout = DataLayout::UNDEFINED) {
+DenseTensor FullLike(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const Scalar& val,
+                     DataType dtype = DataType::UNDEFINED) {
   auto dense_out = Empty<T, Context>(dev_ctx);
   MetaTensor meta_out(&dense_out);
-  CreateLikeInferMeta(x, dtype, layout, &meta_out);
-  FullLikeKernel<T, Context>(dev_ctx, val, &dense_out);
+  CreateLikeInferMeta(x, dtype, &meta_out);
+  FullLikeKernel<T, Context>(dev_ctx, x, val, dtype, &dense_out);
   return dense_out;
 }
 
diff --git a/paddle/pten/kernels/gpu/full_kernel.cu b/paddle/pten/kernels/gpu/full_kernel.cu
index 7f600fb313..4ae50625e2 100644
--- a/paddle/pten/kernels/gpu/full_kernel.cu
+++ b/paddle/pten/kernels/gpu/full_kernel.cu
@@ -33,10 +33,11 @@ struct FullFuctor {
   }
 };
 
-template <typename T, typename ContextT>
-void FullKernel(const ContextT& dev_ctx,
+template <typename T, typename Context>
+void FullKernel(const Context& dev_ctx,
                 const ScalarArray& shape,
                 const Scalar& val,
+                DataType dtype,
                 DenseTensor* out) {
   out->Resize(paddle::framework::make_ddim(shape.GetData()));
   int numel = out->numel();
@@ -53,9 +54,11 @@ void FullKernel(const ContextT& dev_ctx,
   }
 }
 
-template <typename T, typename ContextT>
-void FullLikeKernel(const ContextT& dev_ctx,
+template <typename T, typename Context>
+void FullLikeKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
                     const Scalar& val,
+                    DataType dtype,
                     DenseTensor* out) {
   auto value = val.to<float>();
   using CommonType = typename std::common_type<
diff --git a/paddle/pten/kernels/xpu/full_kernel.cc b/paddle/pten/kernels/xpu/full_kernel.cc
index cf6befac02..bd406fdb3e 100644
--- a/paddle/pten/kernels/xpu/full_kernel.cc
+++ b/paddle/pten/kernels/xpu/full_kernel.cc
@@ -57,6 +57,7 @@
 template <typename T, typename Context>
 void FullKernel(const Context& dev_ctx,
                 const ScalarArray& shape,
                 const Scalar& val,
+                DataType dtype,
                 DenseTensor* out) {
   out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData()));
   FullValueXPU<T>(dev_ctx, out, val.to<T>());
@@ -64,7 +65,9 @@ void FullKernel(const Context& dev_ctx,
 
 template <typename T, typename Context>
 void FullLikeKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
                     const Scalar& val,
+                    DataType dtype,
                     DenseTensor* out) {
   auto value = val.to<float>();
   using XPUInTDType = typename XPUTypeTrait<T>::Type;
diff --git a/paddle/pten/ops/compat/empty_sig.cc b/paddle/pten/ops/compat/empty_sig.cc
index c74f610698..35aa17dcf9 100644
--- a/paddle/pten/ops/compat/empty_sig.cc
+++ b/paddle/pten/ops/compat/empty_sig.cc
@@ -18,11 +18,11 @@ namespace pten {
 
 KernelSignature EmptyOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.HasInput("ShapeTensor")) {
-    return KernelSignature("empty", {}, {"ShapeTensor"}, {"Out"});
+    return KernelSignature("empty", {}, {"ShapeTensor", "dtype"}, {"Out"});
   } else if (ctx.InputSize("ShapeTensorList") > 0) {
-    return KernelSignature("empty", {}, {"ShapeTensorList"}, {"Out"});
+    return KernelSignature("empty", {}, {"ShapeTensorList", "dtype"}, {"Out"});
   } else {
-    return KernelSignature("empty", {}, {"shape"}, {"Out"});
+    return KernelSignature("empty", {}, {"shape", "dtype"}, {"Out"});
   }
 }
 
diff --git a/paddle/pten/ops/compat/fill_any_like_sig.cc b/paddle/pten/ops/compat/fill_any_like_sig.cc
index 81065d0c8a..0440d3769f 100644
--- a/paddle/pten/ops/compat/fill_any_like_sig.cc
+++ b/paddle/pten/ops/compat/fill_any_like_sig.cc
@@ -18,7 +18,7 @@ namespace pten {
 
 KernelSignature FillAnyLikeOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature("full_like", {}, {"value"}, {"Out"});
+  return KernelSignature("full_like", {"X"}, {"value", "dtype"}, {"Out"});
 }
 
 }  // namespace pten
diff --git a/paddle/pten/ops/compat/fill_constant_sig.cc b/paddle/pten/ops/compat/fill_constant_sig.cc
index 73dee270f7..242fefe999 100644
--- a/paddle/pten/ops/compat/fill_constant_sig.cc
+++ b/paddle/pten/ops/compat/fill_constant_sig.cc
@@ -23,42 +23,46 @@ KernelSignature FillConstantOpArgumentMapping(
   if (ctx.HasInput("ShapeTensor")) {
     if (ctx.HasInput("ValueTensor")) {
       return KernelSignature(
-          "full", {}, {"ShapeTensor", "ValueTensor"}, {"Out"});
+          "full", {}, {"ShapeTensor", "ValueTensor", "dtype"}, {"Out"});
     } else {
       const auto& str_value =
           paddle::any_cast<std::string>(ctx.Attr("str_value"));
       if (str_value.empty()) {
-        return KernelSignature("full", {}, {"ShapeTensor", "value"}, {"Out"});
+        return KernelSignature(
+            "full", {}, {"ShapeTensor", "value", "dtype"}, {"Out"});
       } else {
         return KernelSignature(
-            "full", {}, {"ShapeTensor", "str_value"}, {"Out"});
+            "full", {}, {"ShapeTensor", "str_value", "dtype"}, {"Out"});
       }
     }
   } else if (ctx.InputSize("ShapeTensorList") > 0) {
     if (ctx.HasInput("ValueTensor")) {
       return KernelSignature(
-          "full", {}, {"ShapeTensorList", "ValueTensor"}, {"Out"});
+          "full", {}, {"ShapeTensorList", "ValueTensor", "dtype"}, {"Out"});
     } else {
       const auto& str_value =
          paddle::any_cast<std::string>(ctx.Attr("str_value"));
       if (str_value.empty()) {
         return KernelSignature(
-            "full", {}, {"ShapeTensorList", "value"}, {"Out"});
+            "full", {}, {"ShapeTensorList", "value", "dtype"}, {"Out"});
       } else {
         return KernelSignature(
-            "full", {}, {"ShapeTensorList", "str_value"}, {"Out"});
+            "full", {}, {"ShapeTensorList", "str_value", "dtype"}, {"Out"});
       }
     }
   } else {
     if (ctx.HasInput("ValueTensor")) {
-      return KernelSignature("full", {}, {"shape", "ValueTensor"}, {"Out"});
+      return KernelSignature(
+          "full", {}, {"shape", "ValueTensor", "dtype"}, {"Out"});
     } else {
       const auto& str_value =
           paddle::any_cast<std::string>(ctx.Attr("str_value"));
       if (str_value.empty()) {
-        return KernelSignature("full", {}, {"shape", "value"}, {"Out"});
+        return KernelSignature(
+            "full", {}, {"shape", "value", "dtype"}, {"Out"});
       } else {
-        return KernelSignature("full", {}, {"shape", "str_value"}, {"Out"});
+        return KernelSignature(
+            "full", {}, {"shape", "str_value", "dtype"}, {"Out"});
       }
     }
   }
diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml
index 60e64c0284..390ccdd157 100644
--- a/python/paddle/utils/code_gen/api.yaml
+++ b/python/paddle/utils/code_gen/api.yaml
@@ -51,30 +51,28 @@
     func : dot
 
 - api : empty
-  args : (ScalarArray shape, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU, DataLayout layout=DataLayout::NCHW)
+  args : (ScalarArray shape, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU)
   output: Tensor
   infer_meta :
     func : CreateInferMeta
-    param : [shape, dtype, layout]
+    param : [shape, dtype]
  kernel :
    func : empty
-    param : [shape]
+    param : [shape, dtype]
    data_type : dtype
    backend : place
-    layout : layout
 
 - api : empty_like
-  args : (Tensor x, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED, DataLayout layout = DataLayout::UNDEFINED)
+  args : (Tensor x, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED)
   output: Tensor
   infer_meta :
     func : CreateLikeInferMeta
-    param : [x, dtype, layout]
+    param : [x, dtype]
  kernel :
    func : empty_like
-    param : []
+    param : [x, dtype]
    data_type : dtype > x
    backend : place > x
-    layout : layout > x
 
 - api : flatten
   args : (Tensor x, int start_axis, int stop_axis)
@@ -85,30 +83,28 @@
     func : flatten
 
 - api : full
-  args : (ScalarArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU, DataLayout layout=DataLayout::NCHW)
+  args : (ScalarArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Backend place=Backend::CPU)
   output: Tensor
   infer_meta :
     func : CreateInferMeta
-    param : [shape, dtype, layout]
+    param : [shape, dtype]
  kernel :
    func : full
-    param : [shape, value]
+    param : [shape, value, dtype]
    data_type : dtype
    backend : place
-    layout : layout
 
 - api : full_like
-  args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED, DataLayout layout = DataLayout::UNDEFINED)
+  args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Backend place = Backend::UNDEFINED)
   output: Tensor
   infer_meta :
     func : CreateLikeInferMeta
-    param : [x, dtype, layout]
+    param : [x, dtype]
  kernel :
    func : full_like
-    param : [value]
+    param : [x, value, dtype]
    data_type : dtype > x
    backend : place > x
-    layout : layout > x
 
 - api : matmul
   args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)
@@ -136,9 +132,9 @@
     func : multiply
 
 - api : ones_like
-  args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED)
+  args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED)
   output : Tensor
-  invoke : full_like(x, 1, dtype, place, layout)
+  invoke : full_like(x, 1, dtype, place)
 
 - api : reshape
   args : (Tensor x, ScalarArray shape)
@@ -185,6 +181,6 @@
     data_type : x
 
 - api : zeros_like
-  args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED)
+  args : (Tensor x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED)
   output : Tensor
-  invoke : full_like(x, 0, dtype, place, layout)
+  invoke : full_like(x, 0, dtype, place)
diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py
index 26abfdc031..7667a836e5 100644
--- a/python/paddle/utils/code_gen/api_base.py
+++ b/python/paddle/utils/code_gen/api_base.py
@@ -358,8 +358,8 @@ PADDLE_API {self.outputs['return_type']} {self.get_api_func_name() + '_'}({self.
 """
 
         if len(input_names) == 0:
-            assert attr_backend_count > 0 and attr_layout_count > 0 and attr_data_type_count > 0, \
-                f"{api} api: When there is no input tensor, the args must have 'Backend', 'DataLayout' and 'DataType'."
+            assert attr_backend_count > 0 and attr_data_type_count > 0, \
+                f"{api} api: When there is no input tensor, the args must have 'Backend' and 'DataType'."
 
         kernel_select_args = ""
         for input_name in input_names:
diff --git a/python/paddle/utils/code_gen/wrapped_infermeta_gen.py b/python/paddle/utils/code_gen/wrapped_infermeta_gen.py
index 6972b9af25..f5cf870573 100644
--- a/python/paddle/utils/code_gen/wrapped_infermeta_gen.py
+++ b/python/paddle/utils/code_gen/wrapped_infermeta_gen.py
@@ -29,30 +29,36 @@ def gene_wrapped_infermeta_and_register(api):
 PT_REGISTER_INFER_META_FN({api.kernel['func'][0]}, pten::{api.infer_meta['func']});"""
 
     if api.infer_meta['param'] is not None:
+        kernel_params = api.kernel['param']
+        if kernel_params is None:
+            kernel_params = api.inputs['names'] + api.attrs['names']
+        if kernel_params == api.infer_meta['param']:
+            return '', '', register_code
+
+        assert len(api.infer_meta['param']) <= len(kernel_params), \
+            f"{api.api} api: Parameters error. The params of infer_meta should be a subset of kernel params."
+
         tensor_type_map = {
             'const Tensor&': 'const MetaTensor&',
             'const std::vector<Tensor>&': 'const std::vector<MetaTensor>&',
             'Tensor': 'MetaTensor*',
             'std::vector<Tensor>': 'std::vector<MetaTensor>*',
         }
+
         wrapped_infermeta_name = get_wrapped_infermeta_name(api.api)
         args = []
-        check_args = []
         for input_name in api.inputs['names']:
-            args.append(tensor_type_map[api.inputs['input_info'][
-                input_name]] + ' ' + input_name)
-            check_args.append(input_name)
+            if input_name in kernel_params:
+                args.append(tensor_type_map[api.inputs['input_info'][
+                    input_name]] + ' ' + input_name)
         for attr_name in api.attrs['names']:
-            args.append(api.attrs['attr_info'][attr_name][0] + ' ' +
-                        attr_name)
-            check_args.append(attr_name)
+            if attr_name in kernel_params:
+                args.append(api.attrs['attr_info'][attr_name][0] + ' ' +
+                            attr_name)
         for i, out_type in enumerate(api.outputs['types']):
             args.append(tensor_type_map[out_type] + ' ' + api.outputs[
                 'names'][i])
 
-        if check_args == api.infer_meta['param']:
-            return '', '', register_code
-
         invoke_param = api.infer_meta['param']
         invoke_param.extend(api.outputs['names'])
--
GitLab
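
Editor's usage note (not part of the patch): after this change, callers of the C++ creation API pass `dtype` explicitly, the `*Like` kernels take the source tensor `x`, and `Backend`/`DataLayout` are no longer parameters — the backend is implied by the device context and the layout by the infermeta default. The following is a minimal sketch of the new call shape based on the declarations in `full_kernel.h` and `empty_kernel.h` above; the function name and the assumption that a ready-made `pten::CPUContext` and source `DenseTensor` are available are illustrative, not taken from the patch.

```cpp
#include <vector>

#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/kernels/full_kernel.h"

// Hypothetical caller, for illustration only.
void CreationApiAfter39573(const pten::CPUContext& dev_ctx,
                           const pten::DenseTensor& x) {
  // Before #39573: Full<float>(dev_ctx, shape, val, dtype, backend, layout).
  // Now only the dtype remains; backend/layout come from dev_ctx and the
  // CreateInferMeta default (DataLayout::NCHW).
  const pten::ScalarArray shape(std::vector<int64_t>{2, 3});
  auto ones = pten::Full<float>(dev_ctx, shape, 1.0f, pten::DataType::FLOAT32);

  // FullLike/EmptyLike now receive the source tensor; passing
  // DataType::UNDEFINED (the default) makes CreateLikeInferMeta fall back
  // to x.dtype(), mirroring what the fill_any_like op does above.
  auto zeros = pten::FullLike<float>(dev_ctx, x, 0.0f);
  auto scratch = pten::EmptyLike<float>(dev_ctx, x);
}
```

This is also why the compat signatures (`empty_sig.cc`, `fill_any_like_sig.cc`, `fill_constant_sig.cc`) append `"dtype"` to their attribute lists: the kernels now consume the dtype as a regular argument instead of deriving it from a separately selected layout/backend tuple.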