diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index da98eb77a8dc2586fff0647c19fcd06644516128..dcdf15c19b7bf6c1ccc1dcfaa265f64ddf04aed3 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -583,7 +583,11 @@ bool OpDesc::HasOutput(const std::string &name) const { return outputs_.find(name) != outputs_.end(); } -bool OpDesc::HasInput(const std::string &name) const { +bool OpDesc::HasInput(const std::string &name, bool with_attr_var) const { + if (with_attr_var) { + auto it = attrs_.find(name); + if (it != attrs_.end() && HasAttrVar(it->second)) return true; + } return inputs_.find(name) != inputs_.end(); } diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 45c7ba639f6a0d2089dc0e6a978d8fac5f6fa892..85870ea18f27942f14cd107cf8e1e7f542e5e1e5 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -77,7 +77,7 @@ class OpDesc { bool HasOutput(const std::string &name) const; - bool HasInput(const std::string &name) const; + bool HasInput(const std::string &name, bool with_attr_var = false) const; std::vector OutputArgumentNames() const; diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h index 4408d6be1ed3369757861a1cf5f4392e9e58b41a..7ecf94fe2fe3840d9604dd7744ba02caacf0420f 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h @@ -156,6 +156,23 @@ void BuildPhiContext( ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); } else if (attr_type_name == "ir::StrAttribute") { ctx->EmplaceBackAttr(attr_map[t].dyn_cast().data()); + } else if (attr_type_name == + "ir::ArrayAttribute") { + auto array_list = attr_map[t].dyn_cast().data(); + std::vector vec_res; + if (array_list.size() > 0) { + PADDLE_ENFORCE_EQ( + array_list[0].isa(), + true, + phi::errors::Unimplemented( + "the 0th elementwise MUST be dialect::ScalarAttribute")); + for (size_t i = 0; i < array_list.size(); ++i) { + vec_res.push_back(array_list[i] + .dyn_cast() + .data()); + } + } + ctx->EmplaceBackAttr(vec_res); } else if (attr_type_name == "ir::ArrayAttribute") { auto array_list = attr_map[t].dyn_cast().data(); std::vector vec_res; diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 423feaabb30181f5c2116d3a98433f48b41f0332..356f0fc990c74cf0adbdeb0fdf3b23ffdc281740 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -372,7 +372,7 @@ std::vector OpTranscriber::GenerateOperationInput( std::vector legacy_input_vars; // return empty OpResult if this arg is optional and not shown in OpDesc // TODO(lyk): HasInput doesnot consider variadic attribute - if (op_desc.HasInput(legacy_input_name)) { + if (op_desc.HasInput(legacy_input_name, true)) { legacy_input_vars = op_desc.Input(legacy_input_name, true); } @@ -779,18 +779,21 @@ struct AssignValueOpTranscriber : public OpTranscriber { dialect::PlaceAttribute::get(ctx, phi::CPUPlace()); attribute_map["place"] = attr_place; - if (op_desc.HasAttr("bool_values")) { + int dtype = paddle::get(op_desc.GetAttr("dtype")); + + if (dtype == /*BOOL*/ 0) { legacy_attr = op_desc.GetAttr("bool_values"); - } else if (op_desc.HasAttr("fp32_values")) { - legacy_attr = op_desc.GetAttr("fp32_values"); - } else if (op_desc.HasAttr("int32_values")) { + } else if (dtype == /*INT32*/ 2) { legacy_attr = op_desc.GetAttr("int32_values"); - } else if (op_desc.HasAttr("int64_values")) { + } else if (dtype == /*FP32*/ 5) { + legacy_attr = op_desc.GetAttr("fp32_values"); + } else if (dtype == /*INT64*/ 3) { legacy_attr = op_desc.GetAttr("int64_values"); } else { IR_THROW( "Op assign_value should have attribute `**_values` but not find"); } + ir::Attribute attr_values = attribute_translator( attr_info_maps.at("values").type_name, legacy_attr); attribute_map["values"] = attr_values; diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 20e6e8b40d161930c29ed4d422c1b79e9fd02322..8e8cca4296e7d05e79f4837886eb63fde45b0e78 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -208,6 +208,7 @@ - op : argmax(arg_max) inputs : x : X + axis : axis outputs : out : Out scalar: @@ -218,6 +219,7 @@ - op : argmin(arg_min) inputs : x : X + axis : axis outputs : out : Out scalar: diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index d590f0a875d716602a12bd070a2aedd884d81716..f37ae1d688ca5661bc9770e5ff3abc750c1d42ac 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -183,7 +183,6 @@ void ArgMinMaxInferMeta(const MetaTensor& x, } return; } - auto int_axis = axis.to(); const auto& x_dims = x.dims();