Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
c23731e5
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c23731e5
编写于
5月 18, 2020
作者:
C
changzherui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Incremental subgraph initialization
上级
311b7e71
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
59 addition
and
15 deletion
+59
-15
mindspore/ccsrc/ir/meta_tensor.cc
mindspore/ccsrc/ir/meta_tensor.cc
+22
-0
mindspore/ccsrc/ir/meta_tensor.h
mindspore/ccsrc/ir/meta_tensor.h
+3
-1
mindspore/ccsrc/transform/convert.cc
mindspore/ccsrc/transform/convert.cc
+11
-9
mindspore/common/api.py
mindspore/common/api.py
+6
-5
mindspore/common/parameter.py
mindspore/common/parameter.py
+3
-0
mindspore/common/tensor.py
mindspore/common/tensor.py
+14
-0
未找到文件。
mindspore/ccsrc/ir/meta_tensor.cc
浏览文件 @
c23731e5
...
...
@@ -374,6 +374,10 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
return
data_type_
;
}
bool
Tensor
::
is_init
()
{
return
init_flag_
;
}
void
Tensor
::
set_init_flag
(
bool
flag
)
{
init_flag_
=
flag
;
}
bool
Tensor
::
convert_data
(
const
py
::
array
&
in
,
const
TypeId
in_data_type
,
py
::
array
*
const
out
,
const
TypeId
out_data_type
)
{
if
(
out
==
nullptr
)
{
...
...
@@ -499,6 +503,24 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.size()
6
)mydelimiter"
)
.
def
(
"is_init"
,
&
Tensor
::
is_init
,
R"mydelimiter(
Get tensor init_flag.
Returns:
bool, whether the tensor init.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.is_init()
False
)mydelimiter"
)
.
def
(
"set_init_flag"
,
&
Tensor
::
set_init_flag
,
R"mydelimiter(
Set tensor init_flag.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.set_init_flag(True)
)mydelimiter"
)
.
def
(
"dim"
,
&
Tensor
::
DataDim
,
R"mydelimiter(
Get tensor's data dimension.
...
...
mindspore/ccsrc/ir/meta_tensor.h
浏览文件 @
c23731e5
...
...
@@ -389,6 +389,8 @@ class Tensor : public MetaTensor {
std
::
string
ToStringRepr
()
const
;
py
::
array
data_
;
// < Tensor's data value
const
bool
parse_info_
=
true
;
bool
is_init
();
void
set_init_flag
(
bool
flag
);
private:
// brief init tensor
...
...
@@ -398,7 +400,7 @@ class Tensor : public MetaTensor {
// return true if succeed, false if failed.
void
init
(
const
py
::
array
&
input
,
const
TypeId
&
data_type
);
void
init
(
const
py
::
array
&
input
,
const
TypePtr
&
type_ptr
);
bool
init_flag_
{
false
};
// brief init tensor attribute
//
// param data_type [TypeId] Data type of the tensor.
...
...
mindspore/ccsrc/transform/convert.cc
浏览文件 @
c23731e5
...
...
@@ -646,7 +646,6 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
if
(
adpt
==
nullptr
)
continue
;
auto
param_op
=
adpt
->
generate
(
name
+
"_data"
);
MS_LOG
(
INFO
)
<<
"Add parameter "
<<
name
<<
" as input, index "
<<
index
<<
"."
;
(
void
)
std
::
static_pointer_cast
<
Data
>
(
param_op
)
->
set_attr_index
(
index
++
);
if
(
!
training_
)
{
auto
adpt_const
=
FindAdapter
(
kNameConst
,
training_
);
...
...
@@ -675,14 +674,17 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
// we need three variable ops for each graph with same name
// build init subgraph
auto
init_var
=
std
::
make_shared
<
Variable
>
(
name
);
auto
assign_op
=
std
::
make_shared
<
Assign
>
(
"assign_"
+
name
);
(
void
)
init_var
->
update_output_desc_y
(
*
desc
);
(
void
)
assign_op
->
set_input_ref
(
*
init_var
).
set_input_value
(
*
param_op
);
init_input
.
push_back
(
*
init_var
);
init_ops_
.
push_back
(
param_op
);
init_ops_
.
push_back
(
assign_op
);
init_ops_
.
push_back
(
init_var
);
if
(
it
.
second
->
is_init
()
==
0
)
{
(
void
)
std
::
static_pointer_cast
<
Data
>
(
param_op
)
->
set_attr_index
(
index
++
);
auto
init_var
=
std
::
make_shared
<
Variable
>
(
name
);
auto
assign_op
=
std
::
make_shared
<
Assign
>
(
"assign_"
+
name
);
(
void
)
init_var
->
update_output_desc_y
(
*
desc
);
(
void
)
assign_op
->
set_input_ref
(
*
init_var
).
set_input_value
(
*
param_op
);
init_input
.
push_back
(
*
init_var
);
init_ops_
.
push_back
(
param_op
);
init_ops_
.
push_back
(
assign_op
);
init_ops_
.
push_back
(
init_var
);
}
auto
variable
=
std
::
make_shared
<
Variable
>
(
name
);
(
void
)
variable
->
update_output_desc_y
(
*
desc
);
...
...
mindspore/common/api.py
浏览文件 @
c23731e5
...
...
@@ -82,14 +82,15 @@ def _wrap_func(fn):
def
_exec_init_graph
(
obj
,
init_phase
):
"""Execute the parameter initializer graph."""
inst_executor
=
Executor_
.
get_instance
()
exec_init_graph
=
False
for
param
in
obj
.
get_parameter
s
():
param_dict
=
OrderedDict
()
for
name
,
param
in
obj
.
parameters_dict
().
item
s
():
if
not
param
.
is_init
:
param_dict
[
name
]
=
param
param
.
is_init
=
True
exec_init_graph
=
True
param
.
data
.
init_flag
=
True
if
exec_init_graph
:
inst_executor
.
run_init_graph
(
obj
.
parameters_dict
()
,
init_phase
)
if
param_dict
:
inst_executor
.
run_init_graph
(
param_dict
,
init_phase
)
class
_MindSporeFunction
:
...
...
mindspore/common/parameter.py
浏览文件 @
c23731e5
...
...
@@ -188,11 +188,14 @@ class Parameter:
if
isinstance
(
data
,
Tensor
):
# make a copy of Tensor to init the parameter
data
=
Tensor
(
data
.
asnumpy
().
copy
())
data
.
init_flag
=
False
elif
isinstance
(
data
,
Initializer
):
self
.
init_mode
=
data
data
=
MetaTensor
(
self
.
init_mode
.
dtype
,
self
.
init_mode
.
shape
)
else
:
data
=
Tensor
(
data
)
data
.
init_flag
=
False
self
.
default_input
=
data
...
...
mindspore/common/tensor.py
浏览文件 @
c23731e5
...
...
@@ -65,6 +65,7 @@ class Tensor(Tensor_):
else
:
super
(
Tensor
,
self
).
__init__
(
input_data
,
dtype
)
self
.
_virtual_flag
=
False
self
.
_init_flag
=
False
def
__repr__
(
self
):
return
str
(
self
.
__str__
())
...
...
@@ -153,3 +154,16 @@ class Tensor(Tensor_):
if
not
isinstance
(
value
,
bool
):
raise
TypeError
(
"virtual_flag must be bool."
)
self
.
_virtual_flag
=
value
@
property
def
init_flag
(
self
):
"""whether the tensor is init."""
return
self
.
_init_flag
@
init_flag
.
setter
def
init_flag
(
self
,
value
):
"""Set the tensor is init_flag."""
if
not
isinstance
(
value
,
bool
):
raise
TypeError
(
"init_flag must be bool."
)
self
.
set_init_flag
(
value
)
self
.
_init_flag
=
value
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录