提交 683d29c0 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!564 Overlength functions rectification

Merge pull request !564 from YuJianfeng/master
......@@ -70,6 +70,35 @@
namespace mindspore {
namespace opt {
namespace {
// Registers the optional Ascend backend IR fusion passes on the given pass
// manager. Extracted from AscendBackendIRFusionOptimization so the
// registration list lives in one place; passes execute in the exact order
// they are added here, so do not reorder without reviewing pass dependencies.
// NOTE(review): `ir_fusion_pm` is a non-owning pointer supplied by the caller.
void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
MS_EXCEPTION_IF_NULL(ir_fusion_pm);
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
// Lamb optimizer fusion rules: the more specific "WithDecay" variants are
// registered before the generic LambNextMVRule so they get first match.
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayV1Rule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
// Reshape/Transpose fusions in both compositions.
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
// Adam optimizer fusions: decay variant first, then the plain fusion.
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>());
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneFusion>());
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
}
} // namespace
void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<GraphOptimizer>();
......@@ -164,29 +193,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
if (context_ptr->ir_fusion_flag()) {
ir_fusion_pm->AddPass(std::make_shared<SquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByNormNoDivSquareSumFusion>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayV1Rule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
ir_fusion_pm->AddPass(std::make_shared<ClipByValueFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneWithDecayRule>());
ir_fusion_pm->AddPass(std::make_shared<AdamApplyOneFusion>());
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
ir_fusion_pm->AddPass(std::make_shared<MatmulBiasaddFusion>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
}
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
......
......@@ -26,6 +26,7 @@
namespace mindspore {
namespace opt {
namespace {
const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag,
std::vector<CNodePtr> *trans_road) {
if (node == nullptr) {
......@@ -59,6 +60,24 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr
return nullptr;
}
// Builds a fresh KernelBuildInfo for the given cast node: the kernel type,
// fusion type and processor are carried over from the node's currently
// selected build info, while the single input/output format and device types
// are overridden with the supplied values.
// Throws (via MS_EXCEPTION_IF_NULL) if the node, its kernel info, or its
// selected build info is null.
kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type,
                                              TypeId output_type) {
  MS_EXCEPTION_IF_NULL(cast);
  auto info = cast->kernel_info();
  MS_EXCEPTION_IF_NULL(info);
  auto selected = info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(selected);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  // Carry over dispatch attributes from the existing selection.
  info_builder.SetKernelType(selected->kernel_type());
  info_builder.SetFusionType(selected->fusion_type());
  info_builder.SetProcessor(selected->processor());
  // Override the single-input/single-output layout and device types.
  info_builder.SetInputsFormat({format});
  info_builder.SetOutputsFormat({format});
  info_builder.SetInputsDeviceType({input_type});
  info_builder.SetOutputsDeviceType({output_type});
  return info_builder.Build();
}
} // namespace
bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Func graph is nullptr";
......@@ -95,17 +114,7 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) {
auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0);
auto cast = trans_road[1];
auto cast_format = AnfAlgo::GetOutputFormat(cast, 0);
auto cast_build_info = cast->kernel_info()->select_kernel_build_info();
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetOutputsFormat({format});
builder.SetInputsFormat({format});
builder.SetInputsDeviceType({param_dtype});
builder.SetOutputsDeviceType({dtype});
builder.SetKernelType(cast_build_info->kernel_type());
builder.SetFusionType(cast_build_info->fusion_type());
builder.SetProcessor(cast_build_info->processor());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get());
if (param_format == format && param_dtype != dtype) {
manager->Replace(trans_road[2], final_node);
manager->Replace(cur_transop, cast);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册