Commit e712c6cf authored by lichenever

autoparallel support dataset in gpu

Parent 8f4bab4e
@@ -44,7 +44,10 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
   return inst_context_;
 }
 
-ParallelContext::ParallelContext() { Reset(); }
+ParallelContext::ParallelContext() {
+  communication_backend_ = HCCL_BACKEND;
+  Reset();
+}
 
 void ParallelContext::Reset() {
   mirror_mean_ = false;
@@ -53,7 +56,6 @@ void ParallelContext::Reset() {
   loss_repeated_mean_ = true;
   device_num_ = 1;
   global_rank_ = 0;
-  communication_backend_ = HCCL_BACKEND;
   device_num_is_set_ = false;
   global_rank_is_set_ = false;
   parallel_mode_ = STAND_ALONE;
......
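The C++ hunk moves the communication_backend_ default (HCCL_BACKEND) out of Reset() and into the constructor, so resetting the parallel context no longer reverts a backend that was explicitly switched to NCCL for a GPU run. A minimal Python sketch of that semantic difference, using a toy stand-in class rather than the real ParallelContext:

# Toy stand-in; class and attribute names are illustrative, not MindSpore's.
class ToyParallelContext:
    def __init__(self):
        # After this commit the default backend is assigned once, at construction.
        self.communication_backend = "hccl"
        self.reset()

    def reset(self):
        # Before this commit, reset() also re-assigned communication_backend = "hccl",
        # silently undoing a backend that had been switched to NCCL for GPU.
        self.device_num = 1
        self.global_rank = 0
        self.parallel_mode = "stand_alone"

ctx = ToyParallelContext()
ctx.communication_backend = "nccl"   # configured for a GPU cluster
ctx.reset()                          # reset the other parallel options
assert ctx.communication_backend == "nccl"  # the chosen backend survives the reset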
@@ -30,6 +30,8 @@ from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from .parallel_utils import ParallelMode
+from ._utils import _to_full_tensor
+from ..parallel._utils import _need_to_full
 from ..common import dtype as mstype
 from .dataset_helper import DatasetHelper
 from . import amp
@@ -418,6 +420,8 @@ class Model:
         # for data sink dataset_helper only iter once, other wise iter epoch_size times.
         for inputs in dataset_helper:
+            if _need_to_full():
+                inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
             list_callback.step_begin(run_context)
             outputs = self._train_network(*inputs)
             cb_params.cur_step_num += dataset_helper.sink_size()
......
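On the Python side, the data-sink training loop now calls _need_to_full() and, when the parallel context requires full-batch inputs, expands each device's local batch with _to_full_tensor(inputs, self._device_number, self._global_rank) before feeding the parallel-compiled network. Below is a rough numpy sketch of the idea behind such a helper, assuming the full batch is split along axis 0 and each rank fills only its own slice; the zero padding and slicing details are assumptions, not the library's exact implementation:

import numpy as np

def to_full_like(local_batch, device_num, global_rank):
    """Place one rank's shard into a zero-filled full-batch array (illustrative only)."""
    local = np.asarray(local_batch)
    full_shape = (local.shape[0] * device_num,) + local.shape[1:]
    full = np.zeros(full_shape, dtype=local.dtype)
    start = global_rank * local.shape[0]
    full[start:start + local.shape[0]] = local  # this rank's rows; other ranks' rows stay zero
    return full

# Example: 4 devices, rank 2, local batch of 2 samples -> full batch of 8 samples,
# with rows 4..5 holding the local data.
shard = np.ones((2, 3), dtype=np.float32)
print(to_full_like(shard, device_num=4, global_rank=2))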