提交 820f2cb4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5351 move ParalleMode to Context

Merge pull request !5351 from yao_yf/parallel_context_collation
......@@ -28,7 +28,7 @@ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context
_reset_auto_parallel_context
__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
'get_auto_parallel_context', 'reset_auto_parallel_context']
'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode']
GRAPH_MODE = 0
PYNATIVE_MODE = 1
......@@ -647,3 +647,26 @@ def get_context(attr_key):
raise ValueError(
"Get context keyword %s is not recognized!" % attr_key)
return getattr(_context(), attr_key)
class ParallelMode:
"""
Parallel mode options.
There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
"HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
- STAND_ALONE: Only one processor working.
- DATA_PARALLEL: Distributing the data across different processors.
- HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
- SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
- AUTO_PARALLEL: Achieving parallelism automatically.
MODE_LIST: The list for all supported parallel modes.
"""
STAND_ALONE = "stand_alone"
DATA_PARALLEL = "data_parallel"
HYBRID_PARALLEL = "hybrid_parallel"
SEMI_AUTO_PARALLEL = "semi_auto_parallel"
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
......@@ -20,7 +20,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import Validator
from mindspore.communication.management import get_group_size
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from ..cell import Cell
from ..._checkparam import Validator as validator, Rel
......@@ -129,9 +129,9 @@ class EmbeddingLookup(Cell):
embedding_size (int): The size of each embedding vector.
param_init (str): The initialize way of embedding table. Default: 'normal'.
target (str): Specify the target where the op is executed. The value should in
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through
nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'.
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi_auto_parallel/auto_parallel. The value should get through
nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE.
manual_shapes (tuple): The accompaniment array in field slice mode.
Inputs:
......
......@@ -29,7 +29,7 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
......
......@@ -15,7 +15,7 @@
"""Cell_wrapper."""
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
_get_parallel_mode)
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from ...common import dtype as mstype
from ...common.parameter import Parameter, ParameterTuple
from ...ops import composite as C
......
......@@ -251,8 +251,9 @@ class DistributedGradReducer(Cell):
>>> from mindspore.ops import operations as P
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple
>>> from mindspore import ParameterTuple
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
......
......@@ -15,7 +15,7 @@
"""Loss scale cell for loss scale training."""
import mindspore.context as context
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from ..cell import Cell
from ...common import Tensor, RowTensor
......
......@@ -18,8 +18,7 @@ High-Level training interfaces.
Helper functions in train piplines.
"""
from .model import Model
from .parallel_utils import ParallelMode
from .dataset_helper import DatasetHelper
from . import amp
__all__ = ["Model", "ParallelMode", "DatasetHelper", "amp"]
__all__ = ["Model", "DatasetHelper", "amp"]
......@@ -23,7 +23,7 @@ from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from .parallel_utils import ParallelMode
from ..context import ParallelMode
from .. import context
__all__ = ["build_train_network"]
......
......@@ -30,7 +30,7 @@ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_r
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from .parallel_utils import ParallelMode
from ..context import ParallelMode
from ..parallel._utils import _need_to_full, _to_full_tensor
from ..common import dtype as mstype
from .dataset_helper import DatasetHelper
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parallel utils"""
__all__ = ["ParallelMode"]
class ParallelMode:
"""
Parallel mode options.
There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
"HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
- STAND_ALONE: Only one processor working.
- DATA_PARALLEL: Distributing the data across different processors.
- HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
- SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
- AUTO_PARALLEL: Achieving parallelism automatically.
MODE_LIST: The list for all supported parallel modes.
"""
STAND_ALONE = "stand_alone"
DATA_PARALLEL = "data_parallel"
HYBRID_PARALLEL = "hybrid_parallel"
SEMI_AUTO_PARALLEL = "semi_auto_parallel"
AUTO_PARALLEL = "auto_parallel"
MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
......@@ -17,7 +17,8 @@ import argparse
from mindspore import context
from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore import Model, ParallelMode
from mindspore import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import Callback, CheckpointConfig, ModelCheckpoint, TimeMonitor
from src.md_dataset import create_dataset
......
......@@ -26,7 +26,8 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SGD
import mindspore.dataset.engine as de
......
......@@ -28,7 +28,8 @@ from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import cifar_cfg as cfg
......
......@@ -21,7 +21,7 @@ import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.rmsprop import RMSProp
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
......
......@@ -24,7 +24,8 @@ import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SGD
import mindspore.dataset.engine as de
......
......@@ -30,7 +30,8 @@ from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
......
......@@ -22,7 +22,8 @@ import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint
......
......@@ -28,7 +28,8 @@ from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
......
......@@ -22,7 +22,8 @@ from mindspore import Tensor
from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
......
......@@ -21,7 +21,8 @@ from mindspore import context
from mindspore import Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint
......
......@@ -102,7 +102,8 @@ class DistributedGradReducerThor(Cell):
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple
>>> from mindspore import ParameterTuple
>>> from mindspore.context import ParallelMode
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
......
......@@ -18,7 +18,7 @@ import math
from mindspore.train.callback import RunContext
from mindspore import context
from mindspore import nn
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.train.model import Model
from mindspore.parallel._utils import _need_to_full, _to_full_tensor
from mindspore.common.dtype import pytype_to_dtype
......
......@@ -22,7 +22,7 @@ from mindspore import context
from mindspore import Tensor
from mindspore import dataset as de
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import ParallelMode
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init, get_rank, get_group_size
......
......@@ -20,7 +20,7 @@ import datetime
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint
......
......@@ -19,6 +19,7 @@ import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter, context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
......@@ -388,7 +389,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
......
......@@ -21,7 +21,8 @@ import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from src.config import config
......
......@@ -29,7 +29,8 @@ from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from src.dataset import vgg_create_dataset
......
......@@ -16,7 +16,7 @@
import numpy as np
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
_get_parallel_mode)
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
......
......@@ -21,7 +21,8 @@ import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore import dataset as de
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_group_size, get_rank
......
......@@ -25,7 +25,7 @@ from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import Tensor
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
......
......@@ -17,6 +17,7 @@ import mindspore as ms
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
......@@ -417,7 +418,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
......
......@@ -18,7 +18,7 @@ import time
import argparse
import datetime
from mindspore import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore import Tensor
import mindspore.nn as nn
......
......@@ -25,7 +25,7 @@ from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import Tensor
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms
......
......@@ -17,6 +17,7 @@ import mindspore as ms
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
......@@ -417,7 +418,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
......
......@@ -19,7 +19,7 @@ import time
import argparse
import datetime
from mindspore import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore import Tensor
from mindspore import context
......
......@@ -19,6 +19,7 @@ import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.common.initializer import TruncatedNormal
......@@ -652,7 +653,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
......
......@@ -29,7 +29,8 @@ import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer
......
......@@ -24,7 +24,7 @@ import mindspore.communication.management as D
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
......
......@@ -25,7 +25,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from .bert_for_pre_training import clip_grad
......
......@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.ops import _selected_ops
......
......@@ -35,7 +35,7 @@ from mindspore import log as logger
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
_current_dir = os.path.dirname(os.path.realpath(__file__))
......
......@@ -27,7 +27,7 @@ from mindspore.ops import _selected_ops
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from .bert_model import BertModel
from .config import cfg
from .lr_generator import get_bert_damping
......
......@@ -102,7 +102,8 @@ class DistributedGradReducerThor(Cell):
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple
>>> from mindspore import ParameterTuple
>>> from mindspore.context import ParallelMode
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
......
......@@ -36,7 +36,7 @@ from mindspore.parallel._utils import _need_to_full
from mindspore.train import amp
from mindspore.parallel._utils import _to_full_tensor
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from .dataset_helper import DatasetHelper
......
......@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from .transformer import Transformer
......
......@@ -26,7 +26,8 @@ from mindspore.nn.optim import Adam, Lamb
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore import context, ParallelMode, Parameter
from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.communication import management as MultiAscend
from mindspore.train.serialization import load_checkpoint
......
......@@ -24,7 +24,7 @@ import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.callback import TimeMonitor
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore import log as logger
......
......@@ -26,7 +26,7 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from .tinybert_model import BertModel, TinyBertModel, BertModelCLS
......
......@@ -22,7 +22,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.communication.management import get_group_size
from mindspore import context
......
......@@ -29,7 +29,7 @@ from mindspore.train.callback import Callback, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.dataset.engine as de
import mindspore.communication.management as D
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore import context
from src.transformer_for_train import TransformerTrainOneStepCell, TransformerNetworkWithLoss, \
......
......@@ -19,7 +19,8 @@ import argparse
import random
import numpy as np
from mindspore import context, ParallelMode
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
......
......@@ -17,7 +17,7 @@ callbacks
import time
from mindspore.train.callback import Callback
from mindspore import context
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank
def add_write(file_path, out_str):
......
......@@ -23,7 +23,7 @@ from mindspore.ops import operations as P
from mindspore.nn import Dropout
from mindspore.nn.optim import Adam, FTRL, LazyAdam
from mindspore.common.initializer import Uniform, initializer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
......
......@@ -20,7 +20,7 @@ import sys
import mindspore.dataset.engine as de
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
......
......@@ -20,7 +20,7 @@ import sys
import numpy as np
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......
......@@ -20,7 +20,7 @@ import sys
import numpy as np
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......
......@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
from mindspore.nn import Dropout, Flatten
from mindspore.nn.optim import Adam, FTRL
from mindspore.common.initializer import Uniform, initializer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
......
......@@ -20,7 +20,7 @@ import numpy as np
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......
......@@ -28,7 +28,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.train.callback import Callback
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
......
......@@ -30,7 +30,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
random.seed(1)
np.random.seed(1)
......
......@@ -30,7 +30,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
random.seed(1)
np.random.seed(1)
......
......@@ -19,7 +19,7 @@ import os
import sys
from mindspore import Model, context
from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.parallel import set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
......
......@@ -25,7 +25,7 @@ from mindspore.nn.optim import Adam, FTRL
from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
import numpy as np
......
......@@ -20,7 +20,7 @@ import sys
import numpy as np
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
......
......@@ -19,6 +19,7 @@ import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.common.initializer import TruncatedNormal
......@@ -652,7 +653,7 @@ class TrainingWrapper(nn.Cell):
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean")
......
......@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from .bert_model import BertModel
......
......@@ -26,7 +26,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore import context
from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
......
......@@ -16,7 +16,7 @@
from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_full_shapes
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
def _send_data(dataset):
"""Engine dataset to write data to tdt queue."""
......
......@@ -103,7 +103,8 @@ class DistributedGradReducerThor(Cell):
>>> from mindspore.ops import functional as F
>>> from mindspore import context
>>> from mindspore import nn
>>> from mindspore import ParallelMode, ParameterTuple
>>> from mindspore import ParameterTuple
>>> from mindspore.context import ParallelMode
>>>
>>> device_id = int(os.environ["DEVICE_ID"])
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
......
......@@ -30,7 +30,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from mindspore.train import amp
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from .dataset_helper import DatasetHelper
......
......@@ -24,7 +24,8 @@ import numpy as np
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
import mindspore.nn as nn
......
......@@ -32,7 +32,8 @@ from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
random.seed(1)
np.random.seed(1)
......
......@@ -32,7 +32,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import Callback
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
random.seed(1)
np.random.seed(1)
......
......@@ -25,7 +25,7 @@ from mindspore.common.api import _executor
from mindspore.nn import Momentum
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
class DenseMMNet(nn.Cell):
......
......@@ -21,7 +21,8 @@ import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Model, ParallelMode
from mindspore import Tensor, Model
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
......
......@@ -19,7 +19,8 @@ import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Model, ParallelMode
from mindspore import Tensor, Model
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.ops.operations import TensorAdd
from ....dataset_mock import MindData
......
......@@ -26,7 +26,7 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from tests.ops_common import convert
from ....train_step_wrap import train_step_with_loss_warp
......
......@@ -22,7 +22,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
......
......@@ -24,7 +24,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
......
......@@ -23,7 +23,8 @@ from mindspore.common.parameter import Parameter
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss
......
......@@ -29,7 +29,8 @@ from mindspore.ops import operations as P
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._utils import _reset_op_id as resset_op_id
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
......
......@@ -26,7 +26,8 @@ from mindspore.nn.layer.pooling import MaxPool2d
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
dev_num = 8
......
......@@ -27,7 +27,7 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
......
......@@ -22,7 +22,8 @@ from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C, functional as F, operations as P
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from tests.dataset_mock import MindData
......
......@@ -23,7 +23,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
class Dataset(MindData):
......
......@@ -27,7 +27,8 @@ from mindspore.nn.optim import Momentum
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
context.set_context(mode=context.GRAPH_MODE)
device_number = 32
......
......@@ -25,7 +25,8 @@ from mindspore.ops import functional as F
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.nn as nn
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
......
......@@ -25,7 +25,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel._utils import _reset_op_id
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
context.set_context(mode=context.GRAPH_MODE)
......
......@@ -25,7 +25,8 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss
......
......@@ -29,7 +29,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.ops.operations import TensorAdd
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
dev_num = 8
......
......@@ -23,7 +23,7 @@ from mindspore.nn import Dense
from mindspore.nn import Momentum
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops import operations as P
from mindspore.train.parallel_utils import ParallelMode
from mindspore.context import ParallelMode
class Net(nn.Cell):
......
......@@ -24,7 +24,8 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
context.set_context(mode=context.GRAPH_MODE)
......
......@@ -28,7 +28,8 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import _VirtualDataset
from mindspore.parallel import set_algo_parameters
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss
......
......@@ -21,7 +21,8 @@ from mindspore.common.parameter import Parameter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData
......
......@@ -20,7 +20,8 @@ import mindspore.context as context
from mindspore import Tensor
from mindspore import amp
from mindspore import nn
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype
from ....dataset_mock import MindData
from mindspore.parallel._auto_parallel_context import auto_parallel_context
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册