未验证 提交 bf92ccc7 编写于 作者: H hong 提交者: GitHub

[NewIR]Fix tensor attribute translator bug (#55129)

* suport optional input in new_ir

* polish code

* add coverate test

* update

* update

* add unitest

* remove reduplicate code

* udpate

* fix assign error

* revert test arg min max

* update

* fix bug

* polish code
上级 d6e90046
......@@ -583,7 +583,11 @@ bool OpDesc::HasOutput(const std::string &name) const {
return outputs_.find(name) != outputs_.end();
}
bool OpDesc::HasInput(const std::string &name) const {
bool OpDesc::HasInput(const std::string &name, bool with_attr_var) const {
if (with_attr_var) {
auto it = attrs_.find(name);
if (it != attrs_.end() && HasAttrVar(it->second)) return true;
}
return inputs_.find(name) != inputs_.end();
}
......
......@@ -77,7 +77,7 @@ class OpDesc {
bool HasOutput(const std::string &name) const;
bool HasInput(const std::string &name) const;
bool HasInput(const std::string &name, bool with_attr_var = false) const;
std::vector<std::string> OutputArgumentNames() const;
......
......@@ -156,6 +156,23 @@ void BuildPhiContext(
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "ir::StrAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data());
} else if (attr_type_name ==
"ir::ArrayAttribute<paddle::dialect::ScalarAttribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
std::vector<phi::Scalar> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
array_list[0].isa<paddle::dialect::ScalarAttribute>(),
true,
phi::errors::Unimplemented(
"the 0th elementwise MUST be dialect::ScalarAttribute"));
for (size_t i = 0; i < array_list.size(); ++i) {
vec_res.push_back(array_list[i]
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data());
}
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
std::vector<int32_t> vec_res;
......
......@@ -372,7 +372,7 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
std::vector<std::string> legacy_input_vars;
// return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput doesnot consider variadic attribute
if (op_desc.HasInput(legacy_input_name)) {
if (op_desc.HasInput(legacy_input_name, true)) {
legacy_input_vars = op_desc.Input(legacy_input_name, true);
}
......@@ -779,18 +779,21 @@ struct AssignValueOpTranscriber : public OpTranscriber {
dialect::PlaceAttribute::get(ctx, phi::CPUPlace());
attribute_map["place"] = attr_place;
if (op_desc.HasAttr("bool_values")) {
int dtype = paddle::get<int>(op_desc.GetAttr("dtype"));
if (dtype == /*BOOL*/ 0) {
legacy_attr = op_desc.GetAttr("bool_values");
} else if (op_desc.HasAttr("fp32_values")) {
legacy_attr = op_desc.GetAttr("fp32_values");
} else if (op_desc.HasAttr("int32_values")) {
} else if (dtype == /*INT32*/ 2) {
legacy_attr = op_desc.GetAttr("int32_values");
} else if (op_desc.HasAttr("int64_values")) {
} else if (dtype == /*FP32*/ 5) {
legacy_attr = op_desc.GetAttr("fp32_values");
} else if (dtype == /*INT64*/ 3) {
legacy_attr = op_desc.GetAttr("int64_values");
} else {
IR_THROW(
"Op assign_value should have attribute `**_values` but not find");
}
ir::Attribute attr_values = attribute_translator(
attr_info_maps.at("values").type_name, legacy_attr);
attribute_map["values"] = attr_values;
......
......@@ -208,6 +208,7 @@
- op : argmax(arg_max)
inputs :
x : X
axis : axis
outputs :
out : Out
scalar:
......@@ -218,6 +219,7 @@
- op : argmin(arg_min)
inputs :
x : X
axis : axis
outputs :
out : Out
scalar:
......
......@@ -183,7 +183,6 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
}
return;
}
auto int_axis = axis.to<int64_t>();
const auto& x_dims = x.dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册