提交 ce57e02d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1562 don't set parameter's format when it's has been setted before

Merge pull request !1562 from lianliguang/r0.3
......@@ -166,20 +166,10 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// we set special device info of a input tensor.
bool is_ref = false;
auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
if (op_info != nullptr) {
is_ref = op_info->is_ref();
}
MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
continue;
}
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)};
builder->SetOutputsDeviceType(output_type);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
}
......
......@@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
return false;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
return false;
}
if (param_type == "dynamic") {
if (dyn_input_sizes.empty()) {
MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
......@@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
}
dyn_input_idx++;
} else if (param_type == "required") {
......@@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
} else {
if (kernel_info_index < real_input_num) {
MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
......@@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
inputs_device_type.push_back(type_id);
inputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
}
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
return false;
}
reshape_types.push_back(reshape_type);
}
builder->SetInputReshapeType(reshape_types);
......@@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!";
continue;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
return false;
}
size_t output_num = 0;
if (output->param_type() == "dynamic") {
if (outputs.size() > 1) {
......@@ -467,12 +475,9 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
outputs_device_type.push_back(type_id);
outputs_format.push_back(formats[builder_idex]);
reshape_types.push_back(reshape_type);
output_idx++;
}
std::vector<Axis> reshape_type;
if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
return false;
}
reshape_types.push_back(reshape_type);
}
......
......@@ -33,12 +33,15 @@ using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &node, const TypeId device_type,
const kernel::KernelBuildInfo &ori_build_info) {
const kernel::KernelBuildInfo &ori_build_info,
const std::vector<kernel::Axis> &reshape_type) {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetOutputsFormat({output_format});
builder.SetInputsDeviceType({device_type});
builder.SetOutputsDeviceType({device_type});
builder.SetOutputReshapeType({reshape_type});
builder.SetInputReshapeType({reshape_type});
builder.SetKernelType(ori_build_info.kernel_type());
builder.SetFusionType(ori_build_info.fusion_type());
builder.SetProcessor(ori_build_info.processor());
......@@ -175,6 +178,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
AnfNodePtr trans_node = nullptr;
AnfNodePtr input_node = node;
AnfNodePtr trans_data = nullptr;
std::vector<kernel::Axis> reshape_type = AnfAlgo::GetOutputReshapeType(node, 0);
TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
MS_EXCEPTION_IF_NULL(node);
if (origin_format.empty() || dest_format.empty()) {
......@@ -189,6 +193,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
MS_EXCEPTION_IF_NULL(cnode);
input_node = AnfAlgo::GetInputNode(cnode, insert_index);
reshape_type = AnfAlgo::GetInputReshapeType(node, insert_index);
}
bool need_padding = false;
if (is_insert_input) {
......@@ -222,7 +227,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_EXCEPTION_IF_NULL(trans_data);
MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info);
auto kernel_build_info =
RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info, reshape_type);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
return trans_node;
}
......@@ -309,9 +315,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
auto cur_input = AnfAlgo::GetInputNode(cnode, input_index);
auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0);
auto is_weight_boundary = [](const AnfNodePtr &node) -> bool {
if (node->isa<ValueNode>()) {
return true;
} else if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
if (node->isa<ValueNode>() || node->isa<Parameter>()) {
return true;
}
return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册