diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 699d4d5d4005f378a55df4b49f97c033807686a1..e4d7c7ca5f41b3fd785b02f5a8e3e34275ff9592 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -166,20 +166,10 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co std::shared_ptr builder = std::make_shared(); // we set special device info of a input tensor. - bool is_ref = false; - auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); - if (op_info != nullptr) { - is_ref = op_info->is_ref(); - } - MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); - if (MsContext::GetInstance()->execution_mode() == kPynativeMode && - AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) { - continue; - } - if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { + if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown) { std::vector output_format = {selected_kernel_info.GetInputFormat(input_index)}; builder->SetOutputsFormat(output_format); - std::vector output_type = {selected_kernel_info.GetInputDeviceType(input_index)}; + std::vector output_type = {AnfAlgo::GetOutputInferDataType(real_input_node, 0)}; builder->SetOutputsDeviceType(output_type); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get()); } diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index c809618b339851c18cab672eeb58d2d91b622f81..bd3ed8e30ef2611ecdc5b8aefdea04fda16c5716 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector> &inp return false; } + std::vector reshape_type; + if (!StringToAxisVector(input->reshape_type(), 
&reshape_type)) { + return false; + } + if (param_type == "dynamic") { if (dyn_input_sizes.empty()) { MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; @@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector> &inp auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); inputs_device_type.push_back(type_id); inputs_format.push_back(formats[builder_idex]); + reshape_types.push_back(reshape_type); } dyn_input_idx++; } else if (param_type == "required") { @@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector> &inp auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); inputs_device_type.push_back(type_id); inputs_format.push_back(formats[builder_idex]); + reshape_types.push_back(reshape_type); } else { if (kernel_info_index < real_input_num) { MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index; @@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector> &inp auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); inputs_device_type.push_back(type_id); inputs_format.push_back(formats[builder_idex]); + reshape_types.push_back(reshape_type); } } - std::vector reshape_type; - if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { - return false; - } - reshape_types.push_back(reshape_type); } builder->SetInputReshapeType(reshape_types); @@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector> &ou MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!"; continue; } + std::vector reshape_type; + if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { + return false; + } + size_t output_num = 0; if (output->param_type() == "dynamic") { if (outputs.size() > 1) { @@ -467,12 +475,8 @@ bool SetKernelBuilderOutputInfo(const std::vector> &ou auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); outputs_device_type.push_back(type_id); 
outputs_format.push_back(formats[builder_idex]); + reshape_types.push_back(reshape_type); output_idx++; } - std::vector reshape_type; - if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { - return false; - } - reshape_types.push_back(reshape_type); } diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 05d36fd4f22e32b9ccddfad7537dde2f3c7ce12b..0422456971fcc2615de299a478e78179c59f07a5 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -33,12 +33,15 @@ using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const AnfNodePtr &node, const TypeId device_type, - const kernel::KernelBuildInfo &ori_build_info) { + const kernel::KernelBuildInfo &ori_build_info, + const std::vector &reshape_type) { KernelBuildInfoBuilder builder; builder.SetInputsFormat({input_format}); builder.SetOutputsFormat({output_format}); builder.SetInputsDeviceType({device_type}); builder.SetOutputsDeviceType({device_type}); + builder.SetOutputReshapeType({reshape_type}); + builder.SetInputReshapeType({reshape_type}); builder.SetKernelType(ori_build_info.kernel_type()); builder.SetFusionType(ori_build_info.fusion_type()); builder.SetProcessor(ori_build_info.processor()); @@ -175,6 +178,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt AnfNodePtr trans_node = nullptr; AnfNodePtr input_node = node; AnfNodePtr trans_data = nullptr; + std::vector reshape_type = AnfAlgo::GetOutputReshapeType(node, 0); TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); MS_EXCEPTION_IF_NULL(node); if (origin_format.empty() || dest_format.empty()) { @@ -189,6 +193,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt dtype = 
AnfAlgo::GetInputDeviceDataType(cnode, insert_index); MS_EXCEPTION_IF_NULL(cnode); input_node = AnfAlgo::GetInputNode(cnode, insert_index); + reshape_type = AnfAlgo::GetInputReshapeType(node, insert_index); } bool need_padding = false; if (is_insert_input) { @@ -222,7 +227,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt MS_EXCEPTION_IF_NULL(trans_data); MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); - auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info); + auto kernel_build_info = + RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info, reshape_type); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get()); return trans_node; } @@ -309,9 +315,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { - if (node->isa()) { - return true; - } else if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { + if (node->isa() || node->isa()) { return true; } return false;