Commit 1b406148 authored by mindspore-ci-bot, committed by Gitee

!5517 gpu kernel_info_setter code review

Merge pull request !5517 from limingqi107/master
@@ -49,10 +49,10 @@ using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
 void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  bool in_black_list = CheckInModeBlackList(kernel_graph);
+  bool graph_format_transform = IsSupportFormatTransform(kernel_graph);
   for (const auto &kernel_node : kernel_graph->execution_order()) {
     MS_EXCEPTION_IF_NULL(kernel_node);
-    device::gpu::SetKernelInfo(kernel_node, in_black_list);
+    device::gpu::SetKernelInfo(kernel_node, graph_format_transform);
   }
 }
@@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared<opt::ReplaceBNGradCastFusion>());
   pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
   pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
-  if (!CheckInModeBlackList(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
+  if (IsSupportFormatTransform(kernel_graph) && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
     pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
     pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
     pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
@@ -193,14 +193,14 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
   }
 }
-bool GPUSession::CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+bool GPUSession::IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   auto kernels = kernel_graph->execution_order();
   size_t conv_cnt = 0;
   size_t bn_cnt = 0;
   for (const auto &kernel : kernels) {
     auto kernel_name = AnfAlgo::GetCNodeName(kernel);
     if (kernel_name == prim::kPrimLayerNorm->name()) {
-      return true;
+      return false;
     }
     if (kernel_name == prim::kPrimConv2D->name()) {
       conv_cnt++;
@@ -210,9 +210,9 @@ bool GPUSession::CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const {
     }
   }
   if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
-    return true;
+    return false;
   }
-  return false;
+  return true;
 }
 GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
......
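Taken together, the two hunks above do more than rename the helper: they invert its meaning. CheckInModeBlackList returned true for graphs that had to skip the format-transform passes, while IsSupportFormatTransform returns true when the transform is allowed, which is why the return values are flipped along with the call sites in SelectKernel and Optimize. A minimal standalone sketch of the renamed predicate, with the kernel graph reduced to a list of op names; the threshold values and the op counted into bn_cnt are elided by the diff, so they are assumptions here:

// Standalone sketch only: models the logic of GPUSession::IsSupportFormatTransform
// as shown in the diff, not the real MindSpore session API.
#include <cstddef>
#include <string>
#include <vector>

namespace sketch {
constexpr size_t kConv2dCount = 96;          // placeholder value; the real constant lives elsewhere
constexpr size_t kFusedBatchNormCount = 94;  // placeholder value; the real constant lives elsewhere

// True when the graph may use the format-transform fusion passes; LayerNorm
// graphs and graphs whose Conv2D / batch-norm counts hit the thresholds are
// excluded, mirroring the flipped returns in the diff above.
bool IsSupportFormatTransform(const std::vector<std::string> &op_names) {
  size_t conv_cnt = 0;
  size_t bn_cnt = 0;
  for (const auto &name : op_names) {
    if (name == "LayerNorm") {
      return false;
    }
    if (name == "Conv2D") {
      conv_cnt++;
    }
    if (name == "FusedBatchNorm") {  // assumption: the elided branch counts this op
      bn_cnt++;
    }
  }
  if (conv_cnt == kConv2dCount && bn_cnt == kFusedBatchNormCount) {
    return false;
  }
  return true;
}
}  // namespace sketch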
@@ -67,7 +67,7 @@ class GPUSession : public SessionBasic {
   void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
-  bool CheckInModeBlackList(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  bool IsSupportFormatTransform(const std::shared_ptr<KernelGraph> &kernel_graph) const;
 #ifdef ENABLE_DEBUGGER
   void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
......
@@ -404,7 +404,9 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v
   // Release the kernel resource.
   for (const auto &kernel : execution_order) {
     auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
-    MS_EXCEPTION_IF_NULL(kernel_mod);
+    if (kernel_mod == nullptr) {
+      continue;
+    }
     kernel_mod->ReleaseResource();
   }
 }
......
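The runtime hunk above relaxes graph cleanup: a kernel whose kernel_mod was never created is now skipped instead of triggering MS_EXCEPTION_IF_NULL, so the remaining kernels still get their resources released. A tiny standalone sketch of that guard pattern (stand-in types, not the real KernelMod interface):

// Sketch of the tolerant cleanup loop; KernelMod here is a stub.
#include <memory>
#include <vector>

namespace sketch {
struct KernelMod {
  void ReleaseResource() { /* free workspace, cached handles, ... */ }
};

void ReleaseKernelResources(const std::vector<std::shared_ptr<KernelMod>> &kernel_mods) {
  for (const auto &kernel_mod : kernel_mods) {
    if (kernel_mod == nullptr) {
      continue;  // previously an exception here aborted the whole cleanup
    }
    kernel_mod->ReleaseResource();
  }
}
}  // namespace sketch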
@@ -176,9 +176,18 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
   if (inputs_type.size() == 0) {
     return false;
   }
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  if (input_shape.size() != 4) {
-    return false;
+  auto inputs_format_position = iter->second.first;
+  // If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
+  if (inputs_format_position.size() == 0) {
+    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); input_index++) {
+      inputs_format_position.push_back(input_index);
+    }
+  }
+  for (const auto &input_format_position : inputs_format_position) {
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position);
+    if (input_shape.size() != 4) {
+      return false;
+    }
   }
   return true;
 }
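The reworked IsNeedProcessFormatInfo keys the 4-D check off the per-op input positions (iter->second.first) instead of always inspecting input 0, and treats an empty position list as "check every input" so variable-arity ops such as AddN are covered. A self-contained sketch of that flow, with input shapes passed in directly and the position map reduced to the AddN entry visible in the header hunk further down; the names and signature here are stand-ins, not the real AnfAlgo API:

// Standalone sketch of the updated shape/position check.
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

namespace sketch {
using Shape = std::vector<size_t>;

// Per-op format positions: {input positions, output positions}; an empty input
// list means "all inputs", which is how variable-arity ops such as AddN are marked.
static const std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kFormatPositionMap = {
    {"AddN", {{}, {0}}},
};

bool IsNeedProcessFormatInfo(const std::string &op_name, const std::vector<Shape> &input_shapes) {
  auto iter = kFormatPositionMap.find(op_name);
  if (iter == kFormatPositionMap.end() || input_shapes.empty()) {
    return false;
  }
  auto inputs_format_position = iter->second.first;  // copied so it can be filled in below
  // Empty position list: the op takes a variable number of inputs, so check them all.
  if (inputs_format_position.empty()) {
    for (size_t input_index = 0; input_index < input_shapes.size(); ++input_index) {
      inputs_format_position.push_back(input_index);
    }
  }
  // Only ops whose checked inputs are all 4-D take part in the format transform.
  for (const auto &pos : inputs_format_position) {
    if (input_shapes[pos].size() != 4) {
      return false;
    }
  }
  return true;
}
}  // namespace sketch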
@@ -223,7 +232,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
 }
 } // namespace
-void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) {
+void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
   std::vector<std::string> inputs_format;
   std::vector<TypeId> inputs_type;
   for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
@@ -237,7 +246,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) {
     outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
   }
   std::string origin_data_format = kOpFormat_DEFAULT;
-  if (!in_black_list && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
+  if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
     UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
   }
   std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
......
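With the rename, SetKernelInfo only rewrites formats when the whole graph supports the transform (graph_format_transform) and the individual op passes IsNeedProcessFormatInfo; otherwise every input and output keeps kOpFormat_DEFAULT. A minimal sketch of that gate follows; the ChooseOriginDataFormat helper and the NHWC target are illustrative assumptions, since the diff does not show UpdateKernelFormatInfo's body:

// Minimal sketch of the format gate inside SetKernelInfo; plain strings stand
// in for the per-input/per-output format vectors the real code builds.
#include <string>

namespace sketch {
const std::string kOpFormat_DEFAULT = "DefaultFormat";  // treated as NCHW on the GPU backend
const std::string kOpFormat_NHWC = "NHWC";              // assumed transform target

std::string ChooseOriginDataFormat(bool graph_format_transform, bool need_process_format_info) {
  // Only graphs where IsSupportFormatTransform() returned true, and only ops
  // registered in the format-position map with 4-D inputs, leave the default.
  if (graph_format_transform && need_process_format_info) {
    return kOpFormat_NHWC;  // assumption: UpdateKernelFormatInfo switches such ops to NHWC
  }
  return kOpFormat_DEFAULT;
}
}  // namespace sketch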
@@ -53,7 +53,7 @@ static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>
   {prim::kPrimAddN->name(), {{}, {0}}},
 };
-void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list = false);
+void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform = false);
 class KernelAttr {
  public:
......