diff --git a/paddle/pten/core/infermeta_utils.cc b/paddle/pten/core/infermeta_utils.cc index 3edd1eb87267457cf9fb2132119839b76ae62674..326f9aa720be23611f4f99668c96ffb55d0da552 100644 --- a/paddle/pten/core/infermeta_utils.cc +++ b/paddle/pten/core/infermeta_utils.cc @@ -66,6 +66,19 @@ const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; } const MetaTensor& InferMetaContext::InputAt(size_t idx) const { return *inputs_.at(idx); } + +std::vector InferMetaContext::InputsBetween(size_t start, + size_t end) const { + std::vector result; + result.reserve(end - start); + + for (size_t i = start; i < end; ++i) { + result.emplace_back(*inputs_.at(i)); + } + + return result; +} + MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { return outputs_.at(idx).get(); } diff --git a/paddle/pten/core/infermeta_utils.h b/paddle/pten/core/infermeta_utils.h index fecfab7153f53178aed4458e4239d5331fb6f74e..aed1fc2b77ff0cc8f40e8a025f910311b1016036 100644 --- a/paddle/pten/core/infermeta_utils.h +++ b/paddle/pten/core/infermeta_utils.h @@ -17,6 +17,8 @@ limitations under the License. */ #include #include +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/enforce.h" #include "paddle/pten/core/macros.h" #include "paddle/pten/core/meta_tensor.h" @@ -46,6 +48,7 @@ class InferMetaContext { const MetaConfig& GetMetaConfig() const; const MetaTensor& InputAt(size_t idx) const; + std::vector InputsBetween(size_t start, size_t end) const; MetaTensor* MutableOutputAt(size_t idx); template @@ -85,7 +88,8 @@ class InferMetaContext { "InferMeta's Attributes should appear before Outputs."); \ attr_type arg = ctx->AttrAt(attr_idx); \ InferMetaFnCallHelper< \ - Tail...>::template Call(pargs..., \ + Tail...>::template Call(ctx, \ + pargs..., \ arg); \ } \ } @@ -124,6 +128,35 @@ struct InferMetaFnImpl { } }; + template + struct InferMetaFnCallHelper&, Tail...> { + template + static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "InferMeta's Input should appear before Attributes."); + static_assert(out_idx == 0, + "InferMeta's Input should appear before Outputs."); + const std::pair range = ctx->InputRangeAt(in_idx); + std::vector arg = + ctx->InputsBetween(range.first, range.second); + InferMetaFnCallHelper< + Tail...>::template Call(ctx, + pargs..., + arg); + } + }; + + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector&); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE( + const std::vector&); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&); + PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&); + // TODO(chenweihang): support vector input later template @@ -227,7 +260,6 @@ struct InferMetaFnRegistrar { "PT_REGISTER_INFER_META_FN must be called in global namespace."); \ static const ::pten::InferMetaFnRegistrar \ __registrar_arg_map_fn_for_##kernel_name_prefix( \ - #kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \ - int TouchInferMetaFnSymbol_##op_type() { return 0; } + #kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)) } // namespace pten diff --git a/paddle/pten/infermeta/nullary.cc b/paddle/pten/infermeta/nullary.cc index fd9b2a8f717f47aec0a80e0ede68508520153cf7..6823c6252eaddc01db0cb9c6e53e51b88047b18e 100644 --- a/paddle/pten/infermeta/nullary.cc +++ b/paddle/pten/infermeta/nullary.cc @@ -16,10 +16,10 @@ limitations under the License. */ namespace pten { -void CreateInferMeta(const std::vector& shape, - DataType dtype, - DataLayout layout, - MetaTensor* out) { +void CreateInferMetaBase(const std::vector& shape, + DataType dtype, + DataLayout layout, + MetaTensor* out) { auto out_dims = pten::framework::make_ddim(shape); out->set_dims(out_dims); out->set_dtype(dtype); @@ -30,7 +30,7 @@ void CreateInferMeta(const ScalarArray& shape, DataType dtype, DataLayout layout, MetaTensor* out) { - CreateInferMeta(shape.GetData(), dtype, layout, out); + CreateInferMetaBase(shape.GetData(), dtype, layout, out); } } // namespace pten diff --git a/paddle/pten/infermeta/nullary.h b/paddle/pten/infermeta/nullary.h index f0b6aad26bea6a0ff606f1d353ca4649042234e7..965e240e903b67adb34b6b5b8c1249a878de2be3 100644 --- a/paddle/pten/infermeta/nullary.h +++ b/paddle/pten/infermeta/nullary.h @@ -28,10 +28,10 @@ namespace pten { // Because functions in this file not only can infer shape, but also need // infer lod or other useful data. -void CreateInferMeta(const std::vector& shape, - DataType dtype, - DataLayout layout, - MetaTensor* out); +void CreateInferMetaBase(const std::vector& shape, + DataType dtype, + DataLayout layout, + MetaTensor* out); void CreateInferMeta(const ScalarArray& shape, DataType dtype, diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index ae1461fe8e74204faf0a158d4e7167eaba31970e..5f3b0712b5863145f7340dfb6ae34a6809d9d635 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -242,14 +242,14 @@ void SumInferMeta(const MetaTensor& x, DataType dtype, bool keep_dim, MetaTensor* out) { - ReduceInferMeta(x, axis, keep_dim, dtype, std::move(out)); + ReduceInferMetaBase(x, axis, keep_dim, dtype, out); } -void ReduceInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - DataType dtype, - MetaTensor* out) { +void ReduceInferMetaBase(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + DataType dtype, + MetaTensor* out) { bool reduce_all = true; std::set dims_set(axis.begin(), axis.end()); for (int64_t i = 0; i < x.dims().size(); ++i) { @@ -304,7 +304,7 @@ void ReduceInferMeta(const MetaTensor& x, const std::vector& axis, bool keep_dim, MetaTensor* out) { - ReduceInferMeta(x, axis, keep_dim, DataType::UNDEFINED, out); + ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out); } void TransferLayoutInferMeta(const MetaTensor& x, @@ -316,5 +316,3 @@ void TransferLayoutInferMeta(const MetaTensor& x, } } // namespace pten - -PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta); diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index 65c6380695c1af808672aef2d75b8b0b54598709..f1dc806b4e9caee980f0bd4b9d5085f375c55bd2 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -53,11 +53,11 @@ void ReshapeInferMeta(const MetaTensor& x, const ScalarArray& shape, MetaTensor* out); -void ReduceInferMeta(const MetaTensor& x, - const std::vector& axis, - bool keep_dim, - DataType dtype, - MetaTensor* out); +void ReduceInferMetaBase(const MetaTensor& x, + const std::vector& axis, + bool keep_dim, + DataType dtype, + MetaTensor* out); void ReduceInferMeta(const MetaTensor& x, const std::vector& axis, diff --git a/paddle/pten/kernels/math_kernel.h b/paddle/pten/kernels/math_kernel.h index eb39b618eb6b3a87fca30a9d27706979774bbbac..4245316568b19ab020abdc2c1b41e347bb9e32c1 100644 --- a/paddle/pten/kernels/math_kernel.h +++ b/paddle/pten/kernels/math_kernel.h @@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx, bool keep_dim) { auto dense_out = pten::Empty(dev_ctx); MetaTensor meta_out(&dense_out); - ReduceInferMeta(x, axis, keep_dim, x.dtype(), &meta_out); + ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out); MeanKernel(dev_ctx, x, axis, keep_dim, &dense_out); return dense_out; } diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 7768cb926e454cf8c5b10c097e1173f290fddff6..22f672704525f48556b1d745861c1e295f319ba3 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -161,6 +161,14 @@ kernel : func : scale +- api : sign + args : (const Tensor& x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : sign + - api : subtract args : (const Tensor& x, const Tensor& y) output : Tensor @@ -173,10 +181,10 @@ - api : sum args : (const Tensor& x, const std::vector& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false) output : Tensor - infer_meta : + infer_meta : func : SumInferMeta param: [x, axis, dtype, keep_dim] - kernel : + kernel : func : sum param : [x, axis, dtype, keep_dim] data_type : x diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index 497522c8585e00547a3a1cd4d41d0075dad4ff7b..05c8861dcfc8cffb44195548d974f263064a8378 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -379,14 +379,7 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare input_infos = self.inputs['input_info'] kernel_args_type_list = ['const platform::DeviceContext&'] - input_tensor_code = "" - for input_name in input_names: - # set input code - input_tensor_code = input_tensor_code + f""" - auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});""" - attr_names = self.attrs['names'] - kernel_param = self.kernel['param'] if kernel_param is None: kernel_param = input_names + attr_names @@ -401,11 +394,11 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare elif input_name in self.data_transform['support_trans_dtype']: trans_flag = "{false, true}" input_tensor_code = input_tensor_code + f""" - auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});""" + auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});""" else: input_tensor_code = input_tensor_code + f""" - auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});""" + auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});""" kernel_args = "*dev_ctx, " for param in kernel_param: diff --git a/python/paddle/utils/code_gen/api_gen.py b/python/paddle/utils/code_gen/api_gen.py index 3a2a5a697004d8921b011f5ed7b1578ddab69b12..7039129a796d69beb45f3e6141338fa05df60e25 100644 --- a/python/paddle/utils/code_gen/api_gen.py +++ b/python/paddle/utils/code_gen/api_gen.py @@ -60,6 +60,14 @@ class ForwardAPI(BaseAPI): return kernel_output, output_names, output_create + def gene_infer_meta_register(self): + if self.is_base_api: + return f""" +PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});""" + + else: + return '' + def header_include(): return """ @@ -83,6 +91,7 @@ def source_include(header_file_path): #include "paddle/pten/api/lib/data_transform.h" #include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/utils/storage.h" +#include "paddle/pten/core/infermeta_utils.h" #include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/infermeta/binary.h" #include "paddle/pten/infermeta/multiary.h" @@ -127,15 +136,21 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): source_file.write(source_include(include_header_file)) source_file.write(namespace[0]) + infer_meta_register_code = '' + for api in apis: api_code = ForwardAPI(api) print(api_code.gene_api_declaration()) header_file.write(api_code.gene_api_declaration()) source_file.write(api_code.gene_api_code()) + infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register( + ) header_file.write(namespace[1]) source_file.write(namespace[1]) + source_file.write(api_register()) + source_file.write(infer_meta_register_code) header_file.close() source_file.close()