未验证 提交 8c54f1fb 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #10906 from chengduoZH/fix_data_trans

Fix DataTransFunc
......@@ -16,31 +16,25 @@ limitations under the License. */
namespace paddle {
namespace framework {
static const platform::DeviceContext* GetDeviceContext(
const platform::Place& src_place, const platform::Place& dst_place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
return pool.Get(src_place);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
return pool.Get(dst_place);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}
void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
Tensor* out) {
void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
Tensor *out) {
VLOG(3) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place;
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
TensorCopy(in, dst_place, *dev_ctx, out);
if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) {
dev_ctx->Wait();
}
PADDLE_ENFORCE_NE(
in.place().which(), dst_place.which(),
"Currently, model parallelism is only supported between CPU and CUDA");
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
// the enforced checkings have been done in GetDeviceContext, so the
// `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
// slow, especially when the number of elements is little, for example,
// the elements of learning rate are one and it's CPU side.
// One solution is to use a CUDA kernel to complete the copy operation when
// the transforming is from CPU to GPU and the number of elements is little.
// But the embarrassment is that this solution this solution makes training
// slower.
TensorCopySync(in, dst_place, out);
}
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册