未验证 提交 ffd40860 编写于 作者: S ShenLiang 提交者: GitHub

[Hybrid Parallel] Support dp & mp in dygraph (#32323)

* support dp & mp
上级 4d69eeaa
......@@ -27,6 +27,9 @@ from .runtime_factory import RuntimeFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import ModelParallel
from ..meta_optimizers import HybridParallelOptimizer
def _inited_runtime_handler_(func):
......@@ -219,6 +222,9 @@ class Fleet(object):
if paddle.fluid.framework.in_dygraph_mode():
if self.worker_num() == 1:
# if worker_num is 1, should construct default topology & hcg
self._topology = tp.CommunicateTopology()
self._hcg = tp.HybridCommunicateGroup(self._topology)
return
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn(
......@@ -694,10 +700,12 @@ class Fleet(object):
self._context = {}
# TODO(shenliang03): This is a temporary solution to support amp. In the case of a dynamic graph,
# the optimizer is returned directly. This problem will be fixed in the future.
if paddle.fluid.framework.in_dygraph_mode():
return optimizer
if self.worker_num() > 1:
return HybridParallelOptimizer(optimizer, self._hcg,
self._user_defined_strategy)
else:
return optimizer
return self
@dygraph_only
......@@ -756,15 +764,22 @@ class Fleet(object):
"""
assert model is not None
self.model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.fuse_grad_size_in_MB,
last_comm_buffer_size=self._user_defined_strategy.
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
return self.model
assert model is not None, "model should not be None"
if self.worker_num() <= 1:
return model
if self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
distributed_model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.
fuse_grad_size_in_MB,
last_comm_buffer_size=self._user_defined_strategy.
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
distributed_model = ModelParallel(
model, self._hcg, strategy=self._user_defined_strategy)
return distributed_model
@dygraph_only
def state_dict(self):
......
......@@ -17,6 +17,10 @@ from ..meta_optimizers import *
meta_optimizer_names = list(
filter(lambda name: name.endswith("Optimizer"), dir()))
# Because HybridParallelOptimizer is dygraph optimizer, it
# should be removed
meta_optimizer_names.remove("HybridParallelOptimizer")
class MetaOptimizerFactory(object):
def __init__(self):
......
......@@ -24,8 +24,16 @@ __all__ = ['CommunicateTopology', 'HybridCommunicateGroup']
_HYBRID_PARALLEL_GROUP = None
class ParallelMode(object):
DATA_PARALLEL = 0
MODEL_PARALLEL = 1
PIPELINE_PARALLEL = 2
class CommunicateTopology(object):
def __init__(self, hybrid_group_names, dims):
def __init__(self,
hybrid_group_names=["data", "pipe", "model"],
dims=[1, 1, 1]):
self._parallel_names = hybrid_group_names
self._dims = dims
self.coordinate = collections.namedtuple('Coordinate',
......@@ -118,15 +126,29 @@ class HybridCommunicateGroup(object):
# create comm group for data parallel
self._dp_group, self._dp_comm_group = self._set_comm_group("data")
print("data parallel group", self._dp_group, file=sys.stderr)
# create comm group for model parallel
self._mp_group, self._mp_comm_group = self._set_comm_group("model")
print("data parallel group", self._mp_group, file=sys.stderr)
debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \
"mp_degree: %d, pp_degree: %d\n" % (self.global_rank, self._dp_degree,
self._mp_degree,self._pp_degree)
debug_str += "dp_group: %s, mp_group: %s" % (self._dp_group,
self._mp_group)
print(debug_str, file=sys.stderr)
global _HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP = self
def get_parallel_mode(self):
# there are three modes : DataParallel / ModelParallel / PipelineParallel
if self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1:
# initialize the seed
return ParallelMode.MODEL_PARALLEL
elif self._pp_degree > 1:
return ParallelMode.PIPELINE_PARALLEL
def _check_vaild_topo(self):
return self._dp_degree * self._mp_degree * self._pp_degree == self.nranks
......
......@@ -25,3 +25,4 @@ from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
from .sharding_optimizer import ShardingOptimizer
from .dygraph_optimizer import HybridParallelOptimizer
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from .hybrid_parallel_optimizer import HybridParallelOptimizer
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.optimizer import Optimizer
from ...utils.hybrid_parallel_util import fused_allreduce_gradients
from ...base.topology import ParallelMode
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import framework
from paddle.fluid.framework import Variable
class HybridParallelOptimizer:
def __init__(self, optimizer, hcg, strategy):
self._inner_opt = optimizer
self._strategy = strategy
self._hcg = hcg
self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
@imperative_base.no_grad
@framework.dygraph_only
def step(self):
if self._is_mp and self._need_dp:
fused_allreduce_gradients(
list(self._inner_opt._parameter_list), self._hcg)
self._inner_opt.step()
@imperative_base.no_grad
def minimize(self,
loss,
startup_program=None,
parameters=None,
no_grad_set=None):
assert isinstance(loss, Variable), "The loss should be an Tensor."
parameter_list = parameters if parameters \
else self._parameter_list
if self._is_mp and self._need_dp:
fused_allreduce_gradients(list(parameter_list), self._hcg)
return self._inner_opt.minimize(loss, startup_program, parameters,
no_grad_set)
def __getattr__(self, item):
return getattr(self._inner_opt, item)
......@@ -13,3 +13,4 @@
# limitations under the License.
from .mp_utils import *
from .model_parallel import ModelParallel
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.dygraph.layers import Layer
import logging
class MetaParallelBase(Layer):
def __init__(self, layers, hcg, strategy):
super(MetaParallelBase,
self).__init__(layers.full_name() + "_meta_parallel_base")
self._layers = layers
self._hcg = hcg
self._prepare_for_model()
def _prepare_for_model(self):
pass
def _pre_forward(self, *inputs, **kwargs):
pass
def forward(self, *inputs, **kwargs):
self._pre_forward(*inputs, **kwargs)
output = self._layers(*inputs, **kwargs)
self._post_forward(output)
return output
def _post_forward(self, output):
pass
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.dygraph.layers import Layer
from .meta_parallel_base import MetaParallelBase
from ..utils.hybrid_parallel_util import *
class ModelParallel(MetaParallelBase):
def __init__(self, layers, hcg, **kwargs):
super(ModelParallel, self).__init__(layers, hcg, **kwargs)
def _prepare_for_model(self):
broadcast_mp_parameters(self._layers, self._hcg)
broadcast_dp_parameters(self._layers, self._hcg)
def _pre_forward(self, *inputs, **kwargs):
return broadcast_input_data(self._hcg, *inputs, **kwargs)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import numpy as np
import warnings
from paddle import framework
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, construct_groups
from collections import OrderedDict
def _apply_collective_grads(parameters, comm_group):
grad_var_set = set()
grad_vars = []
sparse_grad_vars = []
for param in parameters:
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
assert not g_var._is_sparse(
), "Now, it doesn't support sparse parameters"
grad_vars.append(g_var)
assert g_var not in grad_var_set
grad_var_set.add(g_var)
coalesced_grads_and_vars = construct_groups(grad_vars, 128 * 1024 * 1024)
for coalesced_grad, _, _ in coalesced_grads_and_vars:
# need to div nranks
coalesced_grad = coalesced_grad / comm_group.nranks
paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
_split_tensors(coalesced_grads_and_vars)
def broadcast_input_data(hcg, *inputs, **kwargs):
model_parallel_group = hcg.get_model_parallel_group()
src_rank = hcg.get_model_parallel_group_src_rank()
for input_ in inputs:
if isinstance(input_, core.VarBase):
with framework.no_grad():
paddle.distributed.broadcast(
input_,
src=src_rank,
group=model_parallel_group,
use_calc_stream=True)
else:
print("it doesn't support data type {}".format(type(input_)))
for k, v in kwargs.items():
if isinstance(v, core.VarBase):
with framework.no_grad():
paddle.distributed.broadcast(
v,
src=src_rank,
group=model_parallel_group,
use_calc_stream=True)
kwargs[k] = v
else:
print("it doesn't support data type {}".format(type(v)))
return inputs, kwargs
def broadcast_mp_parameters(model, hcg):
model_parallel_group = hcg.get_model_parallel_group()
src_rank = hcg.get_model_parallel_group_src_rank()
sync_params_buffers(
model, model_parallel_group, src_rank, is_model_parallel=True)
def broadcast_dp_parameters(model, hcg):
data_parallel_group = hcg.get_data_parallel_group()
src_rank = hcg.get_data_parallel_group_src_rank()
sync_params_buffers(
model, data_parallel_group, src_rank, is_model_parallel=False)
def fused_allreduce_gradients(parameter_list, hcg):
data_parallel_group = hcg.get_data_parallel_group()
with framework.no_grad():
_apply_collective_grads(parameter_list, data_parallel_group)
......@@ -25,6 +25,7 @@ from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
from ..layers import collective
from paddle.fluid.dygraph import base as imperative_base
import warnings
import paddle
import itertools
......@@ -320,6 +321,62 @@ def scale_loss(loss):
return scaled_loss
@imperative_base.no_grad
@framework.dygraph_only
def construct_groups(vars, group_size):
group_idx = 0
memory_counter = 0
var_groups = OrderedDict()
dtype = vars[0].dtype
for var in vars:
bytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
if memory_counter < group_size and dtype == var.dtype:
memory_counter += bytes
else:
memory_counter = 0
dtype = var.dtype
group_idx += 1
var_groups.setdefault(group_idx, []).append(var)
return _coalesce_tensors(var_groups)
@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(model,
comm_group=None,
src_rank=0,
is_model_parallel=False):
model_vars = []
for _, param in model.state_dict().items():
if not isinstance(param, core.VarBase):
raise TypeError("The data type of '%s' must be Varbase" %
param.name)
# is_distributed param not need to sync when in mp mode
if is_model_parallel and param.is_distributed:
continue
model_vars.append(param.detach())
if len(model_vars) == 0:
return
# group size is 128M
coalesced_vars = construct_groups(model_vars, 128 * 1024 * 1024)
for coalesced_var, _, _ in coalesced_vars:
paddle.distributed.broadcast(
coalesced_var, src=src_rank, group=comm_group, use_calc_stream=True)
for coalesced_var, origin_vars, var_shapes in coalesced_vars:
var_len = [np.prod(v_shape) for v_shape in var_shapes]
paddle.fluid.framework._dygraph_tracer().trace_op(
type='split',
inputs={'X': coalesced_var},
outputs={'Out': origin_vars},
attrs={'sections': var_len,
'axis': 0})
class DataParallel(layers.Layer):
"""
Run the dygraph module with data parallelism.
......@@ -443,7 +500,7 @@ class DataParallel(layers.Layer):
# TODO(liuyuhui) Currently not support xpu. xpu is
# still broadcasting parameters when calling layer
if not paddle.is_compiled_with_xpu():
self._sync_params_buffers()
sync_params_buffers(self._layers)
self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
# NOTE(shenliang03): We can set environment variables to control
......@@ -516,46 +573,6 @@ class DataParallel(layers.Layer):
return itertools.chain(*map(self._find_varbase, obj.values()))
return []
def _sync_params_buffers(self):
model_vars = []
for _, param in self._layers.state_dict().items():
if not isinstance(param, core.VarBase):
raise TypeError("The data type of '%s' must be Varbase" %
param.name)
model_vars.append(param.detach())
if len(model_vars) == 0:
return
mega_bytes = 128 * 1024 * 1024
group_idx = 0
memory_counter = 0
var_groups = OrderedDict()
dtype = model_vars[0].dtype
for var in model_vars:
bytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
if memory_counter < mega_bytes and dtype == var.dtype:
memory_counter += bytes
else:
memory_counter = 0
dtype = var.dtype
group_idx += 1
var_groups.setdefault(group_idx, []).append(var)
coalesced_vars = _coalesce_tensors(var_groups)
for coalesced_var, _, _ in coalesced_vars:
collective._broadcast(coalesced_var, root=0, sync_mode=True)
for coalesced_var, origin_vars, var_shapes in coalesced_vars:
var_len = [np.prod(v_shape) for v_shape in var_shapes]
framework._dygraph_tracer().trace_op(
type='split',
inputs={'X': coalesced_var},
outputs={'Out': origin_vars},
attrs={'sections': var_len,
'axis': 0})
def forward(self, *inputs, **kwargs):
outputs = self._layers(*inputs, **kwargs)
if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
import paddle.fluid.generator as generator
from paddle.io import DataLoader, Dataset
import unittest
def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + rank_id)
vocab_size = 5
hidden_size = 10
inner_size = 8
output_size = 2
seq_length = 2
class SimpleMPNet(fluid.dygraph.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1,
np_fc2, mp_id):
super(SimpleMPNet, self).__init__()
if mp_id == 0:
init_fc1_data = np_fc1[:, :(inner_size // 2)]
init_fc2_data = np_fc2[:(inner_size // 2), :]
else:
init_fc1_data = np_fc1[:, (inner_size // 2):]
init_fc2_data = np_fc2[(inner_size // 2):, :]
self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc1_data)),
gather_output=False,
has_bias=True)
self.linear2 = fleet.meta_parallel.RowParallelLinear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc2_data)),
input_is_parallel=True,
has_bias=True)
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)))
self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5))
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x
class SimpleDPNet(fluid.dygraph.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1,
np_fc2):
super(SimpleDPNet, self).__init__()
self.linear1 = paddle.nn.Linear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc1)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)))
self.linear2 = paddle.nn.Linear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc2)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)))
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)))
self.embedding = paddle.nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5))
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
return x
class TrainDataset(Dataset):
def __init__(self, length):
self.length = length
def __len__(self):
return self.length
def __getitem__(self, index):
np_input_data = np.random.randint(0, vocab_size, (seq_length, ))
return np_input_data
class TestDistTraning(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 2
self.data_parallel_size = 1
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": 1
}
fleet.init(is_collective=True, strategy=strategy)
def test_mp_model(self):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
mp_id = hcg.get_model_parallel_rank()
dp_id = hcg.get_data_parallel_rank()
rank_id = dist.get_rank()
set_random_seed(1024, dp_id, rank_id)
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
train_data = TrainDataset(length=10000)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_data,
batch_size=4,
shuffle=False,
num_replicas=self.data_parallel_size,
rank=dp_id)
train_data_loader = DataLoader(
dataset=train_data,
batch_sampler=train_batch_sampler,
num_workers=0,
return_list=True)
model_a = SimpleMPNet(vocab_size, hidden_size, inner_size, output_size,
np_fc1, np_fc2, mp_id)
optimizer_a = paddle.optimizer.SGD(learning_rate=0.001,
parameters=model_a.parameters())
model_a = fleet.distributed_model(model_a)
optimizer_a = fleet.distributed_optimizer(optimizer_a)
model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
np_fc1, np_fc2)
optimizer_b = paddle.optimizer.SGD(learning_rate=0.001,
parameters=model_b.parameters())
for step, batch in enumerate(train_data_loader):
if step > 5:
return
output_a = model_a(batch)
loss_a = output_a.mean()
loss_a.backward()
optimizer_a.step()
optimizer_a.clear_grad()
output_b = model_b(batch)
loss_b = output_b.mean()
loss_b.backward()
optimizer_b.step()
optimizer_b.clear_grad()
np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())
if __name__ == "__main__":
unittest.main()
......@@ -15,7 +15,6 @@
from __future__ import print_function
import unittest
import time
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
......@@ -28,6 +27,9 @@ class TestHybridParallel(TestMultipleGpus):
def test_hybrid_parallel_mp_random(self):
self.run_mnist_2gpu('hybrid_parallel_mp_random.py')
def test_hybrid_parallel_mp_model(self):
self.run_mnist_2gpu('hybrid_parallel_mp_model.py')
if __name__ == "__main__":
unittest.main()
......@@ -150,6 +150,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.meta_optimizers.sharding',
'paddle.distributed.fleet.meta_optimizers.ascend',
'paddle.distributed.fleet.meta_optimizers.dygraph_optimizer',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册