未验证 提交 cb4467f0 编写于 作者: H hong 提交者: GitHub

fix_op_translator_input_check (#55065)

上级 ce31a72e
...@@ -314,6 +314,21 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput( ...@@ -314,6 +314,21 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
ir::Program* program) { ir::Program* program) {
VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance"; VLOG(10) << "[op:" << op_desc.Type() << "][input] entrance";
auto& op_normalizer = OpNameNormalizer::instance();
const auto* mutable_attributes =
op_normalizer.GetMutableAttributes(op_desc.Type());
std::set<std::string> yaml_input_set;
for (const auto& info : input_infos) {
if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
continue;
}
std::string legacy_input_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
yaml_input_set.insert(legacy_input_name);
}
// scan all inputs to see if any of them is generated as a vector<Tensor> // scan all inputs to see if any of them is generated as a vector<Tensor>
// so need an additional `SliceOp` to take it out. // so need an additional `SliceOp` to take it out.
for (const auto& n : op_desc.Inputs()) { for (const auto& n : op_desc.Inputs()) {
...@@ -321,7 +336,9 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput( ...@@ -321,7 +336,9 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
auto& args = n.second; auto& args = n.second;
for (const auto& arg_name : args) { for (const auto& arg_name : args) {
IR_ENFORCE(param_map->count(arg_name) != 0, bool check =
param_map->count(arg_name) != 0 || !yaml_input_set.count(arg_name);
IR_ENFORCE(check,
"arg %s.%s as input should be exists before prasing %s", "arg %s.%s as input should be exists before prasing %s",
name, name,
arg_name, arg_name,
...@@ -337,9 +354,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput( ...@@ -337,9 +354,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
VLOG(10) << "[op:" << op_desc.Type() << "][input] start"; VLOG(10) << "[op:" << op_desc.Type() << "][input] start";
std::vector<ir::OpResult> op_inputs; std::vector<ir::OpResult> op_inputs;
auto& op_normalizer = OpNameNormalizer::instance();
const auto* mutable_attributes =
op_normalizer.GetMutableAttributes(op_desc.Type());
for (const auto& info : input_infos) { for (const auto& info : input_infos) {
if (auto special_handler = this->GetSpecialInputHandlers(info.name)) { if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册