提交 7017ac3f 编写于 作者: W wangcong

depthwisecon2d-eltwise pass

上级 89f55e63
......@@ -38,11 +38,6 @@ constexpr auto kFusionKernelNamePrfix = "te_fusion";
constexpr auto kOptional = "optional_";
constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";
std::map<std::string, std::string> TbeKernelBuild::buffer_fussion_op_map_ = {
{"DepthwiseConv2dNative", "DepthwiseConv2D"},
{"TensorAdd", "Add"}
};
std::string NormalizeFullScopeName(const string &full_scope_name) {
// exp:Default/ReLU-op0 -->Default_ReLU_op0
string normal_ret = full_scope_name;
......@@ -726,6 +721,16 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
}
std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
static std::map<std::string, std::string> TbeKernelBuild::buffer_fussion_op_map = {
{"DepthwiseConv2dNative", "DepthwiseConv2D"}, {"TensorAdd", "Add"}};
string result = origin_type;
if (buffer_fussion_op_map.find(origin_type) != buffer_fussion_op_map.end()) {
result = buffer_fussion_op_map[origin_type];
}
return result;
}
bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
std::vector<nlohmann::json> *input_desc_list, size_t *index) {
......@@ -831,9 +836,7 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n
// gen others
auto type = AnfAlgo::GetCNodeName(cnode);
// replace special op type for buffer fusion op
if (buffer_fussion_op_map_.find(type) != buffer_fussion_op_map_.end()) {
type = buffer_fussion_op_map_[type];
}
type = GetRealOpType(type);
(*compute_op_str)["type"] = type;
tbe::TbeAdapter::NormalizeFuncName(&type);
(*compute_op_str)["func_name"] = type;
......
......@@ -76,7 +76,7 @@ class TbeKernelBuild {
std::map<const AnfNodePtr, FusionDataType> *spec_data_input);
static bool IsDynamicInput(const CNodePtr &cnode);
static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
static std::map<std::string, std::string> buffer_fussion_op_map_;
std::string GetRealOpType(const std::string &origin_type);
};
class TbeKernelJsonCreator {
......
......@@ -200,6 +200,7 @@ const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
......
......@@ -555,7 +555,7 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::
// DepthwiseConvolution--->Elemwise
auto depthwise_conv = cnode->input(1);
MS_EXCEPTION_IF_NULL(depthwise_conv);
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(depthwise_conv) == prim::kPrimDepthwiseConv2dNative->name()) {
if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
......@@ -566,8 +566,7 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::
// Elemwise-->DepthwiseConvolution
auto relu = cnode->input(1);
MS_EXCEPTION_IF_NULL(relu);
if (cnode->isa<CNode>() &&
(AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() || AnfAlgo::GetCNodeName(relu) == kReluV2OpName)) {
if (cnode->isa<CNode>() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
std::unordered_set<AnfNodePtr> record{cnode, relu};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册