提交 239fafb0 编写于 作者: D dzhwinter

"test on parallel do op"

上级 07dd3d25
...@@ -79,6 +79,7 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -79,6 +79,7 @@ inline void CopyOrShare(const framework::Variable &src,
} else { } else {
Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>()); Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
} }
dst->set_lod(src.lod());
} else if (src.IsType<SelectedRows>()) { } else if (src.IsType<SelectedRows>()) {
auto &src_sr = src.Get<SelectedRows>(); auto &src_sr = src.Get<SelectedRows>();
auto *dst_sr = dst->GetMutable<SelectedRows>(); auto *dst_sr = dst->GetMutable<SelectedRows>();
...@@ -89,6 +90,7 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -89,6 +90,7 @@ inline void CopyOrShare(const framework::Variable &src,
} else { } else {
Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
} }
dst_sr->set_rows(src_sr.rows());
} else { } else {
PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name()); PADDLE_THROW("Expect LoDTensor/SelectedRows, get %s", src.Type().name());
} }
...@@ -145,6 +147,7 @@ class ParallelDoOp : public framework::OperatorBase { ...@@ -145,6 +147,7 @@ class ParallelDoOp : public framework::OperatorBase {
auto *sub_scope = sub_scopes[i]; auto *sub_scope = sub_scopes[i];
auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>(); auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
framework::Copy(src, place, dst); framework::Copy(src, place, dst);
dst->set_lod(src.lod());
} }
} }
WaitOnPlaces(places); WaitOnPlaces(places);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册