From cb4467f0a2d051913712da070a6e3785373c5acf Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Mon, 3 Jul 2023 10:39:24 +0800 Subject: [PATCH] fix_op_translator_input_check (#55065) --- .../ir_adaptor/translator/op_translator.cc | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index b12b067e72b..423feaabb30 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -314,6 +314,21 @@ std::vector OpTranscriber::GenerateOperationInput( ir::Program* program) { VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance"; + auto& op_normalizer = OpNameNormalizer::instance(); + const auto* mutable_attributes = + op_normalizer.GetMutableAttributes(op_desc.Type()); + + std::set yaml_input_set; + for (const auto& info : input_infos) { + if (auto special_handler = this->GetSpecialInputHandlers(info.name)) { + continue; + } + + std::string legacy_input_name = + op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); + + yaml_input_set.insert(legacy_input_name); + } // scan all inputs to see if any of them is generated as a vector // so need an additional `SliceOp` to take it out. for (const auto& n : op_desc.Inputs()) { @@ -321,7 +336,9 @@ std::vector OpTranscriber::GenerateOperationInput( auto& args = n.second; for (const auto& arg_name : args) { - IR_ENFORCE(param_map->count(arg_name) != 0, + bool check = + param_map->count(arg_name) != 0 || !yaml_input_set.count(arg_name); + IR_ENFORCE(check, "arg %s.%s as input should be exists before prasing %s", name, arg_name, @@ -337,9 +354,6 @@ std::vector OpTranscriber::GenerateOperationInput( VLOG(10) << "[op:" << op_desc.Type() << "][input] start"; std::vector op_inputs; - auto& op_normalizer = OpNameNormalizer::instance(); - const auto* mutable_attributes = - op_normalizer.GetMutableAttributes(op_desc.Type()); for (const auto& info : input_infos) { if (auto special_handler = this->GetSpecialInputHandlers(info.name)) { -- GitLab