Unverified commit 71e26232, authored by daquexian, committed by GitHub

ddp broadcast params and buffers (#5913)

* ddp broadcast params and buffers
Signed-off-by: daquexian <daquexian566@gmail.com>

* auto format by CI
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Parent 7171eb1e
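For context, a minimal usage sketch of the behavior this commit adds, written by hand rather than taken from the diff: wrapping a module now broadcasts its parameters from rank 0 once at wrap time, and, unless broadcast_buffers=False is passed, its buffers are re-broadcast before every forward pass. The import path oneflow.nn.parallel.DistributedDataParallel and the two-rank launch are assumptions, not something the diff shows.

# Sketch only; assumes two ranks started via oneflow.distributed.launch and that the
# wrapper is importable as oneflow.nn.parallel.DistributedDataParallel.
import oneflow as flow
import oneflow.nn as nn
from oneflow.nn.parallel import DistributedDataParallel as ddp

m = nn.Linear(4, 4).to("cuda")
# Wrapping broadcasts the parameters from rank 0, so every rank starts from the same
# weights; by default the buffers are also re-broadcast before each forward pass.
m = ddp(m)

# Keep per-rank buffers (e.g. locally accumulated statistics) by opting out:
m_local = ddp(nn.BatchNorm1d(4).to("cuda"), broadcast_buffers=False)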
@@ -249,29 +249,36 @@ Maybe<one::UserOpExpr> FindOrCreatEagerNcclBroadcastOpExpr(Symbol<ParallelDesc>
   }
   return iter->second;
 }
 
 }  // namespace
 
-Maybe<Tensor> GetSyncedTensorIfBroadcast(const std::shared_ptr<Tensor>& tensor,
-                                         Symbol<ParallelDesc> parallel_desc,
-                                         Symbol<cfg::NdSbp> nd_sbp) {
-  Optional<int64_t> parallel_id;
-  JUST(GetDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
-  if (!parallel_id.has_value()) { return tensor; }
-  const auto& broadcast_parallel_desc = JUST(GetBroadcastSubParallelDesc(parallel_desc, nd_sbp));
-  if (broadcast_parallel_desc->parallel_num() == 1 /* no broadcast */) { return tensor; }
-  std::shared_ptr<UserOpExpr> op_expr =
-      JUST(FindOrCreatEagerNcclBroadcastOpExpr(broadcast_parallel_desc));
-  if (JUST(broadcast_parallel_desc->MachineId4ParallelId(0)) == GlobalProcessCtx::Rank()) {
+Maybe<Tensor> Broadcast(const std::shared_ptr<Tensor>& tensor, Symbol<ParallelDesc> parallel_desc) {
+  CHECK_OR_RETURN(parallel_desc->containing_current_rank());
+  if (parallel_desc->parallel_num() == 1 /* no broadcast */) { return tensor; }
+  std::shared_ptr<UserOpExpr> op_expr = JUST(FindOrCreatEagerNcclBroadcastOpExpr(parallel_desc));
+  if (JUST(parallel_desc->MachineId4ParallelId(0)) == GlobalProcessCtx::Rank()) {
     // inplace.
     TensorTuple outputs{tensor};
     JUST(OpInterpUtil::Dispatch(*op_expr, {tensor}, &outputs,
-                                one::OpExprInterpContext(AttrMap{}, broadcast_parallel_desc)));
+                                one::OpExprInterpContext(AttrMap{}, parallel_desc)));
     return tensor;
   } else {
     return JUST(OpInterpUtil::Dispatch<one::Tensor>(
-        *op_expr, {tensor}, one::OpExprInterpContext(AttrMap{}, broadcast_parallel_desc)));
+        *op_expr, {tensor}, one::OpExprInterpContext(AttrMap{}, parallel_desc)));
   }
 }
 
+namespace {
+
+Maybe<Tensor> GetSyncedTensorIfBroadcast(const std::shared_ptr<Tensor>& tensor,
+                                         Symbol<ParallelDesc> parallel_desc,
+                                         Symbol<cfg::NdSbp> nd_sbp) {
+  Optional<int64_t> parallel_id;
+  JUST(GetDevice4CurrentProcessCtx(parallel_desc, &parallel_id));
+  if (!parallel_id.has_value()) { return tensor; }
+  const auto& broadcast_parallel_desc = JUST(GetBroadcastSubParallelDesc(parallel_desc, nd_sbp));
+  return Broadcast(tensor, broadcast_parallel_desc);
+}
 
 Maybe<Shape> CalcPhysicalShape(Symbol<ConsistentTensorMeta> consistent_tensor_meta) {
   const auto& opt_parallel_id =
       JUST(GetParallelId4CurrentProcessCtx(consistent_tensor_meta->parallel_desc()));
@@ -19,8 +19,14 @@ namespace oneflow {
 class Device;
 class TensorTuple;
+class ParallelDesc;
 
 namespace one {
 
+class Tensor;
+
 Maybe<void> RunEmptyOp(TensorTuple* outputs);
-}
+
+Maybe<Tensor> Broadcast(const std::shared_ptr<Tensor>& tensor, Symbol<ParallelDesc> parallel_desc);
+
+}  // namespace one
 }  // namespace oneflow
@@ -970,6 +970,10 @@
signature: "Tensor AllReduce(Tensor x)"
bind_python: True
- name: "broadcast"
signature: "Tensor Broadcast(Tensor x)"
bind_python: True
- name: "select_first"
signature: "Tensor SelectFirst(TensorTuple inputs)"
bind_python: True
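This functional_api.yaml entry is what exposes the new BroadcastFunctor to Python as flow.F.broadcast, which the DDP changes below call on every parameter and buffer. A hedged sketch of calling it directly; the two-rank setup is an assumption, and the rank helper is the one the test below uses:

# Sketch only: flow.F.broadcast(x) returns rank 0's value of x on every rank.
import oneflow as flow

rank = flow.framework.distribute.get_rank()
x = flow.tensor([1.0, 2.0]).to("cuda") * (rank + 1)
y = flow.F.broadcast(x)  # y == [1.0, 2.0] on both ranks; on rank 0 it is done in place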
@@ -18,6 +18,7 @@ limitations under the License.
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"
@@ -33,6 +34,19 @@ namespace one {
 namespace functional {
 namespace impl {
 
+class BroadcastFunctor {
+ public:
+  BroadcastFunctor() = default;
+  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
+    const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
+    std::string device_type_str = JUST(x->device())->type();
+    CHECK_OR_RETURN(device_type_str == "cuda" || device_type_str == "cpu");
+    DeviceType device_type = device_type_str == "cuda" ? DeviceType::kGPU : DeviceType::kCPU;
+    const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(device_type, rank_group));
+    return one::Broadcast(x, parallel_desc);
+  }
+};
+
 class AllReduceFunctor {
  public:
   AllReduceFunctor() = default;
@@ -75,7 +89,10 @@ class AllReduceFunctor {
 };
 
 }  // namespace impl
 
-ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::AllReduceFunctor>("AllReduce"); };
+ONEFLOW_FUNCTION_LIBRARY(m) {
+  m.add_functor<impl::AllReduceFunctor>("AllReduce");
+  m.add_functor<impl::BroadcastFunctor>("Broadcast");
+};
 }  // namespace functional
 }  // namespace one
@@ -40,9 +40,18 @@ def allreduce_fn(ddp_state_for_reversed_params, param):
     return allreduce
 
 
-def DistributedDataParallel(module: "flow.nn.Module"):
+def DistributedDataParallel(
+    module: "flow.nn.Module", *, broadcast_buffers: bool = True
+):
     world_size = flow.distributed.get_world_size()
-    # TODO(jianhao): broadcast parameters and buffers
+    with flow.no_grad():
+        for x in module.parameters():
+            requires_grad = x.requires_grad
+            x.copy_(flow.F.broadcast(x))
+            # TODO: fix the bug that x's requires_grad is discarded
+            # after flow.F.broadcast
+            x.requires_grad_(requires_grad)
     ddp_state_for_reversed_params = OrderedDict(
         reversed([(x, [False, False]) for x in module.parameters()])
     )
@@ -51,7 +60,7 @@ def DistributedDataParallel(module: "flow.nn.Module"):
         param.register_hook(lambda grad: grad / world_size)
         param.register_hook(allreduce_fn(ddp_state_for_reversed_params, param))
 
-    def hook(module, input, output):
+    def post_forward_hook(module, input, output):
         ddp_state_for_reversed_params = module._ddp_state_for_reversed_params
         for state in ddp_state_for_reversed_params.values():
             state[0], state[1] = False, False
@@ -60,5 +69,15 @@ def DistributedDataParallel(module: "flow.nn.Module"):
             )
         return output
 
-    module.register_forward_hook(hook)
+    module.register_forward_hook(post_forward_hook)
+
+    if broadcast_buffers:
+
+        def pre_forward_hook(module, input):
+            with flow.no_grad():
+                for x in module.buffers():
+                    x.copy_(flow.F.broadcast(x))
+
+        module.register_forward_pre_hook(pre_forward_hook)
+
     return module
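The buffer synchronization added above is nothing more than a forward pre-hook that copies freshly broadcast values into each buffer under no_grad. A stripped-down sketch of the same pattern outside of DDP; the module and the buffer name "scale" are made up for illustration:

# Sketch only: the pre-forward-hook pattern used by DistributedDataParallel above,
# applied to a hypothetical module with one buffer.
import oneflow as flow
import oneflow.nn as nn


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", flow.ones(1))

    def forward(self, x):
        return x * self.scale


def sync_buffers(module, input):
    # Runs before every forward call; overwrite each buffer with rank 0's value.
    with flow.no_grad():
        for buf in module.buffers():
            buf.copy_(flow.F.broadcast(buf))


m = WithBuffer().to("cuda")
m.register_forward_pre_hook(sync_buffers)
out = m(flow.ones(2, 1).to("cuda"))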
@@ -129,6 +129,49 @@ class TestDDP(flow.unittest.TestCase):
         test_case.assertTrue(np_allclose_with_shape(m.w2.grad.numpy(), np.array([4.5])))
         test_case.assertTrue(np_allclose_with_shape(m.w3.grad.numpy(), np.array([3])))
 
+    def test_broadcast_buffer(test_case):
+        rank = flow.framework.distribute.get_rank()
+
+        class CustomModule(flow.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("buf", flow.tensor([1, 2]) * (rank + 1))
+
+            def forward(self, x):
+                res = self.buf + x
+                self.buf.copy_(x)
+                return res
+
+        x = flow.tensor([2, 3]) * (rank + 1)
+        x = x.to("cuda")
+
+        m = CustomModule()
+        m = m.to("cuda")
+        m = ddp(m)
+
+        y1 = m(x)
+        y2 = m(x)
+
+        m = CustomModule()
+        m = m.to("cuda")
+        m = ddp(m, broadcast_buffers=False)
+
+        y3 = m(x)
+        y4 = m(x)
+
+        if rank == 0:
+            test_case.assertTrue(np_allclose_with_shape(y1.numpy(), np.array([3, 5])))
+            test_case.assertTrue(np_allclose_with_shape(y2.numpy(), np.array([4, 6])))
+            test_case.assertTrue(np_allclose_with_shape(y3.numpy(), np.array([3, 5])))
+            test_case.assertTrue(np_allclose_with_shape(y4.numpy(), np.array([4, 6])))
+        elif rank == 1:
+            test_case.assertTrue(np_allclose_with_shape(y1.numpy(), np.array([5, 8])))
+            test_case.assertTrue(np_allclose_with_shape(y2.numpy(), np.array([6, 9])))
+            test_case.assertTrue(np_allclose_with_shape(y3.numpy(), np.array([6, 10])))
+            test_case.assertTrue(np_allclose_with_shape(y4.numpy(), np.array([8, 12])))
+        else:
+            raise ValueError()
+
 
 if __name__ == "__main__":
     unittest.main()