未验证 提交 b754700f 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix reduce and broadcast to avoid multi-stream, test=develop (#19889)

上级 8359b415
...@@ -38,8 +38,6 @@ void BroadcastOpHandle::RunImpl() { ...@@ -38,8 +38,6 @@ void BroadcastOpHandle::RunImpl() {
VarHandle *in_var_handle = in_var_handles[0]; VarHandle *in_var_handle = in_var_handles[0];
WaitInputVarGenerated();
BroadcastOneVar(*in_var_handle, out_var_handles, local_exec_scopes_); BroadcastOneVar(*in_var_handle, out_var_handles, local_exec_scopes_);
} }
...@@ -59,6 +57,7 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -59,6 +57,7 @@ void BroadcastOpHandle::BroadcastOneVar(
InitOutputValue(in_var_handle, out_var_handles); InitOutputValue(in_var_handle, out_var_handles);
if (platform::is_cpu_place(in_tensor.place())) { if (platform::is_cpu_place(in_tensor.place())) {
WaitInputVarGenerated();
for (auto *out_var_handle : out_var_handles) { for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(in_var_handle)) { if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue; continue;
...@@ -109,6 +108,7 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -109,6 +108,7 @@ void BroadcastOpHandle::BroadcastOneVar(
}); });
} }
WaitInputVarGenerated();
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
{ {
platform::NCCLGroupGuard guard; platform::NCCLGroupGuard guard;
......
...@@ -78,44 +78,59 @@ struct ReduceBufferData { ...@@ -78,44 +78,59 @@ struct ReduceBufferData {
} }
}; };
inline void GatherLocalSelectedRows( struct GatherLocalSelectedRowsFunctor {
const std::vector<const SelectedRows *> &src_selecte_rows_, GatherLocalSelectedRowsFunctor(
const std::vector<platform::Place> &in_places, const std::vector<const SelectedRows *> &src_selected_rows,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes, const std::vector<platform::Place> &in_places,
const platform::Place &out_place, SelectedRows *dst_selecte_rows) { const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
PADDLE_ENFORCE(!src_selecte_rows_.empty()); const platform::Place &out_place, SelectedRows *dst_selected_rows)
: dev_ctxes_(dev_ctxes),
std::vector<Tensor> in_tensors; in_places_(in_places),
std::vector<int64_t> out_rows; out_place_(out_place),
dst_selected_rows_(dst_selected_rows) {
for (auto in_sr_ptr : src_selecte_rows_) { PADDLE_ENFORCE_EQ(src_selected_rows.empty(), false);
auto &in_sr = *in_sr_ptr;
in_tensors.emplace_back(in_sr.value()); std::vector<int64_t> out_rows;
out_rows.insert(out_rows.end(), in_sr.rows().begin(), in_sr.rows().end());
for (auto in_sr_ptr : src_selected_rows) {
auto &in_sr = *in_sr_ptr;
in_tensors_.emplace_back(in_sr.value());
out_rows.insert(out_rows.end(), in_sr.rows().begin(), in_sr.rows().end());
}
auto &pre_in = src_selected_rows[0];
auto &dst_tensor = *dst_selected_rows_;
dst_tensor.set_height(pre_in->height());
dst_tensor.set_rows(out_rows);
size_t rows = out_rows.size();
DDim out_dim = pre_in->GetCompleteDims();
out_dim[0] = static_cast<int64_t>(rows);
dst_tensor.mutable_value()->Resize(out_dim);
dst_tensor.mutable_value()->mutable_data(out_place, pre_in->value().type());
} }
auto &pre_in = src_selecte_rows_[0]; void operator()() {
auto *out_tensor = dst_selected_rows_->mutable_value();
auto &dst_tensor = *dst_selecte_rows; // copy
dst_tensor.set_height(pre_in->height()); int s = 0, e = 0;
dst_tensor.set_rows(out_rows); for (size_t j = 0; j < in_tensors_.size(); ++j) {
size_t rows = out_rows.size(); e += in_tensors_[j].dims()[0];
DDim out_dim = pre_in->GetCompleteDims(); auto sub_out = out_tensor->Slice(s, e);
out_dim[0] = static_cast<int64_t>(rows); paddle::framework::TensorCopy(in_tensors_[j], out_place_,
dst_tensor.mutable_value()->Resize(out_dim); *(dev_ctxes_.at(in_places_[j])), &sub_out);
dst_tensor.mutable_value()->mutable_data(out_place, pre_in->value().type()); s = e;
Tensor *out_tensor = dst_tensor.mutable_value(); }
// copy
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place,
*(dev_ctxes.at(in_places[j])), &sub_out);
s = e;
} }
}
private:
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes_;
std::vector<platform::Place> in_places_;
std::vector<Tensor> in_tensors_;
platform::Place out_place_;
SelectedRows *dst_selected_rows_;
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -66,8 +66,11 @@ void ReduceOpHandle::GatherSelectedRows( ...@@ -66,8 +66,11 @@ void ReduceOpHandle::GatherSelectedRows(
auto gathered_var_mid = scope->Var(gathered_var_name); auto gathered_var_mid = scope->Var(gathered_var_name);
auto gathered_select_rows = auto gathered_select_rows =
gathered_var_mid->GetMutable<framework::SelectedRows>(); gathered_var_mid->GetMutable<framework::SelectedRows>();
GatherLocalSelectedRows(src_selected_rows, in_places, dev_ctxes, out_place, GatherLocalSelectedRowsFunctor functor(
gathered_select_rows); src_selected_rows, in_places, dev_ctxes, out_place, gathered_select_rows);
WaitInputVarGenerated();
functor();
// FIXME(gongwb): remove this Wait. // FIXME(gongwb): remove this Wait.
Wait(dev_ctxes); Wait(dev_ctxes);
...@@ -167,9 +170,6 @@ void ReduceOpHandle::RunImpl() { ...@@ -167,9 +170,6 @@ void ReduceOpHandle::RunImpl() {
var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name()); var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name());
PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE_NOT_NULL(pre_in_var);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated();
// NOTE: The Places of all input tensor must be all on CPU or all on GPU. // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std::vector<platform::Place> in_places; // used to get dev_ctx std::vector<platform::Place> in_places; // used to get dev_ctx
for (auto *in_handle : in_var_handles) { for (auto *in_handle : in_var_handles) {
...@@ -209,9 +209,11 @@ void ReduceOpHandle::RunImpl() { ...@@ -209,9 +209,11 @@ void ReduceOpHandle::RunImpl() {
// TODO(gongwb): add cpu support // TODO(gongwb): add cpu support
if (collective_context.endpoints_.size() <= 1 || if (collective_context.endpoints_.size() <= 1 ||
is_cpu_place(in_places[0]) || is_cpu_place(t_out_p)) { is_cpu_place(in_places[0]) || is_cpu_place(t_out_p)) {
GatherLocalSelectedRows(in_selected_rows, in_places, dev_ctxes_, GatherLocalSelectedRowsFunctor functor(
t_out_p, in_selected_rows, in_places, dev_ctxes_, t_out_p,
out_var->GetMutable<framework::SelectedRows>()); out_var->GetMutable<framework::SelectedRows>());
WaitInputVarGenerated();
functor();
return; return;
} }
...@@ -236,6 +238,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -236,6 +238,7 @@ void ReduceOpHandle::RunImpl() {
GetInputValues<LoDTensor>(in_var_handles, var_scopes); GetInputValues<LoDTensor>(in_var_handles, var_scopes);
if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) { if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) {
WaitInputVarGenerated();
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
// FIXME(zcd): The order of summing is important, // FIXME(zcd): The order of summing is important,
// especially when the type of data is float or double. // especially when the type of data is float or double.
...@@ -295,6 +298,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -295,6 +298,7 @@ void ReduceOpHandle::RunImpl() {
}); });
} }
WaitInputVarGenerated();
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
platform::NCCLGroupGuard guard; platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) { for (auto &call : all_reduce_calls) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册