Commit 9514b52a authored by mindspore-ci-bot, committed by Gitee

!2147 Add a callback named SummaryCollector and delete SummaryStep callback

Merge pull request !2147 from ougongchang/master
......@@ -14,7 +14,10 @@
# ============================================================================
"""Train utility."""
import os
from collections.abc import Iterable
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype
......@@ -213,6 +216,7 @@ def _check_to_numpy(plugin, tensor):
raise ValueError('The tensor should not be empty.')
return np_value
def _check_lineage_value(plugin, value):
"""Check the lineage value."""
def raises(plugin, prototype):
......@@ -229,3 +233,20 @@ def _check_lineage_value(plugin, value):
if plugin == 'custom_lineage_data' and not isinstance(value, UserDefinedInfo):
raises(plugin, UserDefinedInfo)
def check_value_type(arg_name, arg_value, valid_types):
"""Checks whether a value is instance of some types."""
valid_types = tuple(valid_types) if isinstance(valid_types, Iterable) else (valid_types,)
is_valid = True
# bool is a subclass of int, so a bool value needs an extra check
if isinstance(arg_value, int) and isinstance(arg_value, bool) and bool not in valid_types:
is_valid = False
if not isinstance(arg_value, valid_types):
is_valid = False
if not is_valid:
raise TypeError(f'For `{arg_name}` the type should be a valid type of {[t.__name__ for t in valid_types]}, '
f'but got {type(arg_value).__name__}.')
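A minimal usage sketch of the new helper (the argument names and values are hypothetical), showing the bool/int special case:

check_value_type('collect_freq', 10, int)             # passes: plain int
check_value_type('keep_default_action', True, bool)   # passes: bool accepted explicitly
check_value_type('collect_freq', True, int)           # raises TypeError: bool is rejected when only int is allowed
check_value_type('summary_dir', 123, [str])           # raises TypeError listing ['str']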
......@@ -22,7 +22,8 @@ from ._checkpoint import CheckpointConfig
from ._checkpoint import CheckpointManager as _CheckpointManager
from ._checkpoint import ModelCheckpoint
from ._loss_monitor import LossMonitor
from ._summary_step import SummaryStep
from ._time_monitor import TimeMonitor
from ._summary_collector import SummaryCollector
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint",
"SummaryCollector", "CheckpointConfig", "RunContext"]
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define dataset graph related operations."""
import json
from importlib import import_module
from mindspore.train import lineage_pb2
class DatasetGraph:
"""Handle the data graph and packages it into binary data."""
def package_dataset_graph(self, dataset):
"""
Package the dataset graph into binary data.
Args:
dataset (MindData): Refer to MindDataset.
Returns:
DatasetGraph, an object of lineage_pb2.DatasetGraph.
"""
dataset_package = import_module('mindspore.dataset')
dataset_dict = dataset_package.serialize(dataset)
json_str = json.dumps(dataset_dict, indent=2)
dataset_dict = json.loads(json_str)
dataset_graph_proto = lineage_pb2.DatasetGraph()
if "children" in dataset_dict:
children = dataset_dict.pop("children")
if children:
self._package_children(children=children, message=dataset_graph_proto)
self._package_current_dataset(operation=dataset_dict, message=dataset_graph_proto)
return dataset_graph_proto
def _package_children(self, children, message):
"""
Package children in dataset operation.
Args:
children (list[dict]): Child operations.
message (DatasetGraph): Children proto message.
"""
for child in children:
if child:
child_graph_message = getattr(message, "children").add()
grandson = child.pop("children")
if grandson:
self._package_children(children=grandson, message=child_graph_message)
# package other parameters
self._package_current_dataset(operation=child, message=child_graph_message)
def _package_current_dataset(self, operation, message):
"""
Package operation parameters in event message.
Args:
operation (dict): Operation dict.
message (Operation): Operation proto message.
"""
for key, value in operation.items():
if value and key == "operations":
for operator in value:
self._package_enhancement_operation(
operator,
message.operations.add()
)
elif value and key == "sampler":
self._package_enhancement_operation(
value,
message.sampler
)
else:
self._package_parameter(key, value, message.parameter)
def _package_enhancement_operation(self, operation, message):
"""
Package enhancement operation in MapDataset.
Args:
operation (dict): Enhancement operation.
message (Operation): Enhancement operation proto message.
"""
for key, value in operation.items():
if isinstance(value, list):
if all(isinstance(ele, int) for ele in value):
message.size.extend(value)
else:
message.weights.extend(value)
else:
self._package_parameter(key, value, message.operationParam)
@staticmethod
def _package_parameter(key, value, message):
"""
Package parameters in operation.
Args:
key (str): Parameter name.
value (Union[str, bool, int, float, list, None]): Parameter value.
message (OperationParameter): Operation proto message.
"""
if isinstance(value, str):
message.mapStr[key] = value
elif isinstance(value, bool):
message.mapBool[key] = value
elif isinstance(value, int):
message.mapInt[key] = value
elif isinstance(value, float):
message.mapDouble[key] = value
elif isinstance(value, list) and key != "operations":
if value:
replace_value_list = list(map(lambda x: "" if x is None else x, value))
message.mapStrList[key].strValue.extend(replace_value_list)
elif value is None:
message.mapStr[key] = "None"
else:
raise ValueError(f"Parameter {key} is not supported in event package.")
This diff is collapsed.
......@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""Model."""
from collections.abc import Iterable
import numpy as np
from mindspore import log as logger
......@@ -345,7 +347,8 @@ class Model:
cb_params.parallel_mode = self._parallel_mode
cb_params.device_number = self._device_number
cb_params.train_dataset = train_dataset
cb_params.list_callback = callbacks
cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.train_dataset_element = None
# build callback list
with _CallbackManager(callbacks) as list_callback:
......@@ -358,6 +361,17 @@ class Model:
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
@staticmethod
def _transform_callbacks(callbacks):
"""Transform callback to a list."""
if callbacks is None:
return []
if isinstance(callbacks, Iterable):
return list(callbacks)
return [callbacks]
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
"""
Training process. The data would be passed to network through dataset channel.
......@@ -449,6 +463,7 @@ class Model:
scaling_sens = self._get_scaling_sens()
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
cb_params.train_dataset_element = next_element
outputs = self._train_network(*next_element)
cb_params.net_outputs = outputs
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
......@@ -628,6 +643,7 @@ class Model:
cb_params.batch_num = valid_dataset.get_dataset_size()
cb_params.mode = "eval"
cb_params.cur_step_num = 0
cb_params.list_callback = self._transform_callbacks(callbacks)
self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval'
......
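The `_transform_callbacks` helper added above normalizes whatever the caller passes as `callbacks`; an illustrative check (the callback object is hypothetical):

from mindspore.train.callback import LossMonitor
cb = LossMonitor()
assert Model._transform_callbacks(None) == []             # no callbacks -> empty list
assert Model._transform_callbacks(cb) == [cb]             # single callback -> wrapped in a list
assert Model._transform_callbacks((cb, cb)) == [cb, cb]   # iterable -> plain list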
......@@ -12,45 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SummaryStep Callback class."""
"""Summary's enumeration file."""
from enum import Enum
from ._callback import Callback
class BaseEnum(Enum):
"""The base enum class."""
class SummaryStep(Callback):
"""
The summary callback class.
@classmethod
def to_list(cls):
"""Converts the enumeration into a list."""
return [member.value for member in cls.__members__.values()]
Args:
summary (Object): Summary recode object.
flush_step (int): Number of interval steps to execute. Default: 10.
"""
def __init__(self, summary, flush_step=10):
super(SummaryStep, self).__init__()
if not isinstance(flush_step, int) or isinstance(flush_step, bool) or flush_step <= 0:
raise ValueError("`flush_step` should be int and greater than 0")
self._summary = summary
self._flush_step = flush_step
class PluginEnum(BaseEnum):
"""The list of plugins currently supported by the summary."""
GRAPH = 'graph'
SCALAR = 'scalar'
IMAGE = 'image'
TENSOR = 'tensor'
HISTOGRAM = 'histogram'
TRAIN_LINEAGE = 'train_lineage'
EVAL_LINEAGE = 'eval_lineage'
DATASET_GRAPH = 'dataset_graph'
def __enter__(self):
self._summary.__enter__()
return self
def __exit__(self, *err):
return self._summary.__exit__(*err)
def step_end(self, run_context):
"""
Save summary.
Args:
run_context (RunContext): Context of the train running.
"""
cb_params = run_context.original_args()
if cb_params.cur_step_num % self._flush_step == 0:
self._summary.record(cb_params.cur_step_num, cb_params.train_network)
@property
def summary_file_name(self):
return self._summary.full_file_name
class ModeEnum(BaseEnum):
"""The modes currently supported by the summary."""
TRAIN = 'train'
EVAL = 'eval'
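A quick illustration of the enum helper, using the values defined above (assuming PluginEnum and ModeEnum are importable from this module):

assert ModeEnum.to_list() == ['train', 'eval']
assert PluginEnum.to_list() == ['graph', 'scalar', 'image', 'tensor', 'histogram',
                                'train_lineage', 'eval_lineage', 'dataset_graph']
assert PluginEnum.DATASET_GRAPH.value == 'dataset_graph'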
......@@ -75,7 +75,7 @@ class TestGpuSummary:
if not os.path.exists(self.summary_dir):
os.mkdir(self.summary_dir)
def teardown_emthod(self):
def teardown_method(self):
"""Run after method."""
if os.path.exists(self.summary_dir):
shutil.rmtree(self.summary_dir)
......
......@@ -20,8 +20,8 @@ import numpy as np
import mindspore.nn as nn
from mindspore import Model, context
from mindspore.nn.optim import Momentum
from mindspore.train.callback import SummaryStep
from mindspore.train.summary.summary_record import SummaryRecord
from mindspore.train.summary import SummaryRecord
from mindspore.train.callback import SummaryCollector
from .....dataset_mock import MindData
CUR_DIR = os.getcwd()
......@@ -107,16 +107,9 @@ def test_graph_summary_sample():
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
model.train(2, dataset)
# step 2: create the Event
for i in range(1, 5):
test_writer.record(i)
# step 3: send the event to mq
# step 4: accept the event and write the file
log.debug("finished test_graph_summary_sample")
def test_graph_summary_callback():
dataset = get_dataset()
......@@ -125,18 +118,8 @@ def test_graph_summary_callback():
optim = Momentum(net.trainable_params(), 0.1, 0.9)
context.set_context(mode=context.GRAPH_MODE)
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=model._train_network) as test_writer:
summary_cb = SummaryStep(test_writer, 1)
model.train(2, dataset, callbacks=summary_cb)
def test_graph_summary_callback2():
dataset = get_dataset()
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits()
optim = Momentum(net.trainable_params(), 0.1, 0.9)
context.set_context(mode=context.GRAPH_MODE)
model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
with SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=net) as test_writer:
summary_cb = SummaryStep(test_writer, 1)
model.train(2, dataset, callbacks=summary_cb)
summary_collector = SummaryCollector(SUMMARY_DIR,
collect_freq=1,
keep_default_action=False,
collect_specified_data={'collect_graph': True})
model.train(1, dataset, callbacks=[summary_collector])
......@@ -26,9 +26,8 @@ import mindspore.nn as nn
from mindspore import Model, context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.train.callback import SummaryStep
from mindspore.train.summary.summary_record import SummaryRecord, \
_cache_summary_tensor_data
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
from mindspore.train.callback import Callback
from .....dataset_mock import MindData
CUR_DIR = os.getcwd()
......@@ -155,7 +154,8 @@ def get_dataset():
return dataset
class ImageSummaryCallback:
class ImageSummaryCallback(Callback):
"""Image summary callback."""
def __init__(self, summary_record):
self._summary_record = summary_record
......@@ -164,9 +164,10 @@ class ImageSummaryCallback:
return self
def __exit__(self, *err):
pass
self._summary_record.close()
def record(self, step, train_network=None):
"""record data."""
self._summary_record.record(step, train_network)
self._summary_record.flush()
......@@ -183,9 +184,8 @@ def test_image_summary_train():
# step 2: create the Event
model = get_model()
fn = ImageSummaryCallback(test_writer)
summary_recode = SummaryStep(fn, 1)
model.train(2, dataset, callbacks=summary_recode)
callback = ImageSummaryCallback(test_writer)
model.train(2, dataset, callbacks=[callback])
# step 3: send the event to mq
......
......@@ -24,11 +24,9 @@ import random
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.train.callback import SummaryStep
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
CUR_DIR = os.getcwd()
......@@ -192,16 +190,6 @@ def test_scalar_summary_with_ge_2():
def test_validate():
with SummaryRecord(SUMMARY_DIR) as sr:
with pytest.raises(ValueError):
SummaryStep(sr, 0)
with pytest.raises(ValueError):
SummaryStep(sr, -1)
with pytest.raises(ValueError):
SummaryStep(sr, 1.2)
with pytest.raises(ValueError):
SummaryStep(sr, True)
with pytest.raises(ValueError):
SummaryStep(sr, "str")
sr.record(1)
with pytest.raises(ValueError):
sr.record(False)
......@@ -215,17 +203,3 @@ def test_validate():
sr.record("str")
with pytest.raises(ValueError):
sr.record(sr)
SummaryStep(sr, 1)
with pytest.raises(ValueError):
SummaryStep(sr, 1.2)
with pytest.raises(ValueError):
SummaryStep(sr, False)
with pytest.raises(ValueError):
SummaryStep(sr, "str")
with pytest.raises(ValueError):
SummaryStep(sr, (1, 2))
with pytest.raises(ValueError):
SummaryStep(sr, [3, 4])
with pytest.raises(ValueError):
SummaryStep(sr, sr)
......@@ -59,7 +59,8 @@ def test_summaryrecord_input_null_string():
log.debug("begin test_summaryrecord_input_null_string")
# step 0: create the thread
try:
SummaryRecord("")
with SummaryRecord(""):
pass
except:
assert True
else:
......@@ -71,7 +72,8 @@ def test_summaryrecord_input_None():
log.debug("begin test_summaryrecord_input_None")
# step 0: create the thread
try:
SummaryRecord(None)
with SummaryRecord(None):
pass
except:
assert True
else:
......@@ -83,7 +85,8 @@ def test_summaryrecord_input_relative_dir_1():
log.debug("begin test_summaryrecord_input_relative_dir_1")
# step 0: create the thread
try:
SummaryRecord("./test_temp_summary_event_file/")
with SummaryRecord("./test_temp_summary_event_file/"):
pass
except:
assert False
else:
......@@ -95,7 +98,8 @@ def test_summaryrecord_input_relative_dir_2():
log.debug("begin test_summaryrecord_input_relative_dir_2")
# step 0: create the thread
try:
SummaryRecord("../summary/")
with SummaryRecord("../summary/"):
pass
except:
assert False
else:
......@@ -107,7 +111,8 @@ def test_summaryrecord_input_invalid_type_dir():
log.debug("begin test_summaryrecord_input_invalid_type_dir")
# step 0: create the thread
try:
SummaryRecord(32)
with SummaryRecord(32):
pass
except:
assert True
else:
......@@ -119,7 +124,8 @@ def test_mulit_layer_directory():
log.debug("begin test_mulit_layer_directory")
# step 0: create the thread
try:
SummaryRecord("./test_temp_summary_event_file/test/t1/")
with SummaryRecord("./test_temp_summary_event_file/test/t1/"):
pass
except:
assert False
else:
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the exception parameter scenario for summary collector."""
import os
import tempfile
import shutil
import pytest
from mindspore.train.callback import SummaryCollector
class TestSummaryCollector:
"""Test the exception parameter for summary collector."""
base_summary_dir = ''
def setup_class(self):
"""Run before test this class."""
self.base_summary_dir = tempfile.mkdtemp(suffix='summary')
def teardown_class(self):
"""Run after test this class."""
if os.path.exists(self.base_summary_dir):
shutil.rmtree(self.base_summary_dir)
@pytest.mark.parametrize("summary_dir", [1234, None, True, ''])
def test_params_with_summary_dir_value_error(self, summary_dir):
"""Test the exception scenario for summary dir."""
if isinstance(summary_dir, str):
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir=summary_dir)
assert str(exc.value) == 'For `summary_dir` the value should be a valid string of path, ' \
'but got empty string.'
else:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir)
assert 'For `summary_dir` the type should be a valid type' in str(exc.value)
def test_params_with_summary_dir_not_dir(self):
"""Test the given summary dir parameter is not a directory."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
summary_file = os.path.join(summary_dir, 'temp_file.txt')
with open(summary_file, 'w') as file_handle:
file_handle.write('temp')
print(os.path.isfile(summary_file))
with pytest.raises(NotADirectoryError):
SummaryCollector(summary_dir=summary_file)
@pytest.mark.parametrize("collect_freq", [None, 0, 0.01])
def test_params_with_collect_freq_exception(self, collect_freq):
"""Test the exception scenario for collect freq."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
if isinstance(collect_freq, int):
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
expected_msg = f'For `collect_freq` the value should be greater than 0, but got `{collect_freq}`.'
assert expected_msg == str(exc.value)
else:
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir, collect_freq=collect_freq)
expected_msg = f"For `collect_freq` the type should be a valid type of ['int'], " \
f'but got {type(collect_freq).__name__}.'
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("action", [None, 123, '', '123'])
def test_params_with_action_exception(self, action):
"""Test the exception scenario for action."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir=summary_dir, keep_default_action=action)
expected_msg = f"For `keep_default_action` the type should be a valid type of ['bool'], " \
f"bug got {type(action).__name__}."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("collect_specified_data", [123])
def test_params_with_collect_specified_data_type_error(self, collect_specified_data):
"""Test type error scenario for collect specified data param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
expected_msg = f"For `collect_specified_data` the type should be a valid type of ['dict', 'NoneType'], " \
f"bug got {type(collect_specified_data).__name__}."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("collect_specified_data", [
{
123: 123
},
{
None: True
}
])
def test_params_with_collect_specified_data_key_type_error(self, collect_specified_data):
"""Test the key of collect specified data param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
param_name = list(collect_specified_data)[0]
expected_msg = f"For `{param_name}` the type should be a valid type of ['str'], " \
f"bug got {type(param_name).__name__}."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("collect_specified_data", [
{
'collect_metric': None
},
{
'collect_graph': 123
},
{
'histogram_regular': 123
},
])
def test_params_with_collect_specified_data_value_type_error(self, collect_specified_data):
"""Test the value of collect specified data param."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, collect_specified_data=collect_specified_data)
param_name = list(collect_specified_data)[0]
param_value = collect_specified_data[param_name]
expected_type = "['bool']" if param_name != 'histogram_regular' else "['str', 'NoneType']"
expected_msg = f'For `{param_name}` the type should be a valid type of {expected_type}, ' \
f'but got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)
def test_params_with_collect_specified_data_unexpected_key(self):
"""Test the collect_specified_data parameter with unexpected key."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
data = {'unexpected_key': True}
with pytest.raises(ValueError) as exc:
SummaryCollector(summary_dir, collect_specified_data=data)
expected_msg = f"For `collect_specified_data` the keys {set(data)} are unsupported."
assert expected_msg == str(exc.value)
@pytest.mark.parametrize("custom_lineage_data", [
123,
{
'custom': {}
},
{
'custom': None
},
{
123: 'custom'
}
])
def test_params_with_custom_lineage_data_type_error(self, custom_lineage_data):
"""Test the custom lineage data parameter type error."""
summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
with pytest.raises(TypeError) as exc:
SummaryCollector(summary_dir, custom_lineage_data=custom_lineage_data)
if not isinstance(custom_lineage_data, dict):
expected_msg = f"For `custom_lineage_data` the type should be a valid type of ['dict', 'NoneType'], " \
f"bug got {type(custom_lineage_data).__name__}."
else:
param_name = list(custom_lineage_data)[0]
param_value = custom_lineage_data[param_name]
if not isinstance(param_name, str):
arg_name = f'custom_lineage_data -> {param_name}'
expected_msg = f"For `{arg_name}` the type should be a valid type of ['str'], " \
f'but got {type(param_name).__name__}.'
else:
arg_name = f'the value of custom_lineage_data -> {param_name}'
expected_msg = f"For `{arg_name}` the type should be a valid type of ['int', 'str', 'float'], " \
f'but got {type(param_value).__name__}.'
assert expected_msg == str(exc.value)
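For contrast with the failure cases above, a hedged sketch of arguments these checks accept (keys must be str; values int, str, or float; the concrete values are made up):

summary_dir = tempfile.mkdtemp()  # any existing directory
SummaryCollector(summary_dir, custom_lineage_data={'dataset_name': 'mnist', 'batch_size': 32, 'lr': 0.1})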
......@@ -20,8 +20,8 @@ import pytest
import mindspore.nn as nn
from mindspore import Model, context
from mindspore import Tensor
from mindspore.train.callback import Callback
from mindspore.nn.optim import Momentum
from mindspore.train.callback import SummaryStep
from ..ut_filter import non_graph_engine
from ....dataset_mock import MindData
......@@ -174,7 +174,7 @@ class TestGraphMode:
model.train(1, dataset)
class CallbackTest:
class CallbackTest(Callback):
""" CallbackTest definition """
def __init__(self):
......@@ -186,19 +186,19 @@ class CallbackTest:
def __exit__(self, *err):
pass
def record(self, step, *args):
print(step, args)
def step_end(self, run_context):
cb_params = run_context.original_args()
print(cb_params.cur_epoch_num, cb_params.cur_step_num)
def test_train_callback(test_with_simu):
""" test_train_callback """
dataset = get_dataset()
model = get_model()
fn = CallbackTest()
summary_recode = SummaryStep(fn, 2)
callback = CallbackTest()
if test_with_simu:
return
model.train(2, dataset, callbacks=summary_recode)
model.train(2, dataset, callbacks=callback)
log = logging.getLogger("test")
......