提交 cd6ed0e3 编写于 作者: J jjfeing

support buffer fusion

上级 bf1d0031
...@@ -122,10 +122,12 @@ def get_args(op_info, arg_type): ...@@ -122,10 +122,12 @@ def get_args(op_info, arg_type):
elif arg_type == 'attrs': elif arg_type == 'attrs':
for item in op_info[arg_type]: for item in op_info[arg_type]:
if 'value' not in item: if item["valid"]:
raise ValueError("Json string Errors, attr key:value not found.") if 'value' not in item:
if item["name"] != "isRef": raise ValueError("Json string Errors, attr key:value not found.")
args.append(item['value']) if item["name"] != "isRef":
args.append(item['value'])
return args return args
......
...@@ -108,7 +108,8 @@ std::map<int32_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo> ...@@ -108,7 +108,8 @@ std::map<int32_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
} }
if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) { if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) {
MS_LOG(DEBUG) << "fuison op build failed, err log: " << task_result << " change to single op build."; MS_LOG(INFO) << "Fusion warning: Fuison op build failed, err log: " << task_result
<< " change to single op build.";
build_failed_num++; build_failed_num++;
} }
auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false); auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false);
......
...@@ -153,6 +153,52 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vec ...@@ -153,6 +153,52 @@ void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vec
} }
} }
void TbeAdapter::FusionInputOrderPass(const std::string &op_name, const std::vector<nlohmann::json> &inputs_list,
std::vector<nlohmann::json> *inputs_json) {
MS_EXCEPTION_IF_NULL(inputs_json);
if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) {
(void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json)));
} else {
if (op_name == "MinimumGrad" || op_name == "MaximumGrad") {
inputs_json->emplace_back(inputs_list[2]);
inputs_json->emplace_back(inputs_list[0]);
inputs_json->emplace_back(inputs_list[1]);
for (size_t i = 3; i < inputs_list.size(); ++i) {
inputs_json->emplace_back(inputs_list[i]);
}
} else {
inputs_json->emplace_back(inputs_list[1]);
inputs_json->emplace_back(inputs_list[0]);
for (size_t i = 2; i < inputs_list.size(); ++i) {
inputs_json->emplace_back(inputs_list[i]);
}
}
}
}
void TbeAdapter::FusionDataOrderPass(const std::string &op_name, const std::vector<AnfNodePtr> &data_layer,
std::vector<AnfNodePtr> *reorder_data_layer) {
MS_EXCEPTION_IF_NULL(reorder_data_layer);
if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) {
(void)std::copy(data_layer.begin(), data_layer.end(), std::back_inserter((*reorder_data_layer)));
} else {
if (op_name == "MinimumGrad" || op_name == "MaximumGrad") {
reorder_data_layer->emplace_back(data_layer[2]);
reorder_data_layer->emplace_back(data_layer[0]);
reorder_data_layer->emplace_back(data_layer[1]);
for (size_t i = 3; i < data_layer.size(); ++i) {
reorder_data_layer->emplace_back(data_layer[i]);
}
} else {
reorder_data_layer->emplace_back(data_layer[1]);
reorder_data_layer->emplace_back(data_layer[0]);
for (size_t i = 2; i < data_layer.size(); ++i) {
reorder_data_layer->emplace_back(data_layer[i]);
}
}
}
}
std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = { std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = {
{"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass}, {"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass},
{"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass}, {"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass},
......
...@@ -44,15 +44,12 @@ class TbeAdapter { ...@@ -44,15 +44,12 @@ class TbeAdapter {
static void GenTopKV2IndicesTensorInfo(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index, static void GenTopKV2IndicesTensorInfo(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
std::vector<nlohmann::json> *input_list, kCreaterType creater_type); std::vector<nlohmann::json> *input_list, kCreaterType creater_type);
static void FusionInputOrderPass(const std::string &op_name, const std::vector<nlohmann::json> &inputs_list,
std::vector<nlohmann::json> *inputs_json);
static void FusionDataOrderPass(const std::string &op_name, const std::vector<AnfNodePtr> &data_layer,
std::vector<AnfNodePtr> *reorder_data_layer);
private: private:
static void Conv2DAttrJsonPass(const AnfNodePtr &anf_node, const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json);
static void Conv2DBackpropFilterAttrJsonPass(const AnfNodePtr &anf_node,
const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json);
static void Conv2DBackpropInputAttrJsonPass(const AnfNodePtr &anf_node,
const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json);
static void MaximumGradAttrJsonPass(const AnfNodePtr &anf_node, static void MaximumGradAttrJsonPass(const AnfNodePtr &anf_node,
const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs, const std::vector<std::shared_ptr<OpAttr>> &op_info_attrs,
nlohmann::json *attrs_json); nlohmann::json *attrs_json);
......
...@@ -375,20 +375,26 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_no ...@@ -375,20 +375,26 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_no
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
for (const auto &attr_ptr : attrs_ptr) { for (const auto &attr_ptr : attrs_ptr) {
std::string attr_name = attr_ptr->name(); std::string attr_name = attr_ptr->name();
nlohmann::json attr_obj;
attr_obj["name"] = attr_name;
if (primitive->GetAttr(attr_name) != nullptr) { if (primitive->GetAttr(attr_name) != nullptr) {
nlohmann::json attr_obj;
auto value = primitive->GetAttr(attr_name); auto value = primitive->GetAttr(attr_name);
std::string type = attr_ptr->type(); std::string type = attr_ptr->type();
ParseAttrValue(type, value, &attr_obj); ParseAttrValue(type, value, &attr_obj);
attr_obj["name"] = attr_name;
attr_obj["valid"] = true; attr_obj["valid"] = true;
(*attrs_json).push_back(attr_obj);
} else { } else {
if (attr_ptr->param_type() == "required" && creater_type_ == SINGLE_BUILD && op_info->impl_path() != "") { if (op_info->impl_path().empty()) {
MS_LOG(EXCEPTION) << "op name: " << op_info->op_name() << " attr: " << attr_name attr_obj["valid"] = false;
<< " is required, but not set."; } else {
if (attr_ptr->param_type() == "required" && creater_type_ == SINGLE_BUILD) {
MS_LOG(EXCEPTION) << "op name: " << op_info->op_name() << " attr: " << attr_name
<< " is required, but not set.";
} else {
attr_obj["valid"] = false;
}
} }
} }
(*attrs_json).push_back(attr_obj);
} }
return true; return true;
} }
...@@ -484,7 +490,8 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &inp ...@@ -484,7 +490,8 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &inp
MS_EXCEPTION_IF_NULL(fusion_kernel); MS_EXCEPTION_IF_NULL(fusion_kernel);
// get input layer info // get input layer info
std::vector<std::vector<mindspore::AnfNodePtr>> input_layers; std::vector<std::vector<mindspore::AnfNodePtr>> input_layers;
if (!GetInputLayers(input_nodes, compute_nodes, &input_layers)) { std::map<const AnfNodePtr, FusionDataType> spec_data_input;
if (!GetInputLayers(input_nodes, compute_nodes, &input_layers, &spec_data_input)) {
return false; return false;
} }
// gen fusion scopre_op jsom // gen fusion scopre_op jsom
...@@ -505,8 +512,8 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &inp ...@@ -505,8 +512,8 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &inp
for (const auto &layer : input_layers) { for (const auto &layer : input_layers) {
for (const auto &data_input : layer) { for (const auto &data_input : layer) {
nlohmann::json data_str; nlohmann::json data_str;
if (!GenFusionDataInputJson(data_input, &data_str, &index)) { if (!GenFusionDataInputJson(data_input, spec_data_input, &data_str, &index)) {
MS_LOG(DEBUG) << "GenFusionDataInputJson faild."; MS_LOG(INFO) << "Fusion error: gen fusion datainput json faild.";
return false; return false;
} }
data_list.push_back(data_str); data_list.push_back(data_str);
...@@ -519,7 +526,7 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &inp ...@@ -519,7 +526,7 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &inp
} }
void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx, void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
size_t desc_output_idx, nlohmann::json *output_desc) { size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) {
std::string output_desc_name = anf_node->fullname_with_scope(); std::string output_desc_name = anf_node->fullname_with_scope();
if (node_out_idx > 0) { if (node_out_idx > 0) {
output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx); output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx);
...@@ -539,58 +546,109 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_ ...@@ -539,58 +546,109 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_
(*output_desc)["shape"] = shape; (*output_desc)["shape"] = shape;
auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx); auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
if (format == kOpFormat_DEFAULT) { if (format == kOpFormat_DEFAULT) {
if (ori_shape.size() == 4) { format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND;
format = kOpFormat_NCHW;
} else {
format = kOpFormat_ND;
}
} }
(*output_desc)["format"] = format; (*output_desc)["format"] = format;
(*output_desc)["ori_format"] = kOpFormat_NCHW; (*output_desc)["ori_format"] = kOpFormat_NCHW;
(*output_desc)["output_index"] = desc_output_idx; (*output_desc)["output_index"] = desc_output_idx;
if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) {
std::vector<size_t> spec_shape = {};
spec_shape.emplace_back(shape[0]);
spec_shape.emplace_back(shape[1]);
spec_shape.emplace_back(shape[2] * shape[3]);
spec_shape.emplace_back(shape[4]);
(*output_desc)["shape"] = spec_shape;
} else if (fusion_data_type == kFusionReLUGradV2 && (*output_desc)["data_type"] == "uint8") {
std::vector<size_t> spec_shape = {};
spec_shape.emplace_back(shape[0]);
spec_shape.emplace_back(shape[1]);
spec_shape.emplace_back(shape[2] * shape[3]);
spec_shape.emplace_back(16);
(*output_desc)["shape"] = spec_shape;
(*output_desc)["data_type"] = "bool";
}
} }
void TbeKernelBuild::GenReusedOutputDesc(const shared_ptr<mindspore::AnfNode> &anf_node, size_t index, void TbeKernelBuild::GenReusedOutputDesc(const shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
size_t output_index, nlohmann::json *output_desc) { size_t output_index, nlohmann::json *output_desc) {
std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index); std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index);
(*output_desc)["name"] = NormalizeFullScopeName(output_desc_name); (*output_desc)["name"] = NormalizeFullScopeName(output_desc_name);
(*output_desc)["data_type"] = tbe::TypeIdToString(kNumberTypeFloat32);
(*output_desc)["output_index"] = output_index; (*output_desc)["output_index"] = output_index;
std::vector<size_t> shape; std::vector<size_t> shape;
(*output_desc)["shape"] = shape; (*output_desc)["shape"] = shape;
} }
bool TbeKernelBuild::GetInputLayers(const vector<mindspore::AnfNodePtr> &input_nodes, bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name,
const vector<mindspore::AnfNodePtr> &compute_nodes, const std::vector<mindspore::AnfNodePtr> &reorder_layer,
std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers) { std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) {
MS_LOG(INFO) << "Fusion error: node(" << op_name << " )'s input is null. ";
return false;
}
MS_LOG(INFO) << "Fusion info: op_name: " << op_name << "input layer size: " << reorder_layer.size();
if (op_name == kReluGradV2OpName) {
(*spec_data_input)[reorder_layer[0]] = kFusionReLUGradV2;
} else if (op_name == kAddNOpName) {
for (const auto &it : reorder_layer) {
(*spec_data_input)[it] = kFusionAddN;
}
}
return true;
}
bool TbeKernelBuild::GetInputLayers(const std::vector<mindspore::AnfNodePtr> &input_nodes,
const std::vector<mindspore::AnfNodePtr> &compute_nodes,
std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers,
std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) {
auto op_name = AnfAlgo::GetCNodeName(it);
return op_name == kConv2DBackpropInputOpName;
});
bool need_spec = (result != compute_nodes.end());
size_t input_size = 0; size_t input_size = 0;
for (const auto &compute_node : compute_nodes) { for (const auto &compute_node : compute_nodes) {
std::vector<mindspore::AnfNodePtr> layer; std::vector<mindspore::AnfNodePtr> layer = {};
std::vector<mindspore::AnfNodePtr> reorder_layer = {};
MS_EXCEPTION_IF_NULL(compute_node); MS_EXCEPTION_IF_NULL(compute_node);
auto op_name = AnfAlgo::GetCNodeName(compute_node);
auto ccompute_node = compute_node->cast<CNodePtr>(); auto ccompute_node = compute_node->cast<CNodePtr>();
if (ccompute_node == nullptr) { if (ccompute_node == nullptr) {
MS_LOG(DEBUG) << "fusion compute node must be cnode"; MS_LOG(INFO) << "Fusion error: fusion compute node must be cnode";
return false; return false;
} }
MS_LOG(INFO) << "Fusion info: compute name: " << compute_node->fullname_with_scope();
for (size_t i = 1; i < ccompute_node->inputs().size(); ++i) { for (size_t i = 1; i < ccompute_node->inputs().size(); ++i) {
auto input = ccompute_node->input(i); auto input = ccompute_node->input(i);
auto find_iter = std::find(input_nodes.begin(), input_nodes.end(), input); auto find_iter = std::find(input_nodes.begin(), input_nodes.end(), input);
if (find_iter != input_nodes.end()) { if (find_iter != input_nodes.end()) {
MS_LOG(INFO) << "Fusion info: add compute node's [" << i << "] input: " << input->fullname_with_scope();
layer.emplace_back((*find_iter)); layer.emplace_back((*find_iter));
} else {
MS_LOG(INFO) << "Fusion warnig: this input [" << i << "] may be pre compute(" << input->fullname_with_scope()
<< ") node's output.";
}
}
TbeAdapter::FusionDataOrderPass(op_name, layer, &reorder_layer);
if (need_spec) {
MS_LOG(INFO) << "Fusion info: match conv2d backprop input + ... patten.";
if (!GetSpecInputLayers(op_name, reorder_layer, spec_data_input)) {
return false;
} }
} }
input_size += layer.size(); input_size += reorder_layer.size();
input_layers->emplace_back(layer); input_layers->emplace_back(reorder_layer);
} }
if (input_nodes.size() != input_size) { if (input_nodes.size() != input_size) {
MS_LOG(DEBUG) << "fusion scope error, layer input:" << input_size << ", input_node:" << input_nodes.size(); MS_LOG(INFO) << "Fusion error: fusion scope error, layer input:" << input_size
<< ", input_node:" << input_nodes.size();
return false; return false;
} }
return true; return true;
} }
bool TbeKernelBuild::GenFusionDataInputJson(const shared_ptr<mindspore::AnfNode> &data_input, nlohmann::json *data_str, bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr<mindspore::AnfNode> &data_input,
size_t *index) { const std::map<const AnfNodePtr, FusionDataType> &spec_data_input,
nlohmann::json *data_str, size_t *index) {
MS_EXCEPTION_IF_NULL(data_str); MS_EXCEPTION_IF_NULL(data_str);
MS_EXCEPTION_IF_NULL(index); MS_EXCEPTION_IF_NULL(index);
std::vector<nlohmann::json> output_desc_list; std::vector<nlohmann::json> output_desc_list;
...@@ -604,13 +662,17 @@ bool TbeKernelBuild::GenFusionDataInputJson(const shared_ptr<mindspore::AnfNode> ...@@ -604,13 +662,17 @@ bool TbeKernelBuild::GenFusionDataInputJson(const shared_ptr<mindspore::AnfNode>
output_desc_list.push_back(output_desc); output_desc_list.push_back(output_desc);
(*index)++; (*index)++;
} else { } else {
FusionDataType fusion_data_type = kFusionNormal;
if (spec_data_input.find(data_input) != spec_data_input.end()) {
fusion_data_type = spec_data_input.at(data_input);
}
auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0); auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0);
auto real_node = kernel_idx.first; auto real_node = kernel_idx.first;
size_t real_idx = kernel_idx.second; size_t real_idx = kernel_idx.second;
MS_LOG(INFO) << "real name " << real_node->fullname_with_scope() << " index:" << real_idx; MS_LOG(INFO) << "real name " << real_node->fullname_with_scope() << " index:" << real_idx;
// "output_desc" // "output_desc"
nlohmann::json output_desc; nlohmann::json output_desc;
GenDescJson(real_node, real_idx, real_idx, &output_desc); GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type);
output_desc_list.push_back(output_desc); output_desc_list.push_back(output_desc);
(*data_str)["name"] = NormalizeFullScopeName(real_node->fullname_with_scope()); (*data_str)["name"] = NormalizeFullScopeName(real_node->fullname_with_scope());
} }
...@@ -632,11 +694,12 @@ bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) { ...@@ -632,11 +694,12 @@ bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) {
auto real_input_size = cnode->inputs().size() - 1; auto real_input_size = cnode->inputs().size() - 1;
auto dyn_input_size = dyn_input_sizes.size(); auto dyn_input_size = dyn_input_sizes.size();
if (dyn_input_size != 1) { if (dyn_input_size != 1) {
MS_LOG(DEBUG) << "fusion build not support dyn_input_sizes > 1"; MS_LOG(INFO) << "Fusion error: fusion build not support dyn_input_sizes > 1";
return ret; return ret;
} }
if (IntToSize(dyn_input_sizes[0]) != real_input_size) { if (IntToSize(dyn_input_sizes[0]) != real_input_size) {
MS_LOG(DEBUG) << " dyn_input_size" << dyn_input_sizes[0] << "not equal real_input_size" << real_input_size; MS_LOG(INFO) << "Fusion error: dyn_input_size" << dyn_input_sizes[0] << "not equal real_input_size"
<< real_input_size;
return ret; return ret;
} }
ret = true; ret = true;
...@@ -663,6 +726,7 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, ...@@ -663,6 +726,7 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
std::vector<nlohmann::json> *input_desc_list, size_t *index) { std::vector<nlohmann::json> *input_desc_list, size_t *index) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(input_desc_list); MS_EXCEPTION_IF_NULL(input_desc_list);
std::vector<nlohmann::json> input_desc_list_tmp = {};
bool is_dynamic_input = IsDynamicInput(cnode); bool is_dynamic_input = IsDynamicInput(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) { for (size_t i = 1; i < cnode->inputs().size(); ++i) {
auto input = cnode->input(i); auto input = cnode->input(i);
...@@ -676,7 +740,7 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, ...@@ -676,7 +740,7 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
MS_LOG(INFO) << "node has dynamic input."; MS_LOG(INFO) << "node has dynamic input.";
input_desc["dyn_index"] = (i - 1); input_desc["dyn_index"] = (i - 1);
} }
(*input_desc_list).emplace_back(input_desc); input_desc_list_tmp.emplace_back(input_desc);
} }
size_t optional_num = GetOptionalInput(cnode, is_dynamic_input); size_t optional_num = GetOptionalInput(cnode, is_dynamic_input);
if (optional_num > 0) { if (optional_num > 0) {
...@@ -686,35 +750,24 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, ...@@ -686,35 +750,24 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
optional_input_desc["name"] = std::string(kOptional) + std::to_string(*index); optional_input_desc["name"] = std::string(kOptional) + std::to_string(*index);
(*index)++; (*index)++;
(*layer_iter)->emplace_back(nullptr); (*layer_iter)->emplace_back(nullptr);
(*input_desc_list).emplace_back(optional_input_desc); input_desc_list_tmp.emplace_back(optional_input_desc);
} }
} }
auto op_name = AnfAlgo::GetCNodeName(cnode);
TbeAdapter::FusionInputOrderPass(op_name, input_desc_list_tmp, input_desc_list);
return true; return true;
} }
std::vector<size_t> TbeKernelBuild::GetDescOutputIndex(const std::vector<int> &output_used_nums) { std::vector<size_t> TbeKernelBuild::GetDescOutputIndex(const std::vector<int> &output_used_nums) {
std::vector<size_t> desc_output_index = {}; std::vector<size_t> desc_output_index = {};
bool find_reused = false;
size_t reused_num = 0;
for (size_t idx = 0; idx < output_used_nums.size(); ++idx) { for (size_t idx = 0; idx < output_used_nums.size(); ++idx) {
auto output_use_num_item = output_used_nums[idx]; auto output_use_num_item = output_used_nums[idx];
MS_LOG(INFO) << "output used num[" << idx << "] = " << output_use_num_item; MS_LOG(INFO) << "output used num[" << idx << "] = " << output_use_num_item;
if (output_use_num_item == 1 || output_use_num_item == 0) { desc_output_index.emplace_back(idx);
if (output_use_num_item > 1) {
desc_output_index.emplace_back(idx); desc_output_index.emplace_back(idx);
} else {
if (!find_reused) {
desc_output_index.emplace_back(idx);
} else {
desc_output_index.emplace_back(desc_output_index[idx - 1]);
}
reused_num += (output_use_num_item - 1);
find_reused = true;
} }
} }
auto pad_value = output_used_nums.size() == 1 ? 0 : desc_output_index[desc_output_index.size() - 1] + 1;
for (size_t i = 0; i < reused_num; ++i) {
desc_output_index.emplace_back(pad_value);
}
return desc_output_index; return desc_output_index;
} }
...@@ -811,6 +864,7 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vecto ...@@ -811,6 +864,7 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vecto
} }
auto ret = GetIOSizeImpl(data_output); auto ret = GetIOSizeImpl(data_output);
input_size_list->push_back(ret); input_size_list->push_back(ret);
MS_LOG(INFO) << "Fusion info: scope input name: " << op["name"] << ", size: " << ret;
} }
} }
} }
...@@ -819,26 +873,31 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vecto ...@@ -819,26 +873,31 @@ bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vecto
auto kernel_idx = AnfAlgo::VisitKernel(output_node, 0); auto kernel_idx = AnfAlgo::VisitKernel(output_node, 0);
auto real_node = kernel_idx.first; auto real_node = kernel_idx.first;
size_t real_idx = kernel_idx.second; size_t real_idx = kernel_idx.second;
auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope());
MS_LOG(INFO) << "Fusion info: real node name: " << normal_name << ", real output index: " << real_idx;
for (const auto &op : fusion_op_list) { for (const auto &op : fusion_op_list) {
auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope());
if (op["name"] == normal_name) { if (op["name"] == normal_name) {
auto op_output_desces = op["output_desc"]; auto op_output_desces = op["output_desc"];
if (output_node != real_node) { if (output_node != real_node) {
// tuple_get item // tuple_get item
MS_LOG(DEBUG) << "output is a tuple getitem node"; MS_LOG(INFO) << "output is a tuple getitem node";
auto output_desc = op_output_desces[real_idx]; auto output_desc = op_output_desces[real_idx];
if (output_desc["shape"].empty()) { if (output_desc["shape"].empty()) {
continue; MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx;
return false;
} }
auto ret = GetIOSizeImpl(output_desc); auto ret = GetIOSizeImpl(output_desc);
output_size_list->push_back(ret); output_size_list->push_back(ret);
MS_LOG(INFO) << "Fusion info: scope output index: " << real_idx << ", size: " << ret;
} else { } else {
for (const auto &output_desc : op_output_desces) { for (const auto &output_desc : op_output_desces) {
if (output_desc["shape"].empty()) { if (output_desc["shape"].empty()) {
MS_LOG(INFO) << "Fusion info: output_desc's shape is empty, may be this node output";
continue; continue;
} }
auto ret = GetIOSizeImpl(output_desc); auto ret = GetIOSizeImpl(output_desc);
output_size_list->push_back(ret); output_size_list->push_back(ret);
MS_LOG(INFO) << "Fusion info: scope output size: " << ret;
} }
} }
} }
......
...@@ -35,6 +35,8 @@ namespace kernel { ...@@ -35,6 +35,8 @@ namespace kernel {
// kernel operate type used for generate json // kernel operate type used for generate json
class TbeKernelBuild { class TbeKernelBuild {
enum FusionDataType { kFusionNormal = 0, kFusionAddN, kFusionReLUGradV2 };
public: public:
static bool GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list, static bool GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
std::vector<size_t> *output_size_list); std::vector<size_t> *output_size_list);
...@@ -48,8 +50,9 @@ class TbeKernelBuild { ...@@ -48,8 +50,9 @@ class TbeKernelBuild {
private: private:
TbeKernelBuild() = default; TbeKernelBuild() = default;
~TbeKernelBuild() = default; ~TbeKernelBuild() = default;
static bool GenFusionDataInputJson(const std::shared_ptr<mindspore::AnfNode> &data_input, nlohmann::json *data_str, static bool GenFusionDataInputJson(const std::shared_ptr<mindspore::AnfNode> &data_input,
size_t *index); const std::map<const AnfNodePtr, FusionDataType> &spec_data_input,
nlohmann::json *data_str, size_t *index);
static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node,
std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter, std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index); nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index);
...@@ -60,13 +63,17 @@ class TbeKernelBuild { ...@@ -60,13 +63,17 @@ class TbeKernelBuild {
static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode,
std::vector<nlohmann::json> *output_desc_list); std::vector<nlohmann::json> *output_desc_list);
static void GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx, static void GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
size_t desc_output_idx, nlohmann::json *output_desc); size_t desc_output_idx, nlohmann::json *output_desc,
FusionDataType fusion_data_type = kFusionNormal);
static void GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index, static void GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
size_t output_index, nlohmann::json *output_desc); size_t output_index, nlohmann::json *output_desc);
static size_t GetIOSizeImpl(const nlohmann::json &desc); static size_t GetIOSizeImpl(const nlohmann::json &desc);
static bool GetSpecInputLayers(const std::string &op_name, const std::vector<mindspore::AnfNodePtr> &reorder_layer,
std::map<const AnfNodePtr, FusionDataType> *spec_data_input);
static bool GetInputLayers(const std::vector<mindspore::AnfNodePtr> &input_nodes, static bool GetInputLayers(const std::vector<mindspore::AnfNodePtr> &input_nodes,
const std::vector<mindspore::AnfNodePtr> &compute_nodes, const std::vector<mindspore::AnfNodePtr> &compute_nodes,
std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers); std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers,
std::map<const AnfNodePtr, FusionDataType> *spec_data_input);
static bool IsDynamicInput(const CNodePtr &cnode); static bool IsDynamicInput(const CNodePtr &cnode);
static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input); static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
}; };
......
...@@ -38,6 +38,7 @@ ...@@ -38,6 +38,7 @@
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h" #include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h"
#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h" #include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h"
#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h" #include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h"
#include "pre_activate/ascend/ir_fusion/transdata_split.h" #include "pre_activate/ascend/ir_fusion/transdata_split.h"
#include "pre_activate/ascend/ir_fission/topk_split.h" #include "pre_activate/ascend/ir_fission/topk_split.h"
...@@ -265,6 +266,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern ...@@ -265,6 +266,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm->AddPass(std::make_shared<AllReduceFusion>()); other_pm->AddPass(std::make_shared<AllReduceFusion>());
other_pm->AddPass(std::make_shared<AllGatherFusion>()); other_pm->AddPass(std::make_shared<AllGatherFusion>());
other_pm->AddPass(std::make_shared<ParameterTransOpFusion>()); other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
other_pm->AddPass(std::make_shared<BufferFusion>()); other_pm->AddPass(std::make_shared<BufferFusion>());
other_pm->AddPass(std::make_shared<GetitemTuple>()); other_pm->AddPass(std::make_shared<GetitemTuple>());
other_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); other_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <algorithm> #include <algorithm>
#include <iterator>
#include "kernel/kernel_fusion.h" #include "kernel/kernel_fusion.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
...@@ -261,23 +262,24 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v ...@@ -261,23 +262,24 @@ CNodePtr CreateFusionOp(const std::vector<AnfNodePtr> &inputs_list, const std::v
return buffer_fusion_kernel; return buffer_fusion_kernel;
} }
kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list_in, kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector<AnfNodePtr> &inputs_list,
const std::vector<AnfNodePtr> &inputs_list,
const std::vector<AnfNodePtr> &outputs_list) { const std::vector<AnfNodePtr> &outputs_list) {
MS_LOG(DEBUG) << "Start Create Kernel Info"; MS_LOG(DEBUG) << "Start Create Kernel Info";
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
// inputs format and data type // inputs format and data type
std::vector<std::string> inputs_format; std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_data_type; std::vector<TypeId> inputs_data_type;
for (auto node : inputs_list_in) { for (const auto &input : inputs_list) {
auto cnode = node->cast<CNodePtr>(); if (input->isa<CNode>() && AnfAlgo::GetCNodeName(input) == prim::kPrimTupleGetItem->name()) {
MS_EXCEPTION_IF_NULL(cnode); auto tuple_getitem = input->cast<CNodePtr>();
auto &inputs = cnode->inputs(); MS_EXCEPTION_IF_NULL(tuple_getitem);
for (size_t input_index = 1; input_index < inputs.size(); ++input_index) { inputs_format.push_back(AnfAlgo::GetOutputFormat(
if (std::find(inputs_list.begin(), inputs_list.end(), inputs[input_index]) != inputs_list.end()) { tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
inputs_format.push_back(AnfAlgo::GetInputFormat(node, input_index - 1)); inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(
inputs_data_type.push_back(AnfAlgo::GetInputDeviceDataType(node, input_index - 1)); tuple_getitem->input(1), IntToSize(GetValue<int>(GetValueNode(tuple_getitem->input(2))))));
} } else {
inputs_format.push_back(AnfAlgo::GetOutputFormat(input, 0));
inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(input, 0));
} }
} }
// outputs format and data type // outputs format and data type
...@@ -360,62 +362,6 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi ...@@ -360,62 +362,6 @@ void ReplaceOldNode(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusi
} }
} }
void GetInputList(const CNodePtr &node, const int32_t cur_fusion_id, std::vector<AnfNodePtr> *inputs_list) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inputs_list);
auto &inputs = node->inputs();
for (size_t input_index = 1; input_index < inputs.size(); ++input_index) {
auto input = inputs[input_index];
if (AnfAlgo::IsRealCNodeKernel(input)) {
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, input)) {
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input, kOpAttrFusionId);
if (fusion_id != cur_fusion_id) {
inputs_list->push_back(input);
}
} else {
inputs_list->push_back(input);
}
} else if (input->isa<CNode>()) {
for (auto &input_in : input->cast<CNodePtr>()->inputs()) {
if (AnfAlgo::IsRealCNodeKernel(input_in)) {
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, input_in)) {
auto fusion_id = AnfAlgo::GetNodeAttr<int32_t>(input_in, kOpAttrFusionId);
if (fusion_id != cur_fusion_id) {
inputs_list->push_back(input);
}
} else {
inputs_list->push_back(input);
}
}
}
} else {
inputs_list->push_back(input);
}
}
}
void CheckCurrentNodeIsInput(const CNodePtr &node, const int32_t &cur_fusion_id,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
if ((*buffer_fusion_infos).find(cur_fusion_id) == (*buffer_fusion_infos).end()) {
BufferFusionInfo_t buffer_fusion_info;
(*buffer_fusion_infos)[cur_fusion_id] = buffer_fusion_info;
}
std::vector<AnfNodePtr> inputs_list;
GetInputList(node, cur_fusion_id, &inputs_list);
if (!inputs_list.empty()) {
if (!(*buffer_fusion_infos)[cur_fusion_id].inputs_list.empty()) {
(void)(*buffer_fusion_infos)[cur_fusion_id].inputs_list.insert(
(*buffer_fusion_infos)[cur_fusion_id].inputs_list.end(), inputs_list.begin(), inputs_list.end());
(void)(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.insert(
(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.end(), node);
} else {
(*buffer_fusion_infos)[cur_fusion_id].inputs_list = inputs_list;
(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.push_back(node);
}
}
}
void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
...@@ -429,6 +375,45 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, ...@@ -429,6 +375,45 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
} }
} }
void GetFusionScopeInputNodeList(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
auto manager = kernel_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
auto fusion_id = buffer_fusion_info.first;
auto fusion_info = buffer_fusion_info.second;
for (const auto &node : fusion_info.anf_nodes) {
auto cnode = node->cast<CNodePtr>();
for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) {
auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0);
if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) ==
fusion_info.anf_nodes.end()) {
if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(),
(*buffer_fusion_infos)[fusion_id].inputs_list.end(),
cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) {
(*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx));
}
}
}
}
}
}
bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) {
MS_EXCEPTION_IF_NULL(node1);
MS_EXCEPTION_IF_NULL(node2);
auto getitem1 = node1->cast<CNodePtr>();
auto getitem2 = node2->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem1);
MS_EXCEPTION_IF_NULL(getitem2);
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2)));
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2)));
return output_idx1 < output_idx2;
}
void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) { std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) {
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
...@@ -454,14 +439,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, ...@@ -454,14 +439,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(),
std::back_inserter(tuple_getitem_nodes), std::back_inserter(tuple_getitem_nodes),
[](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; }); [](const std::pair<AnfNodePtr, int> &use_node) { return use_node.first; });
std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare);
[](const AnfNodePtr &node1, const AnfNodePtr &node2) {
auto getitem1 = node1->cast<CNodePtr>();
auto getitem2 = node2->cast<CNodePtr>();
auto output_idx1 = GetValue<int>(GetValueNode(getitem1->input(2)));
auto output_idx2 = GetValue<int>(GetValueNode(getitem2->input(2)));
return output_idx1 < output_idx2;
});
for (auto getitem : tuple_getitem_nodes) { for (auto getitem : tuple_getitem_nodes) {
auto getitem_ptr = getitem->cast<CNodePtr>(); auto getitem_ptr = getitem->cast<CNodePtr>();
auto input2 = getitem_ptr->input(2); auto input2 = getitem_ptr->input(2);
...@@ -484,6 +462,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, ...@@ -484,6 +462,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
} }
} }
void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
const AnfNodePtr &fusion_kernel) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto manager = kernel_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
auto output = outputs_list[idx];
if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
auto real_output = AnfAlgo::VisitKernel(output, 0);
auto output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
auto input2 = output_cnode->input(2);
auto output_idx = GetValue<int>(GetValueNode(input2));
session::AnfWithOutIndex out_pair(real_output.first, output_idx);
if (kernel_graph->IsInRefOutputMap(out_pair)) {
auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
}
} else {
session::AnfWithOutIndex out_pair(output, 0);
if (kernel_graph->IsInRefOutputMap(out_pair)) {
auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
}
}
}
}
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) { std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
...@@ -634,24 +642,12 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord ...@@ -634,24 +642,12 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const { std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
MS_EXCEPTION_IF_NULL(buffer_fusion_infos); MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
auto cur_fusion_id = AnfAlgo::GetNodeAttr<int32_t>(cnode, kOpAttrFusionId);
CheckCurrentNodeIsInput(cnode, cur_fusion_id, buffer_fusion_infos);
}
}
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
GetFusionScopeInputNodeList(kernel_graph, buffer_fusion_infos);
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
for (auto &buffer_fusion_info : *buffer_fusion_infos) { for (auto &buffer_fusion_info : *buffer_fusion_infos) {
buffer_fusion_info.second.kernel_build_info = buffer_fusion_info.second.kernel_build_info =
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list_in, buffer_fusion_info.second.inputs_list, CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);
buffer_fusion_info.second.outputs_list);
} }
} }
...@@ -743,7 +739,7 @@ bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_ ...@@ -743,7 +739,7 @@ bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_
} }
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get());
AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get());
// replace node SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion);
ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph);
return true; return true;
} }
......
...@@ -30,7 +30,6 @@ namespace opt { ...@@ -30,7 +30,6 @@ namespace opt {
struct BufferFusionInfo_t { struct BufferFusionInfo_t {
std::vector<AnfNodePtr> anf_nodes; std::vector<AnfNodePtr> anf_nodes;
std::vector<AnfNodePtr> inputs_list; std::vector<AnfNodePtr> inputs_list;
std::vector<AnfNodePtr> inputs_list_in;
std::vector<AnfNodePtr> outputs_list; std::vector<AnfNodePtr> outputs_list;
kernel::KernelBuildInfoPtr kernel_build_info; kernel::KernelBuildInfoPtr kernel_build_info;
}; };
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "operator/ops.h"
#include "device/kernel_info.h"
#include "pre_activate/common/helper.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
void DoRefresh(const CNodePtr &cnode) {
if (cnode == nullptr) {
MS_LOG(EXCEPTION) << "node is nullptr";
}
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) {
auto input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index);
if (input_kernel_node->isa<Parameter>()) {
std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
auto cnode_input_format = AnfAlgo::GetInputFormat(cnode, input_index);
auto kernel_node_format = AnfAlgo::GetOutputFormat(input_kernel_node, 0);
auto dtype = AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0);
if (kernel_node_format != cnode_input_format) {
builder->SetOutputsFormat({cnode_input_format});
builder->SetOutputsDeviceType({dtype});
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
}
}
}
}
bool RefreshParameterFormat::Run(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func_graph is nullptr.";
return false;
}
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto node : node_list) {
if (node == nullptr || !node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
auto node_name = AnfAlgo::GetCNodeName(cnode);
if (node_name == kBNTrainingUpdateOpName) {
DoRefresh(cnode);
}
}
return true;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
#include <vector>
#include <memory>
#include <utility>
#include "ir/anf.h"
#include "pre_activate/common/pass.h"
namespace mindspore {
namespace opt {
class RefreshParameterFormat : public Pass {
public:
explicit RefreshParameterFormat(size_t groups = 1) : Pass("refresh_parameter_format"), groups_(groups) {}
~RefreshParameterFormat() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
size_t groups_ = 1;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
...@@ -825,6 +825,8 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n ...@@ -825,6 +825,8 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
static std::map<std::string, std::map<size_t, size_t>> spec_node_list = { static std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
{prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}}, {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
{prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}}, {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
{kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}}},
{kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}}},
{prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}}, {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
{prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}, {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
{prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}, {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
...@@ -835,7 +837,7 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n ...@@ -835,7 +837,7 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n
auto node_name = AnfAlgo::GetCNodeName(anf_node); auto node_name = AnfAlgo::GetCNodeName(anf_node);
if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) { if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
auto find = spec_node_list.find(node_name); auto find = spec_node_list.find(node_name);
if (find != spec_node_list.end()) { if (find != spec_node_list.end() && cur_index < find->second.size()) {
ret = find->second[cur_index]; ret = find->second[cur_index];
MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name; MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
} }
......
...@@ -122,6 +122,10 @@ constexpr auto kSendOpName = "Send"; ...@@ -122,6 +122,10 @@ constexpr auto kSendOpName = "Send";
constexpr auto kRecvOpName = "Recv"; constexpr auto kRecvOpName = "Recv";
constexpr auto kReluV2OpName = "ReLUV2"; constexpr auto kReluV2OpName = "ReLUV2";
constexpr auto kReluGradV2OpName = "ReluGradV2"; constexpr auto kReluGradV2OpName = "ReluGradV2";
constexpr auto kAddNOpName = "AddN";
constexpr auto kConv2DBackpropInputOpName = "Conv2DBackpropInput";
constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBackpropInput_ReluGradV2";
constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2DBackpropInput_AddN_ReluGradV2";
// attr key name // attr key name
constexpr auto kAttrInputNames = "input_names"; constexpr auto kAttrInputNames = "input_names";
......
...@@ -31,6 +31,7 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \ ...@@ -31,6 +31,7 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \
.dtype_format(DataType.U8_Default, DataType.U8_Default) \ .dtype_format(DataType.U8_Default, DataType.U8_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.get_op_info() .get_op_info()
......
...@@ -654,6 +654,7 @@ class Conv2D(PrimitiveWithInfer): ...@@ -654,6 +654,7 @@ class Conv2D(PrimitiveWithInfer):
self.add_prim_attr('data_format', "NCHW") self.add_prim_attr('data_format', "NCHW")
self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name)
self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) self.group = validator.check_integer('group', group, 0, Rel.GT, self.name)
self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape): def infer_shape(self, x_shape, w_shape):
validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册