diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index d7a2a42ca7dc751f8a6834ef4b3e53e2e0467523..3a1733344171636e49e7b8343ae2ca0b76b03902 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -28,6 +28,7 @@ limitations under the License. */ #include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/tensor_utils.h" namespace paddle { @@ -447,7 +448,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, auto attr_reader = ctx->Attrs(); for (size_t i = 0; i < attr_names.size(); ++i) { auto& attr_name = attr_names[i]; - if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { + if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { // When attr is a vector_tensor or tensor, transform it to IntArray if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) { auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); @@ -517,8 +518,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name)); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::Scalar))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { if (ctx->HasAttr(attr_name)) { // TODO(chentianyu03): support other attrs later auto& attr = attr_reader.GetAttr(attr_name); @@ -558,8 +558,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name, infershape_input.size())); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { auto& attr = attr_reader.GetAttr(attr_name); if (std::type_index(attr.type()) == std::type_index(typeid(std::vector))) { @@ -606,27 +605,23 @@ CompatInferMetaContext 
BuildInferMetaContext(InferShapeContext* ctx, } else if (ctx->HasAttr(attr_name)) { // Emplace Back Attr according to the type of attr. auto& attr = attr_reader.GetAttr(attr_name); - if (attr_defs[i].type_index == std::type_index(typeid(bool))) { + if (attr_defs[i].type_index == phi::AttributeType::BOOL) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(int))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::string))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::BOOLS) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { if (std::type_index(attr.type()) == std::type_index(typeid(std::vector))) { // Emplace Back Attr according to the type of Phi_Kernel args. 
@@ -638,20 +633,16 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT64S) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::DataType))) { + } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { auto data_type = paddle::framework::TransToPhiDataType( static_cast( BOOST_GET_CONST(int, attr))); @@ -663,7 +654,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } } else if (ctx->HasInput(attr_name)) { // convert from data - if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) { + if (attr_defs[i].type_index == phi::AttributeType::INT32) { if (ctx->IsRuntime()) { auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 140103b10592fdfdee95a2ba8d03d12d7880aa5a..abb645915ed556a6dad38b60b32b1c15261647b5 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2413,7 +2413,7 @@ void OperatorWithKernel::BuildPhiKernelContext( VLOG(4) << "Done outputs"; for (size_t i = 0; i < attr_names.size(); ++i) { - if (attr_defs[i].type_index == 
std::type_index(typeid(phi::IntArray))) { + if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { auto attr_iter = Attrs().find(attr_names[i]); if (attr_iter != Attrs().end()) { // shape is in the attribute if (std::type_index(attr_iter->second.type()) == @@ -2444,8 +2444,7 @@ void OperatorWithKernel::BuildPhiKernelContext( std::move(experimental::MakePhiIntArrayFromVarList(ins_vector))); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::Scalar))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { // TODO(chenweihang): support other attrs later // TODO(zhangyunfei): Scalar should hold scaler type, and we should check // attribtue type by attr_defs @@ -2475,8 +2474,7 @@ void OperatorWithKernel::BuildPhiKernelContext( std::move(experimental::MakePhiScalarFromVar(*ins_vector.front()))); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { auto& attr = Attrs().at(attr_names[i]); if (std::type_index(attr.type()) == std::type_index(typeid(std::vector))) { @@ -2521,9 +2519,8 @@ void OperatorWithKernel::BuildPhiKernelContext( attr_names[i])); } } else { - // TODO(chenweihang): support other attrs later auto attr_it = attrs_.find(attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(int))) { + if (attr_defs[i].type_index == phi::AttributeType::INT32) { if (attr_it == attrs_.end()) { auto in_it = ctx.inputs.find(attr_names[i]); if (in_it != ctx.inputs.end()) { @@ -2540,27 +2537,24 @@ void OperatorWithKernel::BuildPhiKernelContext( pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(int, attr_it->second)); } - } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(float, attr_it->second)); - } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { + } 
else if (attr_defs[i].type_index == phi::AttributeType::BOOL) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(bool, attr_it->second)); - } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(int64_t, attr_it->second)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::string))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(std::string, attr_it->second)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::DataType))) { + } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { auto data_type = paddle::framework::TransToPhiDataType( static_cast( BOOST_GET_CONST(int, attr_it->second))); pt_kernel_context->EmplaceBackAttr(data_type); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { if (std::type_index(attr_it->second.type()) == std::type_index(typeid(std::vector))) { pt_kernel_context->EmplaceBackAttr( @@ -2574,17 +2568,14 @@ void OperatorWithKernel::BuildPhiKernelContext( vector_int_attr.end()); pt_kernel_context->EmplaceBackAttr(vector_int64_attr); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr_it->second); pt_kernel_context->EmplaceBackAttr(vector_int_attr); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr_it->second)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == 
phi::AttributeType::FLOAT32S) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr_it->second)); } else { diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index dedb6a382efa6f0be2d6de9d07c3cd4580d0d453..6cc86f8129913c028aa7f92f58eec7ff25117249 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -378,7 +378,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, } for (size_t i = 0; i < attr_names.size(); ++i) { - if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { + if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { if (attrs.find(attr_names[i]) != attrs.end()) { // shape is in the attribute auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); @@ -398,8 +398,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, std::type_index(typeid(int32_t))) { kernel_ctx->EmplaceBackAttr( std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); kernel_ctx->EmplaceBackAttr(vector_int_attr); } else { @@ -423,9 +422,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, std::move(experimental::MakePhiIntArrayFromVarList(variables))); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::Scalar))) { - // TODO(chenweihang): support other attrs later + } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { // TODO(zhangyunfei): Scalar should hold scaler type, and we should check // attribtue type by attr_defs if (attrs.find(attr_names[i]) != attrs.end() || @@ -460,14 +457,13 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, auto& ins_vector = 
ins.at(attr_names[i]); auto tensor_attr = experimental::MakePhiScalarFromVar(ins_vector[0]->Var()); - if (attr_defs[i].type_index == std::type_index(typeid(int))) { + if (attr_defs[i].type_index == phi::AttributeType::INT32) { int val = tensor_attr.template to(); kernel_ctx->EmplaceBackAttr(val); } else { PADDLE_THROW(platform::errors::Unimplemented("only support int here")); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); if (std::type_index(attr.type()) == std::type_index(typeid(std::vector))) { @@ -521,28 +517,23 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, attr_names[i])); } } else { - // TODO(chenweihang): support other attrs later - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(int))) { + if (attr_defs[i].type_index == phi::AttributeType::INT32) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { + } else if (attr_defs[i].type_index == phi::AttributeType::BOOL) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::string))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); - } else if (attr_defs[i].type_index == - 
std::type_index(typeid(phi::DataType))) { + } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { auto data_type = framework::TransToPhiDataType( static_cast( BOOST_GET_CONST(int, attr))); kernel_ctx->EmplaceBackAttr(data_type); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { if (std::type_index(attr.type()) == std::type_index(typeid(std::vector))) { kernel_ctx->EmplaceBackAttr( @@ -555,15 +546,12 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, vector_int_attr.end()); kernel_ctx->EmplaceBackAttr(vector_int64_attr); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { kernel_ctx->EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector, attr)); } else { PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 6d71c5016bda4ee5e8b4cc9be9716260766bea06..08329d0c8636a93797181241a5e931622dcaff8b 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -140,6 +140,68 @@ const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( return iter->second.cbegin()->second.args_def(); } +std::ostream& operator<<(std::ostream& os, AttributeType attr_type) { + switch (attr_type) { + case AttributeType::BOOL: + os << "bool"; + break; + case AttributeType::INT32: + os << "int"; + break; + case 
AttributeType::INT64: + os << "int64_t"; + break; + case AttributeType::FLOAT32: + os << "float"; + break; + case AttributeType::FLOAT64: + os << "double"; + break; + case AttributeType::STRING: + os << "string"; + break; + case AttributeType::BOOLS: + os << "vector<bool>"; + break; + case AttributeType::INT32S: + os << "vector<int>"; + break; + case AttributeType::INT64S: + os << "vector<int64_t>"; + break; + case AttributeType::FLOAT32S: + os << "vector<float>"; + break; + case AttributeType::FLOAT64S: + os << "vector<double>"; + break; + case AttributeType::STRINGS: + os << "vector<string>"; + break; + case AttributeType::SCALAR: + os << "Scalar"; + break; + case AttributeType::SCALARS: + os << "vector<Scalar>"; + break; + case AttributeType::INT_ARRAY: + os << "IntArray"; + break; + case AttributeType::DATA_TYPE: + os << "DataType"; + break; + case AttributeType::DATA_LAYOUT: + os << "DataLayout"; + break; + case AttributeType::PLACE: + os << "Place"; + break; + default: + os << "Undefined"; + } + return os; +} + // print kernel info with json format: // { // "(CPU, Undefined(AnyLayout), complex64)": { @@ -175,7 +237,7 @@ std::ostream& operator<<(std::ostream& os, const Kernel& kernel) { need_comma = false; for (auto& arg_def : kernel.args_def().attribute_defs()) { if (need_comma) os << ","; - os << "\"" << arg_def.type_index.name() << "\""; + os << "\"" << arg_def.type_index << "\""; need_comma = true; } os << "]}"; diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 3ac99a426319dd82a97ca2dafe407b01244a8419..9d7ebd97895168ed9e79553f1a4bd4cfcc17cad6 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -122,11 +122,33 @@ struct TensorArgDef { } }; +// Align the original fluid Attribute type with lower overhead +enum class AttributeType { + UNDEFINED = 0, + BOOL, + INT32, + INT64, + FLOAT32, + FLOAT64, + STRING, + BOOLS, + INT32S, + INT64S, + FLOAT32S, + FLOAT64S, + STRINGS, + SCALAR, + SCALARS, + INT_ARRAY, + DATA_TYPE, + DATA_LAYOUT, + PLACE, 
+}; + struct AttributeArgDef { - std::type_index type_index; + AttributeType type_index; - explicit AttributeArgDef(std::type_index type_index) - : type_index(type_index) {} + explicit AttributeArgDef(AttributeType type_index) : type_index(type_index) {} }; class KernelArgsDef { @@ -147,7 +169,7 @@ class KernelArgsDef { output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index)); } - void AppendAttribute(std::type_index type_index) { + void AppendAttribute(AttributeType type_index) { attribute_defs_.emplace_back(AttributeArgDef(type_index)); } @@ -277,6 +299,8 @@ inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) { return os; } +std::ostream& operator<<(std::ostream& os, AttributeType attr_type); + std::ostream& operator<<(std::ostream& os, const Kernel& kernel); std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory); diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index 356ab58f40726b93eb0f4f57c06a7b873b8d5153..36ab9c081cc374d38272dd435594f75a86cff3ff 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -163,11 +163,51 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == std::type_index(typeid(bool))) { + args_def->AppendAttribute(AttributeType::BOOL); + } else if (arg_type == std::type_index(typeid(int))) { + args_def->AppendAttribute(AttributeType::INT32); + } else if (arg_type == std::type_index(typeid(int64_t))) { + args_def->AppendAttribute(AttributeType::INT64); + } else if (arg_type == std::type_index(typeid(float))) { + args_def->AppendAttribute(AttributeType::FLOAT32); + } else if (arg_type == std::type_index(typeid(double))) { + args_def->AppendAttribute(AttributeType::FLOAT64); + } else if (arg_type == std::type_index(typeid(std::string))) { + args_def->AppendAttribute(AttributeType::STRING); + } else if (arg_type == + std::type_index(typeid(const 
std::vector<bool>&))) { + args_def->AppendAttribute(AttributeType::BOOLS); + } else if (arg_type == std::type_index(typeid(const std::vector<int>&))) { + args_def->AppendAttribute(AttributeType::INT32S); + } else if (arg_type == + std::type_index(typeid(const std::vector<int64_t>&))) { + args_def->AppendAttribute(AttributeType::INT64S); + } else if (arg_type == + std::type_index(typeid(const std::vector<float>&))) { + args_def->AppendAttribute(AttributeType::FLOAT32S); + } else if (arg_type == + std::type_index(typeid(const std::vector<double>&))) { + args_def->AppendAttribute(AttributeType::FLOAT64S); + } else if (arg_type == + std::type_index(typeid(const std::vector<std::string>&))) { + args_def->AppendAttribute(AttributeType::STRINGS); + } else if (arg_type == std::type_index(typeid(const Scalar&))) { + args_def->AppendAttribute(AttributeType::SCALAR); + } else if (arg_type == + std::type_index(typeid(const std::vector<Scalar>&))) { + args_def->AppendAttribute(AttributeType::SCALARS); + } else if (arg_type == std::type_index(typeid(const IntArray&))) { + args_def->AppendAttribute(AttributeType::INT_ARRAY); + } else if (arg_type == std::type_index(typeid(DataType))) { + args_def->AppendAttribute(AttributeType::DATA_TYPE); + } else if (arg_type == std::type_index(typeid(DataLayout))) { + args_def->AppendAttribute(AttributeType::DATA_LAYOUT); + } else if (arg_type == std::type_index(typeid(Place))) { + args_def->AppendAttribute(AttributeType::PLACE); } else { - // Attribute deal with - // TODO(chenweihang): now here allow any types of attribute, maybe - // should add limits here - args_def->AppendAttribute(arg_type); + PADDLE_THROW(phi::errors::Unavailable( + "Unsupported kernel argument type `%s`.", arg_type.name())); } } } diff --git a/tools/infrt/generate_phi_kernel_dialect.py b/tools/infrt/generate_phi_kernel_dialect.py index 0b67c6ba44a1da7a50331a4dde9f5392052b74dc..b83bfe911aa48e6cf63cc5afd5aa97239197b8f8 100644 --- a/tools/infrt/generate_phi_kernel_dialect.py +++ b/tools/infrt/generate_phi_kernel_dialect.py @@ 
-20,12 +20,12 @@ from get_compat_kernel_signature import get_compat_kernels_info #TODO @DannyIsFunny: more attr types need to be supported. attr_type_converter = { - "i": 'SI32Attr', - "b": 'BoolAttr', - "l": 'SI64Attr', - "f": 'F32Attr', - "NSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE": 'StrAttr', - "St6vectorIiSaIiEE": 'I32ArrayAttr' + "int": 'SI32Attr', + "bool": 'BoolAttr', + "int64_t": 'SI64Attr', + "float": 'F32Attr', + "string": 'StrAttr', + "vector<int>": 'I32ArrayAttr' } target_type_converter = {"CPU": "CPU", "GPU": "GPU", "Undefined": "UNK"}