提交 932e0ccf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1476 fix Addn GetInputReshapeType failed

Merge pull request !1476 from wenchunjiang/fix_reshape_type_bug
......@@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
return false;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
return false;
}
if (param_type == "dynamic") {
if (dyn_input_sizes.empty()) {
MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
......@@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
}
dyn_input_idx++;
} else if (param_type == "required") {
......@@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
} else {
if (kernel_info_index < real_input_num) {
MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
......@@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
}
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
return false;
}
reshape_types.push_back(reshape_type);
}
builder->SetInputReshapeType(reshape_types);
......@@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!";
continue;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
return false;
}
size_t output_num = 0;
if (output->param_type() == "dynamic") {
if (outputs.size() > 1) {
......@@ -467,13 +475,9 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
outputs_device_type.push_back(type_id);
outputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
output_idx++;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
return false;
}
reshape_types.push_back(reshape_type);
}
builder->SetOutputReshapeType(reshape_types);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册