Unverified · Commit 7d6096ff · authored by zyfncg, committed by GitHub

【Pten】Auto-Generate InterMeta register (#39436)

* fix code conflict

* generate inter_meta register

* clear cache

* just try

* add sign c++ api

* polish some code
Parent 1252f4bb
@@ -66,6 +66,19 @@ const MetaConfig& InferMetaContext::GetMetaConfig() const { return config_; }
const MetaTensor& InferMetaContext::InputAt(size_t idx) const {
  return *inputs_.at(idx);
}

std::vector<MetaTensor> InferMetaContext::InputsBetween(size_t start,
                                                        size_t end) const {
  std::vector<MetaTensor> result;
  result.reserve(end - start);
  for (size_t i = start; i < end; ++i) {
    result.emplace_back(*inputs_.at(i));
  }
  return result;
}

MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
  return outputs_.at(idx).get();
}
......
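The new InputsBetween accessor is the piece that lets one InferMeta parameter stand for a contiguous run of inputs. A minimal standalone sketch of the contract (MiniContext and MetaTensorStub are illustrative stand-ins, not the real pten classes):

#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

struct MetaTensorStub { int64_t numel = 0; };  // stand-in for pten::MetaTensor

struct MiniContext {
  std::vector<std::shared_ptr<MetaTensorStub>> inputs_;

  // Same contract as InferMetaContext::InputsBetween: copy the [start, end)
  // slice of the flat input list into a value vector for the callee.
  std::vector<MetaTensorStub> InputsBetween(size_t start, size_t end) const {
    std::vector<MetaTensorStub> result;
    result.reserve(end - start);
    for (size_t i = start; i < end; ++i) {
      result.emplace_back(*inputs_.at(i));
    }
    return result;
  }
};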
@@ -17,6 +17,8 @@ limitations under the License. */
#include <string>
#include <utility>
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/macros.h"
#include "paddle/pten/core/meta_tensor.h"
@@ -46,6 +48,7 @@ class InferMetaContext {
  const MetaConfig& GetMetaConfig() const;
  const MetaTensor& InputAt(size_t idx) const;
  std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
  MetaTensor* MutableOutputAt(size_t idx);

  template <typename AttrType>
@@ -85,7 +88,8 @@ class InferMetaContext {
"InferMeta's Attributes should appear before Outputs."); \
attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \
InferMetaFnCallHelper< \
-          Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(pargs..., \
+          Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
+                                                                 pargs..., \
arg); \
} \
}
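The one-line fix above threads ctx through the recursive Call, which every later specialization (including the new multi-input one below) relies on. A reduced, compilable model of the unpacking pattern, with toy types in place of the real context and helpers:

#include <iostream>

struct Ctx { int next_attr = 0; };  // toy stand-in for InferMetaContext

template <typename... Args>
struct CallHelper;

// Peel one parameter off the signature, read its value "from the context",
// then recurse, forwarding ctx plus everything read so far.
template <typename Head, typename... Tail>
struct CallHelper<Head, Tail...> {
  template <typename... Prev>
  static void Call(Ctx* ctx, Prev&... prev) {
    Head arg = static_cast<Head>(ctx->next_attr++);
    CallHelper<Tail...>::Call(ctx, prev..., arg);  // dropping ctx here = bug
  }
};

// Terminal case: all arguments are materialized; the real code invokes the
// target InferMeta function here.
template <>
struct CallHelper<> {
  template <typename... Prev>
  static void Call(Ctx*, Prev&... prev) {
    ((std::cout << prev << ' '), ...);
  }
};

int main() {
  Ctx ctx;
  CallHelper<int, int>::Call(&ctx);  // prints "0 1"
}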
@@ -124,6 +128,35 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};
template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor>&, Tail...> {
  template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
  static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
    static_assert(attr_idx == 0,
                  "InferMeta's Input should appear before Attributes.");
    static_assert(out_idx == 0,
                  "InferMeta's Input should appear before Outputs.");
    const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
    std::vector<MetaTensor> arg =
        ctx->InputsBetween(range.first, range.second);
    InferMetaFnCallHelper<
        Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
                                                               pargs...,
                                                               arg);
  }
};
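With this specialization, an InferMeta function can declare a variable-length input list directly; the helper maps the parameter to ctx->InputRangeAt(in_idx) and materializes it via InputsBetween. A hedged sketch of the kind of signature this enables (ConcatLikeInferMeta is illustrative only, not a function added by this commit):

// Stacks inputs along axis 0; assumes all inputs share dtype and rank.
void ConcatLikeInferMeta(const std::vector<MetaTensor>& xs, MetaTensor* out) {
  auto out_dims = xs.at(0).dims();
  for (size_t i = 1; i < xs.size(); ++i) {
    out_dims[0] += xs[i].dims()[0];
  }
  out->set_dims(out_dims);
  out->set_dtype(xs.at(0).dtype());
}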
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
    const std::vector<int64_t>&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
// TODO(chenweihang): support vector<MetaTensor> input later
template <typename... Tail>
@@ -227,7 +260,6 @@ struct InferMetaFnRegistrar {
"PT_REGISTER_INFER_META_FN must be called in global namespace."); \
static const ::pten::InferMetaFnRegistrar \
__registrar_arg_map_fn_for_##kernel_name_prefix( \
#kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn)); \
int TouchInferMetaFnSymbol_##op_type() { return 0; }
#kernel_name_prefix, PT_INFER_META(variadic_infer_meta_fn))
} // namespace pten
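With the Touch symbol gone, registration is a single file-scope statement; usage and an approximate expansion (sketched from the macro body above, not verbatim preprocessor output):

PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);

// ...approximately expands to a global registrar object:
//   static const ::pten::InferMetaFnRegistrar
//       __registrar_arg_map_fn_for_sign(
//           "sign", PT_INFER_META(pten::UnchangedInferMeta));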
@@ -16,10 +16,10 @@ limitations under the License. */
namespace pten {
-void CreateInferMeta(const std::vector<int64_t>& shape,
-                     DataType dtype,
-                     DataLayout layout,
-                     MetaTensor* out) {
+void CreateInferMetaBase(const std::vector<int64_t>& shape,
+                         DataType dtype,
+                         DataLayout layout,
+                         MetaTensor* out) {
  auto out_dims = pten::framework::make_ddim(shape);
  out->set_dims(out_dims);
  out->set_dtype(dtype);
@@ -30,7 +30,7 @@ void CreateInferMeta(const ScalarArray& shape,
                     DataType dtype,
                     DataLayout layout,
                     MetaTensor* out) {
-  CreateInferMeta(shape.GetData(), dtype, layout, out);
+  CreateInferMetaBase(shape.GetData(), dtype, layout, out);
}
} // namespace pten
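The rename establishes the convention this change uses throughout (ReduceInferMetaBase below follows it too): the overload taking plain C++ arguments becomes the *Base variant, and thin overloads adapt wrapper types onto it, giving the code generator one unambiguous name to bind. Reduced to its shape (a sketch, not the full paddle signatures):

// Explicit-argument workhorse: generated code targets this exact name.
void CreateInferMetaBase(const std::vector<int64_t>& shape,
                         DataType dtype, DataLayout layout, MetaTensor* out);

// Convenience overload for the ScalarArray attribute wrapper.
inline void CreateInferMeta(const ScalarArray& shape, DataType dtype,
                            DataLayout layout, MetaTensor* out) {
  CreateInferMetaBase(shape.GetData(), dtype, layout, out);
}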
@@ -28,10 +28,10 @@ namespace pten {
// Because the functions in this file not only infer shape, they may also
// need to infer lod or other useful data.

-void CreateInferMeta(const std::vector<int64_t>& shape,
-                     DataType dtype,
-                     DataLayout layout,
-                     MetaTensor* out);
+void CreateInferMetaBase(const std::vector<int64_t>& shape,
+                         DataType dtype,
+                         DataLayout layout,
+                         MetaTensor* out);

void CreateInferMeta(const ScalarArray& shape,
                     DataType dtype,
......
@@ -242,14 +242,14 @@ void SumInferMeta(const MetaTensor& x,
                  DataType dtype,
                  bool keep_dim,
                  MetaTensor* out) {
-  ReduceInferMeta(x, axis, keep_dim, dtype, std::move(out));
+  ReduceInferMetaBase(x, axis, keep_dim, dtype, out);
}
-void ReduceInferMeta(const MetaTensor& x,
-                     const std::vector<int64_t>& axis,
-                     bool keep_dim,
-                     DataType dtype,
-                     MetaTensor* out) {
+void ReduceInferMetaBase(const MetaTensor& x,
+                         const std::vector<int64_t>& axis,
+                         bool keep_dim,
+                         DataType dtype,
+                         MetaTensor* out) {
  bool reduce_all = true;
  std::set<int64_t> dims_set(axis.begin(), axis.end());
  for (int64_t i = 0; i < x.dims().size(); ++i) {
@@ -304,7 +304,7 @@ void ReduceInferMeta(const MetaTensor& x,
                     const std::vector<int64_t>& axis,
                     bool keep_dim,
                     MetaTensor* out) {
-  ReduceInferMeta(x, axis, keep_dim, DataType::UNDEFINED, out);
+  ReduceInferMetaBase(x, axis, keep_dim, DataType::UNDEFINED, out);
}
void TransferLayoutInferMeta(const MetaTensor& x,
@@ -316,5 +316,3 @@ void TransferLayoutInferMeta(const MetaTensor& x,
}
} // namespace pten
-PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);
@@ -53,11 +53,11 @@ void ReshapeInferMeta(const MetaTensor& x,
                      const ScalarArray& shape,
                      MetaTensor* out);

-void ReduceInferMeta(const MetaTensor& x,
-                     const std::vector<int64_t>& axis,
-                     bool keep_dim,
-                     DataType dtype,
-                     MetaTensor* out);
+void ReduceInferMetaBase(const MetaTensor& x,
+                         const std::vector<int64_t>& axis,
+                         bool keep_dim,
+                         DataType dtype,
+                         MetaTensor* out);

void ReduceInferMeta(const MetaTensor& x,
                     const std::vector<int64_t>& axis,
......
@@ -156,7 +156,7 @@ DenseTensor Mean(const Context& dev_ctx,
                 bool keep_dim) {
  auto dense_out = pten::Empty<T, Context>(dev_ctx);
  MetaTensor meta_out(&dense_out);
-  ReduceInferMeta(x, axis, keep_dim, x.dtype(), &meta_out);
+  ReduceInferMetaBase(x, axis, keep_dim, x.dtype(), &meta_out);
  MeanKernel<T, Context>(dev_ctx, x, axis, keep_dim, &dense_out);
  return dense_out;
}
......
@@ -161,6 +161,14 @@
  kernel :
    func : scale

- api : sign
  args : (const Tensor& x)
  output : Tensor
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : sign
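For a base api entry like this, the generator emits a public declaration in the generated header and, via gene_infer_meta_register further down in this diff, an InferMeta registration in the generated source. A hedged sketch of the two generated lines for sign (exact formatting may differ):

PADDLE_API Tensor sign(const Tensor& x);
PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);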
- api : subtract
  args : (const Tensor& x, const Tensor& y)
  output : Tensor
@@ -173,10 +181,10 @@
- api : sum
  args : (const Tensor& x, const std::vector<int64_t>& axis={}, DataType dtype=DataType::UNDEFINED, bool keep_dim=false)
  output : Tensor
  infer_meta :
    func : SumInferMeta
    param: [x, axis, dtype, keep_dim]
  kernel :
    func : sum
    param : [x, axis, dtype, keep_dim]
    data_type : x
......
@@ -379,14 +379,7 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
        input_infos = self.inputs['input_info']
        kernel_args_type_list = ['const platform::DeviceContext&']

        input_tensor_code = ""
-        for input_name in input_names:
-            # set input code
-            input_tensor_code = input_tensor_code + f"""
-  auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""

        attr_names = self.attrs['names']
        kernel_param = self.kernel['param']
        if kernel_param is None:
            kernel_param = input_names + attr_names
@@ -401,11 +394,11 @@ PADDLE_API {self.outputs['return_type']} {self.api}({self.args_str['args_declare
            elif input_name in self.data_transform['support_trans_dtype']:
                trans_flag = "{false, true}"
            input_tensor_code = input_tensor_code + f"""
  auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({i}), {trans_flag});"""
        else:
            input_tensor_code = input_tensor_code + f"""
  auto {PREFIX_TENSOR_NAME}{input_name} = TensorToDenseTensor({input_name});"""

        kernel_args = "*dev_ctx, "
        for param in kernel_param:
......
@@ -60,6 +60,14 @@ class ForwardAPI(BaseAPI):
        return kernel_output, output_names, output_create

    def gene_infer_meta_register(self):
        if self.is_base_api:
            return f"""
PT_REGISTER_INFER_META_FN({self.kernel['func']}, pten::{self.infer_meta['func']});"""
        else:
            return ''
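The returned strings are concatenated across all apis and written once after api_register() (see generate_api below). For this yaml the accumulated block would contain lines like the following, assuming both entries are base apis (a sketch, not captured generator output):

PT_REGISTER_INFER_META_FN(sign, pten::UnchangedInferMeta);
PT_REGISTER_INFER_META_FN(sum, pten::SumInferMeta);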
def header_include():
    return """
@@ -83,6 +91,7 @@ def source_include(header_file_path):
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/infermeta/multiary.h"
@@ -127,15 +136,21 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
    source_file.write(source_include(include_header_file))
    source_file.write(namespace[0])

    infer_meta_register_code = ''
    for api in apis:
        api_code = ForwardAPI(api)
        print(api_code.gene_api_declaration())
        header_file.write(api_code.gene_api_declaration())
        source_file.write(api_code.gene_api_code())
        infer_meta_register_code = infer_meta_register_code + api_code.gene_infer_meta_register(
        )

    header_file.write(namespace[1])
    source_file.write(namespace[1])

    source_file.write(api_register())
    source_file.write(infer_meta_register_code)

    header_file.close()
    source_file.close()
......