未验证 提交 7d6096ff 编写于 作者: Z zyfncg 提交者: GitHub

【Pten】Auto-Generate InterMeta register (#39436)

* fix code conflict

* generate inter_meta register

* clear cache

* just try

* add sign c++ api

* polish some code
上级 1252f4bb
...@@ -66,6 +66,19 @@ const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; } ...@@ -66,6 +66,19 @@ const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }
const MetaTensor& InferMetaContext::InputAt(size_t idx) const { const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
return *inputs_.at(idx); return *inputs_.at(idx);
} }
std::vector<MetaTensor> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.emplace_back(*inputs_.at(i));
}
return result;
}
MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) { MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get(); return outputs_.at(idx).get();
} }
......
...@@ -17,6 +17,8 @@ limitations under the License. */ ...@@ -17,6 +17,8 @@ limitations under the License. */
#include <string> #include <string>
#include <utility> #include <utility>
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/enforce.h" #include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/macros.h" #include "paddle/pten/core/macros.h"
#include "paddle/pten/core/meta_tensor.h" #include "paddle/pten/core/meta_tensor.h"
...@@ -46,6 +48,7 @@ class InferMetaContext { ...@@ -46,6 +48,7 @@ class InferMetaContext {
const MetaConfig& GetMetaConfig() const; const MetaConfig& GetMetaConfig() const;
const MetaTensor& InputAt(size_t idx) const; const MetaTensor& InputAt(size_t idx) const;
std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx); MetaTensor* MutableOutputAt(size_t idx);
template <typename AttrType> template <typename AttrType>
...@@ -85,7 +88,8 @@ class InferMetaContext { ...@@ -85,7 +88,8 @@ class InferMetaContext {
"InferMeta's Attributes should appear before Outputs."); \ "InferMeta's Attributes should appear before Outputs."); \
attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \ attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \
InferMetaFnCallHelper< \ InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(pargs..., \ Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \ arg); \
} \ } \
} }
...@@ -124,6 +128,35 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> { ...@@ -124,6 +128,35 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
} }
}; };
// Specialization for an InferMeta argument of type
// `const std::vector<MetaTensor>&`, i.e. a variadic tensor-input slot.
// Unpacks the corresponding input range from the context and recurses
// on the remaining argument types.
template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor>&, Tail...> {
  template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
  static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
    // Inputs must be consumed before any attribute or output argument.
    static_assert(attr_idx == 0,
                  "InferMeta's Input should appear before Attributes.");
    static_assert(out_idx == 0,
                  "InferMeta's Input should appear before Outputs.");
    // [range.first, range.second) are the context input indices backing
    // this one logical argument.
    const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
    std::vector<MetaTensor> arg =
        ctx->InputsBetween(range.first, range.second);
    // Advance only the input counter; attr/out counters are untouched.
    InferMetaFnCallHelper<
        Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
                                                               pargs...,
                                                               arg);
  }
};
// Instantiate InferMetaFnCallHelper specializations for every attribute
// type an InferMeta function signature may declare. Each expansion reads
// the attribute from the context and forwards it to the recursive call.
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
    const std::vector<int64_t>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
// TODO(chenweihang): support vector<MetaTensor> input later // TODO(chenweihang): support vector<MetaTensor> input later
template <typename... Tail> template <typename... Tail>
...@@ -227,7 +260,6 @@ struct InferMetaFnRegistrar { ...@@ -227,7 +260,6 @@ struct InferMetaFnRegistrar {
"PT_REGISTER_INFER_META_FN must be called in global namespace."); \ "PT_REGISTER_INFER_META_FN must be called in global namespace."); \
static const ::pten::InferMetaFnRegistrar \ static const ::pten::InferMetaFnRegistrar \
__registrar_arg_map_fn_for_##kernel_name_prefix( \ __registrar_arg_map_fn_for_##kernel_name_prefix( \
#kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \ #kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn))
int TouchInferMetaFnSymbol_##op_type() { return 0; }
} // namespace pten } // namespace pten
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace pten { namespace pten {
void CreateInferMeta(const std::vector<int64_t>& shape, void CreateInferMetaBase(const std::vector<int64_t>& shape,
DataType dtype, DataType dtype,
DataLayout layout, DataLayout layout,
MetaTensor* out) { MetaTensor* out) {
...@@ -30,7 +30,7 @@ void CreateInferMeta(const ScalarArray& shape, ...@@ -30,7 +30,7 @@ void CreateInferMeta(const ScalarArray& shape,
DataType dtype, DataType dtype,
DataLayout layout, DataLayout layout,
MetaTensor* out) { MetaTensor* out) {
CreateInferMeta(shape.GetData(), dtype, layout, out); CreateInferMetaBase(shape.GetData(), dtype, layout, out);
} }
} // namespace pten } // namespace pten
...@@ -28,7 +28,7 @@ namespace pten { ...@@ -28,7 +28,7 @@ namespace pten {
// Because functions in this file not only can infer shape, but also need // Because functions in this file not only can infer shape, but also need
// infer lod or other useful data. // infer lod or other useful data.
void CreateInferMeta(const std::vector<int64_t>& shape, void CreateInferMetaBase(const std::vector<int64_t>& shape,
DataType dtype, DataType dtype,
DataLayout layout, DataLayout layout,
MetaTensor* out); MetaTensor* out);
......
...@@ -242,10 +242,10 @@ void SumInferMeta(const MetaTensor& x, ...@@ -242,10 +242,10 @@ void SumInferMeta(const MetaTensor& x,
DataType dtype, DataType dtype,
bool keep_dim, bool keep_dim,
MetaTensor* out) { MetaTensor* out) {
ReduceInferMeta(x, axis, keep_dim, dtype, std::move(out)); ReduceInferMetaBase(x, axis, keep_dim, dtype, out);
} }
void ReduceInferMeta(const MetaTensor& x, void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
DataType dtype, DataType dtype,
...@@ -304,7 +304,7 @@ void ReduceInferMeta(const MetaTensor& x, ...@@ -304,7 +304,7 @@ void ReduceInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
MetaTensor* out) { MetaTensor* out) {
ReduceInferMeta(x, axis, keep_dim, DataType::UNDEFINED, out); ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out);
} }
void TransferLayoutInferMeta(const MetaTensor& x, void TransferLayoutInferMeta(const MetaTensor& x,
...@@ -316,5 +316,3 @@ void TransferLayoutInferMeta(const MetaTensor& x, ...@@ -316,5 +316,3 @@ void TransferLayoutInferMeta(const MetaTensor& x,
} }
} // namespace pten } // namespace pten
PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);
...@@ -53,7 +53,7 @@ void ReshapeInferMeta(const MetaTensor& x, ...@@ -53,7 +53,7 @@ void ReshapeInferMeta(const MetaTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
MetaTensor* out); MetaTensor* out);
void ReduceInferMeta(const MetaTensor& x, void ReduceInferMetaBase(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
bool keep_dim, bool keep_dim,
DataType dtype, DataType dtype,
......
...@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx, ...@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
bool keep_dim) { bool keep_dim) {
auto dense_out = pten::Empty<T, Context>(dev_ctx); auto dense_out = pten::Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
ReduceInferMeta(x, axis, keep_dim, x.dtype(), &meta_out); ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out);
MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out); MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
return dense_out; return dense_out;
} }
......
...@@ -161,6 +161,14 @@ ...@@ -161,6 +161,14 @@
kernel : kernel :
func : scale func : scale
# sign op: single tensor input, single tensor output; output metadata is
# inferred by UnchangedInferMeta (same as input), dispatched to the
# `sign` kernel.
- api : sign
  args : (const Tensor& x)
  output : Tensor
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : sign
- api : subtract - api : subtract
args : (const Tensor& x, const Tensor& y) args : (const Tensor& x, const Tensor& y)
output : Tensor output : Tensor
......
...@@ -379,14 +379,7 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare ...@@ -379,14 +379,7 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
input_infos = self.inputs['input_info'] input_infos = self.inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&'] kernel_args_type_list = ['const platform::DeviceContext&']
input_tensor_code = ""
for input_name in input_names:
# set input code
input_tensor_code = input_tensor_code + f"""
auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""
attr_names = self.attrs['names'] attr_names = self.attrs['names']
kernel_param = self.kernel['param'] kernel_param = self.kernel['param']
if kernel_param is None: if kernel_param is None:
kernel_param = input_names + attr_names kernel_param = input_names + attr_names
......
...@@ -60,6 +60,14 @@ class ForwardAPI(BaseAPI): ...@@ -60,6 +60,14 @@ class ForwardAPI(BaseAPI):
return kernel_output, output_names, output_create return kernel_output, output_names, output_create
def gene_infer_meta_register(self):
    """Generate the PT_REGISTER_INFER_META_FN registration line for this api.

    Returns an empty string for non-base apis (they have no direct
    kernel/infer_meta pair to register).
    """
    if not self.is_base_api:
        return ''
    # Registration macro pairs the kernel function name with its
    # corresponding pten InferMeta function.
    return f"""
PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});"""
def header_include(): def header_include():
return """ return """
...@@ -83,6 +91,7 @@ def source_include(header_file_path): ...@@ -83,6 +91,7 @@ def source_include(header_file_path):
#include "paddle/pten/api/lib/data_transform.h" #include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/binary.h" #include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/infermeta/multiary.h" #include "paddle/pten/infermeta/multiary.h"
...@@ -127,15 +136,21 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): ...@@ -127,15 +136,21 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
source_file.write(source_include(include_header_file)) source_file.write(source_include(include_header_file))
source_file.write(namespace[0]) source_file.write(namespace[0])
infer_meta_register_code = ''
for api in apis: for api in apis:
api_code = ForwardAPI(api) api_code = ForwardAPI(api)
print(api_code.gene_api_declaration()) print(api_code.gene_api_declaration())
header_file.write(api_code.gene_api_declaration()) header_file.write(api_code.gene_api_declaration())
source_file.write(api_code.gene_api_code()) source_file.write(api_code.gene_api_code())
infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register(
)
header_file.write(namespace[1]) header_file.write(namespace[1])
source_file.write(namespace[1]) source_file.write(namespace[1])
source_file.write(api_register()) source_file.write(api_register())
source_file.write(infer_meta_register_code)
header_file.close() header_file.close()
source_file.close() source_file.close()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册