Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ce57e02d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ce57e02d
编写于
5月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1562 don't set parameter's format when it's has been setted before
Merge pull request !1562 from lianliguang/r0.3
上级
07724c70
085d8f12
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
25 addition
and
26 deletion
+25
-26
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
+2
-12
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
+14
-9
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
+9
-5
未找到文件。
mindspore/ccsrc/device/ascend/kernel_select_ascend.cc
浏览文件 @
ce57e02d
...
...
@@ -166,20 +166,10 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
std
::
shared_ptr
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
// we set special device info of a input tensor.
bool
is_ref
=
false
;
auto
op_info
=
mindspore
::
kernel
::
OpLib
::
FindOp
(
AnfAlgo
::
GetCNodeName
(
kernel_node
),
kernel
::
kTBE
);
if
(
op_info
!=
nullptr
)
{
is_ref
=
op_info
->
is_ref
();
}
MS_EXCEPTION_IF_NULL
(
MsContext
::
GetInstance
());
if
(
MsContext
::
GetInstance
()
->
execution_mode
()
==
kPynativeMode
&&
AnfAlgo
::
GetOutputDeviceDataType
(
real_input_node
,
0
)
!=
kTypeUnknown
)
{
continue
;
}
if
(
AnfAlgo
::
GetOutputDeviceDataType
(
real_input_node
,
0
)
==
kTypeUnknown
||
is_ref
)
{
if
(
AnfAlgo
::
GetOutputDeviceDataType
(
real_input_node
,
0
)
==
kTypeUnknown
)
{
std
::
vector
<
std
::
string
>
output_format
=
{
selected_kernel_info
.
GetInputFormat
(
input_index
)};
builder
->
SetOutputsFormat
(
output_format
);
std
::
vector
<
TypeId
>
output_type
=
{
selected_kernel_info
.
GetInputDeviceType
(
input_index
)};
std
::
vector
<
TypeId
>
output_type
=
{
AnfAlgo
::
GetOutputInferDataType
(
real_input_node
,
0
)};
builder
->
SetOutputsDeviceType
(
output_type
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
builder
->
Build
(),
real_input_node
.
get
());
}
...
...
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
浏览文件 @
ce57e02d
...
...
@@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
return
false
;
}
std
::
vector
<
Axis
>
reshape_type
;
if
(
!
StringToAxisVector
(
input
->
reshape_type
(),
&
reshape_type
))
{
return
false
;
}
if
(
param_type
==
"dynamic"
)
{
if
(
dyn_input_sizes
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"
;
...
...
@@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto
type_id
=
tbe
::
DtypeToTypeId
(
dtypes
[
builder_idex
]);
inputs_device_type
.
push_back
(
type_id
);
inputs_format
.
push_back
(
formats
[
builder_idex
]);
reshape_types
.
push_back
(
reshape_type
);
}
dyn_input_idx
++
;
}
else
if
(
param_type
==
"required"
)
{
...
...
@@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto
type_id
=
tbe
::
DtypeToTypeId
(
dtypes
[
builder_idex
]);
inputs_device_type
.
push_back
(
type_id
);
inputs_format
.
push_back
(
formats
[
builder_idex
]);
reshape_types
.
push_back
(
reshape_type
);
}
else
{
if
(
kernel_info_index
<
real_input_num
)
{
MS_LOG
(
INFO
)
<<
"Set input kernel builder info, input type is optional, input index is "
<<
kernel_info_index
;
...
...
@@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
auto
type_id
=
tbe
::
DtypeToTypeId
(
dtypes
[
builder_idex
]);
inputs_device_type
.
push_back
(
type_id
);
inputs_format
.
push_back
(
formats
[
builder_idex
]);
reshape_types
.
push_back
(
reshape_type
);
}
}
std
::
vector
<
Axis
>
reshape_type
;
if
(
!
StringToAxisVector
(
input
->
reshape_type
(),
&
reshape_type
))
{
return
false
;
}
reshape_types
.
push_back
(
reshape_type
);
}
builder
->
SetInputReshapeType
(
reshape_types
);
...
...
@@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
MS_LOG
(
WARNING
)
<<
"real_output_num: "
<<
real_output_num
<<
", output_idx: "
<<
output_idx
<<
"is out of limit!"
;
continue
;
}
std
::
vector
<
Axis
>
reshape_type
;
if
(
!
StringToAxisVector
(
output
->
reshape_type
(),
&
reshape_type
))
{
return
false
;
}
size_t
output_num
=
0
;
if
(
output
->
param_type
()
==
"dynamic"
)
{
if
(
outputs
.
size
()
>
1
)
{
...
...
@@ -467,12 +475,9 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
auto
type_id
=
tbe
::
DtypeToTypeId
(
dtypes
[
builder_idex
]);
outputs_device_type
.
push_back
(
type_id
);
outputs_format
.
push_back
(
formats
[
builder_idex
]);
reshape_types
.
push_back
(
reshape_type
);
output_idx
++
;
}
std
::
vector
<
Axis
>
reshape_type
;
if
(
!
StringToAxisVector
(
output
->
reshape_type
(),
&
reshape_type
))
{
return
false
;
}
reshape_types
.
push_back
(
reshape_type
);
}
...
...
mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
浏览文件 @
ce57e02d
...
...
@@ -33,12 +33,15 @@ using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace
{
kernel
::
KernelBuildInfoPtr
RefreshKernelBuildInfo
(
const
std
::
string
&
input_format
,
const
std
::
string
&
output_format
,
const
AnfNodePtr
&
node
,
const
TypeId
device_type
,
const
kernel
::
KernelBuildInfo
&
ori_build_info
)
{
const
kernel
::
KernelBuildInfo
&
ori_build_info
,
const
std
::
vector
<
kernel
::
Axis
>
&
reshape_type
)
{
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsFormat
({
input_format
});
builder
.
SetOutputsFormat
({
output_format
});
builder
.
SetInputsDeviceType
({
device_type
});
builder
.
SetOutputsDeviceType
({
device_type
});
builder
.
SetOutputReshapeType
({
reshape_type
});
builder
.
SetInputReshapeType
({
reshape_type
});
builder
.
SetKernelType
(
ori_build_info
.
kernel_type
());
builder
.
SetFusionType
(
ori_build_info
.
fusion_type
());
builder
.
SetProcessor
(
ori_build_info
.
processor
());
...
...
@@ -175,6 +178,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
AnfNodePtr
trans_node
=
nullptr
;
AnfNodePtr
input_node
=
node
;
AnfNodePtr
trans_data
=
nullptr
;
std
::
vector
<
kernel
::
Axis
>
reshape_type
=
AnfAlgo
::
GetOutputReshapeType
(
node
,
0
);
TypeId
dtype
=
AnfAlgo
::
GetOutputDeviceDataType
(
node
,
0
);
MS_EXCEPTION_IF_NULL
(
node
);
if
(
origin_format
.
empty
()
||
dest_format
.
empty
())
{
...
...
@@ -189,6 +193,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
dtype
=
AnfAlgo
::
GetInputDeviceDataType
(
cnode
,
insert_index
);
MS_EXCEPTION_IF_NULL
(
cnode
);
input_node
=
AnfAlgo
::
GetInputNode
(
cnode
,
insert_index
);
reshape_type
=
AnfAlgo
::
GetInputReshapeType
(
node
,
insert_index
);
}
bool
need_padding
=
false
;
if
(
is_insert_input
)
{
...
...
@@ -222,7 +227,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
MS_EXCEPTION_IF_NULL
(
trans_data
);
MS_EXCEPTION_IF_NULL
(
trans_data
->
kernel_info
());
auto
trans_ori_build_info
=
trans_data
->
kernel_info
()
->
select_kernel_build_info
();
auto
kernel_build_info
=
RefreshKernelBuildInfo
(
origin_format
,
dest_format
,
input_node
,
dtype
,
*
trans_ori_build_info
);
auto
kernel_build_info
=
RefreshKernelBuildInfo
(
origin_format
,
dest_format
,
input_node
,
dtype
,
*
trans_ori_build_info
,
reshape_type
);
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info
,
trans_data
.
get
());
return
trans_node
;
}
...
...
@@ -309,9 +315,7 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod
auto
cur_input
=
AnfAlgo
::
GetInputNode
(
cnode
,
input_index
);
auto
kernel_with_index
=
AnfAlgo
::
VisitKernel
(
cur_input
,
0
);
auto
is_weight_boundary
=
[](
const
AnfNodePtr
&
node
)
->
bool
{
if
(
node
->
isa
<
ValueNode
>
())
{
return
true
;
}
else
if
(
node
->
isa
<
Parameter
>
()
&&
AnfAlgo
::
IsParameterWeight
(
node
->
cast
<
ParameterPtr
>
()))
{
if
(
node
->
isa
<
ValueNode
>
()
||
node
->
isa
<
Parameter
>
())
{
return
true
;
}
return
false
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录