diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index b12b067e72b5d9035e82cd0b55d41d240d3029b3..423feaabb30181f5c2116d3a98433f48b41f0332 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -314,6 +314,21 @@ std::vector OpTranscriber::GenerateOperationInput( ir::Program* program) { VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance"; + auto& op_normalizer = OpNameNormalizer::instance(); + const auto* mutable_attributes = + op_normalizer.GetMutableAttributes(op_desc.Type()); + + std::set yaml_input_set; + for (const auto& info : input_infos) { + if (auto special_handler = this->GetSpecialInputHandlers(info.name)) { + continue; + } + + std::string legacy_input_name = + op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); + + yaml_input_set.insert(legacy_input_name); + } // scan all inputs to see if any of them is generated as a vector // so need an additional `SliceOp` to take it out. for (const auto& n : op_desc.Inputs()) { @@ -321,7 +336,9 @@ std::vector OpTranscriber::GenerateOperationInput( auto& args = n.second; for (const auto& arg_name : args) { - IR_ENFORCE(param_map->count(arg_name) != 0, + bool check = + param_map->count(arg_name) != 0 || !yaml_input_set.count(arg_name); + IR_ENFORCE(check, "arg %s.%s as input should be exists before prasing %s", name, arg_name, @@ -337,9 +354,6 @@ std::vector OpTranscriber::GenerateOperationInput( VLOG(10) << "[op:" << op_desc.Type() << "][input] start"; std::vector op_inputs; - auto& op_normalizer = OpNameNormalizer::instance(); - const auto* mutable_attributes = - op_normalizer.GetMutableAttributes(op_desc.Type()); for (const auto& info : input_infos) { if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {