提交 9fd9a1af 编写于 作者: L lianliguang

Add warning logs that report how many nodes had to raise or reduce precision in order to select a kernel info

上级 3dd369ce
......@@ -342,7 +342,7 @@ void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelB
std::vector<int> *node_mix_precision_datatype_index) {
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
bool add_node_datatype_flag = false;
if (node_mix_precision_datatype->size() == 0) {
if (node_mix_precision_datatype->empty()) {
add_node_datatype_flag = true;
}
for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
......@@ -464,8 +464,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecis
}
} // namespace
void SelectKernelInfo(const CNodePtr &kernel_node) {
int SelectKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
int status = kStatusAllMatched;
MS_EXCEPTION_IF_NULL(kernel_node);
bool precision_reduce = false;
std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
......@@ -486,11 +487,13 @@ void SelectKernelInfo(const CNodePtr &kernel_node) {
<< "] cannot find valid kernel info, not supported the type" << buffer.str();
} else {
PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
}
}
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
// Set format and data type for input tensor.
SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
return status;
}
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
......
......@@ -21,7 +21,7 @@
namespace mindspore {
namespace device {
namespace ascend {
void SelectKernelInfo(const CNodePtr &kernel_node);
int SelectKernelInfo(const CNodePtr &kernel_node);
bool CheckKernelAccuracySupported(const CNodePtr &kernel_node, const kernel::KernelBuildInfoPtr &new_kernel_build_info);
} // namespace ascend
} // namespace device
......
......@@ -312,10 +312,25 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &gr
// compile graph steps
// Selects a kernel for every node in the graph's execution order and logs a
// summary of how the selection went.
//
// For each node the status returned by device::ascend::SelectKernelInfo is
// inspected: kStatusRaisePrecision and kStatusReducePrecision are tallied
// separately; any other status is not counted. After the loop, one WARNING
// line is emitted per non-zero tally.
//
// @param kernel_graph  Graph whose execution_order() nodes are processed.
void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
MS_LOG(INFO) << "Start!";
// Number of nodes whose selection status was kStatusRaisePrecision.
size_t raise_precision_count = 0;
// Number of nodes whose selection status was kStatusReducePrecision.
size_t reduce_precision_count = 0;
for (const auto &cnode : kernel_graph.execution_order()) {
// NOTE(review): this text is a rendered diff without +/- markers. The bare
// call on the next line looks like the pre-change line that the
// `auto status = ...` line replaces. If both are really present in the file,
// SelectKernelInfo runs twice per node - confirm against the actual source.
device::ascend::SelectKernelInfo(cnode);
auto status = device::ascend::SelectKernelInfo(cnode);
if (status == kStatusRaisePrecision) {
raise_precision_count++;
} else if (status == kStatusReducePrecision) {
reduce_precision_count++;
}
MS_LOG(INFO) << "Select ApplyKernel: " << cnode->DebugString();
}
// Report the tallies only when at least one node was affected, so graphs
// with no raised/reduced nodes produce no warning.
if (raise_precision_count > 0) {
MS_LOG(WARNING) << "There has " << raise_precision_count
<< " node/nodes used raise precision to selected the kernel!";
}
if (reduce_precision_count > 0) {
MS_LOG(WARNING) << "There has " << reduce_precision_count
<< " node/nodes used reduce precision to selected the kernel!";
}
MS_LOG(INFO) << "Finish!";
}
......
......@@ -184,7 +184,10 @@ constexpr auto kControlDependBehindIndex = 2;
// index define of depend
constexpr auto kRealInputIndexInDepend = 1;
constexpr auto kDependAttachNodeIndex = 2;
// Status of the kernel select result, as returned by
// device::ascend::SelectKernelInfo. Declared constexpr to match the other
// constants in this header and to guarantee compile-time initialization.
// Set when the fallback selection reports that precision was reduced.
constexpr int kStatusReducePrecision = -1;
// Set when the fallback selection did not reduce precision, i.e. it was raised.
constexpr int kStatusRaisePrecision = 1;
// Initial status; returned when neither of the two values above was set.
constexpr int kStatusAllMatched = 0;
// format
constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册