Unverified commit bf92ccc7, authored by hong, committed by GitHub

[NewIR]Fix tensor attribute translator bug (#55129)

* support optional input in new_ir

* polish code

* add coverage test

* update

* update

* add unit test

* remove duplicated code

* update

* fix assign error

* revert test arg min max

* update

* fix bug

* polish code
Parent: d6e90046
@@ -583,7 +583,11 @@ bool OpDesc::HasOutput(const std::string &name) const {
   return outputs_.find(name) != outputs_.end();
 }
 
-bool OpDesc::HasInput(const std::string &name) const {
+bool OpDesc::HasInput(const std::string &name, bool with_attr_var) const {
+  if (with_attr_var) {
+    auto it = attrs_.find(name);
+    if (it != attrs_.end() && HasAttrVar(it->second)) return true;
+  }
   return inputs_.find(name) != inputs_.end();
 }
......
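The idea behind the new `with_attr_var` parameter: some operator arguments (for example `axis` of `arg_max`) may arrive either as a regular input variable or as an attribute that refers to a Variable. With `with_attr_var` set, `HasInput` now also reports the latter case. A minimal self-contained sketch of this lookup order, using simplified stand-ins for `OpDesc`, `Attribute`, and `HasAttrVar` rather than Paddle's real types:

    #include <map>
    #include <set>
    #include <string>
    #include <variant>

    // Stand-in for Paddle's attribute variant: an attribute may hold a
    // plain value or a reference to a Variable (the "attr var" case).
    struct VarRef { std::string var_name; };
    using Attribute = std::variant<int, float, VarRef>;

    // True when the attribute actually refers to a Variable.
    bool HasAttrVar(const Attribute &attr) {
      return std::holds_alternative<VarRef>(attr);
    }

    struct OpDescSketch {
      std::map<std::string, Attribute> attrs_;
      std::set<std::string> inputs_;

      // Mirrors the patched OpDesc::HasInput: with with_attr_var set, an
      // attribute wrapping a Variable also counts as an input.
      bool HasInput(const std::string &name, bool with_attr_var = false) const {
        if (with_attr_var) {
          auto it = attrs_.find(name);
          if (it != attrs_.end() && HasAttrVar(it->second)) return true;
        }
        return inputs_.count(name) > 0;
      }
    };

    int main() {
      OpDescSketch desc;
      desc.attrs_["axis"] = VarRef{"axis_tensor_0"};
      // false without the flag, true with it.
      return desc.HasInput("axis", /*with_attr_var=*/true) ? 0 : 1;
    }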
@@ -77,7 +77,7 @@ class OpDesc {
   bool HasOutput(const std::string &name) const;
 
-  bool HasInput(const std::string &name) const;
+  bool HasInput(const std::string &name, bool with_attr_var = false) const;
 
   std::vector<std::string> OutputArgumentNames() const;
......
@@ -156,6 +156,23 @@ void BuildPhiContext(
       ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
     } else if (attr_type_name == "ir::StrAttribute") {
       ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data());
+    } else if (attr_type_name ==
+               "ir::ArrayAttribute<paddle::dialect::ScalarAttribute>") {
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      std::vector<phi::Scalar> vec_res;
+      if (array_list.size() > 0) {
+        PADDLE_ENFORCE_EQ(
+            array_list[0].isa<paddle::dialect::ScalarAttribute>(),
+            true,
+            phi::errors::Unimplemented(
+                "the 0th element MUST be dialect::ScalarAttribute"));
+        for (size_t i = 0; i < array_list.size(); ++i) {
+          vec_res.push_back(array_list[i]
+                                .dyn_cast<paddle::dialect::ScalarAttribute>()
+                                .data());
+        }
+      }
+      ctx->EmplaceBackAttr(vec_res);
     } else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") {
       auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
       std::vector<int32_t> vec_res;
......
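This new branch lets attributes that are lists of scalars (such as the `values` of `assign_value`) be forwarded into the kernel context as `std::vector<phi::Scalar>`; previously there was no case for `ir::ArrayAttribute<paddle::dialect::ScalarAttribute>`. A rough standalone illustration of the validate-then-unpack step, with simplified stand-ins for `phi::Scalar` and `dialect::ScalarAttribute` (not Paddle's real types):

    #include <cassert>
    #include <cstdint>
    #include <variant>
    #include <vector>

    // Stand-in for phi::Scalar: a tagged value covering several dtypes.
    using Scalar = std::variant<bool, int64_t, float>;

    // Stand-in for one array element; here the element kind is fixed by
    // the C++ type, whereas the real branch must first verify
    // array_list[0].isa<ScalarAttribute>() and raise
    // phi::errors::Unimplemented otherwise.
    struct ScalarAttribute {
      Scalar value;
      Scalar data() const { return value; }
    };

    // Mirrors the unpacking loop above: collect every element's payload
    // into a std::vector<Scalar> for EmplaceBackAttr.
    std::vector<Scalar> UnpackScalarArray(
        const std::vector<ScalarAttribute> &array_list) {
      std::vector<Scalar> vec_res;
      for (const auto &elem : array_list) {
        vec_res.push_back(elem.data());
      }
      return vec_res;
    }

    int main() {
      std::vector<ScalarAttribute> values = {{Scalar{int64_t{3}}},
                                             {Scalar{1.5f}}};
      auto scalars = UnpackScalarArray(values);
      assert(scalars.size() == 2);
      return 0;
    }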
@@ -372,7 +372,7 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
     std::vector<std::string> legacy_input_vars;
 
     // return empty OpResult if this arg is optional and not shown in OpDesc
     // TODO(lyk): HasInput doesnot consider variadic attribute
-    if (op_desc.HasInput(legacy_input_name)) {
+    if (op_desc.HasInput(legacy_input_name, true)) {
      legacy_input_vars = op_desc.Input(legacy_input_name, true);
    }
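This is the call site the TODO refers to: without `with_attr_var`, an optional argument that exists only as an attribute Variable looks absent, so the translator falls through to the empty-OpResult path and a tensor-backed `axis` of `arg_max`/`arg_min` is silently dropped. A rough sketch of that decision, using stand-in types rather than the translator's real API (the real code resolves legacy variable names through its value map):

    #include <string>
    #include <vector>

    // Stand-in: OpResult{} models the empty result the translator emits
    // for optional args that are not present in the OpDesc.
    struct OpResult { std::string defining_var; };

    OpResult TranslateOptionalInput(bool has_input,
                                    const std::vector<std::string> &legacy_vars) {
      if (has_input && !legacy_vars.empty()) {
        // Hypothetical resolution step: map the legacy variable name to
        // the ir::OpResult that defines it.
        return OpResult{legacy_vars.front()};
      }
      return OpResult{};  // optional and absent: empty OpResult
    }

    int main() {
      // Before the fix, has_input came from HasInput(name); for a
      // tensor-backed axis that returned false and the input was lost.
      OpResult before = TranslateOptionalInput(false, {"axis_tensor_0"});
      OpResult after = TranslateOptionalInput(true, {"axis_tensor_0"});
      return before.defining_var.empty() && !after.defining_var.empty() ? 0 : 1;
    }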
@@ -779,18 +779,21 @@ struct AssignValueOpTranscriber : public OpTranscriber {
         dialect::PlaceAttribute::get(ctx, phi::CPUPlace());
     attribute_map["place"] = attr_place;
-    if (op_desc.HasAttr("bool_values")) {
+    int dtype = paddle::get<int>(op_desc.GetAttr("dtype"));
+
+    if (dtype == /*BOOL*/ 0) {
       legacy_attr = op_desc.GetAttr("bool_values");
-    } else if (op_desc.HasAttr("fp32_values")) {
-      legacy_attr = op_desc.GetAttr("fp32_values");
-    } else if (op_desc.HasAttr("int32_values")) {
+    } else if (dtype == /*INT32*/ 2) {
       legacy_attr = op_desc.GetAttr("int32_values");
-    } else if (op_desc.HasAttr("int64_values")) {
+    } else if (dtype == /*FP32*/ 5) {
+      legacy_attr = op_desc.GetAttr("fp32_values");
+    } else if (dtype == /*INT64*/ 3) {
       legacy_attr = op_desc.GetAttr("int64_values");
     } else {
       IR_THROW(
           "Op assign_value should have attribute `**_values` but not find");
     }
 
     ir::Attribute attr_values = attribute_translator(
         attr_info_maps.at("values").type_name, legacy_attr);
     attribute_map["values"] = attr_values;
......
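Keying the branch on `dtype` instead of chained `HasAttr` probes is the actual fix here: the legacy `assign_value` op appears to declare all four `*_values` attributes with empty-list defaults, so a presence check can match a list that carries no data and the first branch wins regardless of the real payload. The integer codes in the comments follow Paddle's proto::VarType enum. The same mapping restated as a standalone helper (the function name is illustrative, not part of the patch):

    #include <stdexcept>
    #include <string>

    // Maps the legacy assign_value `dtype` code to the attribute that
    // carries the payload. Codes mirror the /*BOOL*/-style comments in
    // the diff: BOOL = 0, INT32 = 2, INT64 = 3, FP32 = 5.
    std::string ValuesAttrNameForDtype(int dtype) {
      switch (dtype) {
        case 0: return "bool_values";
        case 2: return "int32_values";
        case 3: return "int64_values";
        case 5: return "fp32_values";
        default:
          throw std::invalid_argument(
              "Op assign_value should have attribute `**_values` but not found");
      }
    }

    int main() { return ValuesAttrNameForDtype(2) == "int32_values" ? 0 : 1; }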
@@ -208,6 +208,7 @@
 - op : argmax(arg_max)
   inputs :
     x : X
+    axis : axis
   outputs :
     out : Out
   scalar:
@@ -218,6 +219,7 @@
 - op : argmin(arg_min)
   inputs :
     x : X
+    axis : axis
   outputs :
     out : Out
   scalar:
......
@@ -183,7 +183,6 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
     }
     return;
   }
-
   auto int_axis = axis.to<int64_t>();
   const auto& x_dims = x.dims();
......