未验证 提交 9bc423b1 编写于 作者: C Chen Weihang 提交者: GitHub

[Cherry-pick] Optimize dygraph performance part4 (#42306)

* Remove std::type_index in AttributeArdDef (#42122)

* polish some impl

* add lost attr type

* polish details

* fix error type

* polish in name lists

* add double attr

* adapt infrt attr parse

* add attr type test (#42263)

* opt attr eaque perf (#42272)
上级 5e303e7d
...@@ -203,12 +203,17 @@ struct ExtractAttribute<std::vector<double>> { ...@@ -203,12 +203,17 @@ struct ExtractAttribute<std::vector<double>> {
const std::string& attr_name_; const std::string& attr_name_;
}; };
template <typename T> template <typename T>
inline proto::AttrType AttrTypeID() { inline proto::AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
return static_cast<proto::AttrType>(tmp.which() - 1); return static_cast<proto::AttrType>(tmp.which() - 1);
} }
inline proto::AttrType AttrTypeID(const Attribute& attr) {
return static_cast<proto::AttrType>(attr.which() - 1);
}
class AttrReader { class AttrReader {
public: public:
explicit AttrReader(const AttributeMap& attrs) explicit AttrReader(const AttributeMap& attrs)
......
...@@ -28,6 +28,7 @@ limitations under the License. */ ...@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
namespace paddle { namespace paddle {
...@@ -447,7 +448,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -447,7 +448,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
auto attr_reader = ctx->Attrs(); auto attr_reader = ctx->Attrs();
for (size_t i = 0; i < attr_names.size(); ++i) { for (size_t i = 0; i < attr_names.size(); ++i) {
auto& attr_name = attr_names[i]; auto& attr_name = attr_names[i];
if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) {
// When attr is a vector_tensor or tensor, transform it to IntArray // When attr is a vector_tensor or tensor, transform it to IntArray
if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) { if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
...@@ -498,16 +499,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -498,16 +499,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} }
} else if (ctx->HasAttr(attr_name)) { } else if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
infer_meta_context.EmplaceBackAttr(std::move( infer_meta_context.EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr)))); phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
infer_meta_context.EmplaceBackAttr(std::move( infer_meta_context.EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr)))); phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::INT) {
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::IntArray({BOOST_GET_CONST(int, attr)})); phi::IntArray({BOOST_GET_CONST(int, attr)}));
} else { } else {
...@@ -517,20 +515,17 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -517,20 +515,17 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name)); attr_name));
} }
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) {
std::type_index(typeid(phi::Scalar))) {
if (ctx->HasAttr(attr_name)) { if (ctx->HasAttr(attr_name)) {
// TODO(chentianyu03): support other attrs later // TODO(chentianyu03): support other attrs later
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) { if (AttrTypeID(attr) == proto::AttrType::FLOAT) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(float, attr))); phi::Scalar(BOOST_GET_CONST(float, attr)));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::STRING) {
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(std::string, attr))); phi::Scalar(BOOST_GET_CONST(std::string, attr)));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::INT) {
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(int, attr))); phi::Scalar(BOOST_GET_CONST(int, attr)));
} else { } else {
...@@ -558,11 +553,9 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -558,11 +553,9 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
attr_name, infershape_input.size())); attr_name, infershape_input.size()));
} }
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -570,8 +563,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -570,8 +563,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -579,8 +571,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -579,8 +571,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::FLOATS) {
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr); const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -588,8 +579,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -588,8 +579,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) {
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr); const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -606,29 +596,24 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -606,29 +596,24 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} else if (ctx->HasAttr(attr_name)) { } else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr. // Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (attr_defs[i].type_index == std::type_index(typeid(bool))) { if (attr_defs[i].type_index == phi::AttributeType::BOOL) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(int))) { } else if (attr_defs[i].type_index == phi::AttributeType::INT32) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { } else if (attr_defs[i].type_index == phi::AttributeType::INT64) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) { } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::STRING) {
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::BOOLS) {
std::type_index(typeid(std::vector<bool>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<bool>, attr)); BOOST_GET_CONST(std::vector<bool>, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
std::type_index(typeid(std::vector<int>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr)); BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) {
std::type_index(typeid(std::vector<int64_t>))) { if (AttrTypeID(attr) == proto::AttrType::INTS) {
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
...@@ -638,20 +623,16 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -638,20 +623,16 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr)); BOOST_GET_CONST(std::vector<int64_t>, attr));
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) {
std::type_index(typeid(std::vector<float>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr)); BOOST_GET_CONST(std::vector<float>, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT64S) {
std::type_index(typeid(std::vector<double>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<double>, attr)); BOOST_GET_CONST(std::vector<double>, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) {
std::type_index(typeid(std::vector<std::string>))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr)); BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) {
std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPhiDataType( auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr))); BOOST_GET_CONST(int, attr)));
...@@ -663,7 +644,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -663,7 +644,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} }
} else if (ctx->HasInput(attr_name)) { } else if (ctx->HasInput(attr_name)) {
// convert from data // convert from data
if (attr_defs[i].type_index == std::type_index(typeid(int32_t))) { if (attr_defs[i].type_index == phi::AttributeType::INT32) {
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]); auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]);
......
...@@ -2415,21 +2415,19 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2415,21 +2415,19 @@ void OperatorWithKernel::BuildPhiKernelContext(
VLOG(4) << "Done outputs"; VLOG(4) << "Done outputs";
for (size_t i = 0; i < attr_names.size(); ++i) { for (size_t i = 0; i < attr_names.size(); ++i) {
if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) {
auto attr_iter = Attrs().find(attr_names[i]); auto attr_iter = Attrs().find(attr_names[i]);
if (attr_iter != Attrs().end()) { // shape is in the attribute if (attr_iter != Attrs().end()) { // shape is in the attribute
if (std::type_index(attr_iter->second.type()) == auto& attr = attr_iter->second;
std::type_index(typeid(std::vector<int64_t>))) { if (AttrTypeID(attr) == proto::AttrType::LONGS) {
pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray(
BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(int32_t))) {
pt_kernel_context->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(std::move(
phi::IntArray(&BOOST_GET_CONST(int32_t, attr_iter->second), 1))); phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (AttrTypeID(attr) == proto::AttrType::INTS) {
pt_kernel_context->EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (AttrTypeID(attr) == proto::AttrType::INT) {
pt_kernel_context->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to IntArray when " "Unsupported cast op attribute `%s` to IntArray when "
...@@ -2446,23 +2444,17 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2446,23 +2444,17 @@ void OperatorWithKernel::BuildPhiKernelContext(
std::move(experimental::MakePhiIntArrayFromVarList(ins_vector))); std::move(experimental::MakePhiIntArrayFromVarList(ins_vector)));
} }
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) {
std::type_index(typeid(phi::Scalar))) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
auto attr_iter = Attrs().find(attr_names[i]); auto attr_iter = Attrs().find(attr_names[i]);
if (attr_iter != Attrs().end()) { // scalar is in the attribute if (attr_iter != Attrs().end()) { // scalar is in the attribute
auto& attr = Attrs().at(attr_names[i]); auto& attr = attr_iter->second;
if (std::type_index(attr.type()) == std::type_index(typeid(float))) { if (AttrTypeID(attr) == proto::AttrType::FLOAT) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::STRING) {
std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::INT) {
std::type_index(typeid(int))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
} else { } else {
...@@ -2477,11 +2469,9 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2477,11 +2469,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
std::move(experimental::MakePhiScalarFromVar(*ins_vector.front()))); std::move(experimental::MakePhiScalarFromVar(*ins_vector.front())));
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = Attrs().at(attr_names[i]); auto& attr = Attrs().at(attr_names[i]);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -2489,8 +2479,7 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2489,8 +2479,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -2498,8 +2487,7 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2498,8 +2487,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::FLOATS) {
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr); const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -2507,8 +2495,7 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2507,8 +2495,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) {
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr); const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -2523,9 +2510,8 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2523,9 +2510,8 @@ void OperatorWithKernel::BuildPhiKernelContext(
attr_names[i])); attr_names[i]));
} }
} else { } else {
// TODO(chenweihang): support other attrs later
auto attr_it = attrs_.find(attr_names[i]); auto attr_it = attrs_.find(attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) { if (attr_defs[i].type_index == phi::AttributeType::INT32) {
if (attr_it == attrs_.end()) { if (attr_it == attrs_.end()) {
auto in_it = ctx.inputs.find(attr_names[i]); auto in_it = ctx.inputs.find(attr_names[i]);
if (in_it != ctx.inputs.end()) { if (in_it != ctx.inputs.end()) {
...@@ -2542,33 +2528,28 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2542,33 +2528,28 @@ void OperatorWithKernel::BuildPhiKernelContext(
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int, attr_it->second)); BOOST_GET_CONST(int, attr_it->second));
} }
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) { } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(float, attr_it->second)); BOOST_GET_CONST(float, attr_it->second));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == phi::AttributeType::BOOL) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(bool, attr_it->second)); BOOST_GET_CONST(bool, attr_it->second));
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { } else if (attr_defs[i].type_index == phi::AttributeType::INT64) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int64_t, attr_it->second)); BOOST_GET_CONST(int64_t, attr_it->second));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::STRING) {
std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::string, attr_it->second)); BOOST_GET_CONST(std::string, attr_it->second));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) {
std::type_index(typeid(phi::DataType))) {
auto data_type = paddle::framework::TransToPhiDataType( auto data_type = paddle::framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr_it->second))); BOOST_GET_CONST(int, attr_it->second)));
pt_kernel_context->EmplaceBackAttr(data_type); pt_kernel_context->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) {
std::type_index(typeid(std::vector<int64_t>))) { if (AttrTypeID(attr_it->second) == proto::AttrType::LONGS) {
if (std::type_index(attr_it->second.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr_it->second)); BOOST_GET_CONST(std::vector<int64_t>, attr_it->second));
} else if (std::type_index(attr_it->second.type()) == } else if (AttrTypeID(attr_it->second) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second); BOOST_GET_CONST(std::vector<int>, attr_it->second);
...@@ -2576,17 +2557,14 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2576,17 +2557,14 @@ void OperatorWithKernel::BuildPhiKernelContext(
vector_int_attr.end()); vector_int_attr.end());
pt_kernel_context->EmplaceBackAttr(vector_int64_attr); pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second); BOOST_GET_CONST(std::vector<int>, attr_it->second);
pt_kernel_context->EmplaceBackAttr(vector_int_attr); pt_kernel_context->EmplaceBackAttr(vector_int_attr);
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) {
std::type_index(typeid(std::vector<std::string>))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr_it->second)); BOOST_GET_CONST(std::vector<std::string>, attr_it->second));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) {
std::type_index(typeid(std::vector<float>))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr_it->second)); BOOST_GET_CONST(std::vector<float>, attr_it->second));
} else { } else {
......
...@@ -378,28 +378,23 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -378,28 +378,23 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
} }
for (size_t i = 0; i < attr_names.size(); ++i) { for (size_t i = 0; i < attr_names.size(); ++i) {
if (attr_defs[i].type_index == std::type_index(typeid(phi::IntArray))) { if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) {
if (attrs.find(attr_names[i]) != if (attrs.find(attr_names[i]) !=
attrs.end()) { // shape is in the attribute attrs.end()) { // shape is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr)))); phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr)))); phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::LONG) {
std::type_index(typeid(int64_t))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1))); std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1)));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::INT) {
std::type_index(typeid(int32_t))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
kernel_ctx->EmplaceBackAttr(vector_int_attr); kernel_ctx->EmplaceBackAttr(vector_int_attr);
} else { } else {
...@@ -423,24 +418,20 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -423,24 +418,20 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
std::move(experimental::MakePhiIntArrayFromVarList(variables))); std::move(experimental::MakePhiIntArrayFromVarList(variables)));
} }
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) {
std::type_index(typeid(phi::Scalar))) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check // TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs // attribtue type by attr_defs
if (attrs.find(attr_names[i]) != attrs.end() || if (attrs.find(attr_names[i]) != attrs.end() ||
default_attrs.find(attr_names[i]) != default_attrs.find(attr_names[i]) !=
default_attrs.end()) { // scalar is in the attribute default_attrs.end()) { // scalar is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) { if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::STRING) {
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::INT) {
std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
} else { } else {
...@@ -460,17 +451,15 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -460,17 +451,15 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
auto& ins_vector = ins.at(attr_names[i]); auto& ins_vector = ins.at(attr_names[i]);
auto tensor_attr = auto tensor_attr =
experimental::MakePhiScalarFromVar(ins_vector[0]->Var()); experimental::MakePhiScalarFromVar(ins_vector[0]->Var());
if (attr_defs[i].type_index == std::type_index(typeid(int))) { if (attr_defs[i].type_index == phi::AttributeType::INT32) {
int val = tensor_attr.template to<int>(); int val = tensor_attr.template to<int>();
kernel_ctx->EmplaceBackAttr(val); kernel_ctx->EmplaceBackAttr(val);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented("only support int here")); PADDLE_THROW(platform::errors::Unimplemented("only support int here"));
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
std::type_index(typeid(std::vector<phi::Scalar>))) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == framework::proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -478,8 +467,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -478,8 +467,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -487,8 +475,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -487,8 +475,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::FLOATS) {
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr); const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -496,8 +483,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -496,8 +483,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT64S) {
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr); const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -505,8 +491,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -505,8 +491,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::BOOLEANS) {
std::type_index(typeid(std::vector<bool>))) {
const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr); const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -521,49 +506,39 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -521,49 +506,39 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
attr_names[i])); attr_names[i]));
} }
} else { } else {
// TODO(chenweihang): support other attrs later
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (attr_defs[i].type_index == std::type_index(typeid(int))) { if (attr_defs[i].type_index == phi::AttributeType::INT32) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) { } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == phi::AttributeType::BOOL) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) { } else if (attr_defs[i].type_index == phi::AttributeType::INT64) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::STRING) {
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) {
std::type_index(typeid(phi::DataType))) {
auto data_type = framework::TransToPhiDataType( auto data_type = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr))); BOOST_GET_CONST(int, attr)));
kernel_ctx->EmplaceBackAttr(data_type); kernel_ctx->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) {
std::type_index(typeid(std::vector<int64_t>))) { if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) {
if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr)); BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end()); vector_int_attr.end());
kernel_ctx->EmplaceBackAttr(vector_int64_attr); kernel_ctx->EmplaceBackAttr(vector_int64_attr);
} }
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
std::type_index(typeid(std::vector<int>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) {
std::type_index(typeid(std::vector<std::string>))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr)); BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) {
std::type_index(typeid(std::vector<float>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<float>, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<float>, attr));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
...@@ -140,6 +140,68 @@ const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef( ...@@ -140,6 +140,68 @@ const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
return iter->second.cbegin()->second.args_def(); return iter->second.cbegin()->second.args_def();
} }
std::ostream& operator<<(std::ostream& os, AttributeType attr_type) {
switch (attr_type) {
case AttributeType::BOOL:
os << "bool";
break;
case AttributeType::INT32:
os << "int";
break;
case AttributeType::INT64:
os << "int64_t";
break;
case AttributeType::FLOAT32:
os << "float";
break;
case AttributeType::FLOAT64:
os << "double";
break;
case AttributeType::STRING:
os << "string";
break;
case AttributeType::BOOLS:
os << "vector<bool>";
break;
case AttributeType::INT32S:
os << "vector<int>";
break;
case AttributeType::INT64S:
os << "vector<int64_t>";
break;
case AttributeType::FLOAT32S:
os << "vector<float>";
break;
case AttributeType::FLOAT64S:
os << "vector<double>";
break;
case AttributeType::STRINGS:
os << "vector<string>";
break;
case AttributeType::SCALAR:
os << "Scalar";
break;
case AttributeType::SCALARS:
os << "vector<Scalar>";
break;
case AttributeType::INT_ARRAY:
os << "IntArray";
break;
case AttributeType::DATA_TYPE:
os << "DataType";
break;
case AttributeType::DATA_LAYOUT:
os << "DataLayout";
break;
case AttributeType::PLACE:
os << "Place";
break;
default:
os << "Undefined";
}
return os;
}
// print kernel info with json format: // print kernel info with json format:
// { // {
// "(CPU, Undefined(AnyLayout), complex64)": { // "(CPU, Undefined(AnyLayout), complex64)": {
...@@ -175,7 +237,7 @@ std::ostream& operator<<(std::ostream& os, const Kernel& kernel) { ...@@ -175,7 +237,7 @@ std::ostream& operator<<(std::ostream& os, const Kernel& kernel) {
need_comma = false; need_comma = false;
for (auto& arg_def : kernel.args_def().attribute_defs()) { for (auto& arg_def : kernel.args_def().attribute_defs()) {
if (need_comma) os << ","; if (need_comma) os << ",";
os << "\"" << arg_def.type_index.name() << "\""; os << "\"" << arg_def.type_index << "\"";
need_comma = true; need_comma = true;
} }
os << "]}"; os << "]}";
......
...@@ -122,11 +122,33 @@ struct TensorArgDef { ...@@ -122,11 +122,33 @@ struct TensorArgDef {
} }
}; };
// Align the original fluid Attribute type with lower overhead
enum class AttributeType {
UNDEFINED = 0,
BOOL,
INT32,
INT64,
FLOAT32,
FLOAT64,
STRING,
BOOLS,
INT32S,
INT64S,
FLOAT32S,
FLOAT64S,
STRINGS,
SCALAR,
SCALARS,
INT_ARRAY,
DATA_TYPE,
DATA_LAYOUT,
PLACE,
};
struct AttributeArgDef { struct AttributeArgDef {
std::type_index type_index; AttributeType type_index;
explicit AttributeArgDef(std::type_index type_index) explicit AttributeArgDef(AttributeType type_index) : type_index(type_index) {}
: type_index(type_index) {}
}; };
class KernelArgsDef { class KernelArgsDef {
...@@ -147,7 +169,7 @@ class KernelArgsDef { ...@@ -147,7 +169,7 @@ class KernelArgsDef {
output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index)); output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, type_index));
} }
void AppendAttribute(std::type_index type_index) { void AppendAttribute(AttributeType type_index) {
attribute_defs_.emplace_back(AttributeArgDef(type_index)); attribute_defs_.emplace_back(AttributeArgDef(type_index));
} }
...@@ -277,6 +299,8 @@ inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) { ...@@ -277,6 +299,8 @@ inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) {
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, AttributeType attr_type);
std::ostream& operator<<(std::ostream& os, const Kernel& kernel); std::ostream& operator<<(std::ostream& os, const Kernel& kernel);
std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory); std::ostream& operator<<(std::ostream& os, KernelFactory& kernel_factory);
......
...@@ -163,11 +163,51 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -163,11 +163,51 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
} else if (arg_type == std::type_index(typeid(bool))) {
args_def->AppendAttribute(AttributeType::BOOL);
} else if (arg_type == std::type_index(typeid(int))) {
args_def->AppendAttribute(AttributeType::INT32);
} else if (arg_type == std::type_index(typeid(int64_t))) {
args_def->AppendAttribute(AttributeType::INT64);
} else if (arg_type == std::type_index(typeid(float))) {
args_def->AppendAttribute(AttributeType::FLOAT32);
} else if (arg_type == std::type_index(typeid(double))) {
args_def->AppendAttribute(AttributeType::FLOAT64);
} else if (arg_type == std::type_index(typeid(std::string))) {
args_def->AppendAttribute(AttributeType::STRING);
} else if (arg_type ==
std::type_index(typeid(const std::vector<bool>&))) {
args_def->AppendAttribute(AttributeType::BOOLS);
} else if (arg_type == std::type_index(typeid(const std::vector<int>&))) {
args_def->AppendAttribute(AttributeType::INT32S);
} else if (arg_type ==
std::type_index(typeid(const std::vector<int64_t>&))) {
args_def->AppendAttribute(AttributeType::INT64S);
} else if (arg_type ==
std::type_index(typeid(const std::vector<float>&))) {
args_def->AppendAttribute(AttributeType::FLOAT32S);
} else if (arg_type ==
std::type_index(typeid(const std::vector<double>&))) {
args_def->AppendAttribute(AttributeType::FLOAT64S);
} else if (arg_type ==
std::type_index(typeid(const std::vector<std::string>&))) {
args_def->AppendAttribute(AttributeType::STRINGS);
} else if (arg_type == std::type_index(typeid(const Scalar&))) {
args_def->AppendAttribute(AttributeType::SCALAR);
} else if (arg_type ==
std::type_index(typeid(const std::vector<Scalar>&))) {
args_def->AppendAttribute(AttributeType::SCALARS);
} else if (arg_type == std::type_index(typeid(const IntArray&))) {
args_def->AppendAttribute(AttributeType::INT_ARRAY);
} else if (arg_type == std::type_index(typeid(DataType))) {
args_def->AppendAttribute(AttributeType::DATA_TYPE);
} else if (arg_type == std::type_index(typeid(DataLayout))) {
args_def->AppendAttribute(AttributeType::DATA_LAYOUT);
} else if (arg_type == std::type_index(typeid(Place))) {
args_def->AppendAttribute(AttributeType::PLACE);
} else { } else {
// Attribute deal with PADDLE_THROW(phi::errors::Unavailable(
// TODO(chenweihang): now here allow any types of attribute, maybe "Unsupported kernel argument type `%s`.", arg_type.name()));
// should add limits here
args_def->AppendAttribute(arg_type);
} }
} }
} }
......
...@@ -73,6 +73,67 @@ TEST(KernelRegistry, SetFP32Input) { ...@@ -73,6 +73,67 @@ TEST(KernelRegistry, SetFP32Input) {
EXPECT_EQ(output_defs.at(0).dtype, phi::DataType::FLOAT16); EXPECT_EQ(output_defs.at(0).dtype, phi::DataType::FLOAT16);
} }
TEST(AttributeType, OStream) {
std::ostringstream oss;
oss << phi::AttributeType::UNDEFINED;
EXPECT_EQ(oss.str(), "Undefined");
oss.str("");
oss << phi::AttributeType::BOOL;
EXPECT_EQ(oss.str(), "bool");
oss.str("");
oss << phi::AttributeType::INT32;
EXPECT_EQ(oss.str(), "int");
oss.str("");
oss << phi::AttributeType::INT64;
EXPECT_EQ(oss.str(), "int64_t");
oss.str("");
oss << phi::AttributeType::FLOAT32;
EXPECT_EQ(oss.str(), "float");
oss.str("");
oss << phi::AttributeType::FLOAT64;
EXPECT_EQ(oss.str(), "double");
oss.str("");
oss << phi::AttributeType::STRING;
EXPECT_EQ(oss.str(), "string");
oss.str("");
oss << phi::AttributeType::BOOLS;
EXPECT_EQ(oss.str(), "vector<bool>");
oss.str("");
oss << phi::AttributeType::INT32S;
EXPECT_EQ(oss.str(), "vector<int>");
oss.str("");
oss << phi::AttributeType::INT64S;
EXPECT_EQ(oss.str(), "vector<int64_t>");
oss.str("");
oss << phi::AttributeType::FLOAT32S;
EXPECT_EQ(oss.str(), "vector<float>");
oss.str("");
oss << phi::AttributeType::FLOAT64S;
EXPECT_EQ(oss.str(), "vector<double>");
oss.str("");
oss << phi::AttributeType::STRINGS;
EXPECT_EQ(oss.str(), "vector<string>");
oss.str("");
oss << phi::AttributeType::SCALAR;
EXPECT_EQ(oss.str(), "Scalar");
oss.str("");
oss << phi::AttributeType::SCALARS;
EXPECT_EQ(oss.str(), "vector<Scalar>");
oss.str("");
oss << phi::AttributeType::INT_ARRAY;
EXPECT_EQ(oss.str(), "IntArray");
oss.str("");
oss << phi::AttributeType::DATA_TYPE;
EXPECT_EQ(oss.str(), "DataType");
oss.str("");
oss << phi::AttributeType::DATA_LAYOUT;
EXPECT_EQ(oss.str(), "DataLayout");
oss.str("");
oss << phi::AttributeType::PLACE;
EXPECT_EQ(oss.str(), "Place");
oss.str("");
}
} // namespace tests } // namespace tests
} // namespace phi } // namespace phi
......
...@@ -20,12 +20,12 @@ from get_compat_kernel_signature import get_compat_kernels_info ...@@ -20,12 +20,12 @@ from get_compat_kernel_signature import get_compat_kernels_info
#TODO @DannyIsFunny: more attr types need to be supported. #TODO @DannyIsFunny: more attr types need to be supported.
attr_type_converter = { attr_type_converter = {
"i": 'SI32Attr', "int": 'SI32Attr',
"b": 'BoolAttr', "bool": 'BoolAttr',
"l": 'SI64Attr', "int64_t": 'SI64Attr',
"f": 'F32Attr', "float": 'F32Attr',
"NSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE": 'StrAttr', "string": 'StrAttr',
"St6vectorIiSaIiEE": 'I32ArrayAttr' "vector<int>": 'I32ArrayAttr'
} }
target_type_converter = {"CPU": "CPU", "GPU": "GPU", "Undefined": "UNK"} target_type_converter = {"CPU": "CPU", "GPU": "GPU", "Undefined": "UNK"}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册