diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 7026cc7cf1aa3acdc27728350b7572a0aa8970f7..6c4171a5b896aaf9c34ba62e1e2d16bd02fc5551 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -203,12 +203,17 @@ struct ExtractAttribute> { const std::string& attr_name_; }; + template inline proto::AttrType AttrTypeID() { Attribute tmp = T(); return static_cast(tmp.which() - 1); } +inline proto::AttrType AttrTypeID(const Attribute& attr) { + return static_cast(attr.which() - 1); +} + class AttrReader { public: explicit AttrReader(const AttributeMap& attrs) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index d7a2a42ca7dc751f8a6834ef4b3e53e2e0467523..91dea654ee66283157874c8d93eeb4c7399a1f3f 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -28,6 +28,7 @@ limitations under the License. */ #include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/tensor_utils.h" namespace paddle { @@ -447,7 +448,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, auto attr_reader = ctx->Attrs(); for (size_t i = 0; i < attr_names.size(); ++i) { auto& attr_name = attr_names[i]; - if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { + if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { // When attr is a vector_tensor or tensor, transform it to IntArray if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) { auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); @@ -498,16 +499,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } } else if (ctx->HasAttr(attr_name)) { auto& attr = attr_reader.GetAttr(attr_name); - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + if (AttrTypeID(attr) == proto::AttrType::INTS) { infer_meta_context.EmplaceBackAttr(std::move( phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == proto::AttrType::LONGS) { infer_meta_context.EmplaceBackAttr(std::move( phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int))) { + } else if (AttrTypeID(attr) == proto::AttrType::INT) { infer_meta_context.EmplaceBackAttr( phi::IntArray({BOOST_GET_CONST(int, attr)})); } else { @@ -517,20 +515,17 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name)); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::Scalar))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { if (ctx->HasAttr(attr_name)) { // TODO(chentianyu03): support other attrs later auto& attr = attr_reader.GetAttr(attr_name); - if (std::type_index(attr.type()) == std::type_index(typeid(float))) { + if (AttrTypeID(attr) == proto::AttrType::FLOAT) { infer_meta_context.EmplaceBackAttr( phi::Scalar(BOOST_GET_CONST(float, attr))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::string))) { + } else if (AttrTypeID(attr) == proto::AttrType::STRING) { infer_meta_context.EmplaceBackAttr( phi::Scalar(BOOST_GET_CONST(std::string, attr))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int))) { + } else if (AttrTypeID(attr) == proto::AttrType::INT) { infer_meta_context.EmplaceBackAttr( phi::Scalar(BOOST_GET_CONST(int, attr))); } else { @@ -558,11 +553,9 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name, infershape_input.size())); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { auto& attr = attr_reader.GetAttr(attr_name); - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + if (AttrTypeID(attr) == proto::AttrType::INTS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -570,8 +563,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, scalar_list.emplace_back(val); } infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == proto::AttrType::LONGS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -579,8 +571,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, scalar_list.emplace_back(val); } infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == proto::AttrType::FLOATS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -588,8 +579,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, scalar_list.emplace_back(val); } infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -606,29 +596,24 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } else if (ctx->HasAttr(attr_name)) { // Emplace Back Attr according to the type of attr. auto& attr = attr_reader.GetAttr(attr_name); - if (attr_defs[i].type_index == std::type_index(typeid(bool))) { + if (attr_defs[i].type_index == phi::AttributeType::BOOL) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(int))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::string))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::BOOLS) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { + if (AttrTypeID(attr) == proto::AttrType::INTS) { // Emplace Back Attr according to the type of Phi_Kernel args. const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); const std::vector vector_int64_attr(vector_int_attr.begin(), @@ -638,20 +623,16 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT64S) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { infer_meta_context.EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::DataType))) { + } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { auto data_type = paddle::framework::TransToPhiDataType( static_cast( BOOST_GET_CONST(int, attr))); @@ -663,7 +644,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } } else if (ctx->HasInput(attr_name)) { // convert from data - if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) { + if (attr_defs[i].type_index == phi::AttributeType::INT32) { if (ctx->IsRuntime()) { auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f60d3ced4566135e0cc58727c8b59de5c4f49821..7468aaedecec5ef5c88aebb980133770924ab6db 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2415,21 +2415,19 @@ void OperatorWithKernel::BuildPhiKernelContext( VLOG(4) << "Done outputs"; for (size_t i = 0; i < attr_names.size(); ++i) { - if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { + if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { auto attr_iter = Attrs().find(attr_names[i]); if (attr_iter != Attrs().end()) { // shape is in the attribute - if (std::type_index(attr_iter->second.type()) == - std::type_index(typeid(std::vector))) { - pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( - BOOST_GET_CONST(std::vector, attr_iter->second)))); - } else if (std::type_index(attr_iter->second.type()) == - std::type_index(typeid(std::vector))) { - pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( - BOOST_GET_CONST(std::vector, attr_iter->second)))); - } else if (std::type_index(attr_iter->second.type()) == - std::type_index(typeid(int32_t))) { + auto& attr = attr_iter->second; + if (AttrTypeID(attr) == proto::AttrType::LONGS) { pt_kernel_context->EmplaceBackAttr(std::move( - phi::IntArray(&BOOST_GET_CONST(int32_t, attr_iter->second), 1))); + phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); + } else if (AttrTypeID(attr) == proto::AttrType::INTS) { + pt_kernel_context->EmplaceBackAttr(std::move( + phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); + } else if (AttrTypeID(attr) == proto::AttrType::INT) { + pt_kernel_context->EmplaceBackAttr( + std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to IntArray when " @@ -2446,23 +2444,17 @@ void OperatorWithKernel::BuildPhiKernelContext( std::move(experimental::MakePhiIntArrayFromVarList(ins_vector))); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::Scalar))) { - // TODO(chenweihang): support other attrs later - // TODO(zhangyunfei): Scalar should hold scaler type, and we should check - // attribtue type by attr_defs + } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { auto attr_iter = Attrs().find(attr_names[i]); if (attr_iter != Attrs().end()) { // scalar is in the attribute - auto& attr = Attrs().at(attr_names[i]); - if (std::type_index(attr.type()) == std::type_index(typeid(float))) { + auto& attr = attr_iter->second; + if (AttrTypeID(attr) == proto::AttrType::FLOAT) { pt_kernel_context->EmplaceBackAttr( std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::string))) { + } else if (AttrTypeID(attr) == proto::AttrType::STRING) { pt_kernel_context->EmplaceBackAttr( std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int))) { + } else if (AttrTypeID(attr) == proto::AttrType::INT) { pt_kernel_context->EmplaceBackAttr( std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); } else { @@ -2477,11 +2469,9 @@ void OperatorWithKernel::BuildPhiKernelContext( std::move(experimental::MakePhiScalarFromVar(*ins_vector.front()))); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { auto& attr = Attrs().at(attr_names[i]); - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + if (AttrTypeID(attr) == proto::AttrType::INTS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -2489,8 +2479,7 @@ void OperatorWithKernel::BuildPhiKernelContext( scalar_list.emplace_back(val); } pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == proto::AttrType::LONGS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -2498,8 +2487,7 @@ void OperatorWithKernel::BuildPhiKernelContext( scalar_list.emplace_back(val); } pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == proto::AttrType::FLOATS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -2507,8 +2495,7 @@ void OperatorWithKernel::BuildPhiKernelContext( scalar_list.emplace_back(val); } pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -2523,9 +2510,8 @@ void OperatorWithKernel::BuildPhiKernelContext( attr_names[i])); } } else { - // TODO(chenweihang): support other attrs later auto attr_it = attrs_.find(attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(int))) { + if (attr_defs[i].type_index == phi::AttributeType::INT32) { if (attr_it == attrs_.end()) { auto in_it = ctx.inputs.find(attr_names[i]); if (in_it != ctx.inputs.end()) { @@ -2542,33 +2528,28 @@ void OperatorWithKernel::BuildPhiKernelContext( pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(int, attr_it->second)); } - } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(float, attr_it->second)); - } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { + } else if (attr_defs[i].type_index == phi::AttributeType::BOOL) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(bool, attr_it->second)); - } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(int64_t, attr_it->second)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::string))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(std::string, attr_it->second)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::DataType))) { + } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { auto data_type = paddle::framework::TransToPhiDataType( static_cast( BOOST_GET_CONST(int, attr_it->second))); pt_kernel_context->EmplaceBackAttr(data_type); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { - if (std::type_index(attr_it->second.type()) == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { + if (AttrTypeID(attr_it->second) == proto::AttrType::LONGS) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr_it->second)); - } else if (std::type_index(attr_it->second.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr_it->second) == proto::AttrType::INTS) { // Emplace Back Attr according to the type of Phi_Kernel args. const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr_it->second); @@ -2576,17 +2557,14 @@ void OperatorWithKernel::BuildPhiKernelContext( vector_int_attr.end()); pt_kernel_context->EmplaceBackAttr(vector_int64_attr); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr_it->second); pt_kernel_context->EmplaceBackAttr(vector_int_attr); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr_it->second)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr_it->second)); } else { diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index dedb6a382efa6f0be2d6de9d07c3cd4580d0d453..5c7f337dc6cf4290d9a5ad43f8ee98f28d29a97e 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -378,28 +378,23 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, } for (size_t i = 0; i < attr_names.size(); ++i) { - if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { + if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { if (attrs.find(attr_names[i]) != attrs.end()) { // shape is in the attribute auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) { kernel_ctx->EmplaceBackAttr(std::move( phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) { kernel_ctx->EmplaceBackAttr(std::move( phi::IntArray(BOOST_GET_CONST(std::vector, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int64_t))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::LONG) { kernel_ctx->EmplaceBackAttr( std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int32_t))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::INT) { kernel_ctx->EmplaceBackAttr( std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); kernel_ctx->EmplaceBackAttr(vector_int_attr); } else { @@ -423,24 +418,20 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, std::move(experimental::MakePhiIntArrayFromVarList(variables))); } } - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::Scalar))) { - // TODO(chenweihang): support other attrs later + } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { // TODO(zhangyunfei): Scalar should hold scaler type, and we should check // attribtue type by attr_defs if (attrs.find(attr_names[i]) != attrs.end() || default_attrs.find(attr_names[i]) != default_attrs.end()) { // scalar is in the attribute auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (std::type_index(attr.type()) == std::type_index(typeid(float))) { + if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT) { kernel_ctx->EmplaceBackAttr( std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::string))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::STRING) { kernel_ctx->EmplaceBackAttr( std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(int))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::INT) { kernel_ctx->EmplaceBackAttr( std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); } else { @@ -460,17 +451,15 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, auto& ins_vector = ins.at(attr_names[i]); auto tensor_attr = experimental::MakePhiScalarFromVar(ins_vector[0]->Var()); - if (attr_defs[i].type_index == std::type_index(typeid(int))) { + if (attr_defs[i].type_index == phi::AttributeType::INT32) { int val = tensor_attr.template to(); kernel_ctx->EmplaceBackAttr(val); } else { PADDLE_THROW(platform::errors::Unimplemented("only support int here")); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + if (AttrTypeID(attr) == framework::proto::AttrType::INTS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -478,8 +467,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, scalar_list.emplace_back(val); } kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -487,8 +475,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, scalar_list.emplace_back(val); } kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::FLOATS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -496,8 +483,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, scalar_list.emplace_back(val); } kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT64S) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -505,8 +491,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, scalar_list.emplace_back(val); } kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::BOOLEANS) { const auto& vec = BOOST_GET_CONST(std::vector, attr); std::vector scalar_list; scalar_list.reserve(vec.size()); @@ -521,49 +506,39 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, attr_names[i])); } } else { - // TODO(chenweihang): support other attrs later - auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); - if (attr_defs[i].type_index == std::type_index(typeid(int))) { + if (attr_defs[i].type_index == phi::AttributeType::INT32) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(float))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { + } else if (attr_defs[i].type_index == phi::AttributeType::BOOL) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); - } else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::string))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRING) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(phi::DataType))) { + } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { auto data_type = framework::TransToPhiDataType( static_cast( BOOST_GET_CONST(int, attr))); kernel_ctx->EmplaceBackAttr(data_type); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { - if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { + if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) { kernel_ctx->EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (std::type_index(attr.type()) == - std::type_index(typeid(std::vector))) { + } else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) { // Emplace Back Attr according to the type of Phi_Kernel args. const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); const std::vector vector_int64_attr(vector_int_attr.begin(), vector_int_attr.end()); kernel_ctx->EmplaceBackAttr(vector_int64_attr); } - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { kernel_ctx->EmplaceBackAttr( BOOST_GET_CONST(std::vector, attr)); - } else if (attr_defs[i].type_index == - std::type_index(typeid(std::vector))) { + } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector, attr)); } else { PADDLE_THROW(platform::errors::Unimplemented( diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index d3fd2e0204e54f1cbaed8049ac83abcad7efed7a..b4ae3e00f2dc4783632aad506528de9602c1e9b7 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -140,6 +140,68 @@ const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( return iter->second.cbegin()->second.args_def(); } +std::ostream& operator<<(std::ostream& os, AttributeType attr_type) { + switch (attr_type) { + case AttributeType::BOOL: + os << "bool"; + break; + case AttributeType::INT32: + os << "int"; + break; + case AttributeType::INT64: + os << "int64_t"; + break; + case AttributeType::FLOAT32: + os << "float"; + break; + case AttributeType::FLOAT64: + os << "double"; + break; + case AttributeType::STRING: + os << "string"; + break; + case AttributeType::BOOLS: + os << "vector"; + break; + case AttributeType::INT32S: + os << "vector"; + break; + case AttributeType::INT64S: + os << "vector"; + break; + case AttributeType::FLOAT32S: + os << "vector"; + break; + case AttributeType::FLOAT64S: + os << "vector"; + break; + case AttributeType::STRINGS: + os << "vector"; + break; + case AttributeType::SCALAR: + os << "Scalar"; + break; + case AttributeType::SCALARS: + os << "vector"; + break; + case AttributeType::INT_ARRAY: + os << "IntArray"; + break; + case AttributeType::DATA_TYPE: + os << "DataType"; + break; + case AttributeType::DATA_LAYOUT: + os << "DataLayout"; + break; + case AttributeType::PLACE: + os << "Place"; + break; + default: + os << "Undefined"; + } + return os; +} + // print kernel info with json format: // { // "(CPU, Undefined(AnyLayout), complex64)": { @@ -175,7 +237,7 @@ std::ostream& operator<<(std::ostream& os, const Kernel& kernel) { need_comma = false; for (auto& arg_def : kernel.args_def().attribute_defs()) { if (need_comma) os << ","; - os << "\"" << arg_def.type_index.name() << "\""; + os << "\"" << arg_def.type_index << "\""; need_comma = true; } os << "]}"; diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 812b6222cb5e293ffdaa1051462dcf945052fef9..bce6e6c383fea8e485c51ad954c621bce80b7c7c 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -122,11 +122,33 @@ struct TensorArgDef { } }; +// Align the original fluid Attribute type with lower overhead +enum class AttributeType { + UNDEFINED = 0, + BOOL, + INT32, + INT64, + FLOAT32, + FLOAT64, + STRING, + BOOLS, + INT32S, + INT64S, + FLOAT32S, + FLOAT64S, + STRINGS, + SCALAR, + SCALARS, + INT_ARRAY, + DATA_TYPE, + DATA_LAYOUT, + PLACE, +}; + struct AttributeArgDef { - std::type_index type_index; + AttributeType type_index; - explicit AttributeArgDef(std::type_index type_index) - : type_index(type_index) {} + explicit AttributeArgDef(AttributeType type_index) : type_index(type_index) {} }; class KernelArgsDef { @@ -147,7 +169,7 @@ class KernelArgsDef { output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index)); } - void AppendAttribute(std::type_index type_index) { + void AppendAttribute(AttributeType type_index) { attribute_defs_.emplace_back(AttributeArgDef(type_index)); } @@ -277,6 +299,8 @@ inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) { return os; } +std::ostream& operator<<(std::ostream& os, AttributeType attr_type); + std::ostream& operator<<(std::ostream& os, const Kernel& kernel); std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory); diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index 356ab58f40726b93eb0f4f57c06a7b873b8d5153..36ab9c081cc374d38272dd435594f75a86cff3ff 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -163,11 +163,51 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == std::type_index(typeid(bool))) { + args_def->AppendAttribute(AttributeType::BOOL); + } else if (arg_type == std::type_index(typeid(int))) { + args_def->AppendAttribute(AttributeType::INT32); + } else if (arg_type == std::type_index(typeid(int64_t))) { + args_def->AppendAttribute(AttributeType::INT64); + } else if (arg_type == std::type_index(typeid(float))) { + args_def->AppendAttribute(AttributeType::FLOAT32); + } else if (arg_type == std::type_index(typeid(double))) { + args_def->AppendAttribute(AttributeType::FLOAT64); + } else if (arg_type == std::type_index(typeid(std::string))) { + args_def->AppendAttribute(AttributeType::STRING); + } else if (arg_type == + std::type_index(typeid(const std::vector&))) { + args_def->AppendAttribute(AttributeType::BOOLS); + } else if (arg_type == std::type_index(typeid(const std::vector&))) { + args_def->AppendAttribute(AttributeType::INT32S); + } else if (arg_type == + std::type_index(typeid(const std::vector&))) { + args_def->AppendAttribute(AttributeType::INT64S); + } else if (arg_type == + std::type_index(typeid(const std::vector&))) { + args_def->AppendAttribute(AttributeType::FLOAT32S); + } else if (arg_type == + std::type_index(typeid(const std::vector&))) { + args_def->AppendAttribute(AttributeType::FLOAT64S); + } else if (arg_type == + std::type_index(typeid(const std::vector&))) { + args_def->AppendAttribute(AttributeType::STRINGS); + } else if (arg_type == std::type_index(typeid(const Scalar&))) { + args_def->AppendAttribute(AttributeType::SCALAR); + } else if (arg_type == + std::type_index(typeid(const std::vector&))) { + args_def->AppendAttribute(AttributeType::SCALARS); + } else if (arg_type == std::type_index(typeid(const IntArray&))) { + args_def->AppendAttribute(AttributeType::INT_ARRAY); + } else if (arg_type == std::type_index(typeid(DataType))) { + args_def->AppendAttribute(AttributeType::DATA_TYPE); + } else if (arg_type == std::type_index(typeid(DataLayout))) { + args_def->AppendAttribute(AttributeType::DATA_LAYOUT); + } else if (arg_type == std::type_index(typeid(Place))) { + args_def->AppendAttribute(AttributeType::PLACE); } else { - // Attribute deal with - // TODO(chenweihang): now here allow any types of attribute, maybe - // should add limits here - args_def->AppendAttribute(arg_type); + PADDLE_THROW(phi::errors::Unavailable( + "Unsupported kernel argument type `%s`.", arg_type.name())); } } } diff --git a/paddle/phi/tests/core/test_kernel_factory.cc b/paddle/phi/tests/core/test_kernel_factory.cc index cb4b50f5b6c3dce52f5d188b86d748a59cd41f1e..490d4967eeba252d0869ae747607d19e2df561e3 100644 --- a/paddle/phi/tests/core/test_kernel_factory.cc +++ b/paddle/phi/tests/core/test_kernel_factory.cc @@ -73,6 +73,67 @@ TEST(KernelRegistry, SetFP32Input) { EXPECT_EQ(output_defs.at(0).dtype, phi::DataType::FLOAT16); } +TEST(AttributeType, OStream) { + std::ostringstream oss; + oss << phi::AttributeType::UNDEFINED; + EXPECT_EQ(oss.str(), "Undefined"); + oss.str(""); + oss << phi::AttributeType::BOOL; + EXPECT_EQ(oss.str(), "bool"); + oss.str(""); + oss << phi::AttributeType::INT32; + EXPECT_EQ(oss.str(), "int"); + oss.str(""); + oss << phi::AttributeType::INT64; + EXPECT_EQ(oss.str(), "int64_t"); + oss.str(""); + oss << phi::AttributeType::FLOAT32; + EXPECT_EQ(oss.str(), "float"); + oss.str(""); + oss << phi::AttributeType::FLOAT64; + EXPECT_EQ(oss.str(), "double"); + oss.str(""); + oss << phi::AttributeType::STRING; + EXPECT_EQ(oss.str(), "string"); + oss.str(""); + oss << phi::AttributeType::BOOLS; + EXPECT_EQ(oss.str(), "vector"); + oss.str(""); + oss << phi::AttributeType::INT32S; + EXPECT_EQ(oss.str(), "vector"); + oss.str(""); + oss << phi::AttributeType::INT64S; + EXPECT_EQ(oss.str(), "vector"); + oss.str(""); + oss << phi::AttributeType::FLOAT32S; + EXPECT_EQ(oss.str(), "vector"); + oss.str(""); + oss << phi::AttributeType::FLOAT64S; + EXPECT_EQ(oss.str(), "vector"); + oss.str(""); + oss << phi::AttributeType::STRINGS; + EXPECT_EQ(oss.str(), "vector"); + oss.str(""); + oss << phi::AttributeType::SCALAR; + EXPECT_EQ(oss.str(), "Scalar"); + oss.str(""); + oss << phi::AttributeType::SCALARS; + EXPECT_EQ(oss.str(), "vector"); + oss.str(""); + oss << phi::AttributeType::INT_ARRAY; + EXPECT_EQ(oss.str(), "IntArray"); + oss.str(""); + oss << phi::AttributeType::DATA_TYPE; + EXPECT_EQ(oss.str(), "DataType"); + oss.str(""); + oss << phi::AttributeType::DATA_LAYOUT; + EXPECT_EQ(oss.str(), "DataLayout"); + oss.str(""); + oss << phi::AttributeType::PLACE; + EXPECT_EQ(oss.str(), "Place"); + oss.str(""); +} + } // namespace tests } // namespace phi diff --git a/tools/infrt/generate_phi_kernel_dialect.py b/tools/infrt/generate_phi_kernel_dialect.py index 0b67c6ba44a1da7a50331a4dde9f5392052b74dc..b83bfe911aa48e6cf63cc5afd5aa97239197b8f8 100644 --- a/tools/infrt/generate_phi_kernel_dialect.py +++ b/tools/infrt/generate_phi_kernel_dialect.py @@ -20,12 +20,12 @@ from get_compat_kernel_signature import get_compat_kernels_info #TODO @DannyIsFunny: more attr types need to be supported. attr_type_converter = { - "i": 'SI32Attr', - "b": 'BoolAttr', - "l": 'SI64Attr', - "f": 'F32Attr', - "NSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE": 'StrAttr', - "St6vectorIiSaIiEE": 'I32ArrayAttr' + "int": 'SI32Attr', + "bool": 'BoolAttr', + "int64_t": 'SI64Attr', + "float": 'F32Attr', + "string": 'StrAttr', + "vector": 'I32ArrayAttr' } target_type_converter = {"CPU": "CPU", "GPU": "GPU", "Undefined": "UNK"}