diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index 7164660be04eb13f2b5af19f4a0cc9162bfc4fbe..9da8f0a65bae27305e34c8e2227251f79643f3a2 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -44,7 +44,10 @@ std::shared_ptr ParallelContext::GetInstance() { return inst_context_; } -ParallelContext::ParallelContext() { Reset(); } +ParallelContext::ParallelContext() { + communication_backend_ = HCCL_BACKEND; + Reset(); +} void ParallelContext::Reset() { mirror_mean_ = false; @@ -53,7 +56,6 @@ void ParallelContext::Reset() { loss_repeated_mean_ = true; device_num_ = 1; global_rank_ = 0; - communication_backend_ = HCCL_BACKEND; device_num_is_set_ = false; global_rank_is_set_ = false; parallel_mode_ = STAND_ALONE; diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 07261080282198fff005abbc341b5d8d64e45a36..a439951b6debd98c61d72a67cd066ba4ded29d5f 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -30,6 +30,8 @@ from ..nn.metrics import Loss from .. import nn from ..nn.wrap.cell_wrapper import _VirtualDatasetCell from .parallel_utils import ParallelMode +from ._utils import _to_full_tensor +from ..parallel._utils import _need_to_full from ..common import dtype as mstype from .dataset_helper import DatasetHelper from . import amp @@ -418,6 +420,8 @@ class Model: # for data sink dataset_helper only iter once, other wise iter epoch_size times. for inputs in dataset_helper: + if _need_to_full(): + inputs = _to_full_tensor(inputs, self._device_number, self._global_rank) list_callback.step_begin(run_context) outputs = self._train_network(*inputs) cb_params.cur_step_num += dataset_helper.sink_size()