Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
7ce9f9bd
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7ce9f9bd
编写于
6月 19, 2020
作者:
L
Li Hongzhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove TrainLineage, EvalLineage and related ut/st
上级
ed594114
变更
18
展开全部
隐藏空白更改
内联
并排
Showing
18 changed file
with
252 addition
and
1287 deletion
+252
-1287
mindinsight/lineagemgr/__init__.py
mindinsight/lineagemgr/__init__.py
+3
-10
mindinsight/lineagemgr/collection/model/__init__.py
mindinsight/lineagemgr/collection/model/__init__.py
+0
-14
mindinsight/lineagemgr/common/exceptions/exceptions.py
mindinsight/lineagemgr/common/exceptions/exceptions.py
+0
-19
mindinsight/lineagemgr/common/validator/model_parameter.py
mindinsight/lineagemgr/common/validator/model_parameter.py
+0
-63
mindinsight/lineagemgr/common/validator/validate.py
mindinsight/lineagemgr/common/validator/validate.py
+2
-75
mindinsight/lineagemgr/summary/_summary_adapter.py
mindinsight/lineagemgr/summary/_summary_adapter.py
+0
-198
tests/st/func/lineagemgr/collection/model/test_model_lineage.py
...st/func/lineagemgr/collection/model/test_model_lineage.py
+13
-269
tests/ut/lineagemgr/collection/__init__.py
tests/ut/lineagemgr/collection/__init__.py
+0
-14
tests/ut/lineagemgr/collection/model/__init__.py
tests/ut/lineagemgr/collection/model/__init__.py
+0
-14
tests/ut/lineagemgr/collection/model/test_model_lineage.py
tests/ut/lineagemgr/collection/model/test_model_lineage.py
+0
-456
tests/ut/lineagemgr/summary/test_event_writer.py
tests/ut/lineagemgr/summary/test_event_writer.py
+0
-49
tests/ut/lineagemgr/summary/test_summary_record.py
tests/ut/lineagemgr/summary/test_summary_record.py
+0
-80
tests/utils/lineage_writer/__init__.py
tests/utils/lineage_writer/__init__.py
+5
-0
tests/utils/lineage_writer/_event_writer.py
tests/utils/lineage_writer/_event_writer.py
+2
-2
tests/utils/lineage_writer/_summary_adapter.py
tests/utils/lineage_writer/_summary_adapter.py
+212
-0
tests/utils/lineage_writer/_summary_record.py
tests/utils/lineage_writer/_summary_record.py
+2
-4
tests/utils/lineage_writer/base.py
tests/utils/lineage_writer/base.py
+0
-0
tests/utils/lineage_writer/model_lineage.py
tests/utils/lineage_writer/model_lineage.py
+13
-20
未找到文件。
mindinsight/lineagemgr/__init__.py
浏览文件 @
7ce9f9bd
...
...
@@ -15,19 +15,12 @@
"""
Lineagemgr Module Introduction.
This module provides Python APIs to collect and query the lineage of models.
Users can add the TrainLineage/EvalLineage callback to the MindSpore train/eval callback list to
collect the key parameters and results, such as, the name of the network and optimizer, the
evaluation metric and results.
This module provides Python APIs to query the lineage of models.
The APIs can be used to get the lineage information of the models. For example,
what hyperparameter is used in the model training, which model has the highest
accuracy among all the versions, etc.
"""
from
mindinsight.lineagemgr.api.model
import
get_summary_lineage
,
filter_summary_lineage
from
mindinsight.lineagemgr.common.log
import
logger
try
:
from
mindinsight.lineagemgr.collection.model.model_lineage
import
TrainLineage
,
EvalLineage
except
(
ModuleNotFoundError
,
NameError
,
ImportError
):
logger
.
warning
(
'Not found MindSpore!'
)
__all__
=
[
"TrainLineage"
,
"EvalLineage"
,
"get_summary_lineage"
,
"filter_summary_lineage"
]
__all__
=
[
"get_summary_lineage"
,
"filter_summary_lineage"
]
mindinsight/lineagemgr/collection/model/__init__.py
已删除
100644 → 0
浏览文件 @
ed594114
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
mindinsight/lineagemgr/common/exceptions/exceptions.py
浏览文件 @
7ce9f9bd
...
...
@@ -37,16 +37,6 @@ class LineageParamValueError(MindInsightException):
)
class
LineageParamMissingError
(
MindInsightException
):
"""The parameter missing error in lineage module."""
def
__init__
(
self
,
msg
):
super
(
LineageParamMissingError
,
self
).
__init__
(
error
=
LineageErrors
.
PARAM_MISSING_ERROR
,
message
=
LineageErrorMsg
.
PARAM_MISSING_ERROR
.
value
.
format
(
msg
)
)
class
LineageParamRunContextError
(
MindInsightException
):
"""The input parameter run_context error in lineage module."""
...
...
@@ -67,15 +57,6 @@ class LineageGetModelFileError(MindInsightException):
)
class
LineageSearchModelParamError
(
MindInsightException
):
"""The lineage search model param error."""
def
__init__
(
self
,
msg
):
super
(
LineageSearchModelParamError
,
self
).
__init__
(
error
=
LineageErrors
.
LINEAGE_PARAM_NOT_SUPPORT_ERROR
,
message
=
LineageErrorMsg
.
LINEAGE_PARAM_NOT_SUPPORT_ERROR
.
value
.
format
(
msg
)
)
class
LineageSummaryAnalyzeException
(
MindInsightException
):
"""The summary analyze error in lineage module."""
...
...
mindinsight/lineagemgr/common/validator/model_parameter.py
浏览文件 @
7ce9f9bd
...
...
@@ -14,7 +14,6 @@
# ============================================================================
"""Define schema of model lineage input parameters."""
from
marshmallow
import
Schema
,
fields
,
ValidationError
,
pre_load
,
validates
from
marshmallow.validate
import
Range
from
mindinsight.lineagemgr.common.exceptions.error_code
import
LineageErrorMsg
,
\
LineageErrors
...
...
@@ -28,72 +27,10 @@ from mindinsight.utils.exceptions import MindInsightException
try
:
from
mindspore.dataset.engine
import
Dataset
from
mindspore.nn
import
Cell
,
Optimizer
from
mindspore.common.tensor
import
Tensor
from
mindspore.train.callback
import
_ListCallback
except
(
ImportError
,
ModuleNotFoundError
):
logger
.
error
(
'MindSpore Not Found!'
)
class
RunContextArgs
(
Schema
):
"""Define the parameter schema for RunContext."""
optimizer
=
fields
.
Function
(
allow_none
=
True
)
loss_fn
=
fields
.
Function
(
allow_none
=
True
)
net_outputs
=
fields
.
Function
(
allow_none
=
True
)
train_network
=
fields
.
Function
(
allow_none
=
True
)
train_dataset
=
fields
.
Function
(
allow_none
=
True
)
epoch_num
=
fields
.
Int
(
allow_none
=
True
,
validate
=
Range
(
min
=
1
))
batch_num
=
fields
.
Int
(
allow_none
=
True
,
validate
=
Range
(
min
=
0
))
cur_step_num
=
fields
.
Int
(
allow_none
=
True
,
validate
=
Range
(
min
=
0
))
parallel_mode
=
fields
.
Str
(
allow_none
=
True
)
device_number
=
fields
.
Int
(
allow_none
=
True
,
validate
=
Range
(
min
=
1
))
list_callback
=
fields
.
Function
(
allow_none
=
True
)
@
pre_load
def
check_optimizer
(
self
,
data
,
**
kwargs
):
optimizer
=
data
.
get
(
"optimizer"
)
if
optimizer
and
not
isinstance
(
optimizer
,
Optimizer
):
raise
ValidationError
({
'optimizer'
:
[
"Parameter optimizer must be an instance of mindspore.nn.optim.Optimizer."
]})
return
data
@
pre_load
def
check_train_network
(
self
,
data
,
**
kwargs
):
train_network
=
data
.
get
(
"train_network"
)
if
train_network
and
not
isinstance
(
train_network
,
Cell
):
raise
ValidationError
({
'train_network'
:
[
"Parameter train_network must be an instance of mindspore.nn.Cell."
]})
return
data
@
pre_load
def
check_train_dataset
(
self
,
data
,
**
kwargs
):
train_dataset
=
data
.
get
(
"train_dataset"
)
if
train_dataset
and
not
isinstance
(
train_dataset
,
Dataset
):
raise
ValidationError
({
'train_dataset'
:
[
"Parameter train_dataset must be an instance of "
"mindspore.dataengine.datasets.Dataset"
]})
return
data
@
pre_load
def
check_loss
(
self
,
data
,
**
kwargs
):
net_outputs
=
data
.
get
(
"net_outputs"
)
if
net_outputs
and
not
isinstance
(
net_outputs
,
Tensor
):
raise
ValidationError
({
'net_outpus'
:
[
"The parameter net_outputs is invalid. It should be a Tensor."
]})
return
data
@
pre_load
def
check_list_callback
(
self
,
data
,
**
kwargs
):
list_callback
=
data
.
get
(
"list_callback"
)
if
list_callback
and
not
isinstance
(
list_callback
,
_ListCallback
):
raise
ValidationError
({
'list_callback'
:
[
"Parameter list_callback must be an instance of "
"mindspore.train.callback._ListCallback."
]})
return
data
class
EvalParameter
(
Schema
):
"""Define the parameter schema for Evaluation job."""
...
...
mindinsight/lineagemgr/common/validator/validate.py
浏览文件 @
7ce9f9bd
...
...
@@ -18,18 +18,13 @@ import re
from
marshmallow
import
ValidationError
from
mindinsight.lineagemgr.common.exceptions.error_code
import
LineageErrors
,
LineageErrorMsg
from
mindinsight.lineagemgr.common.exceptions.exceptions
import
LineageParam
Missing
Error
,
\
LineageParam
TypeError
,
LineageParam
ValueError
,
LineageDirNotExistError
from
mindinsight.lineagemgr.common.exceptions.exceptions
import
LineageParam
Type
Error
,
\
LineageParamValueError
,
LineageDirNotExistError
from
mindinsight.lineagemgr.common.log
import
logger
as
log
from
mindinsight.lineagemgr.common.validator.validate_path
import
safe_normalize_path
from
mindinsight.lineagemgr.querier.query_model
import
FIELD_MAPPING
from
mindinsight.utils.exceptions
import
MindInsightException
,
ParamValueError
try
:
from
mindspore.nn
import
Cell
from
mindspore.train.summary
import
SummaryRecord
except
(
ImportError
,
ModuleNotFoundError
):
log
.
warning
(
'MindSpore Not Found!'
)
# Named string regular expression
_name_re
=
r
"^\w+[0-9a-zA-Z\_\.]*$"
...
...
@@ -144,31 +139,6 @@ def validate_int_params(int_param, param_name):
message
=
LineageErrorMsg
.
PARAM_BATCH_SIZE_ERROR
.
value
)
def
validate_network
(
network
):
"""
Verify if the network is valid.
Args:
network (Cell): See mindspore.nn.Cell.
Raises:
LineageParamMissingError: If the network is None.
MindInsightException: If the network is invalid.
"""
if
not
network
:
error_msg
=
"The input network for TrainLineage should not be None."
log
.
error
(
error_msg
)
raise
LineageParamMissingError
(
error_msg
)
if
not
isinstance
(
network
,
Cell
):
log
.
error
(
"Invalid network. Network should be an instance"
"of mindspore.nn.Cell."
)
raise
MindInsightException
(
error
=
LineageErrors
.
PARAM_TRAIN_NETWORK_ERROR
,
message
=
LineageErrorMsg
.
PARAM_TRAIN_NETWORK_ERROR
.
value
)
def
validate_file_path
(
file_path
,
allow_empty
=
False
):
"""
Verify that the file_path is valid.
...
...
@@ -190,28 +160,6 @@ def validate_file_path(file_path, allow_empty=False):
message
=
str
(
error
))
def
validate_train_run_context
(
schema
,
data
):
"""
Validate mindspore train run_context data according to schema.
Args:
schema (Schema): data schema.
data (dict): data to check schema.
Raises:
MindInsightException: If the parameters are invalid.
"""
errors
=
schema
().
validate
(
data
)
for
error_key
,
error_msg
in
errors
.
items
():
if
error_key
in
TRAIN_RUN_CONTEXT_ERROR_MAPPING
.
keys
():
error_code
=
TRAIN_RUN_CONTEXT_ERROR_MAPPING
.
get
(
error_key
)
if
TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING
.
get
(
error_key
):
error_msg
=
TRAIN_RUN_CONTEXT_ERROR_MSG_MAPPING
.
get
(
error_key
)
log
.
error
(
error_msg
)
raise
MindInsightException
(
error
=
error_code
,
message
=
error_msg
)
def
validate_eval_run_context
(
schema
,
data
):
"""
Validate mindspore evaluation job run_context data according to schema.
...
...
@@ -257,27 +205,6 @@ def validate_search_model_condition(schema, data):
raise
MindInsightException
(
error
=
error_code
,
message
=
error_msg
)
def
validate_summary_record
(
summary_record
):
"""
Validate summary_record.
Args:
summary_record (SummaryRecord): SummaryRecord is used to record
the summary value, and summary_record is an instance of SummaryRecord,
see mindspore.train.summary.SummaryRecord
Raises:
MindInsightException: If the parameters are invalid.
"""
if
not
isinstance
(
summary_record
,
SummaryRecord
):
log
.
error
(
"Invalid summary_record. It should be an instance "
"of mindspore.train.summary.SummaryRecord."
)
raise
MindInsightException
(
error
=
LineageErrors
.
PARAM_SUMMARY_RECORD_ERROR
,
message
=
LineageErrorMsg
.
PARAM_SUMMARY_RECORD_ERROR
.
value
)
def
validate_raise_exception
(
raise_exception
):
"""
Validate raise_exception.
...
...
mindinsight/lineagemgr/summary/_summary_adapter.py
浏览文件 @
7ce9f9bd
...
...
@@ -13,131 +13,6 @@
# limitations under the License.
# ============================================================================
"""The converter between proto format event of lineage and dict."""
import
socket
import
time
from
mindinsight.datavisual.proto_files.mindinsight_lineage_pb2
import
LineageEvent
from
mindinsight.lineagemgr.common.exceptions.exceptions
import
LineageParamTypeError
from
mindinsight.lineagemgr.common.log
import
logger
as
log
# Set the Event mark
EVENT_FILE_NAME_MARK
=
"out.events."
# Set lineage file mark
LINEAGE_FILE_NAME_MARK
=
"_lineage"
def
package_dataset_graph
(
graph
):
"""
Package dataset graph.
Args:
graph (dict): Dataset graph.
Returns:
LineageEvent, the proto message event contains dataset graph.
"""
dataset_graph_event
=
LineageEvent
()
dataset_graph_event
.
wall_time
=
time
.
time
()
dataset_graph
=
dataset_graph_event
.
dataset_graph
if
"children"
in
graph
:
children
=
graph
.
pop
(
"children"
)
if
children
:
_package_children
(
children
=
children
,
message
=
dataset_graph
)
_package_current_dataset
(
operation
=
graph
,
message
=
dataset_graph
)
return
dataset_graph_event
def
_package_children
(
children
,
message
):
"""
Package children in dataset operation.
Args:
children (list[dict]): Child operations.
message (DatasetGraph): Children proto message.
"""
for
child
in
children
:
if
child
:
child_graph_message
=
getattr
(
message
,
"children"
).
add
()
grandson
=
child
.
pop
(
"children"
)
if
grandson
:
_package_children
(
children
=
grandson
,
message
=
child_graph_message
)
# package other parameters
_package_current_dataset
(
operation
=
child
,
message
=
child_graph_message
)
def
_package_current_dataset
(
operation
,
message
):
"""
Package operation parameters in event message.
Args:
operation (dict): Operation dict.
message (Operation): Operation proto message.
"""
for
key
,
value
in
operation
.
items
():
if
value
and
key
==
"operations"
:
for
operator
in
value
:
_package_enhancement_operation
(
operator
,
message
.
operations
.
add
()
)
elif
value
and
key
==
"sampler"
:
_package_enhancement_operation
(
value
,
message
.
sampler
)
else
:
_package_parameter
(
key
,
value
,
message
.
parameter
)
def
_package_enhancement_operation
(
operation
,
message
):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for
key
,
value
in
operation
.
items
():
if
isinstance
(
value
,
list
):
if
all
(
isinstance
(
ele
,
int
)
for
ele
in
value
):
message
.
size
.
extend
(
value
)
else
:
message
.
weights
.
extend
(
value
)
else
:
_package_parameter
(
key
,
value
,
message
.
operationParam
)
def
_package_parameter
(
key
,
value
,
message
):
"""
Package parameters in operation.
Args:
key (str): Operation name.
value (Union[str, bool, int, float, list, None]): Operation args.
message (OperationParameter): Operation proto message.
"""
if
isinstance
(
value
,
str
):
message
.
mapStr
[
key
]
=
value
elif
isinstance
(
value
,
bool
):
message
.
mapBool
[
key
]
=
value
elif
isinstance
(
value
,
int
):
message
.
mapInt
[
key
]
=
value
elif
isinstance
(
value
,
float
):
message
.
mapDouble
[
key
]
=
value
elif
isinstance
(
value
,
list
)
and
key
!=
"operations"
:
if
value
:
replace_value_list
=
list
(
map
(
lambda
x
:
""
if
x
is
None
else
x
,
value
))
message
.
mapStrList
[
key
].
strValue
.
extend
(
replace_value_list
)
elif
value
is
None
:
message
.
mapStr
[
key
]
=
"None"
else
:
error_msg
=
"Parameter {} is not supported "
\
"in event package."
.
format
(
key
)
log
.
error
(
error_msg
)
raise
LineageParamTypeError
(
error_msg
)
def
organize_graph
(
graph_message
):
"""
...
...
@@ -296,76 +171,3 @@ def _organize_parameter(parameter):
parameter_result
.
update
(
result_str_list_para
)
return
parameter_result
def
package_user_defined_info
(
user_dict
):
"""
Package user defined info.
Args:
user_dict(dict): User defined info dict.
Returns:
LineageEvent, the proto message event contains user defined info.
"""
user_event
=
LineageEvent
()
user_event
.
wall_time
=
time
.
time
()
user_defined_info
=
user_event
.
user_defined_info
_package_user_defined_info
(
user_dict
,
user_defined_info
)
return
user_event
def
_package_user_defined_info
(
user_defined_dict
,
user_defined_message
):
"""
Setting attribute in user defined proto message.
Args:
user_defined_dict (dict): User define info dict.
user_defined_message (LineageEvent): Proto message of user defined info.
Raises:
LineageParamValueError: When the value is out of range.
LineageParamTypeError: When given a type not support yet.
"""
for
key
,
value
in
user_defined_dict
.
items
():
if
not
isinstance
(
key
,
str
):
error_msg
=
f
"Invalid key type in user defined info. The
{
key
}
's type"
\
f
"'
{
type
(
key
).
__name__
}
' is not supported. It should be str."
log
.
error
(
error_msg
)
if
isinstance
(
value
,
int
):
attr_name
=
"map_int32"
elif
isinstance
(
value
,
float
):
attr_name
=
"map_double"
elif
isinstance
(
value
,
str
):
attr_name
=
"map_str"
else
:
attr_name
=
"attr_name"
add_user_defined_info
=
user_defined_message
.
user_info
.
add
()
try
:
getattr
(
add_user_defined_info
,
attr_name
)[
key
]
=
value
except
AttributeError
:
error_msg
=
f
"Invalid value type in user defined info. The
{
value
}
's type"
\
f
"'
{
type
(
value
).
__name__
}
' is not supported. It should be float, int or str."
log
.
error
(
error_msg
)
def
get_lineage_file_name
():
"""
Get lineage file name.
Lineage filename format is:
EVENT_FILE_NAME_MARK + "summary." + time(seconds) + "." + Hostname + lineage_suffix.
Returns:
str, the name of event log file.
"""
time_second
=
str
(
int
(
time
.
time
()))
hostname
=
socket
.
gethostname
()
file_name
=
f
'
{
EVENT_FILE_NAME_MARK
}
summary.
{
time_second
}
.
{
hostname
}{
LINEAGE_FILE_NAME_MARK
}
'
return
file_name
tests/st/func/lineagemgr/collection/model/test_model_lineage.py
浏览文件 @
7ce9f9bd
...
...
@@ -23,25 +23,23 @@ Usage:
import
os
import
shutil
import
time
from
unittest
import
mock
,
TestCase
from
unittest
import
TestCase
,
mock
import
numpy
as
np
import
pytest
from
mindinsight.lineagemgr
import
get_summary_lineage
,
filter_summary_lineage
from
mindinsight.lineagemgr.collection.model.model_lineage
import
TrainLineage
,
EvalLineage
,
\
AnalyzeObject
from
mindinsight.lineagemgr.common.utils
import
make_directory
from
mindinsight.lineagemgr
import
get_summary_lineage
from
mindinsight.lineagemgr.common.exceptions.error_code
import
LineageErrors
from
mindinsight.lineagemgr.common.exceptions.exceptions
import
LineageParamRunContextError
from
mindinsight.utils.exceptions
import
MindInsightException
from
mindspore.application.model_zoo.resnet
import
ResNet
from
mindspore.common.tensor
import
Tensor
from
mindspore.dataset.engine
import
MindDataset
from
mindspore.nn
import
Momentum
,
SoftmaxCrossEntropyWithLogits
,
WithLossCell
from
mindspore.train.callback
import
RunContext
,
ModelCheckpoin
t
,
SummaryStep
,
_ListCallback
from
mindspore.train.callback
import
ModelCheckpoint
,
RunContex
t
,
SummaryStep
,
_ListCallback
from
mindspore.train.summary
import
SummaryRecord
from
...conftest
import
SUMMARY_DIR
,
SUMMARY_DIR_2
,
SUMMARY_DIR_3
,
BASE_SUMMARY_DIR
from
tests.utils.lineage_writer.model_lineage
import
AnalyzeObject
,
EvalLineage
,
TrainLineage
from
...conftest
import
SUMMARY_DIR
,
SUMMARY_DIR_2
,
SUMMARY_DIR_3
from
.train_one_step
import
TrainOneStep
...
...
@@ -78,65 +76,6 @@ class TestModelLineage(TestCase):
"version"
:
"v1"
}
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_train_begin
(
self
):
"""Test the begin function in TrainLineage."""
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
,
self
.
user_defined_info
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
assert
train_callback
.
initial_learning_rate
==
0.12
lineage_log_path
=
train_callback
.
lineage_summary
.
lineage_log_path
assert
os
.
path
.
isfile
(
lineage_log_path
)
is
True
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_train_begin_with_user_defined_info
(
self
):
"""Test TrainLineage with nested user defined info."""
user_defined_info
=
{
"info"
:
"info1"
,
"version"
:
"v1"
,
"network"
:
"LeNet"
}
train_callback
=
TrainLineage
(
self
.
summary_record
,
False
,
user_defined_info
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
assert
train_callback
.
initial_learning_rate
==
0.12
lineage_log_path
=
train_callback
.
lineage_summary
.
lineage_log_path
assert
os
.
path
.
isfile
(
lineage_log_path
)
is
True
res
=
filter_summary_lineage
(
os
.
path
.
dirname
(
lineage_log_path
))
assert
self
.
user_defined_info
==
res
[
'object'
][
0
][
'model_lineage'
][
'user_defined'
]
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_train_lineage_with_log_dir
(
self
):
"""Test TrainLineage with log_dir."""
summary_dir
=
os
.
path
.
join
(
BASE_SUMMARY_DIR
,
'log_dir'
)
train_callback
=
TrainLineage
(
summary_record
=
summary_dir
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
assert
summary_dir
==
train_callback
.
lineage_log_dir
lineage_log_path
=
train_callback
.
lineage_summary
.
lineage_log_path
assert
os
.
path
.
isfile
(
lineage_log_path
)
is
True
if
os
.
path
.
exists
(
summary_dir
):
shutil
.
rmtree
(
summary_dir
)
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
...
...
@@ -148,7 +87,7 @@ class TestModelLineage(TestCase):
def
test_training_end
(
self
,
*
args
):
"""Test the end function in TrainLineage."""
args
[
0
].
return_value
=
64
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
,
self
.
user_defined_info
)
train_callback
=
TrainLineage
(
SUMMARY_DIR
,
True
,
self
.
user_defined_info
)
train_callback
.
initial_learning_rate
=
0.12
train_callback
.
end
(
RunContext
(
self
.
run_context
))
res
=
get_summary_lineage
(
SUMMARY_DIR
)
...
...
@@ -175,33 +114,6 @@ class TestModelLineage(TestCase):
eval_run_context
[
'step_num'
]
=
32
eval_callback
.
end
(
RunContext
(
eval_run_context
))
@
pytest
.
mark
.
scene_eval
(
3
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_eval_only
(
self
):
"""Test record evaluation event only."""
summary_dir
=
os
.
path
.
join
(
BASE_SUMMARY_DIR
,
'eval_only_dir'
)
summary_record
=
SummaryRecord
(
summary_dir
)
eval_run_context
=
self
.
run_context
eval_run_context
[
'metrics'
]
=
{
'accuracy'
:
0.58
}
eval_run_context
[
'valid_dataset'
]
=
self
.
run_context
[
'train_dataset'
]
eval_run_context
[
'step_num'
]
=
32
eval_only_callback
=
EvalLineage
(
summary_record
)
eval_only_callback
.
end
(
RunContext
(
eval_run_context
))
res
=
get_summary_lineage
(
summary_dir
,
[
'metric'
,
'dataset_graph'
])
expect_res
=
{
'summary_dir'
:
summary_dir
,
'dataset_graph'
:
{},
'metric'
:
{
'accuracy'
:
0.58
}
}
assert
res
==
expect_res
shutil
.
rmtree
(
summary_dir
)
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
...
...
@@ -209,7 +121,7 @@ class TestModelLineage(TestCase):
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
@
mock
.
patch
(
'
mindinsight.lineagemgr.summary.
summary_record.get_lineage_file_name'
)
@
mock
.
patch
(
'
tests.utils.lineage_writer._
summary_record.get_lineage_file_name'
)
@
mock
.
patch
(
'os.path.getsize'
)
def
test_multiple_trains
(
self
,
*
args
):
"""
...
...
@@ -247,89 +159,6 @@ class TestModelLineage(TestCase):
file_num
=
os
.
listdir
(
SUMMARY_DIR_2
)
assert
len
(
file_num
)
==
8
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
@
mock
.
patch
(
'mindinsight.lineagemgr.summary.summary_record.get_lineage_file_name'
)
@
mock
.
patch
(
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_size'
)
def
test_train_eval
(
self
,
*
args
):
"""Callback for train once and eval once."""
args
[
0
].
return_value
=
10
summary_dir
=
os
.
path
.
join
(
BASE_SUMMARY_DIR
,
'train_eval'
)
make_directory
(
summary_dir
)
args
[
1
].
return_value
=
os
.
path
.
join
(
summary_dir
,
f
'train_out.events.summary.
{
str
(
int
(
time
.
time
()))
}
.ubuntu_lineage'
)
train_callback
=
TrainLineage
(
summary_dir
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
train_callback
.
end
(
RunContext
(
self
.
run_context
))
args
[
1
].
return_value
=
os
.
path
.
join
(
summary_dir
,
f
'eval_out.events.summary.
{
str
(
int
(
time
.
time
())
+
1
)
}
.ubuntu_lineage'
)
eval_callback
=
EvalLineage
(
summary_dir
)
eval_run_context
=
self
.
run_context
eval_run_context
[
'metrics'
]
=
{
'accuracy'
:
0.78
}
eval_run_context
[
'valid_dataset'
]
=
self
.
run_context
[
'train_dataset'
]
eval_run_context
[
'step_num'
]
=
32
eval_callback
.
end
(
RunContext
(
eval_run_context
))
res
=
get_summary_lineage
(
summary_dir
)
assert
res
.
get
(
'hyper_parameters'
,
{}).
get
(
'loss_function'
)
\
==
'SoftmaxCrossEntropyWithLogits'
assert
res
.
get
(
'algorithm'
,
{}).
get
(
'network'
)
==
'ResNet'
if
os
.
path
.
exists
(
summary_dir
):
shutil
.
rmtree
(
summary_dir
)
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
@
mock
.
patch
(
'mindinsight.lineagemgr.summary.summary_record.get_lineage_file_name'
)
@
mock
.
patch
(
'mindinsight.lineagemgr.collection.model.model_lineage.AnalyzeObject.get_file_size'
)
def
test_train_multi_eval
(
self
,
*
args
):
"""Callback for train once and eval twice."""
args
[
0
].
return_value
=
10
summary_dir
=
os
.
path
.
join
(
BASE_SUMMARY_DIR
,
'train_multi_eval'
)
make_directory
(
summary_dir
)
args
[
1
].
return_value
=
os
.
path
.
join
(
summary_dir
,
'train_out.events.summary.1590107366.ubuntu_lineage'
)
train_callback
=
TrainLineage
(
summary_dir
,
True
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
train_callback
.
end
(
RunContext
(
self
.
run_context
))
args
[
1
].
return_value
=
os
.
path
.
join
(
summary_dir
,
'eval_out.events.summary.1590107367.ubuntu_lineage'
)
eval_callback
=
EvalLineage
(
summary_dir
,
True
)
eval_run_context
=
self
.
run_context
eval_run_context
[
'valid_dataset'
]
=
self
.
run_context
[
'train_dataset'
]
eval_run_context
[
'metrics'
]
=
{
'accuracy'
:
0.79
}
eval_callback
.
end
(
RunContext
(
eval_run_context
))
res
=
get_summary_lineage
(
summary_dir
)
assert
res
.
get
(
'metric'
,
{}).
get
(
'accuracy'
)
==
0.79
args
[
1
].
return_value
=
os
.
path
.
join
(
summary_dir
,
'eval_out.events.summary.1590107368.ubuntu_lineage'
)
eval_callback
=
EvalLineage
(
summary_dir
,
True
)
eval_run_context
=
self
.
run_context
eval_run_context
[
'valid_dataset'
]
=
self
.
run_context
[
'train_dataset'
]
eval_run_context
[
'metrics'
]
=
{
'accuracy'
:
0.80
}
eval_callback
.
end
(
RunContext
(
eval_run_context
))
res
=
get_summary_lineage
(
summary_dir
)
assert
res
.
get
(
'metric'
,
{}).
get
(
'accuracy'
)
==
0.80
if
os
.
path
.
exists
(
summary_dir
):
shutil
.
rmtree
(
summary_dir
)
@
pytest
.
mark
.
scene_train
(
2
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
...
...
@@ -341,7 +170,7 @@ class TestModelLineage(TestCase):
def
test_train_with_customized_network
(
self
,
*
args
):
"""Test train with customized network."""
args
[
0
].
return_value
=
64
train_callback
=
TrainLineage
(
self
.
summary_record
,
True
,
self
.
user_defined_info
)
train_callback
=
TrainLineage
(
SUMMARY_DIR
,
True
,
self
.
user_defined_info
)
run_context_customized
=
self
.
run_context
del
run_context_customized
[
'optimizer'
]
del
run_context_customized
[
'net_outputs'
]
...
...
@@ -363,27 +192,6 @@ class TestModelLineage(TestCase):
assert
res
.
get
(
'algorithm'
,
{}).
get
(
'network'
)
==
'ResNet'
assert
res
.
get
(
'hyper_parameters'
,
{}).
get
(
'optimizer'
)
==
'Momentum'
@
pytest
.
mark
.
scene_exception
(
1
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_raise_exception
(
self
):
"""Test exception when raise_exception is set True."""
summary_record
=
SummaryRecord
(
SUMMARY_DIR_3
)
full_file_name
=
summary_record
.
full_file_name
assert
os
.
path
.
isfile
(
full_file_name
)
is
True
assert
os
.
path
.
isfile
(
full_file_name
+
"_lineage"
)
is
False
train_callback
=
TrainLineage
(
summary_record
,
True
)
eval_callback
=
EvalLineage
(
summary_record
,
False
)
with
self
.
assertRaises
(
LineageParamRunContextError
):
train_callback
.
begin
(
self
.
run_context
)
eval_callback
.
end
(
self
.
run_context
)
file_num
=
os
.
listdir
(
SUMMARY_DIR_3
)
assert
len
(
file_num
)
==
1
assert
os
.
path
.
isfile
(
full_file_name
+
"_lineage"
)
is
False
@
pytest
.
mark
.
scene_exception
(
1
)
@
pytest
.
mark
.
level0
...
...
@@ -392,53 +200,7 @@ class TestModelLineage(TestCase):
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_raise_exception_init
(
self
):
"""Test exception when error happened during the initialization process."""
if
os
.
path
.
exists
(
SUMMARY_DIR_3
):
shutil
.
rmtree
(
SUMMARY_DIR_3
)
summary_record
=
SummaryRecord
(
SUMMARY_DIR_3
)
train_callback
=
TrainLineage
(
'fake_summary_record'
,
False
)
eval_callback
=
EvalLineage
(
'fake_summary_record'
,
False
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
eval_callback
.
end
(
RunContext
(
self
.
run_context
))
file_num
=
os
.
listdir
(
SUMMARY_DIR_3
)
full_file_name
=
summary_record
.
full_file_name
assert
len
(
file_num
)
==
1
assert
os
.
path
.
isfile
(
full_file_name
+
"_lineage"
)
is
False
@
pytest
.
mark
.
scene_exception
(
1
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_raise_exception_create_file
(
self
):
"""Test exception when error happened after creating file."""
if
os
.
path
.
exists
(
SUMMARY_DIR_3
):
shutil
.
rmtree
(
SUMMARY_DIR_3
)
summary_record
=
SummaryRecord
(
SUMMARY_DIR_3
)
eval_callback
=
EvalLineage
(
summary_record
,
False
)
full_file_name
=
summary_record
.
full_file_name
+
"_lineage"
eval_run_context
=
self
.
run_context
eval_run_context
[
'metrics'
]
=
{
'accuracy'
:
0.78
}
eval_run_context
[
'step_num'
]
=
32
eval_run_context
[
'valid_dataset'
]
=
self
.
run_context
[
'train_dataset'
]
with
open
(
full_file_name
,
'ab'
):
with
mock
.
patch
(
'builtins.open'
)
as
mock_handler
:
mock_handler
.
return_value
.
__enter__
.
return_value
.
write
.
side_effect
=
IOError
eval_callback
.
end
(
RunContext
(
eval_run_context
))
assert
os
.
path
.
isfile
(
full_file_name
)
is
True
assert
os
.
path
.
getsize
(
full_file_name
)
==
0
@
pytest
.
mark
.
scene_exception
(
1
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
@
mock
.
patch
(
'mindinsight.lineagemgr.collection.model.model_lineage.validate_eval_run_context'
)
@
mock
.
patch
(
'tests.utils.lineage_writer.model_lineage.validate_eval_run_context'
)
@
mock
.
patch
.
object
(
AnalyzeObject
,
'get_file_size'
,
return_value
=
64
)
def
test_raise_exception_record_trainlineage
(
self
,
*
args
):
"""Test exception when error happened after recording training infos."""
...
...
@@ -446,32 +208,14 @@ class TestModelLineage(TestCase):
shutil
.
rmtree
(
SUMMARY_DIR_3
)
args
[
1
].
side_effect
=
MindInsightException
(
error
=
LineageErrors
.
PARAM_RUN_CONTEXT_ERROR
,
message
=
"RunContext error."
)
summary_record
=
SummaryRecord
(
SUMMARY_DIR_3
)
train_callback
=
TrainLineage
(
summary_record
,
True
)
train_callback
=
TrainLineage
(
SUMMARY_DIR_3
,
True
)
train_callback
.
begin
(
RunContext
(
self
.
run_context
))
full_file_name
=
train_callback
.
lineage_summary
.
lineage_log_path
file_size1
=
os
.
path
.
getsize
(
full_file_name
)
train_callback
.
end
(
RunContext
(
self
.
run_context
))
file_size2
=
os
.
path
.
getsize
(
full_file_name
)
assert
file_size2
>
file_size1
eval_callback
=
EvalLineage
(
summary_record
,
False
)
eval_callback
=
EvalLineage
(
SUMMARY_DIR_3
,
False
)
eval_callback
.
end
(
RunContext
(
self
.
run_context
))
file_size3
=
os
.
path
.
getsize
(
full_file_name
)
assert
file_size3
==
file_size2
@
pytest
.
mark
.
scene_exception
(
1
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_arm_ascend_training
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
platform_x86_ascend_training
@
pytest
.
mark
.
platform_x86_cpu
@
pytest
.
mark
.
env_single
def
test_raise_exception_non_lineage_file
(
self
):
"""Test exception when lineage summary file cannot be found."""
summary_dir
=
os
.
path
.
join
(
BASE_SUMMARY_DIR
,
'run4'
)
if
os
.
path
.
exists
(
summary_dir
):
shutil
.
rmtree
(
summary_dir
)
summary_record
=
SummaryRecord
(
summary_dir
,
file_suffix
=
'_MS_lineage_none'
)
full_file_name
=
summary_record
.
full_file_name
assert
full_file_name
.
endswith
(
'_lineage_none'
)
assert
os
.
path
.
isfile
(
full_file_name
)
tests/ut/lineagemgr/collection/__init__.py
已删除
100644 → 0
浏览文件 @
ed594114
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
tests/ut/lineagemgr/collection/model/__init__.py
已删除
100644 → 0
浏览文件 @
ed594114
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
tests/ut/lineagemgr/collection/model/test_model_lineage.py
已删除
100644 → 0
浏览文件 @
ed594114
此差异已折叠。
点击以展开。
tests/ut/lineagemgr/summary/test_event_writer.py
已删除
100644 → 0
浏览文件 @
ed594114
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test class EventWriter."""
import
os
from
unittest
import
TestCase
from
mindinsight.lineagemgr.summary.event_writer
import
EventWriter
from
mindinsight.lineagemgr.summary.summary_record
import
LineageSummary
from
mindinsight.lineagemgr.summary.lineage_summary_analyzer
import
LineageSummaryAnalyzer
class
TestEventWriter
(
TestCase
):
"""Test write_event_to_file."""
def
setUp
(
self
):
"""The setup of test."""
self
.
log_path
=
"./test.log"
def
test_write_event_to_file
(
self
):
"""Test write event to file."""
run_context_args
=
{
"train_network"
:
"res"
}
content
=
LineageSummary
.
package_train_message
(
run_context_args
).
SerializeToString
()
event_writer
=
EventWriter
(
self
.
log_path
,
True
)
event_writer
.
write_event_to_file
(
content
)
lineage_info
=
LineageSummaryAnalyzer
.
get_summary_infos
(
self
.
log_path
)
self
.
assertEqual
(
lineage_info
.
train_lineage
.
train_lineage
.
algorithm
.
network
,
run_context_args
.
get
(
"train_network"
)
)
def
tearDown
(
self
):
"""The setup of test."""
if
os
.
path
.
exists
(
self
.
log_path
):
try
:
os
.
remove
(
self
.
log_path
)
except
IOError
:
pass
tests/ut/lineagemgr/summary/test_summary_record.py
已删除
100644 → 0
浏览文件 @
ed594114
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test class SummaryRecord."""
import
json
from
unittest
import
TestCase
,
mock
from
mindinsight.lineagemgr.summary.event_writer
import
EventWriter
from
mindinsight.lineagemgr.summary.summary_record
import
LineageSummary
class
TestSummaryRecord
(
TestCase
):
"""Test summary record."""
def
setUp
(
self
):
"""The setup of test."""
self
.
run_context_args
=
dict
()
self
.
run_context_args
[
"train_network"
]
=
"test_train_network"
self
.
run_context_args
[
"loss"
]
=
0.1
self
.
run_context_args
[
"learning_rate"
]
=
0.1
self
.
run_context_args
[
"optimizer"
]
=
"test_optimizer"
self
.
run_context_args
[
"loss_function"
]
=
"test_loss_function"
self
.
run_context_args
[
"epoch"
]
=
1
self
.
run_context_args
[
"parallel_mode"
]
=
"test_parallel_mode"
self
.
run_context_args
[
"device_num"
]
=
1
self
.
run_context_args
[
"batch_size"
]
=
1
self
.
run_context_args
[
"train_dataset_path"
]
=
"test_train_dataset_path"
self
.
run_context_args
[
"train_dataset_size"
]
=
1
self
.
run_context_args
[
"model_path"
]
=
"test_model_path"
self
.
run_context_args
[
"model_size"
]
=
1
self
.
eval_args
=
dict
()
self
.
eval_args
[
"metrics"
]
=
json
.
dumps
({
"acc"
:
"test"
})
self
.
eval_args
[
"valid_dataset_path"
]
=
"test_valid_dataset_path"
self
.
eval_args
[
"valid_dataset_size"
]
=
1
self
.
hard_info_args
=
dict
()
self
.
hard_info_args
[
"pid"
]
=
1
self
.
hard_info_args
[
"process_start_time"
]
=
921226.0
def
test_package_train_message
(
self
):
"""Test package_train_message."""
event
=
LineageSummary
.
package_train_message
(
self
.
run_context_args
)
self
.
assertEqual
(
event
.
train_lineage
.
algorithm
.
network
,
self
.
run_context_args
.
get
(
"train_network"
))
self
.
assertEqual
(
event
.
train_lineage
.
hyper_parameters
.
optimizer
,
self
.
run_context_args
.
get
(
"optimizer"
))
self
.
assertEqual
(
event
.
train_lineage
.
train_dataset
.
train_dataset_path
,
self
.
run_context_args
.
get
(
"train_dataset_path"
)
)
@
mock
.
patch
.
object
(
EventWriter
,
"write_event_to_file"
)
def
test_record_train_lineage
(
self
,
write_file
):
"""Test record_train_lineage."""
write_file
.
return_value
=
True
lineage_summray
=
LineageSummary
(
lineage_log_dir
=
"test.log"
)
lineage_summray
.
record_train_lineage
(
self
.
run_context_args
)
def
test_package_evaluation_message
(
self
):
"""Test package_evaluation_message."""
event
=
LineageSummary
.
package_evaluation_message
(
self
.
eval_args
)
self
.
assertEqual
(
event
.
evaluation_lineage
.
metric
,
self
.
eval_args
.
get
(
"metrics"
))
@
mock
.
patch
.
object
(
EventWriter
,
"write_event_to_file"
)
def
test_record_eval_lineage
(
self
,
write_file
):
"""Test record_eval_lineage."""
write_file
.
return_value
=
True
lineage_summray
=
LineageSummary
(
lineage_log_dir
=
"test.log"
)
lineage_summray
.
record_evaluation_lineage
(
self
.
eval_args
)
mindinsight/lineagemgr/collection
/__init__.py
→
tests/utils/lineage_writer
/__init__.py
浏览文件 @
7ce9f9bd
...
...
@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Lineage writer module."""
from
._summary_record
import
LineageSummary
__all__
=
[
"LineageSummary"
]
mindinsight/lineagemgr/summary/
event_writer.py
→
tests/utils/lineage_writer/_
event_writer.py
浏览文件 @
7ce9f9bd
...
...
@@ -17,7 +17,7 @@ import os
import
stat
import
struct
from
mindinsight.datavisual
.utils
import
crc32
from
tests
.utils
import
crc32
class
EventWriter
:
...
...
@@ -84,6 +84,6 @@ class EventWriter:
Returns:
bytes, crc of content, 4 bytes.
"""
crc_value
=
crc32
.
GetMaskCrc32cValue
(
content
,
len
(
content
)
)
crc_value
=
crc32
.
get_mask_from_string
(
content
)
return
struct
.
pack
(
"<L"
,
crc_value
)
tests/utils/lineage_writer/_summary_adapter.py
0 → 100644
浏览文件 @
7ce9f9bd
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""The converter between proto format event of lineage and dict."""
import
socket
import
time
from
mindinsight.datavisual.proto_files.mindinsight_lineage_pb2
import
LineageEvent
from
mindinsight.lineagemgr.common.exceptions.exceptions
import
LineageParamTypeError
from
mindinsight.lineagemgr.common.log
import
logger
as
log
# Set the Event mark
EVENT_FILE_NAME_MARK
=
"out.events."
# Set lineage file mark
LINEAGE_FILE_NAME_MARK
=
"_lineage"
def
package_dataset_graph
(
graph
):
"""
Package dataset graph.
Args:
graph (dict): Dataset graph.
Returns:
LineageEvent, the proto message event contains dataset graph.
"""
dataset_graph_event
=
LineageEvent
()
dataset_graph_event
.
wall_time
=
time
.
time
()
dataset_graph
=
dataset_graph_event
.
dataset_graph
if
"children"
in
graph
:
children
=
graph
.
pop
(
"children"
)
if
children
:
_package_children
(
children
=
children
,
message
=
dataset_graph
)
_package_current_dataset
(
operation
=
graph
,
message
=
dataset_graph
)
return
dataset_graph_event
def
_package_children
(
children
,
message
):
"""
Package children in dataset operation.
Args:
children (list[dict]): Child operations.
message (DatasetGraph): Children proto message.
"""
for
child
in
children
:
if
child
:
child_graph_message
=
getattr
(
message
,
"children"
).
add
()
grandson
=
child
.
pop
(
"children"
)
if
grandson
:
_package_children
(
children
=
grandson
,
message
=
child_graph_message
)
# package other parameters
_package_current_dataset
(
operation
=
child
,
message
=
child_graph_message
)
def
_package_current_dataset
(
operation
,
message
):
"""
Package operation parameters in event message.
Args:
operation (dict): Operation dict.
message (Operation): Operation proto message.
"""
for
key
,
value
in
operation
.
items
():
if
value
and
key
==
"operations"
:
for
operator
in
value
:
_package_enhancement_operation
(
operator
,
message
.
operations
.
add
()
)
elif
value
and
key
==
"sampler"
:
_package_enhancement_operation
(
value
,
message
.
sampler
)
else
:
_package_parameter
(
key
,
value
,
message
.
parameter
)
def
_package_enhancement_operation
(
operation
,
message
):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for
key
,
value
in
operation
.
items
():
if
isinstance
(
value
,
list
):
if
all
(
isinstance
(
ele
,
int
)
for
ele
in
value
):
message
.
size
.
extend
(
value
)
else
:
message
.
weights
.
extend
(
value
)
else
:
_package_parameter
(
key
,
value
,
message
.
operationParam
)
def
_package_parameter
(
key
,
value
,
message
):
"""
Package parameters in operation.
Args:
key (str): Operation name.
value (Union[str, bool, int, float, list, None]): Operation args.
message (OperationParameter): Operation proto message.
"""
if
isinstance
(
value
,
str
):
message
.
mapStr
[
key
]
=
value
elif
isinstance
(
value
,
bool
):
message
.
mapBool
[
key
]
=
value
elif
isinstance
(
value
,
int
):
message
.
mapInt
[
key
]
=
value
elif
isinstance
(
value
,
float
):
message
.
mapDouble
[
key
]
=
value
elif
isinstance
(
value
,
list
)
and
key
!=
"operations"
:
if
value
:
replace_value_list
=
list
(
map
(
lambda
x
:
""
if
x
is
None
else
x
,
value
))
message
.
mapStrList
[
key
].
strValue
.
extend
(
replace_value_list
)
elif
value
is
None
:
message
.
mapStr
[
key
]
=
"None"
else
:
error_msg
=
"Parameter {} is not supported "
\
"in event package."
.
format
(
key
)
log
.
error
(
error_msg
)
raise
LineageParamTypeError
(
error_msg
)
def
package_user_defined_info
(
user_dict
):
"""
Package user defined info.
Args:
user_dict(dict): User defined info dict.
Returns:
LineageEvent, the proto message event contains user defined info.
"""
user_event
=
LineageEvent
()
user_event
.
wall_time
=
time
.
time
()
user_defined_info
=
user_event
.
user_defined_info
_package_user_defined_info
(
user_dict
,
user_defined_info
)
return
user_event
def
_package_user_defined_info
(
user_defined_dict
,
user_defined_message
):
"""
Setting attribute in user defined proto message.
Args:
user_defined_dict (dict): User define info dict.
user_defined_message (LineageEvent): Proto message of user defined info.
Raises:
LineageParamValueError: When the value is out of range.
LineageParamTypeError: When given a type not support yet.
"""
for
key
,
value
in
user_defined_dict
.
items
():
if
not
isinstance
(
key
,
str
):
error_msg
=
f
"Invalid key type in user defined info. The
{
key
}
's type"
\
f
"'
{
type
(
key
).
__name__
}
' is not supported. It should be str."
log
.
error
(
error_msg
)
if
isinstance
(
value
,
int
):
attr_name
=
"map_int32"
elif
isinstance
(
value
,
float
):
attr_name
=
"map_double"
elif
isinstance
(
value
,
str
):
attr_name
=
"map_str"
else
:
attr_name
=
"attr_name"
add_user_defined_info
=
user_defined_message
.
user_info
.
add
()
try
:
getattr
(
add_user_defined_info
,
attr_name
)[
key
]
=
value
except
AttributeError
:
error_msg
=
f
"Invalid value type in user defined info. The
{
value
}
's type"
\
f
"'
{
type
(
value
).
__name__
}
' is not supported. It should be float, int or str."
log
.
error
(
error_msg
)
def
get_lineage_file_name
():
"""
Get lineage file name.
Lineage filename format is:
EVENT_FILE_NAME_MARK + "summary." + time(seconds) + "." + Hostname + lineage_suffix.
Returns:
str, the name of event log file.
"""
time_second
=
str
(
int
(
time
.
time
()))
hostname
=
socket
.
gethostname
()
file_name
=
f
'
{
EVENT_FILE_NAME_MARK
}
summary.
{
time_second
}
.
{
hostname
}{
LINEAGE_FILE_NAME_MARK
}
'
return
file_name
mindinsight/lineagemgr/summary/
summary_record.py
→
tests/utils/lineage_writer/_
summary_record.py
浏览文件 @
7ce9f9bd
...
...
@@ -17,8 +17,7 @@ import os
import
time
from
mindinsight.datavisual.proto_files.mindinsight_lineage_pb2
import
LineageEvent
from
mindinsight.lineagemgr.common.validator.validate
import
validate_file_path
from
mindinsight.lineagemgr.summary.event_writer
import
EventWriter
from
._event_writer
import
EventWriter
from
._summary_adapter
import
package_dataset_graph
,
package_user_defined_info
,
get_lineage_file_name
...
...
@@ -40,11 +39,10 @@ class LineageSummary:
>>> lineage_summary.record_train_lineage(train_lineage)
"""
def
__init__
(
self
,
lineage_log_dir
=
None
,
lineage_log_dir
,
override
=
False
):
lineage_log_name
=
get_lineage_file_name
()
self
.
lineage_log_path
=
os
.
path
.
join
(
lineage_log_dir
,
lineage_log_name
)
validate_file_path
(
self
.
lineage_log_path
)
self
.
event_writer
=
EventWriter
(
self
.
lineage_log_path
,
override
)
@
staticmethod
...
...
mindinsight/lineagemgr/collection/model
/base.py
→
tests/utils/lineage_writer
/base.py
浏览文件 @
7ce9f9bd
文件已移动
mindinsight/lineagemgr/collection/model
/model_lineage.py
→
tests/utils/lineage_writer
/model_lineage.py
浏览文件 @
7ce9f9bd
...
...
@@ -17,22 +17,20 @@ import json
import
os
import
numpy
as
np
from
mindinsight.lineagemgr.summary.summary_record
import
LineageSummary
from
mindinsight.utils.exceptions
import
\
MindInsightException
from
mindinsight.lineagemgr.common.validator.validate
import
validate_train_run_context
,
\
validate_eval_run_context
,
validate_file_path
,
validate_network
,
\
validate_int_params
,
validate_summary_record
,
validate_raise_exception
,
\
validate_user_defined_info
from
mindinsight.lineagemgr.common.exceptions.error_code
import
LineageErrors
,
LineageErrorMsg
from
mindinsight.lineagemgr.common.exceptions.exceptions
import
LineageParamRunContextError
,
\
LineageGetModelFileError
,
LineageLogError
from
mindinsight.lineagemgr.common.exceptions.error_code
import
LineageErrorMsg
,
LineageErrors
from
mindinsight.lineagemgr.common.exceptions.exceptions
import
(
LineageGetModelFileError
,
LineageLogError
,
LineageParamRunContextError
)
from
mindinsight.lineagemgr.common.log
import
logger
as
log
from
mindinsight.lineagemgr.common.utils
import
try_except
,
make_directory
from
mindinsight.lineagemgr.common.validator.model_parameter
import
RunContextArgs
,
\
EvalParameter
from
mindinsight.lineagemgr.collection.model.base
import
Metadata
from
mindinsight.lineagemgr.common.utils
import
make_directory
,
try_except
from
mindinsight.lineagemgr.common.validator.model_parameter
import
EvalParameter
from
mindinsight.lineagemgr.common.validator.validate
import
(
validate_eval_run_context
,
validate_file_path
,
validate_int_params
,
validate_raise_exception
,
validate_user_defined_info
)
from
mindinsight.utils.exceptions
import
MindInsightException
from
._summary_record
import
LineageSummary
from
.base
import
Metadata
try
:
from
mindspore.common.tensor
import
Tensor
...
...
@@ -91,7 +89,6 @@ class TrainLineage(Callback):
# make directory if not exist
self
.
lineage_log_dir
=
make_directory
(
summary_record
)
else
:
validate_summary_record
(
summary_record
)
summary_log_path
=
summary_record
.
full_file_name
validate_file_path
(
summary_log_path
)
self
.
lineage_log_dir
=
os
.
path
.
dirname
(
summary_log_path
)
...
...
@@ -145,7 +142,6 @@ class TrainLineage(Callback):
log
.
debug
(
'initial_learning_rate: %s'
,
self
.
initial_learning_rate
)
else
:
network
=
run_context_args
.
get
(
'train_network'
)
validate_network
(
network
)
optimizer
=
AnalyzeObject
.
get_optimizer_by_network
(
network
)
self
.
initial_learning_rate
=
AnalyzeObject
.
analyze_optimizer
(
optimizer
)
log
.
debug
(
'initial_learning_rate: %s'
,
self
.
initial_learning_rate
)
...
...
@@ -183,7 +179,6 @@ class TrainLineage(Callback):
raise
LineageParamRunContextError
(
error_msg
)
run_context_args
=
run_context
.
original_args
()
validate_train_run_context
(
RunContextArgs
,
run_context_args
)
train_lineage
=
dict
()
train_lineage
=
AnalyzeObject
.
get_network_args
(
...
...
@@ -277,7 +272,6 @@ class EvalLineage(Callback):
# make directory if not exist
self
.
lineage_log_dir
=
make_directory
(
summary_record
)
else
:
validate_summary_record
(
summary_record
)
summary_log_path
=
summary_record
.
full_file_name
validate_file_path
(
summary_log_path
)
self
.
lineage_log_dir
=
os
.
path
.
dirname
(
summary_log_path
)
...
...
@@ -639,7 +633,6 @@ class AnalyzeObject:
dict, the lineage metadata.
"""
network
=
run_context_args
.
get
(
'train_network'
)
validate_network
(
network
)
optimizer
=
run_context_args
.
get
(
'optimizer'
)
if
not
optimizer
:
optimizer
=
AnalyzeObject
.
get_optimizer_by_network
(
network
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录