未验证 提交 b754700f 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix reduce and broadcast to avoid multi-stream, test=develop (#19889)

上级 8359b415
......@@ -38,8 +38,6 @@ void BroadcastOpHandle::RunImpl() {
VarHandle *in_var_handle = in_var_handles[0];
WaitInputVarGenerated();
BroadcastOneVar(*in_var_handle, out_var_handles, local_exec_scopes_);
}
......@@ -59,6 +57,7 @@ void BroadcastOpHandle::BroadcastOneVar(
InitOutputValue(in_var_handle, out_var_handles);
if (platform::is_cpu_place(in_tensor.place())) {
WaitInputVarGenerated();
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue;
......@@ -109,6 +108,7 @@ void BroadcastOpHandle::BroadcastOneVar(
});
}
WaitInputVarGenerated();
this->RunAndRecordEvent([&] {
{
platform::NCCLGroupGuard guard;
......
......@@ -78,44 +78,59 @@ struct ReduceBufferData {
}
};
inline void GatherLocalSelectedRows(
const std::vector<const SelectedRows *> &src_selecte_rows_,
const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
const platform::Place &out_place, SelectedRows *dst_selecte_rows) {
PADDLE_ENFORCE(!src_selecte_rows_.empty());
std::vector<Tensor> in_tensors;
std::vector<int64_t> out_rows;
for (auto in_sr_ptr : src_selecte_rows_) {
auto &in_sr = *in_sr_ptr;
in_tensors.emplace_back(in_sr.value());
out_rows.insert(out_rows.end(), in_sr.rows().begin(), in_sr.rows().end());
struct GatherLocalSelectedRowsFunctor {
GatherLocalSelectedRowsFunctor(
const std::vector<const SelectedRows *> &src_selected_rows,
const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
const platform::Place &out_place, SelectedRows *dst_selected_rows)
: dev_ctxes_(dev_ctxes),
in_places_(in_places),
out_place_(out_place),
dst_selected_rows_(dst_selected_rows) {
PADDLE_ENFORCE_EQ(src_selected_rows.empty(), false);
std::vector<int64_t> out_rows;
for (auto in_sr_ptr : src_selected_rows) {
auto &in_sr = *in_sr_ptr;
in_tensors_.emplace_back(in_sr.value());
out_rows.insert(out_rows.end(), in_sr.rows().begin(), in_sr.rows().end());
}
auto &pre_in = src_selected_rows[0];
auto &dst_tensor = *dst_selected_rows_;
dst_tensor.set_height(pre_in->height());
dst_tensor.set_rows(out_rows);
size_t rows = out_rows.size();
DDim out_dim = pre_in->GetCompleteDims();
out_dim[0] = static_cast<int64_t>(rows);
dst_tensor.mutable_value()->Resize(out_dim);
dst_tensor.mutable_value()->mutable_data(out_place, pre_in->value().type());
}
auto &pre_in = src_selecte_rows_[0];
auto &dst_tensor = *dst_selecte_rows;
dst_tensor.set_height(pre_in->height());
dst_tensor.set_rows(out_rows);
size_t rows = out_rows.size();
DDim out_dim = pre_in->GetCompleteDims();
out_dim[0] = static_cast<int64_t>(rows);
dst_tensor.mutable_value()->Resize(out_dim);
dst_tensor.mutable_value()->mutable_data(out_place, pre_in->value().type());
Tensor *out_tensor = dst_tensor.mutable_value();
// copy
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place,
*(dev_ctxes.at(in_places[j])), &sub_out);
s = e;
void operator()() {
auto *out_tensor = dst_selected_rows_->mutable_value();
// copy
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors_.size(); ++j) {
e += in_tensors_[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors_[j], out_place_,
*(dev_ctxes_.at(in_places_[j])), &sub_out);
s = e;
}
}
}
private:
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes_;
std::vector<platform::Place> in_places_;
std::vector<Tensor> in_tensors_;
platform::Place out_place_;
SelectedRows *dst_selected_rows_;
};
} // namespace details
} // namespace framework
......
......@@ -66,8 +66,11 @@ void ReduceOpHandle::GatherSelectedRows(
auto gathered_var_mid = scope->Var(gathered_var_name);
auto gathered_select_rows =
gathered_var_mid->GetMutable<framework::SelectedRows>();
GatherLocalSelectedRows(src_selected_rows, in_places, dev_ctxes, out_place,
gathered_select_rows);
GatherLocalSelectedRowsFunctor functor(
src_selected_rows, in_places, dev_ctxes, out_place, gathered_select_rows);
WaitInputVarGenerated();
functor();
// FIXME(gongwb): remove this Wait.
Wait(dev_ctxes);
......@@ -167,9 +170,6 @@ void ReduceOpHandle::RunImpl() {
var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name());
PADDLE_ENFORCE_NOT_NULL(pre_in_var);
// Wait input done, this Wait is asynchronous operation
WaitInputVarGenerated();
// NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std::vector<platform::Place> in_places; // used to get dev_ctx
for (auto *in_handle : in_var_handles) {
......@@ -209,9 +209,11 @@ void ReduceOpHandle::RunImpl() {
// TODO(gongwb): add cpu support
if (collective_context.endpoints_.size() <= 1 ||
is_cpu_place(in_places[0]) || is_cpu_place(t_out_p)) {
GatherLocalSelectedRows(in_selected_rows, in_places, dev_ctxes_,
t_out_p,
out_var->GetMutable<framework::SelectedRows>());
GatherLocalSelectedRowsFunctor functor(
in_selected_rows, in_places, dev_ctxes_, t_out_p,
out_var->GetMutable<framework::SelectedRows>());
WaitInputVarGenerated();
functor();
return;
}
......@@ -236,6 +238,7 @@ void ReduceOpHandle::RunImpl() {
GetInputValues<LoDTensor>(in_var_handles, var_scopes);
if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) {
WaitInputVarGenerated();
this->RunAndRecordEvent([&] {
// FIXME(zcd): The order of summing is important,
// especially when the type of data is float or double.
......@@ -295,6 +298,7 @@ void ReduceOpHandle::RunImpl() {
});
}
WaitInputVarGenerated();
this->RunAndRecordEvent([&] {
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册