提交 0f42c662 编写于 作者: W wangcong

DepthwiseConv2d+Eltwise fusion pass

上级 168dfb25
......@@ -38,6 +38,11 @@ constexpr auto kFusionKernelNamePrfix = "te_fusion";
constexpr auto kOptional = "optional_";
constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";
std::map<std::string, std::string> TbeKernelBuild::buffer_fussion_op_map_ = {
{"DepthwiseConv2dNative", "DepthwiseConv2D"},
{"TensorAdd", "Add"}
};
std::string NormalizeFullScopeName(const string &full_scope_name) {
// exp:Default/ReLU-op0 -->Default_ReLU_op0
string normal_ret = full_scope_name;
......@@ -825,8 +830,9 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n
(*compute_op_str)["output_desc"] = output_desc_list;
// gen others
auto type = AnfAlgo::GetCNodeName(cnode);
if (type == "TensorAdd") {
type = "Add";
// replace special op type for buffer fusion op
if (buffer_fussion_op_map_.find(type) != buffer_fussion_op_map_.end()) {
type = buffer_fussion_op_map_[type];
}
(*compute_op_str)["type"] = type;
tbe::TbeAdapter::NormalizeFuncName(&type);
......
......@@ -76,6 +76,7 @@ class TbeKernelBuild {
std::map<const AnfNodePtr, FusionDataType> *spec_data_input);
static bool IsDynamicInput(const CNodePtr &cnode);
static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
static std::map<std::string, std::string> buffer_fussion_op_map_;
};
class TbeKernelJsonCreator {
......
......@@ -545,6 +545,39 @@ void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr
}
}
void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion, bool is_order) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
if (is_order) {
// DepthwiseConvolution--->Elemwise
auto depthwise_conv = cnode->input(1);
MS_EXCEPTION_IF_NULL(depthwise_conv);
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(depthwise_conv) == prim::kPrimDepthwiseConv2dNative->name()) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
candidate_fusion->push_back(record);
SetRecordFusionId(record);
}
} else {
// Elemwise-->DepthwiseConvolution
auto relu = cnode->input(1);
MS_EXCEPTION_IF_NULL(relu);
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() ||
AnfAlgo::GetCNodeName() == kReluV2OpName) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
std::unordered_set<AnfNodePtr> record{cnode, relu};
candidate_fusion->push_back(record);
SetRecordFusionId(record);
}
}
}
void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
......@@ -563,7 +596,11 @@ void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph,
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion);
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion);
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimDepthwiseConv2dNative->name()) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
}
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
}
}
}
......
......@@ -51,6 +51,8 @@ class BufferFusion : public Pass {
FusedNodeRecord *candidate_fusion);
void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion, bool is_order);
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册