未验证 提交 8460698b 编写于 作者: S ShenLiang 提交者: GitHub

Support control flow in DataParallel (#31625)

* support control flow

* supoort sync_parameters_buffers

* fix the bug of sparse embedding
上级 40e6c57b
......@@ -152,6 +152,7 @@ message DistributedStrategy {
optional bool fp16_allreduce = 25 [ default = false ];
optional bool sharding = 26 [ default = false ];
optional float last_comm_group_size_MB = 27 [ default = 1 ];
optional bool find_unused_parameters = 28 [ default = true ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......
......@@ -167,8 +167,6 @@ void BKCLParallelContext::WaitCompute(int ring_id) {
platform::errors::OutOfRange("Ring id expected < nrings,"
"but got ring id = %d, nrings = %d",
ring_id, strategy_.nrings_));
// TODO(wangxi16): [Performance optimize] Maybe need to put Wait and
// bkcl_allreduce to comm thread, for bkcl_allreduce is blocking now.
auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
......@@ -188,6 +186,12 @@ void BKCLParallelContext::WaitComm(int ring_id) {
comm_dev_ctx->Wait();
}
void BKCLParallelContext::SynchronizeCompute() {
auto compute_dev_ctx = static_cast<platform::XPUDeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
}
} // namespace imperative
} // namespace paddle
#endif
......@@ -47,6 +47,8 @@ class BKCLParallelContext : public ParallelContext {
void WaitCompute(int ring_id) override;
void WaitComm(int ring_id) override;
void SynchronizeCompute() override;
};
} // namespace imperative
......
......@@ -173,6 +173,12 @@ void NCCLParallelContext::WaitComm(int ring_id) {
#endif
}
void NCCLParallelContext::SynchronizeCompute() {
auto *compute_dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place_));
compute_dev_ctx->Wait();
}
#endif
} // namespace imperative
......
......@@ -65,6 +65,8 @@ class NCCLParallelContext : public ParallelContext {
void WaitComm(int ring_id) override;
void SynchronizeCompute() override;
private:
// used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
std::vector<std::shared_ptr<platform::CudaEventObject>> compute_events_;
......
......@@ -66,6 +66,9 @@ class ParallelContext {
// if CPU, should do nothing.
virtual void WaitComm(int ring_id) = 0;
// synchorize compute stream
virtual void SynchronizeCompute() = 0;
inline int GetNRings() const { return strategy_.nrings_; }
inline int64_t GetNRanks() const { return strategy_.nranks_; }
......
此差异已折叠。
......@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -153,13 +154,20 @@ class Reducer {
void MarkGroupReady(size_t group_index);
void FusedAllReduceSchedule(int run_order, Group& group); // NOLINT
void FusedAllReduceSchedule(const int run_order, Group& group, // NOLINT
const int curr_group_index);
void FinalizeBackward();
std::vector<std::vector<size_t>> RebuildGruops();
inline bool NeedRebuildGroup() { return !has_rebuilt_group_; }
inline bool NeedRebuildGroup() {
return !has_rebuilt_group_ && !find_unused_vars_;
}
void ProcessUnusedDenseVars();
bool HasGrad(size_t var_index);
private:
std::vector<std::shared_ptr<imperative::VarBase>> vars_;
......@@ -188,7 +196,7 @@ class Reducer {
std::vector<size_t> unused_vars_;
bool has_marked_unused_vars_{false};
bool find_unused_vars_{false};
bool all_group_ready_{false};
bool groups_need_finalize_{false};
#ifdef PADDLE_WITH_XPU_BKCL
// comm_pool_ is used for scheduling allreduce in multi Kunlun cards training.
std::unique_ptr<::ThreadPool> comm_pool_{nullptr};
......@@ -196,6 +204,19 @@ class Reducer {
std::mutex mutex_;
std::condition_variable cv_;
#endif
// it just for checking hook, each parameter can only trigger one hook
std::vector<bool> vars_marked_ready_;
// Following variables are to help control flow.
// local_used_vars_ uses 0/1 to indicate whether the
// var is used in iteration. After the end of the
// iteration, global_used_vars_ is obtained synchronously
// globally. Choose whether to update the local
// gradient according to the global_used_vars_.
std::vector<int> local_used_vars_;
// global_used_vars_ is used in comm stream to avoid wait
framework::Variable global_used_vars_;
};
std::vector<std::vector<size_t>> AssignGroupBySize(
......
......@@ -620,6 +620,34 @@ class DistributedStrategy(object):
else:
raise ValueError("last_comm_group_size_MB should be greater than 0")
@property
def find_unused_parameters(self):
"""
Indicating whether we are using find_unused_parameters to
find unused parameters in DataParallel.
Default value: True
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.find_unused_parameters = True
"""
return self.strategy.find_unused_parameters
@find_unused_parameters.setter
@is_strict_auto
def find_unused_parameters(self, flag):
if isinstance(flag, bool):
self.strategy.find_unused_parameters = flag
else:
print(
"WARNING: find_unused_parameters should have value of bool type")
@property
def _fuse_grad_size_in_TFLOPS(self):
return self.strategy.fuse_grad_size_in_TFLOPS
......
......@@ -706,7 +706,9 @@ class Fleet(object):
model,
comm_buffer_size=self._user_defined_strategy.fuse_grad_size_in_MB,
last_comm_buffer_size=self._user_defined_strategy.
last_comm_group_size_MB)
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
return self.model
@dygraph_only
......
......@@ -22,6 +22,7 @@ import copy
import weakref
import warnings
from copy import deepcopy
import paddle
from . import parallel_helper
from .. import unique_name
......@@ -894,9 +895,15 @@ class Layer(core.Layer):
if not self._built:
with program_desc_tracing_guard(False):
self._build_once(*inputs, **kwargs)
if parallel_helper._is_data_parallel_mode():
# TODO(liuyuhui) Only xpu broadcast parameters here.
# The other device is to call _sync_params_buffers in DataParallel
# to realize the parameter synchronization among multiply cards.
if parallel_helper._is_data_parallel_mode(
) and paddle.is_compiled_with_xpu():
parallel_helper._broadcast_parameters(
self._parameters.values())
self._built = True
outputs = self.forward(*inputs, **kwargs)
......
......@@ -24,6 +24,7 @@ from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph import parallel_helper
from paddle.fluid.dygraph import to_variable, no_grad
from paddle.utils import deprecated
from ..layers import collective
import warnings
import paddle
import itertools
......@@ -348,6 +349,18 @@ class DataParallel(layers.Layer):
last_comm_buffer_size(float, optional): It limits memory size(MB) of last buffer in communication
calling. Making the last communication buffer size small is useful to
improve performance. Default: 1.
find_unused_parameters(bool, optional): Whether to traverse the entire backward graph from the
all tensors in the return value of the wrapped model's
forward function. For parameters not involved in loss
calculation, their gradients will be marked as ready in
advance to prepare reduce. Please note that all forward
outputs derived from the wrapped model parameters must
participate in the calculation of loss and subsequent
gradient calculations. If not, serious error will occur.
Note that setting the find_unused_parameters to True
will affect computing performance. Therefore, if all parameters
are sure to participate in the loss calculation and the
autograd graph construction, please set it False. Default: True.
Returns:
Layer: The data paralleled module.
......@@ -403,11 +416,13 @@ class DataParallel(layers.Layer):
layers,
strategy=None,
comm_buffer_size=25,
last_comm_buffer_size=1):
last_comm_buffer_size=1,
find_unused_parameters=True):
super(DataParallel,
self).__init__(layers.full_name() + "_data_parallel")
self._layers = layers
self.find_unused_parameters = find_unused_parameters
# NOTE(chenweihang): The ParallelStrategy here is not strictly a strategy.
# It just stores some environment variables, which can be constructed by
......@@ -419,6 +434,17 @@ class DataParallel(layers.Layer):
self._strategy = _build_default_parallel_strategy()
if self._strategy.nranks > 1:
# check the environment
assert parallel_helper.__parallel_ctx__clz__ is not None, \
"ParallelContext must be initialized before. You should use init_parallel_env() before" \
"constructing the DataParallel."
# sync buffer and params
# TODO(liuyuhui) Currently not support xpu. xpu is
# still broadcasting parameters when calling layer
if not paddle.is_compiled_with_xpu():
self._sync_params_buffers()
self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
# NOTE(shenliang03): We can set environment variables to control
# the size of the group, Default: 1MB. The role of this small group is:
......@@ -449,6 +475,10 @@ class DataParallel(layers.Layer):
trainable_parameters = [param for _, param in layers_param]
assert len(trainable_parameters) > 0, \
"This model does not have any parameters to train, and " \
"does not need to use DataParallel"
# NOTE(shenliang03): Here we can only use the attributes to judge whether
# parameter is sparse(or SelectedRows). The reason is that the sparse message
# can't be obtained when bp hasn't happened yet. So if layer supports sparse parameter,
......@@ -470,19 +500,12 @@ class DataParallel(layers.Layer):
trainable_parameters, is_sparse_gradient,
[self.last_comm_buffer_size, self.comm_buffer_size])
assert parallel_helper.__parallel_ctx__clz__ is not None, \
"ParallelContext must be initialized before. You should use init_parallel_env() before" \
"constructing the DataParallel."
# TODO(shenliang03) "find_unused_vars" interface will be exposed in the future
# to handle control flow to process unused parameters
find_unused_vars = True
self._reducer = core.Reducer(
trainable_parameters,
list(reversed(self.group_indices)), is_sparse_gradient,
parallel_helper.__parallel_ctx__clz__,
[self.last_comm_buffer_size, self.comm_buffer_size],
find_unused_vars)
self.find_unused_parameters)
def _find_varbase(self, obj):
if isinstance(obj, core.VarBase):
......@@ -493,11 +516,54 @@ class DataParallel(layers.Layer):
return itertools.chain(*map(self._find_varbase, obj.values()))
return []
def _sync_params_buffers(self):
model_vars = []
for _, param in self._layers.state_dict().items():
if not isinstance(param, core.VarBase):
raise TypeError("The data type of '%s' must be Varbase" %
param.name)
model_vars.append(param.detach())
if len(model_vars) == 0:
return
mega_bytes = 128 * 1024 * 1024
group_idx = 0
memory_counter = 0
var_groups = OrderedDict()
dtype = model_vars[0].dtype
for var in model_vars:
bytes = np.prod(var.shape) * core.size_of_dtype(var.dtype)
if memory_counter < mega_bytes and dtype == var.dtype:
memory_counter += bytes
else:
memory_counter = 0
dtype = var.dtype
group_idx += 1
var_groups.setdefault(group_idx, []).append(var)
coalesced_vars = _coalesce_tensors(var_groups)
for coalesced_var, _, _ in coalesced_vars:
collective._broadcast(coalesced_var, root=0, sync_mode=True)
for coalesced_var, origin_vars, var_shapes in coalesced_vars:
var_len = [np.prod(v_shape) for v_shape in var_shapes]
framework._dygraph_tracer().trace_op(
type='split',
inputs={'X': coalesced_var},
outputs={'Out': origin_vars},
attrs={'sections': var_len,
'axis': 0})
def forward(self, *inputs, **kwargs):
outputs = self._layers(*inputs, **kwargs)
if self._strategy.nranks > 1:
self._reducer.prepare_for_backward(
list(self._find_varbase(outputs)))
if self._strategy.nranks > 1 and framework._dygraph_tracer()._has_grad:
if self.find_unused_parameters:
self._reducer.prepare_for_backward(
list(self._find_varbase(outputs)))
else:
self._reducer.prepare_for_backward(list(self._find_varbase([])))
return outputs
......
......@@ -19,6 +19,8 @@ list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests.
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
......@@ -160,6 +162,8 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_control_flow)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel)
LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single)
elseif(WITH_GPU)
......@@ -824,10 +828,12 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120)
set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120)
set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120)
if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
endif()
endif()
if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.distributed as dist
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Embedding
import paddle.nn.functional as F
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
paddle.seed(123)
np.random.seed(2021)
class SimpleNet(fluid.Layer):
def __init__(self, hidden_size, vocab_size, is_sparse=False):
super(SimpleNet, self).__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.embedding = Embedding(
size=[self.vocab_size, self.hidden_size],
dtype='float32',
is_sparse=is_sparse)
self.lin_a = paddle.nn.Linear(self.hidden_size, self.vocab_size)
self.lin_b = paddle.nn.Linear(self.vocab_size, 1)
self.unused_net = paddle.nn.Linear(5, 3)
self.phony = self.create_parameter(shape=[1], dtype="float32")
def forward(self, input, label, conf):
x_emb = self.embedding(input)
fc = self.lin_a(x_emb)
mask = conf > 0
mask = paddle.cast(mask, dtype="int64")
mask.stop_gradient = True
emb_mask = mask.max(1).flatten()
emb_mask_inds = paddle.nonzero(emb_mask > 0).flatten()
emb_mask_inds.stop_gradient = True
if emb_mask_inds.numel() == 0:
loss_box = self.phony * 0
else:
projection = self.lin_b(fc)
projection = paddle.reshape(projection, shape=[-1, 1])
output = paddle.gather(projection, emb_mask_inds)
target = paddle.gather(label, emb_mask_inds)
loss_box = F.smooth_l1_loss(
output, target, reduction='sum', delta=1.0)
loss_box = loss_box / len(conf)
return loss_box
# global configs
batch_size = 4
batch_num = 2000
hidden_size = 5
vocab_size = 100
conf_dataset = [[0], [0], [0], [0], [1], [0], [1], [0], [0], [1], [0], [1],
[1], [1], [1], [1], [1], [1], [1], [1], [1], [0], [0], [1]]
def fake_sample_reader():
def __reader__():
for i in range(batch_num):
x_data = np.random.randint(0, vocab_size)
y_data = np.random.random_sample((1, )).astype('float32')
conf_data = np.array(conf_dataset[i % len(conf_dataset)]).astype(
'int64')
yield x_data, y_data, conf_data
return __reader__
class TestSimpleNet(TestParallelDyGraphRunnerBase):
def get_model(self):
model = SimpleNet(
hidden_size=hidden_size, vocab_size=vocab_size, is_sparse=False)
train_reader = paddle.batch(
fake_sample_reader(), batch_size=batch_size, drop_last=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
parameters=model.parameters())
return model, train_reader, optimizer
def run_one_loop(self, model, optimizer, batch):
x_data = np.array([x[0] for x in batch]).astype('int64')
y_data = np.array([x[1] for x in batch]).astype('float32')
conf_data = np.array([x[2] for x in batch]).astype('int64')
x_data = x_data.reshape((-1, 1))
y_data = y_data.reshape((-1, 1))
conf_data = conf_data.reshape((-1, 1))
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
conf = paddle.to_tensor(conf_data)
loss = model(x, y, conf)
return loss
if __name__ == "__main__":
runtime_main(TestSimpleNet)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.base import to_variable
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
np.random.seed(2021)
paddle.seed(1024)
batch_size = 4
batch_num = 1000
class SimpleNet(fluid.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.net_a = paddle.nn.Sequential(
paddle.nn.Linear(10, 20),
paddle.nn.Linear(20, 20), paddle.nn.Linear(20, 5))
self.net_b = paddle.nn.Sequential(
paddle.nn.Linear(10, 20),
paddle.nn.Linear(20, 20), paddle.nn.Linear(20, 5))
self.net_unused = Linear(10, 20)
self.step = 0
def forward(self, x):
if self.step % 2 == 0:
return self.net_a(x)
else:
return self.net_b(x)
self.step = self.step + 1
def fake_sample_reader():
def __reader__():
for i in range(batch_num):
x_data = np.random.random_sample((10, )).astype('float32')
yield x_data
return __reader__
class TestSimpleNet(TestParallelDyGraphRunnerBase):
def get_model(self):
model = SimpleNet()
train_reader = paddle.batch(
fake_sample_reader(), batch_size=batch_size, drop_last=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
parameters=model.parameters())
return model, train_reader, optimizer
def run_one_loop(self, model, optimizer, batch):
x_data = np.array([x for x in batch])
x_data = x_data.reshape((-1, 10))
x = to_variable(x_data)
out = model(x)
loss = out.sum() / len(batch)
return loss
if __name__ == "__main__":
runtime_main(TestSimpleNet)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import paddle.distributed as dist
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
paddle.seed(1024)
np.random.seed(2021)
batch = 5
in_dim = 10
out_dim = 20
class SimpleNet(fluid.Layer):
def __init__(self, train_id):
super(SimpleNet, self).__init__()
self.w1 = self.create_parameter(
shape=[in_dim, out_dim], dtype="float32")
self.w2 = self.create_parameter(
shape=[in_dim, out_dim], dtype="float32")
self.share_net = Linear(out_dim, 10)
self.unused_param = self.create_parameter(
shape=[out_dim, in_dim], dtype="float64")
# just for test sync_params_buffers
self.register_buffer("queue", paddle.randn([10, 5]))
self.queue = paddle.nn.functional.normalize(self.queue, axis=0)
self.register_buffer("queue_ptr", paddle.zeros([1], 'int64'))
self.trainer_id = train_id
def forward(self, x):
is_use = (paddle.equal_all(
x, paddle.ones(shape=(batch, in_dim))).numpy()[0] and
self.trainer_id == 1)
if is_use:
tmp = paddle.matmul(x, self.w1)
else:
tmp = paddle.matmul(x, self.w2)
return self.share_net(tmp)
class TestDistTraning(unittest.TestCase):
def test_multiple_gpus(self):
dist.init_parallel_env()
self.trainer_id = dist.get_rank()
model_a = SimpleNet(self.trainer_id)
model_b = SimpleNet(self.trainer_id)
state_dict = model_a.state_dict()
model_b.set_state_dict(state_dict)
model_a = paddle.DataParallel(model_a)
model_b = paddle.DataParallel(model_b)
ones_input = paddle.ones(shape=(batch, in_dim))
ones_input.stop_gradient = True
w1_grad_sum = np.zeros((in_dim, out_dim), dtype='float32')
w2_grad_sum = np.zeros((in_dim, out_dim), dtype='float32')
for step_id in range(5):
random_input = paddle.rand(shape=(batch, in_dim))
random_input.stop_gradient = True
if step_id % 2 == 0:
out_a = model_a(random_input)
out_b = model_b(random_input)
else:
out_a = model_a(ones_input)
out_b = model_b(ones_input)
out_a.sum().backward()
out_b.sum().backward()
self.check_gradient(model_a.parameters())
self.check_gradient(model_b.parameters())
# test acc gradient
w1_grad_sum = self.check_acc(model_a._layers.w1.grad, w1_grad_sum,
model_b._layers.w1.grad)
w2_grad_sum = self.check_acc(model_a._layers.w2.grad, w2_grad_sum,
model_b._layers.w2.grad)
model_a.clear_gradients()
def check_acc(self, grad, grad_sum, acc_grad):
if grad is not None:
grad_sum = grad_sum + grad
np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6)
return grad_sum
def print_trainer_0(self, *args):
if self.trainer_id == 0:
print(*args)
def broadcast_param(self, param, root):
paddle.distributed.broadcast(param, root)
return param
def check_gradient(self, params):
other_param = []
for param in params:
if param.trainable and (param._grad_ivar() is not None):
grad = param._grad_ivar()
other_grad = self.broadcast_param(grad.clone(), root=1)
if self.trainer_id == 0:
np.testing.assert_allclose(other_grad.numpy(), grad.numpy())
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Linear
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
np.random.seed(2021)
paddle.seed(1024)
batch_size = 4
batch_num = 1000
class SimpleNet(fluid.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.net_a = paddle.nn.Sequential(
paddle.nn.Linear(10, 20),
paddle.nn.Linear(20, 20), paddle.nn.Linear(20, 5))
self.net_b = paddle.nn.Sequential(
paddle.nn.Linear(10, 20),
paddle.nn.Linear(20, 20), paddle.nn.Linear(20, 5))
self.step = 0
def forward(self, x):
return paddle.to_tensor(0.0, dtype='float32')
def fake_sample_reader():
def __reader__():
for i in range(batch_num):
x_data = np.random.random_sample((10, )).astype('float32')
yield x_data
return __reader__
class TestSimpleNet(TestParallelDyGraphRunnerBase):
def get_model(self):
model = SimpleNet()
train_reader = paddle.batch(
fake_sample_reader(), batch_size=batch_size, drop_last=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
parameters=model.parameters())
return model, train_reader, optimizer
def run_one_loop(self, model, optimizer, batch):
x_data = np.array([x for x in batch])
x_data = x_data.reshape((-1, 10))
x = paddle.to_tensor(x_data)
out = model(x)
loss = out.sum() / len(batch)
return loss
if __name__ == "__main__":
runtime_main(TestSimpleNet)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
np.random.seed(2021)
paddle.seed(1024)
class SimpleNet(fluid.Layer):
def __init__(self):
# bias is unused parameters, and it share with net_a
super(SimpleNet, self).__init__()
self.net_a = Linear(input_dim=10, output_dim=5)
self.net_b = Linear(10, 10)
self.bias = self.net_a.bias
def forward(self, x):
return self.net_b(x)
batch_size = 4
batch_num = 1000
def fake_sample_reader():
def __reader__():
for i in range(batch_num):
x_data = np.random.random_sample((10, )).astype('float32')
yield x_data
return __reader__
class TestSimpleNet(TestParallelDyGraphRunnerBase):
def get_model(self):
model = SimpleNet()
train_reader = paddle.batch(
fake_sample_reader(), batch_size=batch_size, drop_last=True)
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
parameters=model.parameters())
return model, train_reader, optimizer
def run_one_loop(self, model, optimizer, batch):
x_data = np.array([x for x in batch])
x_data = x_data.reshape((-1, 10))
x = to_variable(x_data)
out = model(x)
loss = out.sum() / len(batch)
return loss
if __name__ == "__main__":
runtime_main(TestSimpleNet)
......@@ -65,8 +65,6 @@ class SimpleNet(Layer):
def forward(self, input, label):
x_emb = self.embedding(input)
fc = paddle.matmul(x_emb, self.softmax_weight)
# use detach to stop gradient
fc = fc.detach()
fc = paddle.add(fc, self.softmax_bias)
projection = paddle.reshape(fc, shape=[-1, self.vocab_size])
loss = paddle.nn.functional.softmax_with_cross_entropy(
......
......@@ -37,7 +37,7 @@ class SimpleNet(Layer):
self.embedding = Embedding(
self.vocab_size,
self.hidden_size,
sparse=True,
sparse=is_sparse,
weight_attr=paddle.ParamAttr(
name='embedding_param',
initializer=paddle.nn.initializer.Uniform(
......@@ -105,7 +105,7 @@ class TestSparseEmbeddingUnusedVars(TestParallelDyGraphRunnerBase):
vocab_size=vocab_size,
num_steps=num_steps,
init_scale=init_scale,
is_sparse=True)
is_sparse=False)
train_reader = paddle.batch(
fake_sample_reader(), batch_size=batch_size, drop_last=True)
......
......@@ -501,7 +501,12 @@ class TestParallelDyGraphRunnerBase(object):
type(self).__name__,
"begin to prepare context in dygraph with nccl2")
dygraph.parallel.prepare_context(strategy)
model = dygraph.parallel.DataParallel(model, strategy)
if not args.find_unused_parameters:
model = dygraph.parallel.DataParallel(
model, strategy, find_unused_parameters=False)
else:
model = dygraph.parallel.DataParallel(
model, strategy, find_unused_parameters=True)
print_to_err(type(self).__name__, "model built in dygraph")
out_losses = []
print_to_err(type(self).__name__, "begin to run dygraph training")
......@@ -574,9 +579,14 @@ class TestParallelDyGraphRunnerBase(object):
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
# set strategy
strategy = fleet.DistributedStrategy()
if not args.find_unused_parameters:
strategy.find_unused_parameters = False
# 3. init parallel env
if args.update_method == "nccl2" or "bkcl":
fleet.init(is_collective=True)
fleet.init(is_collective=True, strategy=strategy)
# 4. train model
model, train_reader, opt = self.get_model()
......@@ -628,6 +638,7 @@ def runtime_main(test_class):
parser.add_argument('--use_xpu', action='store_true')
parser.add_argument('--use_dgc', action='store_true')
parser.add_argument('--accumulate_gradient', action='store_true')
parser.add_argument('--find_unused_parameters', action='store_true')
parser.add_argument('--use_reduce', action='store_true')
parser.add_argument('--dc_asgd', action='store_true')
parser.add_argument('--hogwild', action='store_true')
......@@ -726,6 +737,7 @@ class TestDistBase(unittest.TestCase):
self._save_model = False
self._fuse_all_reduce = None
self._accumulate_gradient = False
self._find_unused_parameters = True
self._setup_config()
global DIST_UT_PORT
......@@ -852,6 +864,9 @@ class TestDistBase(unittest.TestCase):
if self._accumulate_gradient:
cmd += " --accumulate_gradient"
if self._find_unused_parameters:
cmd += " --find_unused_parameters"
env_local.update(envs)
print("local_cmd: {}, env: {}".format(cmd, env_local))
......@@ -1021,6 +1036,9 @@ class TestDistBase(unittest.TestCase):
if self._accumulate_gradient:
tr_cmd += " --accumulate_gradient"
if self._find_unused_parameters:
tr_cmd += " --find_unused_parameters"
if self._pipeline_mode:
tr_cmd += " --use_pipeline"
if self._mp_mode:
......
......@@ -179,6 +179,15 @@ class TestStrategyConfig(unittest.TestCase):
with self.assertRaises(ValueError):
strategy.last_comm_group_size_MB = -1
def test_find_unused_parameters(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.find_unused_parameters = True
self.assertEqual(strategy.find_unused_parameters, True)
strategy.find_unused_parameters = False
self.assertEqual(strategy.find_unused_parameters, False)
strategy.find_unused_parameters = "True"
self.assertEqual(strategy.find_unused_parameters, False)
def test_fuse_grad_size_in_TFLOPS(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy._fuse_grad_size_in_TFLOPS = 0.1
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
from spawn_runner_base import TestDistSpawnRunner
flag_name = os.path.splitext(__file__)[0]
class TestDygraphControlFlowSame(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
def test_net(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_control_flow_same.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestFleetDygraphControlFlowSame(TestDygraphControlFlowSame):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
self._use_fleet_api = True
class TestFleetDygraphControlFlowSameAccGrad(TestDygraphControlFlowSame):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
self._accumulate_gradient = True
class TestDygraphControlFlowDiff(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
def test_net(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_control_flow_different.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestFleetDygraphControlFlowDiff(TestDygraphControlFlowDiff):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
self._use_fleet_api = True
class TestFleetDygraphControlFlowDiffAccGrad(TestDygraphControlFlowDiff):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
self._accumulate_gradient = True
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import time
import paddle.fluid as fluid
from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, get_gpus, start_local_trainers
def get_cluster_from_args(selected_gpus):
cluster_node_ips = '127.0.0.1'
node_ip = '127.0.0.1'
node_ips = [x.strip() for x in cluster_node_ips.split(',')]
node_ips.index(node_ip)
free_ports = None
free_ports = find_free_ports(len(selected_gpus))
if free_ports is not None:
free_ports = list(free_ports)
trainer_endpoints = []
for ip in node_ips:
trainer_endpoints.append(["%s:%d" % (ip, port) for port in free_ports])
return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus)
class TestMultipleGpus(unittest.TestCase):
def run_mnist_2gpu(self, target_file_name):
if not fluid.core.is_compiled_with_cuda(
) or fluid.core.get_cuda_device_count() == 0:
return
selected_gpus = get_gpus('0,1')
cluster = None
pod = None
cluster, pod = get_cluster_from_args(selected_gpus)
procs = start_local_trainers(
cluster,
pod,
training_script=target_file_name,
training_script_args=[])
while True:
alive = watch_local_trainers(procs, cluster.trainers_nranks())
if not alive:
print("Local procs complete, POD info:{}".format(pod))
break
time.sleep(3)
def test_multiple_gpus_dynamic(self):
self.run_mnist_2gpu('parallel_dygraph_gradient_check.py')
if __name__ == "__main__":
unittest.main()
......@@ -73,6 +73,7 @@ class TestParallelDygraphMnistAccGrad(TestDistBase):
self._dygraph = True
self._use_fleet_api = True
self._accumulate_gradient = True
self._find_unused_parameters = False
def test_mnist(self):
if fluid.core.is_compiled_with_cuda():
......
......@@ -54,6 +54,7 @@ class TestParallelDygraphTransformerAccGrad(TestDistBase):
self._nccl2_mode = True
self._dygraph = True
self._accumulate_gradient = True
self._find_unused_parameters = False
def test_transformer(self):
if fluid.core.is_compiled_with_cuda():
......
......@@ -26,13 +26,13 @@ from parallel_dygraph_unused_variables import TestSparseEmbeddingUnusedVars
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphMnist(TestDistBase):
class TestParallelDygraphUnusedVar(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
def test_mnist(self):
def test_net(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_unused_variables.py",
......@@ -41,6 +41,14 @@ class TestParallelDygraphMnist(TestDistBase):
log_name=flag_name)
class TestFleetDygraphUnusedVar(TestParallelDygraphUnusedVar):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
self._use_fleet_api = True
class TestSparseEmbeddingUnusedVarsSpawn(TestDistSpawnRunner):
def test_mnist_with_spawn(self):
if fluid.core.is_compiled_with_cuda() and sys.version_info >= (3, 4):
......@@ -48,17 +56,31 @@ class TestSparseEmbeddingUnusedVarsSpawn(TestDistSpawnRunner):
test_class=TestSparseEmbeddingUnusedVars, delta=1e-5)
class TestFleetDygraphMnist(TestDistBase):
class TestParallelDygraphNoVar(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
def test_net(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_none_var.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
class TestParallelDygraphSharedUnusedVariables(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
self._use_fleet_api = True
def test_mnist(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_unused_variables.py",
"parallel_dygraph_shared_unused_var.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册