From e712c6cfe5c4f0a98958276bb1a1b03e854082c9 Mon Sep 17 00:00:00 2001
From: lichenever
Date: Wed, 22 Jul 2020 16:50:31 +0800
Subject: [PATCH] autoparallel support dataset in gpu

---
 mindspore/ccsrc/frontend/parallel/context.cc | 6 ++++--
 mindspore/train/model.py                     | 4 ++++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
index 7164660be..9da8f0a65 100644
--- a/mindspore/ccsrc/frontend/parallel/context.cc
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -44,7 +44,10 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
   return inst_context_;
 }
 
-ParallelContext::ParallelContext() { Reset(); }
+ParallelContext::ParallelContext() {
+  communication_backend_ = HCCL_BACKEND;
+  Reset();
+}
 
 void ParallelContext::Reset() {
   mirror_mean_ = false;
@@ -53,7 +56,6 @@ void ParallelContext::Reset() {
   loss_repeated_mean_ = true;
   device_num_ = 1;
   global_rank_ = 0;
-  communication_backend_ = HCCL_BACKEND;
   device_num_is_set_ = false;
   global_rank_is_set_ = false;
   parallel_mode_ = STAND_ALONE;
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 072610802..a439951b6 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -30,6 +30,8 @@ from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from .parallel_utils import ParallelMode
+from ._utils import _to_full_tensor
+from ..parallel._utils import _need_to_full
 from ..common import dtype as mstype
 from .dataset_helper import DatasetHelper
 from . import amp
@@ -418,6 +420,8 @@ class Model:
 
             # for data sink dataset_helper only iter once, other wise iter epoch_size times.
             for inputs in dataset_helper:
+                if _need_to_full():
+                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
                 list_callback.step_begin(run_context)
                 outputs = self._train_network(*inputs)
                 cb_params.cur_step_num += dataset_helper.sink_size()
--
GitLab
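
A note on what the two halves of this patch do, plus an illustrative sketch (this commentary is not part of the patch itself). In context.cc, initializing communication_backend_ in the constructor rather than in Reset() means a backend chosen at runtime (e.g. NCCL on GPU) is no longer forced back to HCCL whenever the parallel context is reset. In model.py, when _need_to_full() reports that the sunk dataset does not already carry the full global batch, each rank's local slice is expanded to full-batch size before being fed to the train network. Below is a minimal numpy sketch of what a _to_full_tensor-style helper conceptually does, assuming the rank's slice is placed at its offset inside a zero-filled full batch; the function name, numpy stand-in, and zero-padding scheme are assumptions for illustration, not MindSpore's actual implementation.

import numpy as np

def to_full_tensor_sketch(local_batch, device_num, global_rank):
    """Embed this rank's local batch slice into a full-batch-sized array.

    Illustrative assumption only: the global batch is split evenly across
    `device_num` ranks, and positions owned by other ranks may stay zero
    because the parallel graph reads only this rank's slice.
    """
    local_size = local_batch.shape[0]
    full_shape = (local_size * device_num,) + local_batch.shape[1:]
    full = np.zeros(full_shape, dtype=local_batch.dtype)
    start = global_rank * local_size
    full[start:start + local_size] = local_batch
    return full

# Example: 4 devices, each holding 2 samples of a global batch of 8.
local = np.ones((2, 3))
full = to_full_tensor_sketch(local, device_num=4, global_rank=1)
print(full.shape)   # (8, 3); rows 2..3 carry this rank's data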