未验证 提交 ca909408 编写于 作者: C Chen Weihang 提交者: GitHub

opt attr eaque perf (#42272)

上级 88d68c08
...@@ -203,12 +203,17 @@ struct ExtractAttribute<std::vector<double>> { ...@@ -203,12 +203,17 @@ struct ExtractAttribute<std::vector<double>> {
const std::string& attr_name_; const std::string& attr_name_;
}; };
template <typename T> template <typename T>
inline proto::AttrType AttrTypeID() { inline proto::AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
return static_cast<proto::AttrType>(tmp.which() - 1); return static_cast<proto::AttrType>(tmp.which() - 1);
} }
inline proto::AttrType AttrTypeID(const Attribute& attr) {
return static_cast<proto::AttrType>(attr.which() - 1);
}
class AttrReader { class AttrReader {
public: public:
explicit AttrReader(const AttributeMap& attrs) explicit AttrReader(const AttributeMap& attrs)
......
...@@ -501,16 +501,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -501,16 +501,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} }
} else if (ctx->HasAttr(attr_name)) { } else if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
infer_meta_context.EmplaceBackAttr(std::move( infer_meta_context.EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr)))); phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
infer_meta_context.EmplaceBackAttr(std::move( infer_meta_context.EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr)))); phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::INT) {
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::IntArray({BOOST_GET_CONST(int, attr)})); phi::IntArray({BOOST_GET_CONST(int, attr)}));
} else { } else {
...@@ -524,15 +521,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -524,15 +521,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
if (ctx->HasAttr(attr_name)) { if (ctx->HasAttr(attr_name)) {
// TODO(chentianyu03): support other attrs later // TODO(chentianyu03): support other attrs later
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) { if (AttrTypeID(attr) == proto::AttrType::FLOAT) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(float, attr))); phi::Scalar(BOOST_GET_CONST(float, attr)));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::STRING) {
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(std::string, attr))); phi::Scalar(BOOST_GET_CONST(std::string, attr)));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::INT) {
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(int, attr))); phi::Scalar(BOOST_GET_CONST(int, attr)));
} else { } else {
...@@ -562,8 +557,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -562,8 +557,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
} }
} else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
auto& attr = attr_reader.GetAttr(attr_name); auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -571,8 +565,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -571,8 +565,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -580,8 +573,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -580,8 +573,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::FLOATS) {
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr); const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -589,8 +581,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -589,8 +581,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
infer_meta_context.EmplaceBackAttr(std::move(scalar_list)); infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) {
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr); const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -624,8 +615,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -624,8 +615,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr)); BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) {
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
......
...@@ -2420,18 +2420,16 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2420,18 +2420,16 @@ void OperatorWithKernel::BuildPhiKernelContext(
if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) {
auto attr_iter = Attrs().find(attr_names[i]); auto attr_iter = Attrs().find(attr_names[i]);
if (attr_iter != Attrs().end()) { // shape is in the attribute if (attr_iter != Attrs().end()) { // shape is in the attribute
if (std::type_index(attr_iter->second.type()) == auto& attr = attr_iter->second;
std::type_index(typeid(std::vector<int64_t>))) { if (AttrTypeID(attr) == proto::AttrType::LONGS) {
pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray(
BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(int32_t))) {
pt_kernel_context->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(std::move(
phi::IntArray(&BOOST_GET_CONST(int32_t, attr_iter->second), 1))); phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (AttrTypeID(attr) == proto::AttrType::INTS) {
pt_kernel_context->EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (AttrTypeID(attr) == proto::AttrType::INT) {
pt_kernel_context->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to IntArray when " "Unsupported cast op attribute `%s` to IntArray when "
...@@ -2449,21 +2447,16 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2449,21 +2447,16 @@ void OperatorWithKernel::BuildPhiKernelContext(
} }
} }
} else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { } else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) {
// TODO(chenweihang): support other attrs later
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
auto attr_iter = Attrs().find(attr_names[i]); auto attr_iter = Attrs().find(attr_names[i]);
if (attr_iter != Attrs().end()) { // scalar is in the attribute if (attr_iter != Attrs().end()) { // scalar is in the attribute
auto& attr = Attrs().at(attr_names[i]); auto& attr = attr_iter->second;
if (std::type_index(attr.type()) == std::type_index(typeid(float))) { if (AttrTypeID(attr) == proto::AttrType::FLOAT) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::STRING) {
std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::INT) {
std::type_index(typeid(int))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
} else { } else {
...@@ -2480,8 +2473,7 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2480,8 +2473,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
} else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
auto& attr = Attrs().at(attr_names[i]); auto& attr = Attrs().at(attr_names[i]);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -2489,8 +2481,7 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2489,8 +2481,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -2498,8 +2489,7 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2498,8 +2489,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::FLOATS) {
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr); const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -2507,8 +2497,7 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2507,8 +2497,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) {
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr); const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -2559,12 +2548,10 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2559,12 +2548,10 @@ void OperatorWithKernel::BuildPhiKernelContext(
BOOST_GET_CONST(int, attr_it->second))); BOOST_GET_CONST(int, attr_it->second)));
pt_kernel_context->EmplaceBackAttr(data_type); pt_kernel_context->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) {
if (std::type_index(attr_it->second.type()) == if (AttrTypeID(attr_it->second) == proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context->EmplaceBackAttr( pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr_it->second)); BOOST_GET_CONST(std::vector<int64_t>, attr_it->second));
} else if (std::type_index(attr_it->second.type()) == } else if (AttrTypeID(attr_it->second) == proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second); BOOST_GET_CONST(std::vector<int>, attr_it->second);
......
...@@ -382,20 +382,16 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -382,20 +382,16 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
if (attrs.find(attr_names[i]) != if (attrs.find(attr_names[i]) !=
attrs.end()) { // shape is in the attribute attrs.end()) { // shape is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr)))); phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr)))); phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::LONG) {
std::type_index(typeid(int64_t))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1))); std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1)));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::INT) {
std::type_index(typeid(int32_t))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else if (attr_defs[i].type_index == phi::AttributeType::INT32S) { } else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
...@@ -429,15 +425,13 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -429,15 +425,13 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
default_attrs.find(attr_names[i]) != default_attrs.find(attr_names[i]) !=
default_attrs.end()) { // scalar is in the attribute default_attrs.end()) { // scalar is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == std::type_index(typeid(float))) { if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(float, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::STRING) {
std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::INT) {
std::type_index(typeid(int))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(int, attr)))); std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
} else { } else {
...@@ -465,8 +459,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -465,8 +459,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
} }
} else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) { } else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == framework::proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -474,8 +467,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -474,8 +467,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr); const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -483,8 +475,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -483,8 +475,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::FLOATS) {
std::type_index(typeid(std::vector<float>))) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr); const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -492,8 +483,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -492,8 +483,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT64S) {
std::type_index(typeid(std::vector<double>))) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr); const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -501,8 +491,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -501,8 +491,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
scalar_list.emplace_back(val); scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::BOOLEANS) {
std::type_index(typeid(std::vector<bool>))) {
const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr); const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr);
std::vector<phi::Scalar> scalar_list; std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size()); scalar_list.reserve(vec.size());
...@@ -534,12 +523,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -534,12 +523,10 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
BOOST_GET_CONST(int, attr))); BOOST_GET_CONST(int, attr)));
kernel_ctx->EmplaceBackAttr(data_type); kernel_ctx->EmplaceBackAttr(data_type);
} else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { } else if (attr_defs[i].type_index == phi::AttributeType::INT64S) {
if (std::type_index(attr.type()) == if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) {
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr( kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr)); BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) == } else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) {
std::type_index(typeid(std::vector<int>))) {
// Emplace Back Attr according to the type of Phi_Kernel args. // Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册