提交 690cd1f7 编写于 作者: C chengduoZH

refine gather and broadcast

上级 494c262a
......@@ -99,9 +99,11 @@ void BroadcastOpHandle::RunImpl() {
PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
}
Tensor *out_tensor = GetTensorFromVar(out_var);
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
out_tensor);
auto dev_ctx = dev_ctxes_[out_p];
RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
Tensor *out_tensor = GetTensorFromVar(out_var);
paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctx), out_tensor);
});
}
}
......
......@@ -84,7 +84,7 @@ void GatherOpHandle::RunImpl() {
"The type of input is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
"The height of inputs is not consistent.");
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), ,
PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
"The dims of inputs is not consistent.");
auto in_sr_rows = in_sr.rows();
......@@ -110,14 +110,17 @@ void GatherOpHandle::RunImpl() {
Tensor *out_tensor = out->mutable_value();
// copy
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place,
*(dev_ctxes_[in_places[j]]), &sub_out);
s = e;
}
auto dev_ctx = dev_ctxes_[out_place];
RunAndRecordEvent(out_place, [in_tensors, out_var, dev_ctx, out_place] {
int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) {
e += in_tensors[j].dims()[0];
auto sub_out = out_tensor->Slice(s, e);
paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
&sub_out);
s = e;
}
});
}
std::string GatherOpHandle::Name() const { return "gather"; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册