Commit 4ec4205e authored by mindspore-ci-bot, committed by Gitee

!4221 gpu add format transform pass

Merge pull request !4221 from limingqi107/master
......@@ -62,7 +62,7 @@ class TransposeGpuFwdKernel : public GpuKernel {
MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
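// Note: GetInputDeviceShape returns the shape in the kernel's selected input format, which can
// differ from the inferred NCHW shape when a non-default format (e.g. NHWC) has been selected.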
shape_size_ = input_shape.size();
if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but transpose supports max " << TRANSPOSE_MAX_DIMENSION
......
......@@ -52,6 +52,8 @@ TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
return outputs_device_type_[output_index];
}
const std::string &KernelBuildInfo::GetOriginDataFormat() const { return origin_data_format_; }
const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }
const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
......@@ -132,6 +134,11 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &ke
kernel_build_info_->kernel_type_ = kernel_type;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->origin_data_format_ = origin_data_format;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->inputs_format_ = inputs_format;
......
......@@ -38,6 +38,7 @@ class KernelBuildInfo {
op_pattern_ = kCommonPattern;
input_reshape_type_ = {};
output_reshape_type_ = {};
origin_data_format_ = kOpFormat_DEFAULT;
inputs_format_ = {};
outputs_format_ = {};
inputs_device_type_ = {};
......@@ -64,6 +65,8 @@ class KernelBuildInfo {
std::vector<Axis> GetOutputReshapeType(size_t input_index) const;
const std::string &GetOriginDataFormat() const;
const std::vector<std::string> &GetAllInputFormats() const;
const std::vector<std::string> &GetAllOutputFormats() const;
......@@ -97,6 +100,7 @@ class KernelBuildInfo {
private:
KernelType kernel_type_;
std::string origin_data_format_;
std::vector<std::string> inputs_format_;
OpPattern op_pattern_;
std::vector<std::string> outputs_format_;
......@@ -135,6 +139,8 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void SetKernelType(const KernelType &kernel_type);
void SetOriginDataFormat(const std::string &origin_data_format);
void SetInputsFormat(const std::vector<std::string> &inputs_format);
void SetOutputsFormat(const std::vector<std::string> &outputs_format);
......
......@@ -506,6 +506,45 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con
return output_node_list;
}
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node,
size_t output_index) {
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto iter = manager->node_users().find(node);
if (iter == manager->node_users().end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) {
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
continue;
}
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
output_info.second == kDependAttachNodeIndex) {
continue;
}
size_t used_output_index;
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimTupleGetItem->name()) {
used_output_index = AnfAlgo::GetTupleGetItemOutIndex(utils::cast<CNodePtr>(output_info.first));
} else if (AnfAlgo::GetCNodeName(node) == prim::kPrimTupleGetItem->name()) {
used_output_index = output_index;
} else {
// output_info.second counts the primitive at input 0 of the user CNode, so subtract 1 to get the
// data-input index expected by GetPrevNodeOutput.
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(output_info.first, output_info.second - 1);
if (kernel_with_index.first.get() != node.get()) {
MS_LOG(EXCEPTION) << "Get used node failed for op[" << AnfAlgo::GetCNodeName(node) << "]";
}
used_output_index = kernel_with_index.second;
}
if (used_output_index == output_index) {
output_node_list->push_back(output_info);
}
}
return output_node_list;
}
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
......
......@@ -172,6 +172,10 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node,
size_t output_index);
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
bool AnfEqual(const BaseRef &a, const BaseRef &b);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"
namespace mindspore {
namespace opt {
namespace {
std::vector<int> TransposeAxis(const std::string &src_format, const std::string &dst_format) {
if ((src_format == kOpFormat_NCHW) && (dst_format == kOpFormat_NHWC)) {
return {0, 2, 3, 1};
} else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) {
return {0, 3, 1, 2};
} else {
MS_LOG(EXCEPTION) << "Invaild format transform, from " << src_format << " to " << dst_format;
}
}
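// A minimal illustration (not from the original source) of how a TransposeAxis permutation maps a
// shape: out_shape[i] = in_shape[perm[i]]. With perm = {0, 2, 3, 1} (NCHW -> NHWC):
//   in_shape  = {8, 3, 224, 224}   // N, C, H, W
//   out_shape = {8, 224, 224, 3}   // N, H, W, C
// The inverse permutation {0, 3, 1, 2} maps NHWC back to NCHW.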
void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
auto output_type = AnfAlgo::GetOutputInferDataType(node, 0);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetInputsDeviceType({input_type});
builder.SetOutputsFormat({output_format});
builder.SetOutputsDeviceType({output_type});
builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
builder.SetProcessor(kernel::Processor::CUDA);
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
}
// Insert a transpose op between node and used_node; used_node_index is the input position in used_node.
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
int used_node_index, const std::vector<int> &transpose_perm) {
MS_EXCEPTION_IF_NULL(graph);
// 1. Create a transpose node.
auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
MS_EXCEPTION_IF_NULL(transpose_prim);
// 2. Set the input of transpose.
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
auto transpose_op = graph->NewCNode(transpose_input);
// 3. Set the output info of transpose.
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
// 4. Set the input of used_node.
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
<< ", index: " << used_node_index;
AnfAlgo::SetNodeInput(utils::cast<CNodePtr>(used_node), transpose_op, used_node_index);
// 5. Update the manager info of transpose op.
FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->Clear();
manager->AddFuncGraph(graph);
return transpose_op;
}
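// The rewrite performed above, pictured (illustrative):
//   Before:  node ------------------------------------> used_node (input #used_node_index)
//   After:   node -> Transpose{transpose_perm} -------> used_node (input #used_node_index)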
} // namespace
const AnfNodePtr InsertFormatTransformOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
if (!AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr;
}
auto iter = device::gpu::kKernelFormatPositionMap.find(AnfAlgo::GetCNodeName(node));
if (iter == device::gpu::kKernelFormatPositionMap.end()) {
return nullptr;
}
auto origin_data_format = AnfAlgo::GetOriginDataFormat(node);
if (origin_data_format == kOpFormat_DEFAULT) {
origin_data_format = kOpFormat_NCHW;
}
MS_LOG(DEBUG) << "Process node: " << node->fullname_with_scope();
// Insert input transpose from origin_data_format to input_format.
auto inputs_format = AnfAlgo::GetAllInputFormats(node);
for (size_t i = 0; i < inputs_format.size(); i++) {
if ((inputs_format[i] != kOpFormat_DEFAULT) && (inputs_format[i] != origin_data_format)) {
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
MS_EXCEPTION_IF_NULL(input_node);
auto transpose_perm = TransposeAxis(origin_data_format, inputs_format[i]);
auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm);
SetTransposeOpBuildInfo(kOpFormat_DEFAULT, inputs_format[i], transpose_op);
}
}
// Insert output transpose from output_format to origin_data_format.
auto outputs_format = AnfAlgo::GetAllOutputFormats(node);
for (size_t i = 0; i < outputs_format.size(); i++) {
if ((outputs_format[i] != kOpFormat_DEFAULT) && (outputs_format[i] != origin_data_format)) {
// Find all nodes connected with the node output, and change their inputs to the transpose op.
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
for (size_t j = 0; j < used_node_list->size(); j++) {
auto used_node = used_node_list->at(j).first;
auto used_node_index = used_node_list->at(j).second - 1;
auto transpose_perm = TransposeAxis(outputs_format[i], origin_data_format);
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
MS_LOG(DEBUG) << "The used node of [" << node->fullname_with_scope() << "] is tuple item.";
// For a tuple getitem, its own used nodes must be fetched again.
ProcessForTupleItem(graph, used_node, used_node_index, transpose_perm, outputs_format[i]);
continue;
}
auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
SetTransposeOpBuildInfo(outputs_format[i], kOpFormat_DEFAULT, transpose_op);
}
}
}
return node;
}
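// Worked example (illustrative): a float16 Conv2D selected in NHWC while the origin data format
// is NCHW. The pass rewrites
//   x (NCHW) -> Conv2D
// into
//   x (NCHW) -> Transpose{0,2,3,1} -> Conv2D (NHWC) -> Transpose{0,3,1,2} -> consumers (NCHW)
// using TransposeAxis(NCHW, NHWC) on the input side and TransposeAxis(NHWC, NCHW) on the outputs.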
void InsertFormatTransformOp::ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
const std::vector<int> &transpose_perm,
const std::string &transpose_format) const {
auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
for (size_t i = 0; i < used_node_list->size(); i++) {
auto used_node = used_node_list->at(i).first;
auto used_node_index = used_node_list->at(i).second - 1;
if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
MS_LOG(EXCEPTION) << "The used node of tuple item can't be tuple item.";
}
auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
SetTransposeOpBuildInfo(transpose_format, kOpFormat_DEFAULT, transpose_op);
}
}
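// For tuple outputs (illustrative), the transpose is inserted after the getitem, once per real consumer:
//   node -> TupleGetItem(i) -> Transpose -> consumer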
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class InsertFormatTransformOp : public PatternProcessPass {
public:
explicit InsertFormatTransformOp(bool multigraph = true)
: PatternProcessPass("insert_format_transform_op", multigraph) {}
~InsertFormatTransformOp() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
void ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
const std::vector<int> &transpose_perm, const std::string &transpose_format) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_INSERT_FORMAT_TRANSFORM_OP_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include <memory>
#include <vector>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
const BaseRef RemoveFormatTransformPair::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
VectorRef transpose1 = VectorRef({prim::kPrimTranspose, X});
VectorRef transpose2 = VectorRef({prim::kPrimTranspose, transpose1});
return transpose2;
}
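// The pattern matches back-to-back transposes:  X -> Transpose -> Transpose. Process() replaces
// the pair with X when the outer transpose's output device shape equals the inner transpose's
// input device shape, i.e. the second transpose undoes the first.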
const AnfNodePtr RemoveFormatTransformPair::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
MS_LOG(DEBUG) << "Process node:" << node->fullname_with_scope();
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(input_node);
if (AnfAlgo::GetCNodeName(node) != prim::kPrimTranspose->name() ||
AnfAlgo::GetCNodeName(input_node) != prim::kPrimTranspose->name()) {
MS_LOG(EXCEPTION) << "The pattern is not transpose pair, "
<< "node:" << AnfAlgo::GetCNodeName(node) << " node input:" << AnfAlgo::GetCNodeName(input_node);
}
// If the transpose operator is used by more than one other operator, it cannot be deleted directly.
if (IsUsedByOthers(graph, input_node)) {
MS_LOG(DEBUG) << "The transpose node [" << input_node->fullname_with_scope()
<< "] is used by more than one other operator.";
return nullptr;
}
auto transpose1_input_shape = AnfAlgo::GetInputDeviceShape(input_node, 0);
auto transpose2_output_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
if (transpose2_output_shape == transpose1_input_shape) {
auto transpose1_input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(input_node), 0);
MS_EXCEPTION_IF_NULL(transpose1_input_node);
return transpose1_input_node;
}
return nullptr;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class RemoveFormatTransformPair : public PatternProcessPass {
public:
explicit RemoveFormatTransformPair(bool multigraph = true)
: PatternProcessPass("remove_format_transform_pair", multigraph) {}
~RemoveFormatTransformPair() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_REMOVE_FORMAT_TRANSFORM_PAIR_H_
......@@ -353,6 +353,48 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
}
}
std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]";
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
auto format = build_info->GetAllOutputFormats();
return format;
}
std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]";
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
auto format = build_info->GetAllInputFormats();
return format;
}
std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]";
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
auto build_info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(build_info);
auto format = build_info->GetOriginDataFormat();
return format;
}
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
MS_EXCEPTION_IF_NULL(node);
if (output_idx > GetOutputTensorNum(node)) {
......@@ -829,7 +871,7 @@ void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *
bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// parameter and value node is not a real kernel too
// parameter and value node is a real kernel too
if (!node->isa<CNode>()) {
return true;
}
......
......@@ -101,6 +101,12 @@ class AnfRuntimeAlgorithm {
static size_t GetInputTensorNum(const AnfNodePtr &node);
// get the num of output real_kernel (which can be built and run on device)
static size_t GetOutputTensorNum(const AnfNodePtr &node);
// get all output formats selected for the anf node
static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node);
// get all input formats selected for the anf node
static std::vector<std::string> GetAllInputFormats(const AnfNodePtr &node);
// get the origin data format selected for the anf node
static std::string GetOriginDataFormat(const AnfNodePtr &node);
// get the output format selected for the anf node
static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
// get the input format selected for the anf node
......
......@@ -30,6 +30,8 @@
#include "backend/optimizer/gpu/replace_bn_grad_cast_fusion.h"
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
#include "backend/optimizer/gpu/replace_addn_fusion.h"
#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include "backend/optimizer/gpu/remove_format_transform_pair.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/ms_utils.h"
#include "common/trans.h"
......@@ -76,6 +78,8 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
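// The insert pass runs before the pair-removal pass so that any redundant back-to-back Transpose
// pairs created by the insertion can be folded away.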
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::AllReduceFusion>());
pm->AddPass(std::make_shared<opt::GetitemTuple>());
optimizer->AddPassManager(pm);
......@@ -203,7 +207,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
}
// Assign CUDA streams
AssignStream(graph);
// Hide NoOp from execution graph
// Hide NopOp from execution graph
opt::HideNopNode(graph.get());
// Build kernel if node is cnode
BuildKernel(graph);
......@@ -213,7 +217,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
graph->set_execution_order(execution_order);
// Get summary nodes.
SetSummaryNodes(graph.get());
// Remove NoOp from execution graph
// Remove NopOp from execution graph
opt::RemoveNopNode(graph.get());
// Set graph manager.
MS_EXCEPTION_IF_NULL(context_);
......@@ -272,7 +276,7 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in
MS_EXCEPTION_IF_NULL(kernel_graph);
SelectKernel(kernel_graph);
StartKernelRT();
// Hide NoOp from execution graph
// Hide NopOp from execution graph
opt::HideNopNode(kernel_graph.get());
BuildKernel(kernel_graph);
run_op_graphs_[graph_info] = kernel_graph;
......@@ -282,7 +286,7 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph
const std::vector<tensor::TensorPtr> &input_tensors) {
auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph);
// Remove NoOp from execution graph
// Remove NopOp from execution graph
opt::RemoveNopNode(kernel_graph.get());
RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get());
// Execute the computation
......
......@@ -252,7 +252,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it.");
} else {
if (!stop_send_) {
MS_LOG(WARNING) << "Retry pushing data...";
MS_LOG(DEBUG) << "Retry pushing data...";
continue;
}
break;
......
......@@ -96,10 +96,10 @@ BlockQueueStatus_T BlockingQueue::Create(void *addr, const std::vector<size_t> &
void BlockingQueue::RegisterRelease(const std::function<void(void *)> &func) { queue_->RegisterRelease(func); }
BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, unsigned int timeout_in_sec) {
BlockQueueStatus_T BlockingQueue::Push(const std::vector<DataItemGpu> &data, unsigned int) {
std::unique_lock<std::mutex> locker(mutex_);
if (queue_->IsFull()) {
if (not_full_cond_.wait_for(locker, std::chrono::seconds(timeout_in_sec)) == std::cv_status::timeout) {
if (not_full_cond_.wait_for(locker, std::chrono::microseconds(100)) == std::cv_status::timeout) {
return TIMEOUT;
}
}
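// Note (descriptive, based on the surrounding code): the wait for free queue space is a fixed
// 100 microseconds and the caller-supplied timeout is unused; a TIMEOUT result is left to callers
// such as RetryPushGPUData above, which retry the push.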
......
......@@ -19,6 +19,7 @@
#include <memory>
#include "backend/kernel_compiler/kernel.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/session/anf_runtime_algorithm.h"
......@@ -157,25 +158,87 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
}
}
}
bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
return false;
}
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
return false;
}
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kKernelFormatPositionMap.find(kernel_name);
if (iter == kKernelFormatPositionMap.end()) {
return false;
}
if (inputs_type.size() == 0) {
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() != 4) {
return false;
}
return true;
}
void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeId> &inputs_type,
std::vector<std::string> *inputs_format, std::vector<std::string> *outputs_format,
std::string *origin_data_format) {
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kKernelFormatPositionMap.find(kernel_name);
if (iter == kKernelFormatPositionMap.end()) {
return;
}
auto cal_format = (inputs_type[0] == kNumberTypeFloat16) ? kOpFormat_NHWC : kOpFormat_NCHW;
MS_LOG(DEBUG) << "Kernel node: " << kernel_node->fullname_with_scope() << ", format: " << cal_format;
auto inputs_format_position = iter->second.first;
for (const auto &input_format_position : inputs_format_position) {
if (input_format_position >= inputs_format->size()) {
MS_LOG(EXCEPTION) << "The position [" << input_format_position << "] is out of range of the input size ["
<< inputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]";
}
(*inputs_format)[input_format_position] = cal_format;
}
auto outputs_format_position = iter->second.second;
for (const auto &output_format_position : outputs_format_position) {
if (output_format_position >= outputs_format->size()) {
MS_LOG(EXCEPTION) << "The position [" << output_format_position << "] is out of range of the output size ["
<< outputs_format->size() << "] #kernel_node [" << kernel_node->fullname_with_scope() << "]";
}
(*outputs_format)[output_format_position] = cal_format;
}
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->HasAttr("data_format")) {
*origin_data_format = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "data_format");
}
}
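// Illustrative walk-through: for Conv2D with float16 inputs, kKernelFormatPositionMap yields
// {{0, 1}, {0}}, so cal_format = kOpFormat_NHWC is written to input positions 0 and 1 (data and
// weight) and to output position 0; for float32 inputs the same positions get kOpFormat_NCHW.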
} // namespace
void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(kOpFormat_DEFAULT);
inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
}
builder->SetInputsFormat(inputs_format);
builder->SetInputsDeviceType(inputs_type);
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(kOpFormat_DEFAULT);
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
}
std::string origin_data_format = kOpFormat_DEFAULT;
if (IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
}
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetOriginDataFormat(origin_data_format);
builder->SetInputsFormat(inputs_format);
builder->SetInputsDeviceType(inputs_type);
builder->SetOutputsFormat(outputs_format);
builder->SetOutputsDeviceType(outputs_type);
......
......@@ -20,13 +20,35 @@
#include <utility>
#include <string>
#include <vector>
#include <map>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "utils/utils.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace device {
namespace gpu {
// map<opName, (inputFormatPosition, outputFormatPosition)>, used for getting the insert position of format transform.
static std::map<std::string, std::pair<std::vector<size_t>, std::vector<size_t>>> kKernelFormatPositionMap = {
{prim::kPrimConv2D->name(), {{0, 1}, {0}}},
{prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {0}}},
{prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {0}}},
{prim::kPrimRelu->name(), {{0}, {0}}},
{prim::kPrimReluGrad->name(), {{0, 1}, {0}}},
{prim::kPrimMaxPool->name(), {{0}, {0}}},
{prim::kPrimMaxPoolGrad->name(), {{0, 1, 2}, {0}}},
{kAvgPoolOpName, {{0}, {0}}},
{kAvgPoolGradGpuOpName, {{0, 1, 2}, {0}}},
{kTensorAddOpName, {{0, 1}, {0}}},
{kFusedBatchNormEx, {{0}, {0}}},
{kFusedBatchNormExWithActivation, {{0}, {0}}},
{kFusedBatchNormExWithAddAndActivation, {{0, 5}, {0}}},
{kFusedBatchNormGradEx, {{0, 1}, {0}}},
{kFusedBatchNormGradExWithActivation, {{0, 1, 7}, {0}}},
{kFusedBatchNormGradExWithAddAndActivation, {{0, 1, 7}, {0, 3}}},
};
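// Example lookup (illustrative): Conv2D takes the computed format at input positions {0, 1}
// (data and weight) and at output position {0}:
//   auto iter = kKernelFormatPositionMap.find(prim::kPrimConv2D->name());
//   const auto &input_positions = iter->second.first;    // {0, 1}
//   const auto &output_positions = iter->second.second;  // {0}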
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
class KernelAttr {
......
......@@ -189,6 +189,9 @@ constexpr auto kPullOpName = "Pull";
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
constexpr auto kEmbeddingLookupProxyOpName = "EmbeddingLookupProxy";
constexpr auto kPaddingOpName = "Padding";
constexpr auto kAvgPoolOpName = "AvgPool";
constexpr auto kAvgPoolGradGpuOpName = "AvgPoolGradGpu";
constexpr auto kTensorAddOpName = "TensorAdd";
// attr key name
constexpr auto kAttrInputNames = "input_names";
......