From 2603cb7e86dc4fdfe163d17f286df7ab2f05c4d6 Mon Sep 17 00:00:00 2001
From: QI JUN
Date: Wed, 11 Oct 2017 10:21:54 -0700
Subject: [PATCH] Unify CUDA stream in Tensor CopyFrom interface (#4692)

* init

* unify CopyFrom interface

* fix gpu build error

* fix bug in tensor_py.h

* refine code comments and add TODO list

* fix conflicts in FeedOp and FetchOp
---
 paddle/framework/tensor.h                   | 17 ++++---
 paddle/framework/tensor_array.cc            | 15 ++++--
 paddle/framework/tensor_impl.h              | 51 ++++++++++++++-------
 paddle/framework/tensor_test.cc             | 44 +++++++++++-------
 paddle/operators/feed_op.h                  |  2 +-
 paddle/operators/fetch_op.h                 |  3 +-
 paddle/operators/math/im2col_test.cc        | 32 +++++++------
 paddle/operators/math/math_function_test.cc | 32 +++++++------
 paddle/operators/multiplex_op.cu            |  6 ++-
 paddle/operators/recurrent_op.cc            |  6 +--
 paddle/operators/reshape_op.h               |  4 +-
 paddle/operators/rnn/recurrent_op_utils.cc  |  4 +-
 paddle/operators/rnn/recurrent_op_utils.h   |  2 +-
 paddle/pybind/tensor_py.h                   | 15 +++++-
 14 files changed, 147 insertions(+), 86 deletions(-)

diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index ba82127d9c0..3304d857ae2 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -87,26 +87,31 @@ class Tensor {
   /**
    * @brief   Copy the content of external tensor to a new place.
    *
-   * @param[in] src   The external tensor.
-   * @param[in] ctx   The device context contains place where to store.
+   * @param[in] src        The external tensor.
+   * @param[in] dst_place  The dst place.
+   * @param[in] ctx        The device context contains device resources.
    *
    * @note    CopyFrom supports CPU <-> GPU, GPU <-> GPU.
    */
+  // TODO(qijun): https://github.com/PaddlePaddle/Paddle/issues/4647
+  // Remove `CopyFrom` and `CopyFromVector` from Tensor interface
+  // and make them global functions
   template <typename T>
-  inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);
+  inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
+                       const platform::DeviceContext& ctx);

   /**
    * @brief   Copy the content of an external vector to a tensor.
    *
-   * @param[in] src   The external vector.
-   * @param[in] ctx   The device context contains place where to store.
+   * @param[in] src        The external tensor.
+   * @param[in] ctx        The device context contains device resources.
    *
    *
    * @note    CopyFromVector assumes that the tensor has been resized
    *          before invoking.
    */
   template <typename T>
   inline void CopyFromVector(const std::vector<T>& src,
-                             const platform::Place& dst_place);
+                             const platform::DeviceContext& ctx);

   /**
    * @brief   Return the slice of the tensor.
diff --git a/paddle/framework/tensor_array.cc b/paddle/framework/tensor_array.cc
index 2728bce1c1a..7ae16e99cdb 100644
--- a/paddle/framework/tensor_array.cc
+++ b/paddle/framework/tensor_array.cc
@@ -95,7 +95,8 @@ void TensorArray::Write(size_t index, const LoDTensor& value) {

   values_[index].Resize(value.dims());
   values_[index].mutable_data<value_type>(platform::CPUPlace());
-  values_[index].CopyFrom<value_type>(value, platform::CPUPlace());
+  values_[index].CopyFrom<value_type>(value, platform::CPUPlace(),
+                                      platform::CPUDeviceContext());
 }

 void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
@@ -151,7 +152,8 @@ LoDTensor TensorArray::Stack() const {

   for (size_t idx = 0; idx < size(); idx++) {
     result.Slice<value_type>(idx, idx + 1)
-        .CopyFrom<value_type>(Read(idx), platform::CPUPlace());
+        .CopyFrom<value_type>(Read(idx), platform::CPUPlace(),
+                              platform::CPUDeviceContext());
   }
   return result;
 }
@@ -182,7 +184,8 @@ void TensorArray::Unstack(const LoDTensor& source, bool data_shared) const {
       // copy
       value.Resize(value_dims);
       value.CopyFrom<value_type>(source.Slice<value_type>(elem, elem + 1),
-                                 platform::CPUPlace());
+                                 platform::CPUPlace(),
+                                 platform::CPUDeviceContext());
     }
   }
 }
@@ -236,7 +239,8 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {

     auto target = result.Slice<value_type>(i, i + 1);
     auto source_ = source->Slice<value_type>(index, index + 1);

-    target.CopyFrom<value_type>(source_, platform::CPUPlace());
+    target.CopyFrom<value_type>(source_, platform::CPUPlace(),
+                                platform::CPUDeviceContext());
   }

   return result;
@@ -269,7 +273,8 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
       if (index >= seq_meta.end) break;
       auto source_ = source[batch_id].Slice<value_type>(seq_id, seq_id + 1);
       auto target = result.Slice<value_type>(index, index + 1);
-      target.CopyFrom<value_type>(source_, platform::CPUPlace());
+      target.CopyFrom<value_type>(source_, platform::CPUPlace(),
+                                  platform::CPUDeviceContext());
     }
   }

diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 8ee9941982c..ce73e0a9edb 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -88,7 +88,8 @@ inline Tensor& Tensor::ShareDataWith(const Tensor& src) {

 template <typename T>
 inline void Tensor::CopyFrom(const Tensor& src,
-                             const platform::Place& dst_place) {
+                             const platform::Place& dst_place,
+                             const platform::DeviceContext& ctx) {
   src.check_memory_size<T>();
   Resize(src.dims());
@@ -106,26 +107,45 @@ inline void Tensor::CopyFrom(const Tensor& src,
 #ifdef PADDLE_WITH_CUDA
   else if (platform::is_gpu_place(src_place) &&
           platform::is_cpu_place(dst_place)) {
-    memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
-                 boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
+    auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
+    auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
+    auto ctx_place = ctx.GetPlace();
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
+    auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
+    PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
+    memory::Copy(
+        dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
   } else if (platform::is_cpu_place(src_place) &&
             platform::is_gpu_place(dst_place)) {
-    memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
-                 boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
+    auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
+    auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
+    auto ctx_place = ctx.GetPlace();
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
+    auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
+    PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
+    memory::Copy(
+        dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
   } else if (platform::is_gpu_place(src_place) &&
            platform::is_gpu_place(dst_place)) {
-    memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
-                 boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
+    auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
+    auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
+    auto ctx_place = ctx.GetPlace();
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
+    auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
+    PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
+    memory::Copy(
+        dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
   }
-  PADDLE_ENFORCE(cudaStreamSynchronize(0),
-                 "cudaStreamSynchronize failed in Tensor CopyFrom");
-
 #endif
 }

 template <typename T>
 inline void Tensor::CopyFromVector(const std::vector<T>& src,
-                                   const platform::Place& dst_place) {
+                                   const platform::DeviceContext& ctx) {
+  auto dst_place = ctx.GetPlace();
   auto src_ptr = static_cast<const void*>(src.data());
   platform::CPUPlace src_place;
   auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
@@ -137,12 +157,11 @@ inline void Tensor::CopyFromVector(const std::vector<T>& src,
   }
 #ifdef PADDLE_WITH_CUDA
   else if (platform::is_gpu_place(dst_place)) {
-    memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place,
-                 src_ptr, size, 0);
+    memory::Copy(
+        boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place, src_ptr,
+        size,
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
   }
-  PADDLE_ENFORCE(cudaStreamSynchronize(0),
-                 "cudaStreamSynchronize failed in Tensor CopyFromVector");
-
 #endif
 }

diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 492eba69e1e..0b62fe08ce9 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -194,6 +194,7 @@ TEST(Tensor, CopyFrom) {
   {
     Tensor src_tensor;
     Tensor dst_tensor;
+    CPUDeviceContext cpu_ctx((CPUPlace()));

     int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
@@ -201,7 +202,7 @@
     memcpy(src_ptr, arr, 9 * sizeof(int));

     auto cpu_place = new paddle::platform::CPUPlace();
-    dst_tensor.CopyFrom<int>(src_tensor, *cpu_place);
+    dst_tensor.CopyFrom<int>(src_tensor, *cpu_place, cpu_ctx);

     const int* dst_ptr = dst_tensor.data<int>();
     ASSERT_NE(src_ptr, dst_ptr);
@@ -210,7 +211,7 @@
     }

     Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
-    dst_tensor.CopyFrom<int>(slice_tensor, *cpu_place);
+    dst_tensor.CopyFrom<int>(slice_tensor, *cpu_place, cpu_ctx);
     const int* slice_ptr = slice_tensor.data<int>();
     dst_ptr = dst_tensor.data<int>();
     ASSERT_NE(dst_ptr, slice_ptr);
@@ -231,13 +232,15 @@
     // CPU Tensor to GPU Tensor
     auto gpu_place = new paddle::platform::GPUPlace(0);
-    gpu_tensor.CopyFrom<int>(src_tensor, *gpu_place);
+    CUDADeviceContext gpu_ctx(*gpu_place);
+    gpu_tensor.CopyFrom<int>(src_tensor, *gpu_place, gpu_ctx);

     // GPU Tensor to CPU Tensor
     auto cpu_place = new paddle::platform::CPUPlace();
-    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
+    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place, gpu_ctx);

-    // Compare Tensors
+    // Sync before Compare Tensors
+    gpu_ctx.Wait();
     const int* dst_ptr = dst_tensor.data<int>();
     ASSERT_NE(src_ptr, dst_ptr);
     for (size_t i = 0; i < 9; ++i) {
@@ -247,12 +250,13 @@
     Tensor slice_tensor = src_tensor.Slice<int>(1, 2);

     // CPU Slice Tensor to GPU Tensor
-    gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_place);
+    gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_place, gpu_ctx);

     // GPU Tensor to CPU Tensor
-    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
+    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place, gpu_ctx);

-    // Compare Slice Tensors
+    // Sync before Compare Slice Tensors
+    gpu_ctx.Wait();
     const int* slice_ptr = slice_tensor.data<int>();
     dst_ptr = dst_tensor.data<int>();
     ASSERT_NE(dst_ptr, slice_ptr);
@@ -273,7 +277,8 @@ TEST(Tensor, CopyFromVector) {
     // Copy to CPU Tensor
     cpu_tensor.Resize(make_ddim({3, 3}));
     auto cpu_place = new paddle::platform::CPUPlace();
-    cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
+    CPUDeviceContext cpu_ctx(*cpu_place);
+    cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);

     // Compare Tensors
     const int* cpu_ptr = cpu_tensor.data<int>();
@@ -285,7 +290,7 @@
     src_vec.erase(src_vec.begin(), src_vec.begin() + 5);

     cpu_tensor.Resize(make_ddim({2, 2}));
-    cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
+    cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
     cpu_ptr = cpu_tensor.data<int>();
     src_ptr = src_vec.data();
     ASSERT_NE(src_ptr, cpu_ptr);
@@ -306,16 +311,19 @@
     // Copy to CPU Tensor
     cpu_tensor.Resize(make_ddim({3, 3}));
     auto cpu_place = new paddle::platform::CPUPlace();
-    cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
+    CPUDeviceContext cpu_ctx(*cpu_place);
+    cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);

     // Copy to GPUTensor
     gpu_tensor.Resize(make_ddim({3, 3}));
     auto gpu_place = new paddle::platform::GPUPlace();
-    gpu_tensor.CopyFromVector<int>(src_vec, *gpu_place);
+    CUDADeviceContext gpu_ctx(*gpu_place);
+    gpu_tensor.CopyFromVector<int>(src_vec, gpu_ctx);
     // Copy from GPU to CPU tensor for comparison
-    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
+    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place, gpu_ctx);

-    // Compare Tensors
+    // Sync before Compare Tensors
+    gpu_ctx.Wait();
     const int* src_ptr = src_vec.data();
     const int* cpu_ptr = cpu_tensor.data<int>();
     const int* dst_ptr = dst_tensor.data<int>();
@@ -329,11 +337,13 @@
     src_vec.erase(src_vec.begin(), src_vec.begin() + 5);

     cpu_tensor.Resize(make_ddim({2, 2}));
-    cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
+    cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
     gpu_tensor.Resize(make_ddim({2, 2}));
-    gpu_tensor.CopyFromVector<int>(src_vec, *gpu_place);
-    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
+    gpu_tensor.CopyFromVector<int>(src_vec, gpu_ctx);
+    dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place, gpu_ctx);

+    // Sync before Compare Tensors
+    gpu_ctx.Wait();
     src_ptr = src_vec.data();
     cpu_ptr = cpu_tensor.data<int>();
     dst_ptr = dst_tensor.data<int>();

diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h
index 9d8158299fe..e756cd1842a 100644
--- a/paddle/operators/feed_op.h
+++ b/paddle/operators/feed_op.h
@@ -34,7 +34,7 @@ class FeedKernel : public framework::OpKernel {
     // TODO(qijun):
     //   check tensors[col].dims() with attribute,
     //   except the first dimension.
-    out->CopyFrom<T>(tensors[col], ctx.GetPlace());
+    out->CopyFrom<T>(tensors[col], ctx.GetPlace(), ctx.device_context());
   }
 };

diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h
index eb9c3a7b593..b2a6e958750 100644
--- a/paddle/operators/fetch_op.h
+++ b/paddle/operators/fetch_op.h
@@ -35,7 +35,8 @@ class FetchKernel : public framework::OpKernel {
     PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
     (*tensors)[col].Resize(input->dims());
     (*tensors)[col].mutable_data<T>(platform::CPUPlace());
-    (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
+    (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace(),
+                                ctx.device_context());
     // TODO(qijun): need to handle LodTensor later
   }
 };

diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc
index 40bdbfe7335..9c506ae89bd 100644
--- a/paddle/operators/math/im2col_test.cc
+++ b/paddle/operators/math/im2col_test.cc
@@ -49,10 +49,22 @@ void testIm2col() {
   memcpy(input_ptr, arr, 6 * sizeof(float));

   auto* place = new Place();
+  paddle::platform::DeviceContext* context;
+  if (paddle::platform::is_cpu_place(*place)) {
+    context =
+        new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    context =
+        new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
+#else
+    PADDLE_THROW("no GPU support");
+#endif  // PADDLE_ONLY_CPU
+  }
   if (paddle::platform::is_cpu_place(*place)) {
     input = input_tmp;
   } else {
-    input.CopyFrom<float>(input_tmp, *place);
+    input.CopyFrom<float>(input_tmp, *place, *context);
   }
   output_cfo.mutable_data<float>(
       {1, filter_size, filter_size, output_height, output_width}, *place);
@@ -66,18 +78,6 @@
       paddle::operators::math::ColFormat::kOCF, Place, float>
       im2col_ocf;

-  paddle::platform::DeviceContext* context;
-  if (paddle::platform::is_cpu_place(*place)) {
-    context =
-        new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
-  } else {
-#ifdef PADDLE_WITH_CUDA
-    context =
-        new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
-#else
-    PADDLE_THROW("no GPU support");
-#endif  // PADDLE_ONLY_CPU
-  }
   im2col(*context, input, output_cfo, stride, stride, padding, padding);
   im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding);
@@ -85,7 +85,8 @@
   if (paddle::platform::is_cpu_place(*place)) {
     out_cfo_ptr = output_cfo.data<float>();
   } else {
-    output_tmp.CopyFrom<float>(output_cfo, paddle::platform::CPUPlace());
+    output_tmp.CopyFrom<float>(output_cfo, paddle::platform::CPUPlace(),
+                               *context);
     out_cfo_ptr = output_tmp.data<float>();
   }
   EXPECT_EQ(out_cfo_ptr[0], 0);
@@ -101,7 +102,8 @@
   if (paddle::platform::is_cpu_place(*place)) {
     out_ocf_ptr = output_ocf.data<float>();
   } else {
-    output_tmp.CopyFrom<float>(output_ocf, paddle::platform::CPUPlace());
+    output_tmp.CopyFrom<float>(output_ocf, paddle::platform::CPUPlace(),
+                               *context);
     out_ocf_ptr = output_tmp.data<float>();
   }
   EXPECT_EQ(out_ocf_ptr[0], 0);

diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc
index 9945ba101d7..c87d200c3aa 100644
--- a/paddle/operators/math/math_function_test.cc
+++ b/paddle/operators/math/math_function_test.cc
@@ -17,17 +17,18 @@ TEST(math_function, notrans_mul_trans) {

   auto* gpu_place = new paddle::platform::GPUPlace(0);
   paddle::platform::CUDADeviceContext context(*gpu_place);

-  input1_gpu.CopyFrom<float>(input1, *gpu_place);
-  input2_gpu.CopyFrom<float>(input1, *gpu_place);
+  input1_gpu.CopyFrom<float>(input1, *gpu_place, context);
+  input2_gpu.CopyFrom<float>(input1, *gpu_place, context);
   out_gpu.mutable_data<float>({2, 2}, *gpu_place);

   paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
       context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);

-  out.CopyFrom<float>(out_gpu, *cpu_place);
+  out.CopyFrom<float>(out_gpu, *cpu_place, context);

   float* out_ptr = out.data<float>();
+  context.Wait();
   EXPECT_EQ(out_ptr[0], 5);
   EXPECT_EQ(out_ptr[1], 14);
   EXPECT_EQ(out_ptr[2], 14);
@@ -50,17 +51,18 @@ TEST(math_function, trans_mul_notrans) {

   auto* gpu_place = new paddle::platform::GPUPlace(0);
   paddle::platform::CUDADeviceContext context(*gpu_place);

-  input1_gpu.CopyFrom<float>(input1, *gpu_place);
-  input2_gpu.CopyFrom<float>(input1, *gpu_place);
+  input1_gpu.CopyFrom<float>(input1, *gpu_place, context);
+  input2_gpu.CopyFrom<float>(input1, *gpu_place, context);

   out_gpu.mutable_data<float>({3, 3}, *gpu_place);

   paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
       context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);

-  out.CopyFrom<float>(out_gpu, *cpu_place);
+  out.CopyFrom<float>(out_gpu, *cpu_place, context);

   float* out_ptr = out.data<float>();
+  context.Wait();
   EXPECT_EQ(out_ptr[0], 9);
   EXPECT_EQ(out_ptr[1], 12);
   EXPECT_EQ(out_ptr[2], 15);
@@ -98,9 +100,9 @@ TEST(math_function, gemm_notrans_cublas) {

   auto* gpu_place = new paddle::platform::GPUPlace(0);
   paddle::platform::CUDADeviceContext context(*gpu_place);

-  input1_gpu.CopyFrom<float>(input1, *gpu_place);
-  input2_gpu.CopyFrom<float>(input2, *gpu_place);
-  input3_gpu.CopyFrom<float>(input3, *gpu_place);
+  input1_gpu.CopyFrom<float>(input1, *gpu_place, context);
+  input2_gpu.CopyFrom<float>(input2, *gpu_place, context);
+  input3_gpu.CopyFrom<float>(input3, *gpu_place, context);
   float* a = input1_gpu.data<float>();
   float* b = input2_gpu.data<float>();
   float* c = input3_gpu.mutable_data<float>(*gpu_place);
@@ -108,7 +110,7 @@
   paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
       context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);

-  input3.CopyFrom<float>(input3_gpu, *cpu_place);
+  input3.CopyFrom<float>(input3_gpu, *cpu_place, context);

   // numpy code:
   // a = np.arange(6).reshape(2, 3)
   // b = np.arange(12).reshape(3, 4)[:, 1:]
   // c = np.arange(8).reshape(2, 4)[:, 1:]
   // out = np.arange(8).reshape(2, 4)
   // out[:, 1:] = np.dot(a, b) + c
+  context.Wait();
   EXPECT_EQ(input3_ptr[0], 0);
   EXPECT_EQ(input3_ptr[1], 24);
   EXPECT_EQ(input3_ptr[2], 28);
@@ -152,9 +155,9 @@ TEST(math_function, gemm_trans_cublas) {

   auto* gpu_place = new paddle::platform::GPUPlace(0);
   paddle::platform::CUDADeviceContext context(*gpu_place);

-  input1_gpu.CopyFrom<float>(input1, *gpu_place);
-  input2_gpu.CopyFrom<float>(input2, *gpu_place);
-  input3_gpu.CopyFrom<float>(input3, *gpu_place);
+  input1_gpu.CopyFrom<float>(input1, *gpu_place, context);
+  input2_gpu.CopyFrom<float>(input2, *gpu_place, context);
+  input3_gpu.CopyFrom<float>(input3, *gpu_place, context);
   float* a = input1_gpu.data<float>();
   float* b = input2_gpu.data<float>();
   float* c = input3_gpu.mutable_data<float>(*gpu_place);
@@ -162,7 +165,8 @@
   paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
       context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);

-  input3.CopyFrom<float>(input3_gpu, *cpu_place);
+  input3.CopyFrom<float>(input3_gpu, *cpu_place, context);
+  context.Wait();

   EXPECT_EQ(input3_ptr[0], 0);
   EXPECT_EQ(input3_ptr[1], 24);

diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu
index 72b1f96eafd..10cb0e005f4 100644
--- a/paddle/operators/multiplex_op.cu
+++ b/paddle/operators/multiplex_op.cu
@@ -33,7 +33,8 @@ class MultiplexGPUKernel : public framework::OpKernel {
     auto cols = ins[0]->numel() / rows;
     // copy index to cpu
     Tensor index_t_cpu;
-    index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
+    index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace(),
+                                  ctx.device_context());
     auto* index = index_t_cpu.data<int32_t>();
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                       ctx.device_context())
                       .stream();
@@ -70,7 +71,8 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
     auto cols = ins[0]->numel() / rows;
     // copy index to cpu
     Tensor index_t_cpu;
-    index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
+    index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace(),
+                                  ctx.device_context());
     auto* index = index_t_cpu.data<int32_t>();

     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(

diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc
index 04c4c24951f..00647f55f79 100644
--- a/paddle/operators/recurrent_op.cc
+++ b/paddle/operators/recurrent_op.cc
@@ -46,7 +46,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
+  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len, dev_ctx);
 }

 void RecurrentAlgorithm::CreateScopes(const Scope& scope,
@@ -151,12 +151,12 @@ void RecurrentGradientAlgorithm::Run(
   auto& step_scopes = GetStepScopes(scope);
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
   for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
-    if (step_id != seq_len - 1) {
+    if (static_cast<size_t>(step_id) != seq_len - 1) {
       rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
+  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len, dev_ctx);
   LinkBootMemoryGradients(step_scopes[0]);
 }

diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h
index 628dfe4c0fa..3ba4611458f 100644
--- a/paddle/operators/reshape_op.h
+++ b/paddle/operators/reshape_op.h
@@ -33,7 +33,7 @@ class ReshapeKernel : public framework::OpKernel {
     std::transform(shape.begin(), shape.end(), shape_int64.begin(),
                    [](int a) { return static_cast<int64_t>(a); });
     auto out_dims = framework::make_ddim(shape_int64);
-    out->CopyFrom<T>(*in, ctx.GetPlace());
+    out->CopyFrom<T>(*in, ctx.GetPlace(), ctx.device_context());
     out->Resize(out_dims);
   }
 };
@@ -47,7 +47,7 @@ class ReshapeGradKernel : public framework::OpKernel {
     d_x->mutable_data<T>(ctx.GetPlace());

     auto in_dims = d_x->dims();
-    d_x->CopyFrom<T>(*d_out, ctx.GetPlace());
+    d_x->CopyFrom<T>(*d_out, ctx.GetPlace(), ctx.device_context());
     d_x->Resize(in_dims);
   }
 };

diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc
index ef317a71f12..d264664a99e 100644
--- a/paddle/operators/rnn/recurrent_op_utils.cc
+++ b/paddle/operators/rnn/recurrent_op_utils.cc
@@ -51,7 +51,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,

 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<std::string>& outlinks,
-                   const size_t seq_len) {
+                   const size_t seq_len, const platform::DeviceContext& ctx) {
   for (size_t i = 0; i < outlinks.size(); i++) {
     auto* output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
     PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
@@ -72,7 +72,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
       // TODO(luotao02) data type and platform::DeviceContext() should set
      // correctly
       (output->Slice<float>(j, j + 1))
-          .CopyFrom<float>(*step_output, platform::CPUPlace());
+          .CopyFrom<float>(*step_output, platform::CPUPlace(), ctx);
     }
   }
 }

diff --git a/paddle/operators/rnn/recurrent_op_utils.h b/paddle/operators/rnn/recurrent_op_utils.h
index fd17b9b8891..fe173edb24a 100644
--- a/paddle/operators/rnn/recurrent_op_utils.h
+++ b/paddle/operators/rnn/recurrent_op_utils.h
@@ -71,7 +71,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
  */
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<std::string>& outlinks,
-                   const size_t seq_len);
+                   const size_t seq_len, const platform::DeviceContext& ctx);

 void LinkMemories(const std::vector<Scope*>& step_scopes,
                   const std::vector<MemoryAttr>& memories, const size_t step_id,

diff --git a/paddle/pybind/tensor_py.h b/paddle/pybind/tensor_py.h
index 9e73f79cbdd..85f9f22733c 100644
--- a/paddle/pybind/tensor_py.h
+++ b/paddle/pybind/tensor_py.h
@@ -57,7 +57,18 @@ struct CastToPyBufferImpl {
       }
       framework::Tensor dst_tensor;
       if (paddle::platform::is_gpu_place(tensor.place())) {
-        dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
+#ifdef PADDLE_WITH_CUDA
+        auto *src_ptr = static_cast<const void *>(tensor.data<CUR_TYPE>());
+        auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
+            tensor.dims(), platform::CPUPlace()));
+        // TODO(qijun): Here we use default CUDA stream to set GPU Tensor to
+        // a Python numpy array. It's better to manage CUDA stream unifiedly.
+        paddle::platform::GpuMemcpySync(dst_ptr, src_ptr,
+                                        sizeof(CUR_TYPE) * tensor.numel(),
+                                        cudaMemcpyDeviceToHost);
+#else
+        PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
+#endif
       } else if (paddle::platform::is_cpu_place(tensor.place())) {
         dst_tensor = tensor;
       }
@@ -120,6 +131,8 @@ void PyCUDATensorSetFromArray(
   self.Resize(framework::make_ddim(dims));
   auto *dst = self.mutable_data<T>(place);
+  // TODO(qijun): Here we use default CUDA stream to set a Python numpy
+  // array to a GPU Tensor. It's better to manage CUDA stream unifiedly.
   paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(),
                                   cudaMemcpyHostToDevice);
 }
-- 
GitLab
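
A minimal usage sketch of the unified interface introduced by this patch. It is illustrative only, not code from the diff: the function name, the tensor variables, and the include paths for device_context.h and place.h are assumptions, while the CopyFrom and Wait calls mirror the updated tensor_test.cc above.

// Illustrative sketch (not part of the patch); assumes the 2017-era Paddle layout.
#include "paddle/framework/tensor.h"         // shown in this patch
#include "paddle/platform/device_context.h"  // assumed include path
#include "paddle/platform/place.h"           // assumed include path

void CopyToGpuExample(const paddle::framework::Tensor& src_cpu) {
  paddle::platform::GPUPlace gpu_place(0);
  paddle::platform::CUDADeviceContext gpu_ctx(gpu_place);

  paddle::framework::Tensor dst_gpu;
  // The copy is enqueued on gpu_ctx's CUDA stream instead of the default stream 0.
  dst_gpu.CopyFrom<float>(src_cpu, gpu_place, gpu_ctx);

  // Synchronize the same context before touching the copied data,
  // as the updated tensor_test.cc does with gpu_ctx.Wait().
  gpu_ctx.Wait();
}

The design point is that the caller, not the tensor, owns the stream: every copy runs on the stream of the DeviceContext passed in, and the caller decides when to synchronize.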