提交 17a076d8 编写于 作者: C chengduoZH

replace TensorCopy with TensorCopySync

上级 fa613206
...@@ -16,26 +16,14 @@ limitations under the License. */ ...@@ -16,26 +16,14 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static const platform::DeviceContext* GetDeviceContext( void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
const platform::Place& src_place, const platform::Place& dst_place) { Tensor *out) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
return pool.Get(src_place);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
return pool.Get(dst_place);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}
void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
Tensor* out) {
VLOG(3) << "DeviceTransform in, src_place " << in.place() VLOG(3) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place; << " dst_place: " << dst_place;
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
PADDLE_ENFORCE_NE(
in.place().which(), dst_place.which(),
"Currently, model parallelism is only supported between CPU and CUDA");
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
// the enforced checkings have been done in GetDeviceContext, so the // the enforced checkings have been done in GetDeviceContext, so the
...@@ -46,8 +34,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, ...@@ -46,8 +34,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
// the transforming is from CPU to GPU and the number of elements is little. // the transforming is from CPU to GPU and the number of elements is little.
// But the embarrassment is that this solution this solution makes training // But the embarrassment is that this solution this solution makes training
// slower. // slower.
TensorCopy(in, dst_place, *dev_ctx, out); TensorCopySync(in, dst_place, out);
dev_ctx->Wait();
} }
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册