提交 228a448b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4 Fix pylint && merge the mocked mindspore in st and ut

Merge pull request !4 from yelihua/master
......@@ -168,7 +168,7 @@ class TrainLineage(Callback):
train_lineage = AnalyzeObject.get_network_args(
run_context_args, train_lineage
)
train_dataset = run_context_args.get('train_dataset')
callbacks = run_context_args.get('list_callback')
list_callback = getattr(callbacks, '_callbacks', [])
......@@ -601,7 +601,7 @@ class AnalyzeObject:
loss = None
else:
loss = run_context_args.get('net_outputs')
if loss:
log.info('Calculating loss...')
loss_numpy = loss.asnumpy()
......@@ -610,7 +610,7 @@ class AnalyzeObject:
train_lineage[Metadata.loss] = loss
else:
train_lineage[Metadata.loss] = None
# Analyze classname of optimizer, loss function and training network.
train_lineage[Metadata.optimizer] = type(optimizer).__name__ \
if optimizer else None
......
......@@ -21,7 +21,7 @@ import tempfile
import pytest
from .collection.model import mindspore
from ....utils import mindspore
sys.modules['mindspore'] = mindspore
......
......@@ -14,6 +14,6 @@
# ============================================================================
"""Import the mocked mindspore."""
import sys
from .lineagemgr.collection.model import mindspore
from ..utils import mindspore
sys.modules['mindspore'] = mindspore
......@@ -14,6 +14,6 @@
# ============================================================================
"""Import the mocked mindspore."""
import sys
from .collection.model import mindspore
from ...utils import mindspore
sys.modules['mindspore'] = mindspore
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore Interface."""
from .application.model_zoo.resnet import ResNet
from .common.tensor import Tensor
from .dataset import MindDataset
from .nn import *
from .train.callback import _ListCallback, Callback, RunContext, ModelCheckpoint, SummaryStep
from .train.summary import SummaryRecord
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore ResNet class."""
from ...nn.cell import Cell
class ResNet(Cell):
"""Mocked ResNet."""
def __init__(self):
super(ResNet, self).__init__()
self._cells = {}
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/common/tensor.py."""
import numpy as np
class Tensor:
"""Mock the MindSpore Tensor class."""
def __init__(self, value=0):
self._value = value
def asnumpy(self):
"""Get value in numpy format."""
return np.array(self._value)
def __repr__(self):
return str(self.asnumpy())
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
def get_context(key):
"""Get key in context."""
context = {"device_id": 1}
return context.get(key)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock mindspore.dataset."""
from .engine import MindDataset
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock mindspore.dataset.engine."""
from .datasets import MindDataset, Dataset
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/dataset/engine/datasets.py."""
class Dataset:
"""Mock the MindSpore Dataset class."""
def __init__(self, dataset_size=None, dataset_path=None):
self.dataset_size = dataset_size
self.dataset_path = dataset_path
self.input = []
def get_dataset_size(self):
"""Mocked get_dataset_size."""
return self.dataset_size
class MindDataset(Dataset):
"""Mock the MindSpore MindDataset class."""
def __init__(self, dataset_size=None, dataset_file=None):
super(MindDataset, self).__init__(dataset_size)
self.dataset_file = dataset_file
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the mindspore.nn package."""
from .optim import Optimizer, Momentum
from .loss.loss import SoftmaxCrossEntropyWithLogits, _Loss
from .cell import Cell, WithLossCell, TrainOneStepWithLossScaleCell
__all__ = ['Optimizer', 'Momentum', 'SoftmaxCrossEntropyWithLogits',
'_Loss', 'Cell', 'WithLossCell',
'TrainOneStepWithLossScaleCell']
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/train/callback.py."""
class Cell:
"""Mock the Cell class."""
def __init__(self, auto_prefix=True, pips=None):
if pips is None:
pips = dict()
self._auto_prefix = auto_prefix
self._pips = pips
@property
def auto_prefix(self):
"""The property of auto_prefix."""
return self._auto_prefix
@property
def pips(self):
"""The property of pips."""
return self._pips
class WithLossCell(Cell):
"""Mocked WithLossCell class."""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False, pips=backbone.pips)
self._backbone = backbone
self._loss_fn = loss_fn
class TrainOneStepWithLossScaleCell(Cell):
"""Mocked TrainOneStepWithLossScaleCell."""
def __init__(self):
super(TrainOneStepWithLossScaleCell, self).__init__()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore SoftmaxCrossEntropyWithLogits class."""
from ..cell import Cell
class _Loss(Cell):
"""Mocked _Loss."""
def __init__(self, reduction='mean'):
super(_Loss, self).__init__()
self.reduction = reduction
def construct(self, base, target):
"""Mocked construct function."""
raise NotImplementedError
class SoftmaxCrossEntropyWithLogits(_Loss):
"""Mocked SoftmaxCrossEntropyWithLogits."""
def __init__(self, weight=None):
super(SoftmaxCrossEntropyWithLogits, self).__init__(weight)
def construct(self, base, target):
"""Mocked construct."""
return 1
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/nn/optim.py."""
from .cell import Cell
class Parameter:
"""Mock the MindSpore Parameter class."""
def __init__(self, learning_rate):
self._name = "Parameter"
self.default_input = learning_rate
@property
def name(self):
"""The property of name."""
return self._name
def __repr__(self):
format_str = 'Parameter (name={name})'
return format_str.format(name=self._name)
class Optimizer(Cell):
"""Mock the MindSpore Optimizer class."""
def __init__(self, learning_rate):
super(Optimizer, self).__init__()
self.learning_rate = Parameter(learning_rate)
class Momentum(Optimizer):
"""Mock the MindSpore Momentum class."""
def __init__(self, learning_rate):
super(Momentum, self).__init__(learning_rate)
self.dynamic_lr = False
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore wrap package."""
from .loss_scale import TrainOneStepWithLossScaleCell
from .cell_wrapper import WithLossCell
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore cell_wrapper.py."""
from ..cell import Cell
class WithLossCell(Cell):
"""Mock the WithLossCell class."""
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__()
self._backbone = backbone
self._loss_fn = loss_fn
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock MindSpore loss_scale.py."""
from ..cell import Cell
class TrainOneStepWithLossScaleCell(Cell):
"""Mock the TrainOneStepWithLossScaleCell class."""
def __init__(self, network, optimizer):
super(TrainOneStepWithLossScaleCell, self).__init__()
self.network = network
self.optimizer = optimizer
def construct(self, data, label):
"""Mock the construct method."""
raise NotImplementedError
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/train/callback.py."""
import os
class RunContext:
"""Mock the RunContext class."""
def __init__(self, original_args=None):
self._original_args = original_args
self._stop_requested = False
def original_args(self):
"""Mock original_args."""
return self._original_args
def stop_requested(self):
"""Mock stop_requested method."""
return self._stop_requested
class Callback:
"""Mock the Callback class."""
def __init__(self):
pass
def begin(self, run_context):
"""Called once before network training."""
def epoch_begin(self, run_context):
"""Called before each epoch begin."""
class _ListCallback(Callback):
"""Mock the _ListCallabck class."""
def __init__(self, callbacks):
super(_ListCallback, self).__init__()
self._callbacks = callbacks
class ModelCheckpoint(Callback):
"""Mock the ModelCheckpoint class."""
def __init__(self, prefix='CKP', directory=None, config=None):
super(ModelCheckpoint, self).__init__()
self._prefix = prefix
self._directory = directory
self._config = config
self._latest_ckpt_file_name = os.path.join(directory, prefix + 'test_model.ckpt')
@property
def model_file_name(self):
"""Get the file name of model."""
return self._model_file_name
@property
def latest_ckpt_file_name(self):
"""Get the latest file name fo checkpoint."""
return self._latest_ckpt_file_name
class SummaryStep(Callback):
"""Mock the SummaryStep class."""
def __init__(self, summary, flush_step=10):
super(SummaryStep, self).__init__()
self._sumamry = summary
self._flush_step = flush_step
self.summary_file_name = summary.full_file_name
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
from .summary_record import SummaryRecord
__all__ = ["SummaryRecord"]
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindSpore Mock Interface"""
import os
import time
class SummaryRecord:
"""Mock the MindSpore SummaryRecord class."""
def __init__(self,
log_dir: str,
file_prefix: str = "events.",
file_suffix: str = ".MS",
create_time=int(time.time())):
self.log_dir = log_dir
self.prefix = file_prefix
self.suffix = file_suffix
file_name = file_prefix + 'summary.' + str(create_time) + file_suffix
self.full_file_name = os.path.join(log_dir, file_name)
def flush(self):
"""Mock flush method."""
def close(self):
"""Mock close method."""
......@@ -31,6 +31,8 @@ from mindspore.train.callback import RunContext, ModelCheckpoint, SummaryStep
from mindspore.train.summary import SummaryRecord
@mock.patch('builtins.open')
@mock.patch('os.makedirs')
class TestModelLineage(TestCase):
"""Test TrainLineage and EvalLineage class in model_lineage.py."""
......@@ -51,9 +53,9 @@ class TestModelLineage(TestCase):
cls.summary_log_path = '/path/to/summary_log'
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_summary_record_exception(self, mock_validate_summary):
def test_summary_record_exception(self, *args):
"""Test SummaryRecord with exception."""
mock_validate_summary.return_value = None
args[0].return_value = None
summary_record = self.my_summary_record(self.summary_log_path)
with self.assertRaises(MindInsightException) as context:
self.my_train_module(summary_record=summary_record, raise_exception=1)
......@@ -150,9 +152,9 @@ class TestModelLineage(TestCase):
args[6].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_train_end_exception(self, mock_validate_summary):
def test_train_end_exception(self, *args):
"""Test TrainLineage.end method when exception."""
mock_validate_summary.return_value = True
args[0].return_value = True
train_lineage = self.my_train_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(Exception) as context:
train_lineage.end(self.run_context)
......@@ -218,9 +220,9 @@ class TestModelLineage(TestCase):
self.assertTrue('End error in TrainLineage:' in str(context.exception))
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_eval_exception_train_id_none(self, mock_validate_summary):
def test_eval_exception_train_id_none(self, *args):
"""Test EvalLineage.end method with initialization error."""
mock_validate_summary.return_value = True
args[0].return_value = True
with self.assertRaises(MindInsightException) as context:
self.my_eval_module(self.my_summary_record(self.summary_log_path), raise_exception=2)
self.assertTrue('Invalid value for raise_exception.' in str(context.exception))
......@@ -242,9 +244,9 @@ class TestModelLineage(TestCase):
args[0].assert_called()
@mock.patch('mindinsight.lineagemgr.collection.model.model_lineage.validate_summary_record')
def test_eval_end_except_run_context(self, mock_validate_summary):
def test_eval_end_except_run_context(self, *args):
"""Test EvalLineage.end method when run_context is invalid.."""
mock_validate_summary.return_value = True
args[0].return_value = True
eval_lineage = self.my_eval_module(self.my_summary_record(self.summary_log_path), True)
with self.assertRaises(Exception) as context:
eval_lineage.end(self.run_context)
......@@ -284,8 +286,9 @@ class TestModelLineage(TestCase):
eval_lineage.end(self.my_run_context(self.run_context))
self.assertTrue('End error in EvalLineage' in str(context.exception))
def test_epoch_is_zero(self):
def test_epoch_is_zero(self, *args):
"""Test TrainLineage.end method."""
args[0].return_value = None
run_context = self.run_context
run_context['epoch_num'] = 0
with self.assertRaises(MindInsightException):
......@@ -345,7 +348,7 @@ class TestAnalyzer(TestCase):
)
res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train')
res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid')
assert res1 == {'step_num': 10,
assert res1 == {'step_num': 10,
'train_dataset_path': '/path/to',
'train_dataset_size': 50,
'epoch': 2}
......
......@@ -48,7 +48,7 @@ class WithLossCell(Cell):
class TrainOneStepWithLossScaleCell(Cell):
"""Mocked TrainOneStepWithLossScaleCell."""
def __init__(self, network, optimizer):
def __init__(self, network=None, optimizer=None):
super(TrainOneStepWithLossScaleCell, self).__init__()
self.network = network
self.optimizer = optimizer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册