Unverified commit 8c54f1fb, authored by chengduo, committed by GitHub

Merge pull request #10906 from chengduoZH/fix_data_trans

Fix DataTransFunc
@@ -16,31 +16,25 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-static const platform::DeviceContext* GetDeviceContext(
-    const platform::Place& src_place, const platform::Place& dst_place) {
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-  if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) {
-    return pool.Get(src_place);
-  } else if (platform::is_cpu_place(src_place) &&
-             platform::is_gpu_place(dst_place)) {
-    return pool.Get(dst_place);
-  } else {
-    PADDLE_THROW(
-        "Currently, model parallelism is only supported between CPU and CUDA");
-  }
-}
-
-void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
-                     Tensor* out) {
+void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
+                     Tensor *out) {
   VLOG(3) << "DeviceTransform in, src_place " << in.place()
           << " dst_place: " << dst_place;
-  auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
-  TensorCopy(in, dst_place, *dev_ctx, out);
-  if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) {
-    dev_ctx->Wait();
-  }
+
+  PADDLE_ENFORCE_NE(
+      in.place().which(), dst_place.which(),
+      "Currently, model parallelism is only supported between CPU and CUDA");
+
+  // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
+  // the enforced checkings have been done in GetDeviceContext, so the
+  // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
+  // slow, especially when the number of elements is little, for example,
+  // the elements of learning rate are one and it's on the CPU side.
+  // One solution is to use a CUDA kernel to complete the copy operation when
+  // the transforming is from CPU to GPU and the number of elements is little.
+  // But the embarrassment is that this solution makes training slower.
+  TensorCopySync(in, dst_place, out);
 }
 
 }  // namespace framework
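For readers less familiar with the trade-off the FIXME describes: TensorCopy enqueues an asynchronous copy on a device context, so a GPU-to-CPU transfer needs an explicit dev_ctx->Wait() before the host may safely read the result, whereas TensorCopySync blocks inside the call. Below is a minimal standalone CUDA sketch of that difference; it is not Paddle code, and the buffer, stream, and size names are made up for illustration only.

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const size_t n = 1;  // e.g. a one-element learning-rate tensor
  std::vector<float> host(n, 0.f);
  float *device = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&device), n * sizeof(float));
  const float value = 0.01f;
  cudaMemcpy(device, &value, sizeof(float), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Style 1: asynchronous copy, analogous to TensorCopy on a device context.
  // The host must synchronize the stream before touching `host`; this is the
  // role the old code's dev_ctx->Wait() played for GPU -> CPU transfers.
  cudaMemcpyAsync(host.data(), device, n * sizeof(float),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);

  // Style 2: synchronous copy, analogous to TensorCopySync. The call itself
  // blocks until the data is visible to the host, so no separate wait is
  // needed, at the cost of stalling the caller even for tiny transfers.
  cudaMemcpy(host.data(), device, n * sizeof(float), cudaMemcpyDeviceToHost);

  printf("copied value: %f\n", host[0]);
  cudaStreamDestroy(stream);
  cudaFree(device);
  return 0;
}

The commit opts for the blocking style: it accepts the synchronization cost in exchange for not having to track per-direction wait logic in the caller.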