未验证 提交 5063546a 编写于 作者: C Chen Weihang 提交者: GitHub

Optimize attribute selected performence (#42294)

* opt attr eaque perf

* opt attr select code

* fix one hot infermeta

* polish get attr impl

* fix tests failed

* add testcases
上级 7cb49539
...@@ -242,7 +242,7 @@ class AttrReader { ...@@ -242,7 +242,7 @@ class AttrReader {
return *attr_value; return *attr_value;
} }
inline const Attribute& GetAttr(const std::string& name) const { const Attribute* GetAttr(const std::string& name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
bool found = it != attrs_.end(); bool found = it != attrs_.end();
if (!found) { if (!found) {
...@@ -251,11 +251,10 @@ class AttrReader { ...@@ -251,11 +251,10 @@ class AttrReader {
found = it != default_attrs_->end(); found = it != default_attrs_->end();
} }
} }
PADDLE_ENFORCE_EQ(found, true, if (found) {
platform::errors::NotFound( return &it->second;
"Attribute (%s) should be in AttributeMap.", name)); }
return nullptr;
return it->second;
} }
private: private:
......
...@@ -52,8 +52,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { ...@@ -52,8 +52,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
} }
paddle::any Attr(const std::string& name) const override { paddle::any Attr(const std::string& name) const override {
auto& attr = ctx_.Attrs().GetAttr(name); auto* attr = ctx_.Attrs().GetAttr(name);
return GetAttrValue(attr); PADDLE_ENFORCE_NOT_NULL(
attr, platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));
return GetAttrValue(*attr);
} }
size_t InputSize(const std::string& name) const override { size_t InputSize(const std::string& name) const override {
...@@ -450,216 +453,255 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -450,216 +453,255 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
auto attr_reader = ctx->Attrs(); auto attr_reader = ctx->Attrs();
for (size_t i = 0; i < attr_names.size(); ++i) { for (size_t i = 0; i < attr_names.size(); ++i) {
auto& attr_name = attr_names[i]; auto& attr_name = attr_names[i];
if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { VLOG(6) << "BuildInferMetaContext: " << attr_name << ": "
// When attr is a vector_tensor or tensor, transform it to IntArray << attr_defs[i].type_index;
if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) { auto* attr_ptr = attr_reader.GetAttr(attr_name);
auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); switch (attr_defs[i].type_index) {
if (ctx->IsRuntime()) { case phi::AttributeType::SCALAR:
// If is in runtime, we will get tensor's value for IntArray if (attr_ptr) {
// and push it into attrs auto& attr = *attr_ptr;
std::vector<Variable*> vars; switch (AttrTypeID(attr)) {
vars.reserve(infershape_inputs.size()); case framework::proto::AttrType::FLOAT:
for (size_t i = 0; i < infershape_inputs.size(); i++) { infer_meta_context.EmplaceBackAttr(
vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i])); phi::Scalar(BOOST_GET_CONST(float, attr)));
break;
case framework::proto::AttrType::INT:
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(int, attr)));
break;
case framework::proto::AttrType::STRING:
infer_meta_context.EmplaceBackAttr(
phi::Scalar(BOOST_GET_CONST(std::string, attr)));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"InferMetaContext.",
attr_name));
} }
if (infershape_inputs.size() != 1) { } else if (ctx->HasInput(attr_name)) {
infer_meta_context.EmplaceBackAttr( auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name));
std::move(experimental::MakePhiIntArrayFromVarList(vars))); if (infershape_input.size() == 1) {
if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
infer_meta_context.EmplaceBackAttr(
std::move(experimental::MakePhiScalarFromVar(*var)));
} else {
phi::Scalar tensor_scalar(-1);
tensor_scalar.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar));
}
} else { } else {
infer_meta_context.EmplaceBackAttr( PADDLE_THROW(platform::errors::InvalidArgument(
std::move(experimental::MakePhiIntArrayFromVar(*vars[0]))); "Invalid input.size() when cast op attribute `%s` to Scalar, "
"expected 1, but actually is %d .",
attr_name, infershape_input.size()));
} }
} else { } else {
// If is not in runtime, we will set default value(-1) for IntArray // do nothing, skip current attr
std::vector<VarDesc*> vars; }
vars.reserve(infershape_inputs.size()); break;
for (size_t i = 0; i < infershape_inputs.size(); ++i) { case phi::AttributeType::INT_ARRAY:
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i])); // When attr is a vector_tensor or tensor, transform it to IntArray
if (attr_ptr) {
auto& attr = *attr_ptr;
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::INTS:
infer_meta_context.EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
break;
case framework::proto::AttrType::LONGS:
infer_meta_context.EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
break;
case framework::proto::AttrType::INT:
infer_meta_context.EmplaceBackAttr(
phi::IntArray({BOOST_GET_CONST(int, attr)}));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to IntArray when "
"construct InferMetaContext.",
attr_name));
} }
} else if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) {
int64_t num_ele = 0; auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name));
if (vars.size() == 1) { if (ctx->IsRuntime()) {
num_ele = 1; // If is in runtime, we will get tensor's value for IntArray
const auto& tensor_dims = vars[0]->GetShape(); // and push it into attrs
for (size_t i = 0; i < tensor_dims.size(); ++i) { std::vector<Variable*> vars;
num_ele *= tensor_dims[i]; vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
vars.push_back(BOOST_GET_CONST(Variable*, infershape_inputs[i]));
} }
if (infershape_inputs.size() != 1) {
if (num_ele <= 0) { infer_meta_context.EmplaceBackAttr(
PADDLE_THROW(platform::errors::Unimplemented( std::move(experimental::MakePhiIntArrayFromVarList(vars)));
"Invalid number for construct phi::IntArray, expected " } else {
"number > 0, but actually is %d. ", infer_meta_context.EmplaceBackAttr(
num_ele)); std::move(experimental::MakePhiIntArrayFromVar(*vars[0])));
} }
} else { } else {
num_ele = vars.size(); // If is not in runtime, we will set default value(-1) for IntArray
std::vector<VarDesc*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); ++i) {
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
}
int64_t num_ele = 0;
if (vars.size() == 1) {
num_ele = 1;
const auto& tensor_dims = vars[0]->GetShape();
for (size_t i = 0; i < tensor_dims.size(); ++i) {
num_ele *= tensor_dims[i];
}
if (num_ele <= 0) {
PADDLE_THROW(platform::errors::Unimplemented(
"Invalid number for construct phi::IntArray, expected "
"number > 0, but actually is %d. ",
num_ele));
}
} else {
num_ele = vars.size();
}
phi::IntArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
} }
phi::IntArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true);
infer_meta_context.EmplaceBackAttr(std::move(tensor_attr));
}
} else if (ctx->HasAttr(attr_name)) {
auto& attr = attr_reader.GetAttr(attr_name);
if (AttrTypeID(attr) == proto::AttrType::INTS) {
infer_meta_context.EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (AttrTypeID(attr) == proto::AttrType::LONGS) {
infer_meta_context.EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (AttrTypeID(attr) == proto::AttrType::INT) {
infer_meta_context.EmplaceBackAttr(
phi::IntArray({BOOST_GET_CONST(int, attr)}));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( // do nothing, skip current attr
"Unsupported cast op attribute `%s` to IntArray when "
"construct InferMetaContext.",
attr_name));
} }
} break;
} else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) { case phi::AttributeType::SCALARS:
if (ctx->HasAttr(attr_name)) { if (attr_ptr) {
// TODO(chentianyu03): support other attrs later auto& attr = *attr_ptr;
auto& attr = attr_reader.GetAttr(attr_name); switch (AttrTypeID(attr)) {
if (AttrTypeID(attr) == proto::AttrType::FLOAT) { case framework::proto::AttrType::INTS: {
infer_meta_context.EmplaceBackAttr( const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
phi::Scalar(BOOST_GET_CONST(float, attr))); std::vector<phi::Scalar> scalar_list;
} else if (AttrTypeID(attr) == proto::AttrType::STRING) { scalar_list.reserve(vec.size());
infer_meta_context.EmplaceBackAttr( for (const auto& val : vec) {
phi::Scalar(BOOST_GET_CONST(std::string, attr))); scalar_list.emplace_back(val);
} else if (AttrTypeID(attr) == proto::AttrType::INT) { }
infer_meta_context.EmplaceBackAttr( infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
phi::Scalar(BOOST_GET_CONST(int, attr))); } break;
} else { case framework::proto::AttrType::LONGS: {
PADDLE_THROW(platform::errors::Unimplemented( const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
"Unsupported cast op attribute `%s` to Scalar when construct " std::vector<phi::Scalar> scalar_list;
"InferMetaContext.", scalar_list.reserve(vec.size());
attr_name)); for (const auto& val : vec) {
} scalar_list.emplace_back(val);
} else if (ctx->HasInput(attr_name)) { }
auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name)); infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
if (infershape_input.size() == 1) { } break;
if (ctx->IsRuntime()) { case framework::proto::AttrType::FLOATS: {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]); const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
infer_meta_context.EmplaceBackAttr( std::vector<phi::Scalar> scalar_list;
std::move(experimental::MakePhiScalarFromVar(*var))); scalar_list.reserve(vec.size());
} else { for (const auto& val : vec) {
phi::Scalar tensor_scalar(-1); scalar_list.emplace_back(val);
tensor_scalar.SetFromTensor(true); }
infer_meta_context.EmplaceBackAttr(std::move(tensor_scalar)); infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::FLOAT64S: {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext.",
attr_names[i]));
} }
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( // do nothing, skip current attr
"Invalid input.size() when cast op attribute `%s` to Scalar, "
"expected 1, but actually is %d .",
attr_name, infershape_input.size()));
}
}
} else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
auto& attr = attr_reader.GetAttr(attr_name);
if (AttrTypeID(attr) == proto::AttrType::INTS) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (AttrTypeID(attr) == proto::AttrType::LONGS) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (AttrTypeID(attr) == proto::AttrType::FLOATS) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
infer_meta_context.EmplaceBackAttr(std::move(scalar_list));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct InferMetaContext.",
attr_names[i]));
}
} else if (ctx->HasAttr(attr_name)) {
// Emplace Back Attr according to the type of attr.
auto& attr = attr_reader.GetAttr(attr_name);
if (attr_defs[i].type_index == phi::AttributeType::BOOL) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::INT32) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::INT64) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::STRING) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::BOOLS) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<bool>, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::INT64S) {
if (AttrTypeID(attr) == proto::AttrType::INTS) {
// Emplace Back Attr according to the type of Phi_Kernel args.
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(),
vector_int_attr.end());
infer_meta_context.EmplaceBackAttr(vector_int64_attr);
} else {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} }
} else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) { break;
infer_meta_context.EmplaceBackAttr( default:
BOOST_GET_CONST(std::vector<float>, attr)); if (attr_ptr) {
} else if (attr_defs[i].type_index == phi::AttributeType::FLOAT64S) { auto& attr = *attr_ptr;
infer_meta_context.EmplaceBackAttr( switch (attr_defs[i].type_index) {
BOOST_GET_CONST(std::vector<double>, attr)); case phi::AttributeType::FLOAT32:
} else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
infer_meta_context.EmplaceBackAttr( break;
BOOST_GET_CONST(std::vector<std::string>, attr)); case phi::AttributeType::INT32:
} else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
auto data_type = paddle::framework::TransToPhiDataType( break;
static_cast<framework::proto::VarType::Type>( case phi::AttributeType::BOOL:
BOOST_GET_CONST(int, attr))); infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
infer_meta_context.EmplaceBackAttr(data_type); break;
} else { case phi::AttributeType::INT64:
PADDLE_THROW(platform::errors::Unimplemented( infer_meta_context.EmplaceBackAttr(
"Unsupported attribute type is received when call " BOOST_GET_CONST(int64_t, attr));
"InferShapeFunctor.")); break;
} case phi::AttributeType::INT32S:
} else if (ctx->HasInput(attr_name)) { infer_meta_context.EmplaceBackAttr(
// convert from data BOOST_GET_CONST(std::vector<int>, attr));
if (attr_defs[i].type_index == phi::AttributeType::INT32) { break;
if (ctx->IsRuntime()) { case phi::AttributeType::DATA_TYPE: {
auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); auto data_type = paddle::framework::TransToPhiDataType(
auto var_temp = BOOST_GET_CONST(Variable*, infershape_inputs[i]); static_cast<framework::proto::VarType::Type>(
auto val = experimental::MakePhiScalarFromVar(*var_temp); BOOST_GET_CONST(int, attr)));
int32_t val_int = val.template to<int32_t>(); infer_meta_context.EmplaceBackAttr(data_type);
infer_meta_context.EmplaceBackAttr(val_int); } break;
case phi::AttributeType::STRING:
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::string, attr));
break;
case phi::AttributeType::INT64S:
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::LONGS:
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
break;
case framework::proto::AttrType::INTS: {
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(
vector_int_attr.begin(), vector_int_attr.end());
infer_meta_context.EmplaceBackAttr(vector_int64_attr);
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<int64_t> "
"when "
"construct KernelContext.",
attr_names[i]));
}
break;
case phi::AttributeType::FLOAT32S:
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr));
break;
case phi::AttributeType::STRINGS:
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
break;
case phi::AttributeType::BOOLS:
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<bool>, attr));
break;
case phi::AttributeType::FLOAT64S:
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<double>, attr));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
} else { } else {
infer_meta_context.EmplaceBackAttr(-1); // do nothing, skip currnet attr
} }
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Get value from variable only support int yet"));
}
} }
} }
......
...@@ -2469,163 +2469,210 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -2469,163 +2469,210 @@ void OperatorWithKernel::BuildPhiKernelContext(
VLOG(4) << "Done outputs"; VLOG(4) << "Done outputs";
for (size_t i = 0; i < attr_names.size(); ++i) { for (size_t i = 0; i < attr_names.size(); ++i) {
if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { VLOG(6) << "BuildPhiKernelContext: " << attr_names[i] << ": "
auto attr_iter = Attrs().find(attr_names[i]); << attr_defs[i].type_index;
if (attr_iter != Attrs().end()) { // shape is in the attribute auto attr_iter = Attrs().find(attr_names[i]);
auto& attr = attr_iter->second; switch (attr_defs[i].type_index) {
if (AttrTypeID(attr) == proto::AttrType::LONGS) { case phi::AttributeType::SCALAR:
pt_kernel_context->EmplaceBackAttr(std::move( if (attr_iter != Attrs().end()) {
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr)))); // scalar is in the attribute
} else if (AttrTypeID(attr) == proto::AttrType::INTS) { switch (AttrTypeID(attr_iter->second)) {
pt_kernel_context->EmplaceBackAttr(std::move( case proto::AttrType::FLOAT:
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr)))); pt_kernel_context->EmplaceBackAttr(std::move(
} else if (AttrTypeID(attr) == proto::AttrType::INT) { phi::Scalar(BOOST_GET_CONST(float, attr_iter->second))));
pt_kernel_context->EmplaceBackAttr( break;
std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1))); case proto::AttrType::INT:
} else { pt_kernel_context->EmplaceBackAttr(std::move(
PADDLE_THROW(platform::errors::Unimplemented( phi::Scalar(BOOST_GET_CONST(int, attr_iter->second))));
"Unsupported cast op attribute `%s` to IntArray when " break;
"construct KernelContext.", case proto::AttrType::STRING:
attr_names[i])); pt_kernel_context->EmplaceBackAttr(std::move(phi::Scalar(
} BOOST_GET_CONST(std::string, attr_iter->second))));
} else { // shape is in the input break;
auto& ins_vector = ctx.inputs.at(attr_names[i]); default:
if (ins_vector.size() == 1) { // ShapeTensor PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
} else { // scalar is in the input
auto& ins_vector = ctx.inputs.at(attr_names[i]);
pt_kernel_context->EmplaceBackAttr(std::move( pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePhiIntArrayFromVar(*ins_vector.front()))); experimental::MakePhiScalarFromVar(*ins_vector.front())));
} else { // ShapeTensorList
pt_kernel_context->EmplaceBackAttr(
std::move(experimental::MakePhiIntArrayFromVarList(ins_vector)));
}
}
} else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) {
auto attr_iter = Attrs().find(attr_names[i]);
if (attr_iter != Attrs().end()) { // scalar is in the attribute
auto& attr = attr_iter->second;
if (AttrTypeID(attr) == proto::AttrType::FLOAT) {
pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
} else if (AttrTypeID(attr) == proto::AttrType::STRING) {
pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (AttrTypeID(attr) == proto::AttrType::INT) {
pt_kernel_context->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext.",
attr_names[i]));
}
} else {
auto& ins_vector = ctx.inputs.at(attr_names[i]);
pt_kernel_context->EmplaceBackAttr(
std::move(experimental::MakePhiScalarFromVar(*ins_vector.front())));
}
} else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
auto& attr = Attrs().at(attr_names[i]);
if (AttrTypeID(attr) == proto::AttrType::INTS) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
} }
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); break;
} else if (AttrTypeID(attr) == proto::AttrType::LONGS) { case phi::AttributeType::INT_ARRAY:
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr); if (attr_iter != Attrs().end()) {
std::vector<phi::Scalar> scalar_list; switch (AttrTypeID(attr_iter->second)) {
scalar_list.reserve(vec.size()); case proto::AttrType::INTS:
for (const auto& val : vec) { pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray(
scalar_list.emplace_back(val); BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} break;
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); case proto::AttrType::LONGS:
} else if (AttrTypeID(attr) == proto::AttrType::FLOATS) { pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray(
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr); BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
std::vector<phi::Scalar> scalar_list; break;
scalar_list.reserve(vec.size()); case proto::AttrType::INT:
for (const auto& val : vec) { pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray(
scalar_list.emplace_back(val); &BOOST_GET_CONST(int32_t, attr_iter->second), 1)));
} break;
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); case proto::AttrType::LONG:
} else if (AttrTypeID(attr) == proto::AttrType::FLOAT64S) { pt_kernel_context->EmplaceBackAttr(std::move(phi::IntArray(
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr); &BOOST_GET_CONST(int64_t, attr_iter->second), 1)));
std::vector<phi::Scalar> scalar_list; break;
scalar_list.reserve(vec.size()); default:
for (const auto& val : vec) { PADDLE_THROW(platform::errors::Unimplemented(
scalar_list.emplace_back(val); "Unsupported cast op attribute `%s` to IntArray when "
"construct KernelContext.",
attr_names[i]));
}
} else { // shape is in the input
auto& ins_vector = ctx.inputs.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePhiIntArrayFromVar(*ins_vector.front())));
} else { // ShapeTensorList
pt_kernel_context->EmplaceBackAttr(std::move(
experimental::MakePhiIntArrayFromVarList(ins_vector)));
}
} }
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list)); break;
} else { case phi::AttributeType::SCALARS: {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_ENFORCE_NE(
"Unsupported cast op attribute `%s` to vector<Scalar> when " attr_iter, Attrs().end(),
"construct KernelContext.", platform::errors::NotFound("(%s) is not found in AttributeMap when "
attr_names[i])); "buildind static KernelContext.",
} attr_names[i]));
} else { switch (AttrTypeID(attr_iter->second)) {
auto attr_it = attrs_.find(attr_names[i]); case proto::AttrType::INTS: {
if (attr_defs[i].type_index == phi::AttributeType::INT32) { const auto& vec =
if (attr_it == attrs_.end()) { BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second);
auto in_it = ctx.inputs.find(attr_names[i]); std::vector<phi::Scalar> scalar_list;
if (in_it != ctx.inputs.end()) { scalar_list.reserve(vec.size());
// get data from input for (const auto& val : vec) {
auto val = experimental::MakePhiScalarFromVar(*(in_it->second[0])); scalar_list.emplace_back(val);
int32_t val_int = val.template to<int32_t>(); }
pt_kernel_context->EmplaceBackAttr(val_int); pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} else { } break;
PADDLE_THROW(platform::errors::NotFound( case proto::AttrType::LONGS: {
"can not find attribute `%s` both in attribute and input ", const auto& vec =
BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
case proto::AttrType::FLOATS: {
const auto& vec =
BOOST_GET_CONST(std::vector<float>, attr_iter->second);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
case proto::AttrType::FLOAT64S: {
const auto& vec =
BOOST_GET_CONST(std::vector<double>, attr_iter->second);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
case proto::AttrType::BOOLEANS: {
const auto& vec =
BOOST_GET_CONST(std::vector<bool>, attr_iter->second);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
pt_kernel_context->EmplaceBackAttr(std::move(scalar_list));
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext.",
attr_names[i])); attr_names[i]));
}
} else {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(int, attr_it->second));
} }
} else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { } break;
pt_kernel_context->EmplaceBackAttr( default: {
BOOST_GET_CONST(float, attr_it->second)); PADDLE_ENFORCE_NE(
} else if (attr_defs[i].type_index == phi::AttributeType::BOOL) { attr_iter, Attrs().end(),
pt_kernel_context->EmplaceBackAttr( platform::errors::NotFound("(%s) is not found in AttributeMap when "
BOOST_GET_CONST(bool, attr_it->second)); "buildind static KernelContext.",
} else if (attr_defs[i].type_index == phi::AttributeType::INT64) { attr_names[i]));
pt_kernel_context->EmplaceBackAttr( switch (attr_defs[i].type_index) {
BOOST_GET_CONST(int64_t, attr_it->second)); case phi::AttributeType::FLOAT32:
} else if (attr_defs[i].type_index == phi::AttributeType::STRING) { pt_kernel_context->EmplaceBackAttr(
pt_kernel_context->EmplaceBackAttr( BOOST_GET_CONST(float, attr_iter->second));
BOOST_GET_CONST(std::string, attr_it->second)); break;
} else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { case phi::AttributeType::INT32:
auto data_type = paddle::framework::TransToPhiDataType( pt_kernel_context->EmplaceBackAttr(
static_cast<framework::proto::VarType::Type>( BOOST_GET_CONST(int, attr_iter->second));
BOOST_GET_CONST(int, attr_it->second))); break;
pt_kernel_context->EmplaceBackAttr(data_type); case phi::AttributeType::BOOL:
} else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { pt_kernel_context->EmplaceBackAttr(
if (AttrTypeID(attr_it->second) == proto::AttrType::LONGS) { BOOST_GET_CONST(bool, attr_iter->second));
pt_kernel_context->EmplaceBackAttr( break;
BOOST_GET_CONST(std::vector<int64_t>, attr_it->second)); case phi::AttributeType::INT64:
} else if (AttrTypeID(attr_it->second) == proto::AttrType::INTS) { pt_kernel_context->EmplaceBackAttr(
// Emplace Back Attr according to the type of Phi_Kernel args. BOOST_GET_CONST(int64_t, attr_iter->second));
const auto& vector_int_attr = break;
BOOST_GET_CONST(std::vector<int>, attr_it->second); case phi::AttributeType::INT32S:
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), pt_kernel_context->EmplaceBackAttr(
vector_int_attr.end()); BOOST_GET_CONST(std::vector<int>, attr_iter->second));
pt_kernel_context->EmplaceBackAttr(vector_int64_attr); break;
case phi::AttributeType::DATA_TYPE: {
auto data_type = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(
BOOST_GET_CONST(int, attr_iter->second)));
pt_kernel_context->EmplaceBackAttr(data_type);
} break;
case phi::AttributeType::STRING:
pt_kernel_context->EmplaceBackAttr(
std::move(BOOST_GET_CONST(std::string, attr_iter->second)));
break;
case phi::AttributeType::INT64S:
switch (AttrTypeID(attr_iter->second)) {
case proto::AttrType::LONGS:
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second));
break;
case proto::AttrType::INTS: {
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_iter->second);
const std::vector<int64_t> vector_int64_attr(
vector_int_attr.begin(), vector_int_attr.end());
pt_kernel_context->EmplaceBackAttr(vector_int64_attr);
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<int64_t> "
"when "
"construct KernelContext.",
attr_names[i]));
}
break;
case phi::AttributeType::FLOAT32S:
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr_iter->second));
break;
case phi::AttributeType::STRINGS:
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr_iter->second));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
"KernelContext in dygraph.",
attr_names[i]));
} }
} else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr_it->second);
pt_kernel_context->EmplaceBackAttr(vector_int_attr);
} else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr_it->second));
} else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) {
pt_kernel_context->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr_it->second));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
"KernelContext.",
attr_names[i]));
} }
} }
} }
......
...@@ -220,7 +220,7 @@ class PreparedOp { ...@@ -220,7 +220,7 @@ class PreparedOp {
static const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map; static const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map;
}; };
const inline framework::Attribute& GetAttr( const inline framework::Attribute* GetAttr(
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const std::string& name) { const framework::AttributeMap& default_attrs, const std::string& name) {
auto it = attrs.find(name); auto it = attrs.find(name);
...@@ -229,10 +229,10 @@ const inline framework::Attribute& GetAttr( ...@@ -229,10 +229,10 @@ const inline framework::Attribute& GetAttr(
it = default_attrs.find(name); it = default_attrs.find(name);
found = it != default_attrs.end(); found = it != default_attrs.end();
} }
PADDLE_ENFORCE_EQ( if (found) {
found, true, return &it->second;
platform::errors::NotFound("(%s) is not found in AttributeMap.", name)); }
return it->second; return nullptr;
} }
template <typename VarType> template <typename VarType>
...@@ -330,6 +330,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -330,6 +330,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
} }
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
VLOG(6) << "BuildDygraphPhiKernelContext: Inputs parsing completed.";
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second); size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
...@@ -380,178 +381,217 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, ...@@ -380,178 +381,217 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature,
} }
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
} }
VLOG(6) << "BuildDygraphPhiKernelContext: Outputs parsing completed.";
for (size_t i = 0; i < attr_names.size(); ++i) { for (size_t i = 0; i < attr_names.size(); ++i) {
if (attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) { VLOG(6) << "BuildDygraphPhiKernelContext: " << attr_names[i] << ": "
if (attrs.find(attr_names[i]) != << attr_defs[i].type_index;
attrs.end()) { // shape is in the attribute auto* attr_ptr = GetAttr(attrs, default_attrs, attr_names[i]);
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); switch (attr_defs[i].type_index) {
if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) { case phi::AttributeType::SCALAR:
kernel_ctx->EmplaceBackAttr(std::move( if (attr_ptr) {
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr)))); // scalar is in the attribute
} else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) { auto& attr = *attr_ptr;
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::FLOAT:
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
break;
case framework::proto::AttrType::INT:
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
break;
case framework::proto::AttrType::STRING:
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
} else { // scalar is in the input
auto& ins_vector = ins.at(attr_names[i]);
kernel_ctx->EmplaceBackAttr(std::move( kernel_ctx->EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr)))); experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
} else if (AttrTypeID(attr) == framework::proto::AttrType::LONG) {
kernel_ctx->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1)));
} else if (AttrTypeID(attr) == framework::proto::AttrType::INT) {
kernel_ctx->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
kernel_ctx->EmplaceBackAttr(vector_int_attr);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to VectorTensor when "
"construct KernelContext.",
attr_names[i]));
} }
} else { // shape is in the input break;
auto& ins_vector = ins.at(attr_names[i]); case phi::AttributeType::INT_ARRAY:
if (ins_vector.size() == 1) { // ShapeTensor if (attr_ptr) {
kernel_ctx->EmplaceBackAttr(std::move( auto& attr = *attr_ptr;
experimental::MakePhiIntArrayFromVar(ins_vector[0]->Var()))); switch (AttrTypeID(attr)) {
} else { // ShapeTensorList case framework::proto::AttrType::INTS:
std::vector<framework::Variable*> variables; kernel_ctx->EmplaceBackAttr(std::move(
variables.reserve(ins_vector.size()); phi::IntArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
for (const auto& var_base : ins_vector) { break;
variables.push_back(var_base->MutableVar()); case framework::proto::AttrType::LONGS:
kernel_ctx->EmplaceBackAttr(std::move(
phi::IntArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
break;
case framework::proto::AttrType::INT:
kernel_ctx->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int32_t, attr), 1)));
break;
case framework::proto::AttrType::LONG:
kernel_ctx->EmplaceBackAttr(
std::move(phi::IntArray(&BOOST_GET_CONST(int64_t, attr), 1)));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to IntArray when "
"construct KernelContext.",
attr_names[i]));
}
} else { // shape is in the input
auto& ins_vector = ins.at(attr_names[i]);
if (ins_vector.size() == 1) { // ShapeTensor
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePhiIntArrayFromVar(ins_vector[0]->Var())));
} else { // ShapeTensorList
std::vector<framework::Variable*> variables;
variables.reserve(ins_vector.size());
for (const auto& var_base : ins_vector) {
variables.push_back(var_base->MutableVar());
}
kernel_ctx->EmplaceBackAttr(
std::move(experimental::MakePhiIntArrayFromVarList(variables)));
} }
kernel_ctx->EmplaceBackAttr(
std::move(experimental::MakePhiIntArrayFromVarList(variables)));
}
}
} else if (attr_defs[i].type_index == phi::AttributeType::SCALAR) {
// TODO(zhangyunfei): Scalar should hold scaler type, and we should check
// attribtue type by attr_defs
if (attrs.find(attr_names[i]) != attrs.end() ||
default_attrs.find(attr_names[i]) !=
default_attrs.end()) { // scalar is in the attribute
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT) {
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(float, attr))));
} else if (AttrTypeID(attr) == framework::proto::AttrType::STRING) {
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(std::string, attr))));
} else if (AttrTypeID(attr) == framework::proto::AttrType::INT) {
kernel_ctx->EmplaceBackAttr(
std::move(phi::Scalar(BOOST_GET_CONST(int, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to Scalar when construct "
"KernelContext in dygraph.",
attr_names[i]));
}
} else { // scalar is in the input
auto& ins_vector = ins.at(attr_names[i]);
kernel_ctx->EmplaceBackAttr(std::move(
experimental::MakePhiScalarFromVar(ins_vector[0]->Var())));
}
} else if (ins.find(attr_names[i]) != ins.end()) {
// deal tensor attr here
auto& ins_vector = ins.at(attr_names[i]);
auto tensor_attr =
experimental::MakePhiScalarFromVar(ins_vector[0]->Var());
if (attr_defs[i].type_index == phi::AttributeType::INT32) {
int val = tensor_attr.template to<int>();
kernel_ctx->EmplaceBackAttr(val);
} else {
PADDLE_THROW(platform::errors::Unimplemented("only support int here"));
}
} else if (attr_defs[i].type_index == phi::AttributeType::SCALARS) {
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]);
if (AttrTypeID(attr) == framework::proto::AttrType::INTS) {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (AttrTypeID(attr) == framework::proto::AttrType::FLOATS) {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} else if (AttrTypeID(attr) == framework::proto::AttrType::FLOAT64S) {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); break;
} else if (AttrTypeID(attr) == framework::proto::AttrType::BOOLEANS) { case phi::AttributeType::SCALARS: {
const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr); PADDLE_ENFORCE_NOT_NULL(
std::vector<phi::Scalar> scalar_list; attr_ptr,
scalar_list.reserve(vec.size()); platform::errors::NotFound("(%s) is not found in AttributeMap when "
for (const auto& val : vec) { "buildind dygraph KernelContext.",
scalar_list.emplace_back(val); attr_names[i]));
auto& attr = *attr_ptr;
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::INTS: {
const auto& vec = BOOST_GET_CONST(std::vector<int32_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::LONGS: {
const auto& vec = BOOST_GET_CONST(std::vector<int64_t>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::FLOATS: {
const auto& vec = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::FLOAT64S: {
const auto& vec = BOOST_GET_CONST(std::vector<double>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} break;
case framework::proto::AttrType::BOOLEANS: {
const auto& vec = BOOST_GET_CONST(std::vector<bool>, attr);
std::vector<phi::Scalar> scalar_list;
scalar_list.reserve(vec.size());
for (const auto& val : vec) {
scalar_list.emplace_back(val);
}
kernel_ctx->EmplaceBackAttr(std::move(scalar_list));
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<Scalar> when "
"construct KernelContext.",
attr_names[i]));
} }
kernel_ctx->EmplaceBackAttr(std::move(scalar_list)); } break;
} else { default: {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_ENFORCE_NOT_NULL(
"Unsupported cast op attribute `%s` to vector<Scalar> when " attr_ptr,
"construct KernelContext.", platform::errors::NotFound("(%s) is not found in AttributeMap when "
attr_names[i])); "buildind dygraph KernelContext.",
} attr_names[i]));
} else { auto& attr = *attr_ptr;
auto& attr = GetAttr(attrs, default_attrs, attr_names[i]); switch (attr_defs[i].type_index) {
if (attr_defs[i].type_index == phi::AttributeType::INT32) { case phi::AttributeType::FLOAT32:
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32) { break;
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); case phi::AttributeType::INT32:
} else if (attr_defs[i].type_index == phi::AttributeType::BOOL) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); break;
} else if (attr_defs[i].type_index == phi::AttributeType::INT64) { case phi::AttributeType::BOOL:
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::STRING) { break;
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); case phi::AttributeType::INT64:
} else if (attr_defs[i].type_index == phi::AttributeType::DATA_TYPE) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
auto data_type = framework::TransToPhiDataType( break;
static_cast<framework::proto::VarType::Type>( case phi::AttributeType::INT32S:
BOOST_GET_CONST(int, attr))); kernel_ctx->EmplaceBackAttr(
kernel_ctx->EmplaceBackAttr(data_type); BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::INT64S) { break;
if (AttrTypeID(attr) == framework::proto::AttrType::LONGS) { case phi::AttributeType::DATA_TYPE: {
kernel_ctx->EmplaceBackAttr( auto data_type = framework::TransToPhiDataType(
BOOST_GET_CONST(std::vector<int64_t>, attr)); static_cast<framework::proto::VarType::Type>(
} else if (AttrTypeID(attr) == framework::proto::AttrType::INTS) { BOOST_GET_CONST(int, attr)));
// Emplace Back Attr according to the type of Phi_Kernel args. kernel_ctx->EmplaceBackAttr(data_type);
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr); } break;
const std::vector<int64_t> vector_int64_attr(vector_int_attr.begin(), case phi::AttributeType::STRING:
vector_int_attr.end()); kernel_ctx->EmplaceBackAttr(
kernel_ctx->EmplaceBackAttr(vector_int64_attr); std::move(BOOST_GET_CONST(std::string, attr)));
break;
case phi::AttributeType::INT64S: {
switch (AttrTypeID(attr)) {
case framework::proto::AttrType::LONGS:
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
break;
case framework::proto::AttrType::INTS: {
const auto& vector_int_attr =
BOOST_GET_CONST(std::vector<int>, attr);
const std::vector<int64_t> vector_int64_attr(
vector_int_attr.begin(), vector_int_attr.end());
kernel_ctx->EmplaceBackAttr(vector_int64_attr);
} break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to vector<int64_t> "
"when "
"construct KernelContext.",
attr_names[i]));
}
} break;
case phi::AttributeType::FLOAT32S:
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr));
break;
case phi::AttributeType::STRINGS:
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
"KernelContext in dygraph.",
attr_names[i]));
} }
} else if (attr_defs[i].type_index == phi::AttributeType::INT32S) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::STRINGS) {
kernel_ctx->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
} else if (attr_defs[i].type_index == phi::AttributeType::FLOAT32S) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<float>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` when construct "
"KernelContext in dygraph.",
attr_names[i]));
} }
} }
} }
VLOG(6) << "BuildDygraphPhiKernelContext: Attributes parsing completed.";
} }
template <typename VarType> template <typename VarType>
......
...@@ -3011,7 +3011,7 @@ void UnStackInferMeta(const MetaTensor& x, ...@@ -3011,7 +3011,7 @@ void UnStackInferMeta(const MetaTensor& x,
} }
void OneHotRawInferMeta(const MetaTensor& x, void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth, const Scalar& depth,
DataType dtype, DataType dtype,
bool allow_out_of_range, bool allow_out_of_range,
MetaTensor* out) { MetaTensor* out) {
...@@ -3021,7 +3021,7 @@ void OneHotRawInferMeta(const MetaTensor& x, ...@@ -3021,7 +3021,7 @@ void OneHotRawInferMeta(const MetaTensor& x,
1, 1,
phi::errors::InvalidArgument("Rank of Input(X) should be at least 1.")); phi::errors::InvalidArgument("Rank of Input(X) should be at least 1."));
auto out_dims_vec = phi::vectorize(x_dims); auto out_dims_vec = phi::vectorize(x_dims);
out_dims_vec.push_back(depth); out_dims_vec.push_back(depth.to<int>());
auto out_dims = phi::make_ddim(out_dims_vec); auto out_dims = phi::make_ddim(out_dims_vec);
out->set_dims(out_dims); out->set_dims(out_dims);
out->share_lod(x); out->share_lod(x);
......
...@@ -431,7 +431,7 @@ void UnStackInferMeta(const MetaTensor& x, ...@@ -431,7 +431,7 @@ void UnStackInferMeta(const MetaTensor& x,
std::vector<MetaTensor*> outs); std::vector<MetaTensor*> outs);
void OneHotRawInferMeta(const MetaTensor& x, void OneHotRawInferMeta(const MetaTensor& x,
int32_t depth, const Scalar& depth,
DataType dtype, DataType dtype,
bool allow_out_of_range, bool allow_out_of_range,
MetaTensor* out); MetaTensor* out);
......
...@@ -64,18 +64,19 @@ struct OneHotV2OpFunctor { ...@@ -64,18 +64,19 @@ struct OneHotV2OpFunctor {
template <typename T, typename Context> template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx, void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int32_t depth, const Scalar& depth,
DataType dtype, DataType dtype,
bool allow_out_of_range, bool allow_out_of_range,
DenseTensor* out) { DenseTensor* out) {
auto depth_v = depth.to<int>();
auto out_dims = out->dims(); auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) { if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth; out_dims[out_dims.size() - 1] = depth_v;
out->Resize(out_dims); out->Resize(out_dims);
} }
phi::VisitDataType(dtype, phi::VisitDataType(dtype,
OneHotV2OpFunctor<Context, T>(&x, out, depth, dev_ctx)); OneHotV2OpFunctor<Context, T>(&x, out, depth_v, dev_ctx));
} }
} // namespace phi } // namespace phi
......
...@@ -73,18 +73,19 @@ struct OneHotV2OpCUDAFunctor { ...@@ -73,18 +73,19 @@ struct OneHotV2OpCUDAFunctor {
template <typename T, typename Context> template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx, void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int32_t depth, const Scalar& depth,
DataType dtype, DataType dtype,
bool allow_out_of_range, bool allow_out_of_range,
DenseTensor* out) { DenseTensor* out) {
auto depth_v = depth.to<int>();
auto out_dims = out->dims(); auto out_dims = out->dims();
if (out_dims[out_dims.size() - 1] == -1) { if (out_dims[out_dims.size() - 1] == -1) {
out_dims[out_dims.size() - 1] = depth; out_dims[out_dims.size() - 1] = depth_v;
out->Resize(out_dims); out->Resize(out_dims);
} }
phi::VisitDataType( phi::VisitDataType(
dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth, dev_ctx)); dtype, OneHotV2OpCUDAFunctor<Context, T>(&x, out, depth_v, dev_ctx));
} }
} // namespace phi } // namespace phi
......
...@@ -24,9 +24,8 @@ void OneHotKernel(const Context& dev_ctx, ...@@ -24,9 +24,8 @@ void OneHotKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const Scalar& num_classes_s, const Scalar& num_classes_s,
DenseTensor* out) { DenseTensor* out) {
int num_classes = num_classes_s.to<int>();
OneHotRawKernel<T>( OneHotRawKernel<T>(
dev_ctx, x, num_classes, phi::DataType::FLOAT32, false, out); dev_ctx, x, num_classes_s, phi::DataType::FLOAT32, false, out);
} }
} // namespace phi } // namespace phi
......
...@@ -28,7 +28,7 @@ void OneHotKernel(const Context& dev_ctx, ...@@ -28,7 +28,7 @@ void OneHotKernel(const Context& dev_ctx,
template <typename T, typename Context> template <typename T, typename Context>
void OneHotRawKernel(const Context& dev_ctx, void OneHotRawKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int32_t depth, const Scalar& depth,
DataType dtype, DataType dtype,
bool allow_out_of_range, bool allow_out_of_range,
DenseTensor* out); DenseTensor* out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册