From bf92ccc7a82c7ed60f4102b8b1b742cb14ac31ef Mon Sep 17 00:00:00 2001
From: hong <43953930+phlrain@users.noreply.github.com>
Date: Wed, 5 Jul 2023 11:05:45 +0800
Subject: [PATCH] [NewIR]Fix tensor attribute translator bug (#55129)

* support optional input in new_ir

* polish code

* add coverage test

* update

* update

* add unit test

* remove duplicated code

* update

* fix assign error

* revert test arg min max

* update

* fix bug

* polish code
---
 paddle/fluid/framework/op_desc.cc             |  6 +++++-
 paddle/fluid/framework/op_desc.h              |  2 +-
 .../ir/phi_kernel_adaptor/phi_kernel_util.h   | 17 +++++++++++++++++
 .../ir_adaptor/translator/op_translator.cc    | 15 +++++++++------
 paddle/phi/api/yaml/op_compat.yaml            |  2 ++
 paddle/phi/infermeta/unary.cc                 |  1 -
 6 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index da98eb77a8d..dcdf15c19b7 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -583,7 +583,11 @@ bool OpDesc::HasOutput(const std::string &name) const {
   return outputs_.find(name) != outputs_.end();
 }
 
-bool OpDesc::HasInput(const std::string &name) const {
+bool OpDesc::HasInput(const std::string &name, bool with_attr_var) const {
+  if (with_attr_var) {
+    auto it = attrs_.find(name);
+    if (it != attrs_.end() && HasAttrVar(it->second)) return true;
+  }
   return inputs_.find(name) != inputs_.end();
 }
 
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index 45c7ba639f6..85870ea18f2 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -77,7 +77,7 @@ class OpDesc {
 
   bool HasOutput(const std::string &name) const;
 
-  bool HasInput(const std::string &name) const;
+  bool HasInput(const std::string &name, bool with_attr_var = false) const;
 
   std::vector<std::string> OutputArgumentNames() const;
 
diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
index 4408d6be1ed..7ecf94fe2fe 100644
--- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -156,6 +156,23 @@ void BuildPhiContext(
       ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data());
     } else if (attr_type_name == "ir::StrAttribute") {
       ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data());
+    } else if (attr_type_name ==
+               "ir::ArrayAttribute<paddle::dialect::ScalarAttribute>") {
+      auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
+      std::vector<phi::Scalar> vec_res;
+      if (array_list.size() > 0) {
+        PADDLE_ENFORCE_EQ(
+            array_list[0].isa<paddle::dialect::ScalarAttribute>(),
+            true,
+            phi::errors::Unimplemented(
+                "the 0th elementwise MUST be dialect::ScalarAttribute"));
+        for (size_t i = 0; i < array_list.size(); ++i) {
+          vec_res.push_back(array_list[i]
+                                .dyn_cast<paddle::dialect::ScalarAttribute>()
+                                .data());
+        }
+      }
+      ctx->EmplaceBackAttr(vec_res);
     } else if (attr_type_name == "ir::ArrayAttribute") {
       auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
       std::vector vec_res;
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 423feaabb30..356f0fc990c 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -372,7 +372,7 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
     std::vector<std::string> legacy_input_vars;
     // return empty OpResult if this arg is optional and not shown in OpDesc
     // TODO(lyk): HasInput doesnot consider variadic attribute
-    if (op_desc.HasInput(legacy_input_name)) {
+    if (op_desc.HasInput(legacy_input_name, true)) {
       legacy_input_vars = op_desc.Input(legacy_input_name, true);
     }
 
@@ -779,18 +779,21 @@ struct AssignValueOpTranscriber : public OpTranscriber {
         dialect::PlaceAttribute::get(ctx, phi::CPUPlace());
     attribute_map["place"] = attr_place;
 
-    if (op_desc.HasAttr("bool_values")) {
+    int dtype = paddle::get<int>(op_desc.GetAttr("dtype"));
+
+    if (dtype == /*BOOL*/ 0) {
       legacy_attr = op_desc.GetAttr("bool_values");
-    } else if (op_desc.HasAttr("fp32_values")) {
-      legacy_attr = op_desc.GetAttr("fp32_values");
-    } else if (op_desc.HasAttr("int32_values")) {
+    } else if (dtype == /*INT32*/ 2) {
       legacy_attr = op_desc.GetAttr("int32_values");
-    } else if (op_desc.HasAttr("int64_values")) {
+    } else if (dtype == /*FP32*/ 5) {
+      legacy_attr = op_desc.GetAttr("fp32_values");
+    } else if (dtype == /*INT64*/ 3) {
       legacy_attr = op_desc.GetAttr("int64_values");
     } else {
       IR_THROW(
           "Op assign_value should have attribute `**_values` but not find");
     }
+
     ir::Attribute attr_values = attribute_translator(
         attr_info_maps.at("values").type_name, legacy_attr);
     attribute_map["values"] = attr_values;
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 20e6e8b40d1..8e8cca4296e 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -208,6 +208,7 @@
 - op : argmax(arg_max)
   inputs :
     x : X
+    axis : axis
   outputs :
     out : Out
   scalar:
@@ -218,6 +219,7 @@
 - op : argmin(arg_min)
   inputs :
     x : X
+    axis : axis
   outputs :
     out : Out
   scalar:
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index d590f0a875d..f37ae1d688c 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -183,7 +183,6 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
     }
     return;
   }
-  auto int_axis = axis.to<int64_t>();
 
   const auto& x_dims = x.dims();
 
-- 
GitLab
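Note on the op_desc.cc/op_desc.h change: `HasInput` gains a `with_attr_var` flag so that an argument such as `axis`, which may arrive as a Variable held in the attribute map instead of the input map, is still reported as present; the translator now calls `HasInput(legacy_input_name, true)`. The following is a toy illustration of the lookup order, not the real paddle::framework::OpDesc (ToyOpDesc and var_attrs are made-up names).

// Toy model (not the real paddle::framework::OpDesc): shows the lookup
// order behind HasInput(name, /*with_attr_var=*/true) -- first the
// attribute map (for inputs promoted to Variable attributes, e.g. the
// axis of arg_max/arg_min), then the ordinary input map.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct ToyOpDesc {
  std::map<std::string, std::vector<std::string>> inputs;
  // Names of attributes that hold a Variable (stand-in for HasAttrVar).
  std::set<std::string> var_attrs;

  bool HasInput(const std::string& name, bool with_attr_var = false) const {
    if (with_attr_var && var_attrs.count(name) > 0) return true;
    return inputs.count(name) > 0;
  }
};

int main() {
  ToyOpDesc arg_max;
  arg_max.inputs["X"] = {"x0"};
  arg_max.var_attrs.insert("axis");  // axis passed as a Variable attribute

  std::cout << std::boolalpha;
  std::cout << arg_max.HasInput("axis") << "\n";        // false
  std::cout << arg_max.HasInput("axis", true) << "\n";  // true
  return 0;
}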
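Note on the phi_kernel_util.h change: the new branch converts an array-typed attribute into the std::vector the kernel context expects, after checking that the first element has the promised scalar type. Below is a self-contained sketch of that pattern only; it does not use Paddle's ir:: API, and ToyAttribute/ToFloatVector with a float element type are stand-ins chosen for illustration.

// Standalone sketch (not Paddle's ir:: API): converts an array-typed
// attribute into the std::vector the kernel context expects, checking
// -- like the PADDLE_ENFORCE_EQ in the new branch -- that the first
// element holds the promised scalar type before converting the list.
#include <iostream>
#include <stdexcept>
#include <variant>
#include <vector>

using ToyAttribute = std::variant<bool, int, float>;

std::vector<float> ToFloatVector(const std::vector<ToyAttribute>& array_list) {
  std::vector<float> vec_res;
  if (!array_list.empty()) {
    if (!std::holds_alternative<float>(array_list[0])) {
      throw std::runtime_error("the 0th element must be a float attribute");
    }
    vec_res.reserve(array_list.size());
    for (const auto& attr : array_list) {
      // The real code instead dyn_casts each element to the scalar
      // attribute type and pushes its data() into the result vector.
      vec_res.push_back(std::get<float>(attr));
    }
  }
  return vec_res;
}

int main() {
  std::vector<ToyAttribute> attrs = {1.0f, 2.5f, -3.0f};
  for (float v : ToFloatVector(attrs)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 1 2.5 -3
  return 0;
}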
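Note on the op_translator.cc change: AssignValueOpTranscriber now selects the legacy `*_values` attribute from the op's `dtype` attribute instead of probing which `*_values` attribute exists, presumably because a legacy assign_value OpDesc can carry all four lists (the unused ones empty), so the old HasAttr chain could pick the wrong one. A minimal standalone sketch of that dispatch follows; it is not the translator itself, and ValuesAttrNameForDtype is a hypothetical helper. The numeric codes mirror the framework proto VarType values used in the diff's comments (BOOL = 0, INT32 = 2, INT64 = 3, FP32 = 5).

// Minimal standalone sketch (not the translator itself): maps the
// assign_value dtype code to the legacy *_values attribute name, as the
// dispatch added in AssignValueOpTranscriber does.
#include <iostream>
#include <stdexcept>
#include <string>

std::string ValuesAttrNameForDtype(int dtype) {
  switch (dtype) {
    case 0:  // BOOL
      return "bool_values";
    case 2:  // INT32
      return "int32_values";
    case 3:  // INT64
      return "int64_values";
    case 5:  // FP32
      return "fp32_values";
    default:
      throw std::invalid_argument(
          "assign_value: unsupported dtype " + std::to_string(dtype));
  }
}

int main() {
  // dtype == 5 (FP32) now selects fp32_values even when the OpDesc also
  // carries the other, empty *_values attributes.
  std::cout << ValuesAttrNameForDtype(5) << "\n";  // fp32_values
  std::cout << ValuesAttrNameForDtype(0) << "\n";  // bool_values
  return 0;
}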