提交 820f2cb4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5351 move ParalleMode to Context

Merge pull request !5351 from yao_yf/parallel_context_collation
...@@ -28,7 +28,7 @@ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context ...@@ -28,7 +28,7 @@ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context
_reset_auto_parallel_context _reset_auto_parallel_context
__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context', __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
'get_auto_parallel_context', 'reset_auto_parallel_context'] 'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
GRAPH_MODE = 0 GRAPH_MODE = 0
PYNATIVE_MODE = 1 PYNATIVE_MODE = 1
...@@ -647,3 +647,26 @@ def get_context(attr_key): ...@@ -647,3 +647,26 @@ def get_context(attr_key):
raise ValueError( raise ValueError(
"Get context keyword %s is not recognized!" % attr_key) "Get context keyword %s is not recognized!" % attr_key)
return getattr(_context(), attr_key) return getattr(_context(), attr_key)
class ParallelMode:
"""
Parallel mode options.
There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
"HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
- STAND_ALONE: Only one processor working.
- DATA_PARALLEL: Distributing the data across different processors.
- HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
- SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
- AUTO_PARALLEL: Achieving parallelism automatically.
MODE_LIST: The list for all supported parallel modes.
"""
STAND_ALONE = "stand_alone"
DATA_PARALLEL = "data_parallel"
HYBRID_PARALLEL = "hybrid_parallel"
SEMI_AUTO_PARALLEL = "semi_auto_parallel"
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
...@@ -20,7 +20,7 @@ from mindspore.common.parameter import Parameter ...@@ -20,7 +20,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator from mindspore._checkparam import Validator
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode from mindspore.parallel._utils import _get_parallel_mode
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Validator as validator, Rel from ..._checkparam import Validator as validator, Rel
...@@ -130,8 +130,8 @@ class EmbeddingLookup(Cell): ...@@ -130,8 +130,8 @@ class EmbeddingLookup(Cell):
param_init (str): The initialize way of embedding table. Default: 'normal'. param_init (str): The initialize way of embedding table. Default: 'normal'.
target (str): Specify the target where the op is executed. The value should in target (str): Specify the target where the op is executed. The value should in
['DEVICE', 'CPU']. Default: 'CPU'. ['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through
nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'. nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE.
manual_shapes (tuple): The accompaniment array in field slice mode. manual_shapes (tuple): The accompaniment array in field slice mode.
Inputs: Inputs:
......
...@@ -29,7 +29,7 @@ from mindspore._checkparam import Validator as validator ...@@ -29,7 +29,7 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel from mindspore._checkparam import Rel
from mindspore import log as logger from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore import context from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule from mindspore.nn.learning_rate_schedule import LearningRateSchedule
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Cell_wrapper.""" """Cell_wrapper."""
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
_get_parallel_mode) _get_parallel_mode)
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C from ...ops import composite as C
......
...@@ -251,8 +251,9 @@ class DistributedGradReducer(Cell): ...@@ -251,8 +251,9 @@ class DistributedGradReducer(Cell):
>>> from mindspore.ops import operations as P >>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F >>> from mindspore.ops import functional as F
>>> from mindspore import context >>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import nn >>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple >>> from mindspore import ParameterTuple
>>> >>>
>>> device_id = int(os.environ["DEVICE_ID"]) >>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Loss scale cell for loss scale training.""" """Loss scale cell for loss scale training."""
import mindspore.context as context import mindspore.context as context
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from ..cell import Cell from ..cell import Cell
from ...common import Tensor, RowTensor from ...common import Tensor, RowTensor
......
...@@ -18,8 +18,7 @@ High-Level training interfaces. ...@@ -18,8 +18,7 @@ High-Level training interfaces.
Helper functions in train piplines. Helper functions in train piplines.
""" """
from .model import Model from .model import Model
from .parallel_utils import ParallelMode
from .dataset_helper import DatasetHelper from .dataset_helper import DatasetHelper
from . import amp from . import amp
__all__ = ["Model", "ParallelMode", "DatasetHelper", "amp"] __all__ = ["Model", "DatasetHelper", "amp"]
...@@ -23,7 +23,7 @@ from ..nn.wrap.cell_wrapper import _VirtualDatasetCell ...@@ -23,7 +23,7 @@ from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..ops import functional as F from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode from ..parallel._utils import _get_parallel_mode
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from .parallel_utils import ParallelMode from ..context import ParallelMode
from .. import context from .. import context
__all__ = ["build_train_network"] __all__ = ["build_train_network"]
......
...@@ -30,7 +30,7 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r ...@@ -30,7 +30,7 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r
from ..nn.metrics import Loss from ..nn.metrics import Loss
from .. import nn from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode from ..context import ParallelMode
from ..parallel._utils import _need_to_full, _to_full_tensor from ..parallel._utils import _need_to_full, _to_full_tensor
from ..common import dtype as mstype from ..common import dtype as mstype
from .dataset_helper import DatasetHelper from .dataset_helper import DatasetHelper
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parallel utils"""
__all__ = ["ParallelMode"]
class ParallelMode:
"""
Parallel mode options.
There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
"HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
- STAND_ALONE: Only one processor working.
- DATA_PARALLEL: Distributing the data across different processors.
- HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
- SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
- AUTO_PARALLEL: Achieving parallelism automatically.
MODE_LIST: The list for all supported parallel modes.
"""
STAND_ALONE = "stand_alone"
DATA_PARALLEL = "data_parallel"
HYBRID_PARALLEL = "hybrid_parallel"
SEMI_AUTO_PARALLEL = "semi_auto_parallel"
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
...@@ -17,7 +17,8 @@ import argparse ...@@ -17,7 +17,8 @@ import argparse
from mindspore import context from mindspore import context
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore import Model, ParallelMode from mindspore import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset from src.md_dataset import create_dataset
......
...@@ -26,7 +26,8 @@ import mindspore.common.dtype as mstype ...@@ -26,7 +26,8 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SGD from mindspore.nn import SGD
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
......
...@@ -28,7 +28,8 @@ from mindspore import context ...@@ -28,7 +28,8 @@ from mindspore import context
from mindspore.communication.management import init, get_rank from mindspore.communication.management import init, get_rank
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import cifar_cfg as cfg from src.config import cifar_cfg as cfg
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.rmsprop import RMSProp from mindspore.nn.optim.rmsprop import RMSProp
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
......
...@@ -24,7 +24,8 @@ import mindspore.common.dtype as mstype ...@@ -24,7 +24,8 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SGD from mindspore.nn import SGD
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
......
...@@ -30,7 +30,8 @@ from mindspore.nn.loss.loss import _Loss ...@@ -30,7 +30,8 @@ from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
......
...@@ -22,7 +22,8 @@ import numpy as np ...@@ -22,7 +22,8 @@ import numpy as np
from mindspore import context from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore import nn from mindspore import nn
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint from mindspore.train.serialization import load_checkpoint
......
...@@ -28,7 +28,8 @@ from mindspore.nn.loss.loss import _Loss ...@@ -28,7 +28,8 @@ from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
......
...@@ -22,7 +22,8 @@ from mindspore import Tensor ...@@ -22,7 +22,8 @@ from mindspore import Tensor
from mindspore import dataset as de from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
......
...@@ -21,7 +21,8 @@ from mindspore import context ...@@ -21,7 +21,8 @@ from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint from mindspore.train.serialization import load_checkpoint
......
...@@ -102,7 +102,8 @@ class DistributedGradReducerThor(Cell): ...@@ -102,7 +102,8 @@ class DistributedGradReducerThor(Cell):
>>> from mindspore.ops import functional as F >>> from mindspore.ops import functional as F
>>> from mindspore import context >>> from mindspore import context
>>> from mindspore import nn >>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple >>> from mindspore import ParameterTuple
>>> from mindspore.context import ParallelMode
>>> >>>
>>> device_id = int(os.environ["DEVICE_ID"]) >>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
......
...@@ -18,7 +18,7 @@ import math ...@@ -18,7 +18,7 @@ import math
from mindspore.train.callback import RunContext from mindspore.train.callback import RunContext
from mindspore import context from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.parallel._utils import _need_to_full, _to_full_tensor from mindspore.parallel._utils import _need_to_full, _to_full_tensor
from mindspore.common.dtype import pytype_to_dtype from mindspore.common.dtype import pytype_to_dtype
......
...@@ -22,7 +22,7 @@ from mindspore import context ...@@ -22,7 +22,7 @@ from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore import dataset as de from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init, get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
......
...@@ -20,7 +20,7 @@ import datetime ...@@ -20,7 +20,7 @@ import datetime
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, context from mindspore import Tensor, context
from mindspore import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint from mindspore.train.callback import ModelCheckpoint
......
...@@ -19,6 +19,7 @@ import mindspore.common.dtype as mstype ...@@ -19,6 +19,7 @@ import mindspore.common.dtype as mstype
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Parameter, context, Tensor from mindspore import Parameter, context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P from mindspore.ops import operations as P
...@@ -388,7 +389,7 @@ class TrainingWrapper(nn.Cell): ...@@ -388,7 +389,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("mirror_mean")
......
...@@ -21,7 +21,8 @@ import mindspore.nn as nn ...@@ -21,7 +21,8 @@ import mindspore.nn as nn
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from src.config import config from src.config import config
......
...@@ -29,7 +29,8 @@ from mindspore import context ...@@ -29,7 +29,8 @@ from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from src.dataset import vgg_create_dataset from src.dataset import vgg_create_dataset
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import numpy as np import numpy as np
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
_get_parallel_mode) _get_parallel_mode)
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
......
...@@ -21,7 +21,8 @@ import numpy as np ...@@ -21,7 +21,8 @@ import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore import dataset as de from mindspore import dataset as de
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.nn.wrap import WithLossCell from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_group_size, get_rank from mindspore.communication.management import init, get_group_size, get_rank
......
...@@ -25,7 +25,7 @@ from pycocotools.coco import COCO ...@@ -25,7 +25,7 @@ from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
from mindspore import Tensor from mindspore import Tensor
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms import mindspore as ms
......
...@@ -17,6 +17,7 @@ import mindspore as ms ...@@ -17,6 +17,7 @@ import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore import context from mindspore import context
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P from mindspore.ops import operations as P
...@@ -417,7 +418,7 @@ class TrainingWrapper(nn.Cell): ...@@ -417,7 +418,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("mirror_mean")
......
...@@ -18,7 +18,7 @@ import time ...@@ -18,7 +18,7 @@ import time
import argparse import argparse
import datetime import datetime
from mindspore import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore import Tensor from mindspore import Tensor
import mindspore.nn as nn import mindspore.nn as nn
......
...@@ -25,7 +25,7 @@ from pycocotools.coco import COCO ...@@ -25,7 +25,7 @@ from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval
from mindspore import Tensor from mindspore import Tensor
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms import mindspore as ms
......
...@@ -17,6 +17,7 @@ import mindspore as ms ...@@ -17,6 +17,7 @@ import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore import context from mindspore import context
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P from mindspore.ops import operations as P
...@@ -417,7 +418,7 @@ class TrainingWrapper(nn.Cell): ...@@ -417,7 +418,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("mirror_mean")
......
...@@ -19,7 +19,7 @@ import time ...@@ -19,7 +19,7 @@ import time
import argparse import argparse
import datetime import datetime
from mindspore import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
...@@ -652,7 +653,7 @@ class TrainingWrapper(nn.Cell): ...@@ -652,7 +653,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("mirror_mean")
......
...@@ -29,7 +29,8 @@ import mindspore.nn as nn ...@@ -29,7 +29,8 @@ import mindspore.nn as nn
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
......
...@@ -24,7 +24,7 @@ import mindspore.communication.management as D ...@@ -24,7 +24,7 @@ import mindspore.communication.management as D
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
......
...@@ -25,7 +25,7 @@ from mindspore.common.tensor import Tensor ...@@ -25,7 +25,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore import context from mindspore import context
from .bert_for_pre_training import clip_grad from .bert_for_pre_training import clip_grad
......
...@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor ...@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore import context from mindspore import context
from mindspore.ops import _selected_ops from mindspore.ops import _selected_ops
......
...@@ -35,7 +35,7 @@ from mindspore import log as logger ...@@ -35,7 +35,7 @@ from mindspore import log as logger
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
_current_dir = os.path.dirname(os.path.realpath(__file__)) _current_dir = os.path.dirname(os.path.realpath(__file__))
......
...@@ -27,7 +27,7 @@ from mindspore.ops import _selected_ops ...@@ -27,7 +27,7 @@ from mindspore.ops import _selected_ops
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from .bert_model import BertModel from .bert_model import BertModel
from .config import cfg from .config import cfg
from .lr_generator import get_bert_damping from .lr_generator import get_bert_damping
......
...@@ -102,7 +102,8 @@ class DistributedGradReducerThor(Cell): ...@@ -102,7 +102,8 @@ class DistributedGradReducerThor(Cell):
>>> from mindspore.ops import functional as F >>> from mindspore.ops import functional as F
>>> from mindspore import context >>> from mindspore import context
>>> from mindspore import nn >>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple >>> from mindspore import ParameterTuple
>>> from mindspore.context import ParallelMode
>>> >>>
>>> device_id = int(os.environ["DEVICE_ID"]) >>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
......
...@@ -36,7 +36,7 @@ from mindspore.parallel._utils import _need_to_full ...@@ -36,7 +36,7 @@ from mindspore.parallel._utils import _need_to_full
from mindspore.train import amp from mindspore.train import amp
from mindspore.parallel._utils import _to_full_tensor from mindspore.parallel._utils import _to_full_tensor
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from .dataset_helper import DatasetHelper from .dataset_helper import DatasetHelper
......
...@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor ...@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from .transformer import Transformer from .transformer import Transformer
......
...@@ -26,7 +26,8 @@ from mindspore.nn.optim import Adam, Lamb ...@@ -26,7 +26,8 @@ from mindspore.nn.optim import Adam, Lamb
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore import context, ParallelMode, Parameter from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.communication import management as MultiAscend from mindspore.communication import management as MultiAscend
from mindspore.train.serialization import load_checkpoint from mindspore.train.serialization import load_checkpoint
......
...@@ -24,7 +24,7 @@ import mindspore.common.dtype as mstype ...@@ -24,7 +24,7 @@ import mindspore.common.dtype as mstype
from mindspore import context from mindspore import context
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import TimeMonitor from mindspore.train.callback import TimeMonitor
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.optim import AdamWeightDecay from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore import log as logger from mindspore import log as logger
......
...@@ -26,7 +26,7 @@ from mindspore.common import dtype as mstype ...@@ -26,7 +26,7 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from .tinybert_model import BertModel, TinyBertModel, BertModelCLS from .tinybert_model import BertModel, TinyBertModel, BertModelCLS
......
...@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor ...@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore import context from mindspore import context
......
...@@ -29,7 +29,7 @@ from mindspore.train.callback import Callback, TimeMonitor ...@@ -29,7 +29,7 @@ from mindspore.train.callback import Callback, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.communication.management as D import mindspore.communication.management as D
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore import context from mindspore import context
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \ from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
......
...@@ -19,7 +19,8 @@ import argparse ...@@ -19,7 +19,8 @@ import argparse
import random import random
import numpy as np import numpy as np
from mindspore import context, ParallelMode from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
......
...@@ -17,7 +17,7 @@ callbacks ...@@ -17,7 +17,7 @@ callbacks
import time import time
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore import context from mindspore import context
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank from mindspore.communication.management import get_rank
def add_write(file_path, out_str): def add_write(file_path, out_str):
......
...@@ -23,7 +23,7 @@ from mindspore.ops import operations as P ...@@ -23,7 +23,7 @@ from mindspore.ops import operations as P
from mindspore.nn import Dropout from mindspore.nn import Dropout
from mindspore.nn.optim import Adam, FTRL, LazyAdam from mindspore.nn.optim import Adam, FTRL, LazyAdam
from mindspore.common.initializer import Uniform, initializer from mindspore.common.initializer import Uniform, initializer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
......
...@@ -20,7 +20,7 @@ import sys ...@@ -20,7 +20,7 @@ import sys
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
from mindspore import Model, context from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import set_multi_subgraphs from mindspore.parallel import set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
......
...@@ -20,7 +20,7 @@ import sys ...@@ -20,7 +20,7 @@ import sys
import numpy as np import numpy as np
from mindspore import Model, context from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......
...@@ -20,7 +20,7 @@ import sys ...@@ -20,7 +20,7 @@ import sys
import numpy as np import numpy as np
from mindspore import Model, context from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......
...@@ -24,7 +24,7 @@ from mindspore.ops import operations as P ...@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
from mindspore.nn import Dropout, Flatten from mindspore.nn import Dropout, Flatten
from mindspore.nn.optim import Adam, FTRL from mindspore.nn.optim import Adam, FTRL
from mindspore.common.initializer import Uniform, initializer from mindspore.common.initializer import Uniform, initializer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
......
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
from mindspore import Model, context from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......
...@@ -28,7 +28,8 @@ from mindspore.nn.optim.momentum import Momentum ...@@ -28,7 +28,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters from mindspore.parallel import set_algo_parameters
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID'))) context.set_context(device_id=int(os.getenv('DEVICE_ID')))
......
...@@ -30,7 +30,8 @@ from mindspore.nn.optim.momentum import Momentum ...@@ -30,7 +30,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
......
...@@ -30,7 +30,8 @@ from mindspore.nn.optim.momentum import Momentum ...@@ -30,7 +30,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
......
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
import sys import sys
from mindspore import Model, context from mindspore import Model, context
from mindspore.train.callback import TimeMonitor from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import set_multi_subgraphs from mindspore.parallel import set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
......
...@@ -25,7 +25,7 @@ from mindspore.nn.optim import Adam, FTRL ...@@ -25,7 +25,7 @@ from mindspore.nn.optim import Adam, FTRL
from mindspore.common.initializer import Uniform, initializer from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig # from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
import numpy as np import numpy as np
......
...@@ -20,7 +20,7 @@ import sys ...@@ -20,7 +20,7 @@ import sys
import numpy as np import numpy as np
from mindspore import Model, context from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
...@@ -652,7 +653,7 @@ class TrainingWrapper(nn.Cell): ...@@ -652,7 +653,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False self.reducer_flag = False
self.grad_reducer = None self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("mirror_mean")
......
...@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor ...@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore import context from mindspore import context
from .bert_model import BertModel from .bert_model import BertModel
......
...@@ -26,7 +26,7 @@ from mindspore.common.tensor import Tensor ...@@ -26,7 +26,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
from mindspore import context from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from mindspore._checkparam import check_bool from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_full_shapes from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
def _send_data(dataset): def _send_data(dataset):
"""Engine dataset to write data to tdt queue.""" """Engine dataset to write data to tdt queue."""
......
...@@ -103,7 +103,8 @@ class DistributedGradReducerThor(Cell): ...@@ -103,7 +103,8 @@ class DistributedGradReducerThor(Cell):
>>> from mindspore.ops import functional as F >>> from mindspore.ops import functional as F
>>> from mindspore import context >>> from mindspore import context
>>> from mindspore import nn >>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple >>> from mindspore import ParameterTuple
>>> from mindspore.context import ParallelMode
>>> >>>
>>> device_id = int(os.environ["DEVICE_ID"]) >>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
......
...@@ -30,7 +30,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_ ...@@ -30,7 +30,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from .dataset_helper import DatasetHelper from .dataset_helper import DatasetHelper
......
...@@ -24,7 +24,8 @@ import numpy as np ...@@ -24,7 +24,8 @@ import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
import mindspore.nn as nn import mindspore.nn as nn
......
...@@ -32,7 +32,8 @@ from mindspore.communication.management import init ...@@ -32,7 +32,8 @@ from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
......
...@@ -32,7 +32,8 @@ from mindspore.nn.optim.momentum import Momentum ...@@ -32,7 +32,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
......
...@@ -25,7 +25,7 @@ from mindspore.common.api import _executor ...@@ -25,7 +25,7 @@ from mindspore.common.api import _executor
from mindspore.nn import Momentum from mindspore.nn import Momentum
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
class DenseMMNet(nn.Cell): class DenseMMNet(nn.Cell):
......
...@@ -21,7 +21,8 @@ import numpy as np ...@@ -21,7 +21,8 @@ import numpy as np
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Model, ParallelMode from mindspore import Tensor, Model
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
......
...@@ -19,7 +19,8 @@ import numpy as np ...@@ -19,7 +19,8 @@ import numpy as np
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Model, ParallelMode from mindspore import Tensor, Model
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.ops.operations import TensorAdd from mindspore.ops.operations import TensorAdd
from ....dataset_mock import MindData from ....dataset_mock import MindData
......
...@@ -26,7 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell ...@@ -26,7 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from tests.ops_common import convert from tests.ops_common import convert
from ....train_step_wrap import train_step_with_loss_warp from ....train_step_wrap import train_step_with_loss_warp
......
...@@ -22,7 +22,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits ...@@ -22,7 +22,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.parallel import _cost_model_context as cost_model_context from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
......
...@@ -24,7 +24,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits ...@@ -24,7 +24,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
......
...@@ -23,7 +23,8 @@ from mindspore.common.parameter import Parameter ...@@ -23,7 +23,8 @@ from mindspore.common.parameter import Parameter
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss from tests.ut.python.ops.test_math_ops import VirtualLoss
......
...@@ -29,7 +29,8 @@ from mindspore.ops import operations as P ...@@ -29,7 +29,8 @@ from mindspore.ops import operations as P
from mindspore.parallel import _cost_model_context as cost_model_context from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_algo_parameters from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as resset_op_id from mindspore.parallel._utils import _reset_op_id as resset_op_id
from mindspore.train.model import Model, ParallelMode from mindspore.train.model import Model
from mindspore.context import ParallelMode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0) context.set_context(device_id=0)
......
...@@ -26,7 +26,8 @@ from mindspore.nn.layer.pooling import MaxPool2d ...@@ -26,7 +26,8 @@ from mindspore.nn.layer.pooling import MaxPool2d
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
dev_num = 8 dev_num = 8
......
...@@ -27,7 +27,7 @@ from mindspore.nn.optim.momentum import Momentum ...@@ -27,7 +27,7 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
......
...@@ -22,7 +22,8 @@ from mindspore.common.parameter import Parameter, ParameterTuple ...@@ -22,7 +22,8 @@ from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C, functional as F, operations as P from mindspore.ops import composite as C, functional as F, operations as P
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.loss_scale_manager import DynamicLossScaleManager from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
......
...@@ -23,7 +23,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits ...@@ -23,7 +23,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
class Dataset(MindData): class Dataset(MindData):
......
...@@ -27,7 +27,8 @@ from mindspore.nn.optim import Momentum ...@@ -27,7 +27,8 @@ from mindspore.nn.optim import Momentum
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
device_number = 32 device_number = 32
......
...@@ -25,7 +25,8 @@ from mindspore.ops import functional as F ...@@ -25,7 +25,8 @@ from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
......
...@@ -25,7 +25,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits ...@@ -25,7 +25,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
......
...@@ -25,7 +25,8 @@ from mindspore.nn.optim.momentum import Momentum ...@@ -25,7 +25,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss from tests.ut.python.ops.test_math_ops import VirtualLoss
......
...@@ -29,7 +29,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits ...@@ -29,7 +29,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import TensorAdd from mindspore.ops.operations import TensorAdd
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
dev_num = 8 dev_num = 8
......
...@@ -23,7 +23,7 @@ from mindspore.nn import Dense ...@@ -23,7 +23,7 @@ from mindspore.nn import Dense
from mindspore.nn import Momentum from mindspore.nn import Momentum
from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.parallel_utils import ParallelMode from mindspore.context import ParallelMode
class Net(nn.Cell): class Net(nn.Cell):
......
...@@ -24,7 +24,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits ...@@ -24,7 +24,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
......
...@@ -28,7 +28,8 @@ from mindspore.ops import functional as F ...@@ -28,7 +28,8 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import _VirtualDataset from mindspore.ops.operations.comm_ops import _VirtualDataset
from mindspore.parallel import set_algo_parameters from mindspore.parallel import set_algo_parameters
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss from tests.ut.python.ops.test_math_ops import VirtualLoss
......
...@@ -21,7 +21,8 @@ from mindspore.common.parameter import Parameter ...@@ -21,7 +21,8 @@ from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData from tests.dataset_mock import MindData
......
...@@ -20,7 +20,8 @@ import mindspore.context as context ...@@ -20,7 +20,8 @@ import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore import amp from mindspore import amp
from mindspore import nn from mindspore import nn
from mindspore.train import Model, ParallelMode from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from ....dataset_mock import MindData from ....dataset_mock import MindData
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册