Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0fbec0f5
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0fbec0f5
编写于
4月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!750 refresh parameter format
Merge pull request !750 from liubuyu/master
上级
420ef2a3
8f48db29
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
113 addition
and
0 deletion
+113
-0
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
.../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+2
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.cc
...pre_activate/ascend/ir_fusion/refresh_parameter_format.cc
+71
-0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.h
.../pre_activate/ascend/ir_fusion/refresh_parameter_format.h
+40
-0
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
浏览文件 @
0fbec0f5
...
...
@@ -38,6 +38,7 @@
#include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h"
#include "pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h"
#include "pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.h"
#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h"
#include "pre_activate/ascend/ir_fusion/transpose_transdata_fusion.h"
#include "pre_activate/ascend/ir_fission/transdata_split.h"
#include "pre_activate/ascend/ir_fission/topk_split.h"
...
...
@@ -265,6 +266,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
other_pm
->
AddPass
(
std
::
make_shared
<
AllReduceFusion
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
AllGatherFusion
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
ParameterTransOpFusion
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
RefreshParameterFormat
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
BufferFusion
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
other_pm
->
AddPass
(
std
::
make_shared
<
CommonSubexpressionElimination
>
());
...
...
mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.cc
0 → 100644
浏览文件 @
0fbec0f5
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fusion/refresh_parameter_format.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "operator/ops.h"
#include "device/kernel_info.h"
#include "pre_activate/common/helper.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/ascend/ascend_helper.h"
namespace
mindspore
{
namespace
opt
{
void
DoRefresh
(
const
CNodePtr
&
cnode
)
{
if
(
cnode
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"node is nullptr"
;
}
for
(
size_t
input_index
=
0
;
input_index
<
AnfAlgo
::
GetInputTensorNum
(
cnode
);
input_index
++
)
{
auto
input_kernel_node
=
AnfAlgo
::
GetInputNode
(
cnode
,
input_index
);
if
(
input_kernel_node
->
isa
<
Parameter
>
())
{
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
auto
cnode_input_format
=
AnfAlgo
::
GetInputFormat
(
cnode
,
input_index
);
auto
kernel_node_format
=
AnfAlgo
::
GetOutputFormat
(
input_kernel_node
,
0
);
auto
dtype
=
AnfAlgo
::
GetOutputDeviceDataType
(
input_kernel_node
,
0
);
if
(
kernel_node_format
!=
cnode_input_format
)
{
builder
->
SetOutputsFormat
({
cnode_input_format
});
builder
->
SetOutputsDeviceType
({
dtype
});
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
->
Build
(),
input_kernel_node
.
get
());
}
}
}
}
bool
RefreshParameterFormat
::
Run
(
const
FuncGraphPtr
&
func_graph
)
{
if
(
func_graph
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"func_graph is nullptr."
;
return
false
;
}
std
::
vector
<
AnfNodePtr
>
node_list
=
TopoSort
(
func_graph
->
get_return
());
for
(
auto
node
:
node_list
)
{
if
(
node
==
nullptr
||
!
node
->
isa
<
CNode
>
())
{
continue
;
}
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
if
(
cnode
==
nullptr
)
{
continue
;
}
auto
node_name
=
AnfAlgo
::
GetCNodeName
(
cnode
);
if
(
node_name
==
kBNTrainingUpdateOpName
)
{
DoRefresh
(
cnode
);
}
}
return
true
;
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/ascend/ir_fusion/refresh_parameter_format.h
0 → 100644
浏览文件 @
0fbec0f5
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
#include <vector>
#include <memory>
#include <utility>
#include "ir/anf.h"
#include "pre_activate/common/pass.h"
namespace
mindspore
{
namespace
opt
{
class
RefreshParameterFormat
:
public
Pass
{
public:
explicit
RefreshParameterFormat
(
size_t
groups
=
1
)
:
Pass
(
"refresh_parameter_format"
),
groups_
(
groups
)
{}
~
RefreshParameterFormat
()
override
=
default
;
bool
Run
(
const
FuncGraphPtr
&
graph
)
override
;
private:
size_t
groups_
=
1
;
};
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REFRESH_PARAMETER_FORMAT_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录