提交 690cd1f7 作者: chengduoZH

refine gather and broadcast

上级 494c262a
@@ -99,9 +99,11 @@ void BroadcastOpHandle::RunImpl() {
PADDLE_THROW("Var should be LoDTensor or SelectedRows."); PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
} }
Tensor *out_tensor = GetTensorFromVar(out_var); auto dev_ctx = dev_ctxes_[out_p];
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]), RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
out_tensor); Tensor *out_tensor = GetTensorFromVar(out_var);
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctx), out_tensor);
});
} }
} }
......
@@ -84,7 +84,7 @@ void GatherOpHandle::RunImpl() {
"The type of input is not consistent."); "The type of input is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(), PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
"The height of inputs is not consistent."); "The height of inputs is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), , PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
"The dims of inputs is not consistent."); "The dims of inputs is not consistent.");
auto in_sr_rows = in_sr.rows(); auto in_sr_rows = in_sr.rows();
@@ -110,14 +110,17 @@ void GatherOpHandle::RunImpl() {
Tensor *out_tensor = out->mutable_value(); Tensor *out_tensor = out->mutable_value();
// copy // copy
int s = 0, e = 0; auto dev_ctx = dev_ctxes_[out_place];
for (size_t j = 0; j < in_tensors.size(); ++j) { RunAndRecordEvent(out_place, [in_tensors, out_var, dev_ctx, out_place] {
e += in_tensors[j].dims()[0]; int s = 0, e = 0;
auto sub_out = out_tensor->Slice(s, e); for (size_t j = 0; j < in_tensors.size(); ++j) {
paddle::framework::TensorCopy(in_tensors[j], out_place, e += in_tensors[j].dims()[0];
*(dev_ctxes_[in_places[j]]), &sub_out); auto sub_out = out_tensor->Slice(s, e);
s = e; paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
} &sub_out);
s = e;
}
});
} }
std::string GatherOpHandle::Name() const { return "gather"; } std::string GatherOpHandle::Name() const { return "gather"; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册