提交 9b7df3d0 编写于 作者: V VectorSL

gpu optimize transpose

上级 3449abd7
......@@ -33,7 +33,25 @@ std::vector<int> TransposeAxis(const std::string &src_format, const std::string
} else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) {
return {0, 3, 1, 2};
} else {
MS_LOG(EXCEPTION) << "Invaild format transform, from " << src_format << " to " << dst_format;
MS_LOG(EXCEPTION) << "Invalid format transform, from " << src_format << " to " << dst_format;
}
}
// Tells whether a 4-D transpose can be replaced by a no-op reshape.
// That is the case when the two axes checked below both have extent 1:
//   1. perm {0, 2, 3, 1} with out_shape [x, 1, 1, y]  (axes 1 and 2 are 1)
//   2. perm {0, 3, 1, 2} with out_shape [x, y, 1, 1]  (axes 2 and 3 are 1)
// Any other permutation is never treated as fake.
// A shape that is not 4-D is reported through MS_LOG(EXCEPTION).
bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int> &transpose_perm) {
  const size_t rank = out_shape.size();
  if (rank != 4) {
    MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D.";
  }

  // Pick the first of the two adjacent axes that must both be 1 for the given permutation.
  size_t first_unit_axis = 0;
  if (transpose_perm == std::vector<int>{0, 2, 3, 1}) {
    first_unit_axis = 1;
  } else if (transpose_perm == std::vector<int>{0, 3, 1, 2}) {
    first_unit_axis = 2;
  } else {
    // Unsupported permutation: a real transpose is required.
    return false;
  }

  const bool first_is_unit = (out_shape[first_unit_axis] == 1);
  const bool second_is_unit = (out_shape[first_unit_axis + 1] == 1);
  return first_is_unit && second_is_unit;
}
......@@ -56,8 +74,16 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
int used_node_index, const std::vector<int> &transpose_perm) {
MS_EXCEPTION_IF_NULL(graph);
// 1.Create a transpose node.
auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
// 0.Judge whether it is a fake transpose
auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index);
bool is_fake = IsFakeTranspose(transed_shape, transpose_perm);
// 1.Create a transpose node or a fake transpose node:reshape.
mindspore::PrimitivePtr transpose_prim;
if (is_fake) {
transpose_prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
} else {
transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
}
MS_EXCEPTION_IF_NULL(transpose_prim);
// 2.Set the input of transpose.
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
......@@ -66,7 +92,9 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
if (!is_fake) {
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
}
// 4.Set the input of used_node.
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
<< ", index: " << used_node_index;
......
......@@ -57,7 +57,7 @@ const AnfNodePtr ReplaceBNCastFusion::Process(const FuncGraphPtr &graph, const A
if (item_idx == 0) {
auto cast = GetRealNodeUsedList(graph, outlist->at(i).first);
if (AnfAlgo::GetCNodeName(cast->at(0).first) != "Cast") {
return nullptr;
continue;
}
manager->Replace(utils::cast<CNodePtr>(cast->at(0).first), utils::cast<CNodePtr>(outlist->at(i).first));
outputs_type.push_back(kNumberTypeFloat16);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册