diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 8c28709a68bec4fca5acaf2ec74b6d02402a6139..8b71f73c36c33d882b34c833031c50cd14817e76 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -22,7 +22,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
 
 cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
 cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
-nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
+nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)
 
 cc_test(variable_test SRCS variable_test.cc)
 
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 53b0d0fe083579da4f0bb600f292765aa2aa0d8a..cb27de6991674247e6215ce64a2da5000fa78ed4 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -24,8 +24,6 @@ limitations under the License. */
 #include <algorithm>
 #include <iterator>
 
-#include <glog/logging.h>
-
 namespace paddle {
 namespace framework {
 
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 9d1294fdeb9bd76bf944f7ec3687e3c5bb333241..d0ab640485baf6d76ee629ea420b603f42b031b4 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -18,11 +18,11 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include <thrust/device_vector.h>
 #include <thrust/host_vector.h>
-#include <thrust/system/cuda/experimental/pinned_allocator.h>
 #endif
 
 #include <glog/logging.h>
 #include "paddle/framework/ddim.h"
+#include "paddle/framework/mixed_vector.h"
 #include "paddle/framework/tensor.h"
 #include "paddle/framework/tensor_util.h"
 #include "paddle/platform/enforce.h"
@@ -31,15 +31,6 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-#ifndef PADDLE_WITH_CUDA
-template <typename T>
-using Vector = std::vector<T>;
-#else
-template <typename T>
-using Vector = thrust::host_vector<
-    T, thrust::system::cuda::experimental::pinned_allocator<T>>;
-#endif
-
 /*
  * LoD is short for Level of Details.
 *
@@ -55,7 +46,15 @@ using Vector = thrust::host_vector<
  * 0 2 4 7
  * 0 2 5 7 10 12 15 20
  */
-using LoD = std::vector<Vector<size_t>>;
+struct LoD : public std::vector<Vector<size_t>> {
+  using std::vector<Vector<size_t>>::vector;
+
+  void CopyFromCUDA() {
+    for (auto it = this->begin(); it != this->end(); ++it) {
+      it->CopyFromCUDA();
+    }
+  }
+};
 
 std::ostream& operator<<(std::ostream& os, const LoD& lod);
 std::ostream& operator<<(std::ostream& os, const LoDTensor& t);
@@ -109,7 +108,10 @@ bool CheckAbsLoD(const LoD& in, int tensor_height = -1);
  */
 class LoDTensor : public Tensor {
  public:
-  LoDTensor() {}
+  LoDTensor() : Tensor() {}
+
+  /* Constructor with place should only be used in pybind */
+  explicit LoDTensor(const platform::Place& place) : Tensor(place) {}
 
   explicit LoDTensor(const LoD& lod) : lod_(lod) {}
 
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index 4d172c43c7cceacb7d0dfaf1c4d3028717350268..3b63020e685436396071fa05cd7697630ae56c95 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -23,6 +23,17 @@ namespace paddle {
 namespace framework {
 
+TEST(LoD, data) {
+  LoD lod{{0, 1, 2}};
+  lod.push_back({0, 2, 4, 5});
+  lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));
+
+  auto& v = lod[0];
+  for (size_t i = 0; i < v.size(); ++i) {
+    EXPECT_EQ(v[i], i);
+  }
+}
+
 TEST(LodExpand, test) {
   LoD lod{{0, 2}};
   LoDTensor tensor;
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu
index 1e253a2f6f35e827fb2e5db6270da03705b39514..d4c9f00bd9c00f3cae68858ca46c5320fc117405 100644
--- a/paddle/framework/lod_tensor_test.cu
+++ b/paddle/framework/lod_tensor_test.cu
@@ -14,6 +14,8 @@
 
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <stdio.h>
+#include "paddle/framework/init.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/platform/assert.h"
 
@@ -26,7 +28,48 @@ __global__ void test(size_t* a, int size) {
   }
 }
 
+TEST(Vector, Normal) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  using namespace paddle::memory;
+
+  paddle::framework::InitDevices();
+
+  paddle::framework::Vector<size_t> vec({1, 2, 3});
+  size_t* ptr = vec.data();
+  for (size_t i = 0; i < vec.size(); ++i) {
+    EXPECT_EQ(vec[i], *(ptr + i));
+  }
+
+  vec.clear();
+  vec.CopyFromCUDA();
+
+  std::vector<size_t> v = {1, 2, 3};
+  for (size_t i = 0; i < v.size(); ++i) {
+    EXPECT_EQ(v[i], vec[i]);
+  }
+}
+
+TEST(LoD, data) {
+  paddle::framework::InitDevices();
+
+  paddle::framework::LoD lod{{0, 1, 2}};
+  lod.push_back({0, 2, 4, 5});
+  lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));
+
+  auto& v = lod[0];
+  test<<<1, 1>>>(v.cuda_data(), v.size());
+  cudaDeviceSynchronize();
+
+  v.CopyFromCUDA();
+  for (size_t i = 0; i < v.size(); ++i) {
+    EXPECT_EQ(v[i], i * 2);
+  }
+}
+
 TEST(LoDTensor, LoDInGPU) {
+  paddle::framework::InitDevices();
+
   paddle::framework::LoDTensor lod_tensor;
   paddle::platform::CUDAPlace place(0);
 
@@ -42,8 +85,9 @@ TEST(LoDTensor, LoDInGPU) {
 
   auto lod = lod_tensor.lod();
 
-  test<<<1, 8>>>(lod[0].data(), lod[0].size());
+  test<<<1, 8>>>(lod[0].cuda_data(), lod[0].size());
   cudaDeviceSynchronize();
+  lod.CopyFromCUDA();
 
   for (size_t i = 0; i < src_lod[0].size(); ++i) {
     EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
diff --git a/paddle/framework/mixed_vector.h b/paddle/framework/mixed_vector.h
new file mode 100644
index 0000000000000000000000000000000000000000..0e0e23958602343f8e0106e3a88eaac9c6d71066
--- /dev/null
+++ b/paddle/framework/mixed_vector.h
@@ -0,0 +1,154 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include <initializer_list>
+#include <vector>
+
+#include "paddle/memory/memcpy.h"
+#include "paddle/memory/memory.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/enforce.h"
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace framework {
+
+/**
+ * @brief Vector support both cpu and gpu.
+ * host vector lifetime is same with Vector
+ * device vector is lazily malloc and modified.
+ */
+
+template <typename T>
+class Vector : public std::vector<T> {
+ public:
+  /* NOTE(dzhwinter):
+   * Data always store and modified on Host.
+   * If the data is modified when use cuda_data interface,
+   * You need to call the CopyFromCUDA explicitly to synchronize data.
+   *
+   */
+  enum class kDataPosition {
+    kDataOnHost = 0,
+    kDataOnDevice = 1,
+  };
+
+ public:
+  using std::vector<T>::vector;
+
+  Vector() {}
+  Vector(const std::vector<T> &v) : std::vector<T>(v) {}  // NOLINT
+
+  virtual ~Vector() {
+#ifdef PADDLE_WITH_CUDA
+    if (cuda_ptr_ != nullptr) {
+      memory::Free(place_, static_cast<void *>(cuda_ptr_));
+    }
+#endif
+  }
+
+  T *cuda_data() {
+    CopyToCUDA();
+    PADDLE_ENFORCE_NOT_NULL(
+        cuda_ptr_, "No data or Insufficient CUDA memory to allocation");
+    return static_cast<T *>(cuda_ptr_);
+  }
+
+  T *data() { return std::vector<T>::data(); }
+
+  const T *data() const { return std::vector<T>::data(); }
+
+  void CopyToCUDA();
+
+  void CopyFromCUDA();
+
+  void CopyToPeer(platform::Place);
+
+ private:
+  void *cuda_ptr_ = nullptr;
+  size_t cuda_size_ = 0;
+  /*The DataPosition is unused now,
+    if we want support random access from cpu and cuda,
+    we need to overload all the vector method */
+
+  kDataPosition position_ = kDataPosition::kDataOnHost;
+  platform::CUDAPlace place_;
+};
+
+template <typename T>
+void Vector<T>::CopyToCUDA() {
+#ifdef PADDLE_WITH_CUDA
+  if (cuda_ptr_ == nullptr) {
+    cuda_ptr_ =
+        memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T));
+  }
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto *cuda_ctx = pool.GetByPlace(place_);
+
+  memory::Copy(place_, static_cast<void *>(cuda_ptr_), platform::CPUPlace(),
+               static_cast<const void *>(this->data()),
+               this->size() * sizeof(T), cuda_ctx->stream());
+  cuda_ctx->Wait();
+
+  cuda_size_ = this->size();
+#endif
+}
+
+template <typename T>
+void Vector<T>::CopyFromCUDA() {
+#ifdef PADDLE_WITH_CUDA
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto *cuda_ctx = pool.GetByPlace(place_);
+  if (cuda_ptr_ == nullptr) {
+    LOG(WARNING) << "No uncommited cuda data.";
+    return;
+  }
+  this->resize(cuda_size_);
+  memory::Copy(platform::CPUPlace(), static_cast<void *>(this->data()), place_,
+               static_cast<const void *>(cuda_ptr_), this->size() * sizeof(T),
+               cuda_ctx->stream());
+  cuda_ctx->Wait();
+
+#endif
+}
+
+template <typename T>
+void Vector<T>::CopyToPeer(platform::Place peer_place) {
+  if (platform::is_cpu_place(peer_place)) {
+    return;
+  }
+#ifdef PADDLE_WITH_CUDA
+  auto *cuda_ctx = platform::DeviceContextPool::Instance().GetByPlace(place_);
+  void *peer_cuda_ptr_ = memory::Alloc<platform::CUDAPlace>(
+      boost::get<platform::CUDAPlace>(peer_place), this->size() * sizeof(T));
+  memory::Copy(boost::get<platform::CUDAPlace>(peer_place),
+               static_cast<void *>(peer_cuda_ptr_), place_,
+               static_cast<void *>(cuda_ptr_), this->size() * sizeof(T),
+               cuda_ctx->stream());
+  cuda_ctx->Wait();
+
+  memory::Free(place_, static_cast<void *>(cuda_ptr_));
+  place_ = boost::get<platform::CUDAPlace>(peer_place);
+  cuda_ptr_ = peer_cuda_ptr_;
+#endif
+}
+
+template class Vector<int>;
+template class Vector<unsigned>;
+template class Vector<size_t>;
+template class Vector<int64_t>;
+
+}  // namespace framework
+}  // namespace paddle
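The host/device contract of the new Vector<T> is easiest to see as a small usage sketch distilled from the lod_tensor_test.cu changes above; `my_kernel` and the `<<<1, 1>>>` launch shape are placeholders, not code from this patch:

    // Host storage is the source of truth. cuda_data() lazily allocates the
    // device buffer on first use and copies host -> device before returning.
    paddle::framework::Vector<size_t> vec({1, 2, 3});
    my_kernel<<<1, 1>>>(vec.cuda_data(), vec.size());  // kernel mutates data
    cudaDeviceSynchronize();
    vec.CopyFromCUDA();  // device writes reach the host only after this call

As the NOTE(dzhwinter) comment says, device-side writes stay invisible on the host until CopyFromCUDA(); host-side writes are pushed back by the next cuda_data() call, which re-copies the whole vector (the device buffer itself is allocated once, at the size of the first upload).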
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 4aaa29d794c95592832a1fe990e2dce274eba9d5..f0ea709a5c37e769e3ffa1b2e9d1e39721979251 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -47,6 +47,11 @@ class Tensor {
  public:
   Tensor() : offset_(0) {}
 
+  /*! Constructor with place should only be used in pybind. */
+  explicit Tensor(const platform::Place& place) : offset_(0) {
+    holder_->set_place(place);
+  }
+
   /*! Return a pointer to mutable memory block. */
   template <typename T>
   inline T* data();
@@ -137,6 +142,7 @@ class Tensor {
     virtual std::type_index type() const = 0;
     virtual platform::Place place() const = 0;
     virtual void set_type(std::type_index type) = 0;
+    virtual void set_place(platform::Place place) = 0;
   };
 
   template <typename Place>
@@ -156,6 +162,7 @@ class Tensor {
     virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
     virtual std::type_index type() const { return type_; }
     virtual void set_type(std::type_index type) { type_ = type; }
+    virtual void set_place(platform::Place place) { place_ = place; }
 
     /*! the pointer of memory block. */
     std::unique_ptr<uint8_t, memory::PODDeleter<uint8_t, Place>> ptr_;
diff --git a/paddle/inference/CMakeLists.txt b/paddle/inference/CMakeLists.txt
index 0288266c08f3ddfc5337bf2847cf65267491105b..2289ddc139cbddfbaa5238e683b2f8e784a7291e 100644
--- a/paddle/inference/CMakeLists.txt
+++ b/paddle/inference/CMakeLists.txt
@@ -1,4 +1,4 @@
-set(FLUID_CORE_MODULES proto_desc paddle_memory executor prune init)
+set(FLUID_CORE_MODULES proto_desc paddle_memory lod_tensor executor prune init)
 
 cc_library(paddle_fluid_api
     SRCS io.cc
diff --git a/paddle/operators/adagrad_op.cu b/paddle/operators/adagrad_op.cu
index 4e579387924a5b0499f29609bc6b1322030a3c0d..00cb6e9cafb4e79ed3d59cd4a6e40ea132e5efda 100644
--- a/paddle/operators/adagrad_op.cu
+++ b/paddle/operators/adagrad_op.cu
@@ -82,7 +82,7 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
     math::scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
     auto grad_merge = merge_func(context, grad);
     auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
-    auto& merge_rows = grad_merge.rows();
+    framework::Vector<int64_t> merge_rows(grad_merge.rows());
     // 2. m += g_m * g_m
     math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func;
     auto grad_square = sqare_func(context, grad_merge, grad_merge);
@@ -101,8 +101,8 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
     SparseAdagradFunctorKernel<
         T, 256><<<grid2, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(grad_merge_data, grad_merge.rows().data(),
-                                   lr, param_data, moment_data, grad_width,
+                      .stream()>>>(grad_merge_data, merge_rows.cuda_data(), lr,
+                                   param_data, moment_data, grad_width,
                                    epsilon);
   }
 };
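The adagrad change above is the template for every sparse-update kernel below: rows() used to live in a pinned thrust::host_vector whose data() pointer kernels could read directly (pinned memory being device-addressable under UVA), so with the mixed Vector<int64_t> each launch site must now request an explicit device pointer. A minimal before/after sketch; `SomeKernel`, `selected_rows`, and the launch configuration are placeholders:

    // before: pinned host pointer handed straight to the kernel
    //   SomeKernel<<<grid, threads, 0, stream>>>(rows.data(), ...);
    // after: wrap the row index in a mixed Vector and pass its device copy
    framework::Vector<int64_t> rows(selected_rows.rows());
    SomeKernel<<<grid, threads, 0, stream>>>(rows.cuda_data(), rows.size());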
diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h
index 9cc34bdded780e61e8700eb4fa4a295c84fb48bc..bf536687d398b8342e6ae76a07c11e5fe47483e0 100644
--- a/paddle/operators/adam_op.h
+++ b/paddle/operators/adam_op.h
@@ -199,7 +199,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
       merge_func(ctx.template device_context<DeviceContext>(), grad);
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      auto* rows = grad_merge.rows().data();
+      int64_t* rows = nullptr;
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+        rows = grad_merge.mutable_rows()->cuda_data();
+      } else {
+        rows = grad_merge.mutable_rows()->data();
+      }
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
 
       SparseAdamFunctor<T> functor(
diff --git a/paddle/operators/ctc_align_op.cu b/paddle/operators/ctc_align_op.cu
index 45635f16745346b08f7e31db2f25905bdbc3aeeb..2a970cd9fa965b4126356eaa1519068f9c7a7f34 100644
--- a/paddle/operators/ctc_align_op.cu
+++ b/paddle/operators/ctc_align_op.cu
@@ -69,12 +69,11 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
     auto stream = ctx.cuda_device_context().stream();
     MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
-        num_tokens, tokens, num_seq, input_lod[level].data(), blank,
+        num_tokens, tokens, num_seq, input_lod[level].cuda_data(), blank,
         merge_repeated, dev_out_lod0_ptr, output_data);
 
     // set output lod
-    thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(),
-                                              dev_out_lod0.end());
+    std::vector<size_t> host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end());
     framework::LoD out_lod;
     out_lod.push_back(host_out_lod0);
     output->set_lod(out_lod);
diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h
index b1957fb9ce6add8628cb206abf2c569d3f615c85..a08bd4233b02d021aaa64bafe4b855f11a60d338 100644
--- a/paddle/operators/gru_op.h
+++ b/paddle/operators/gru_op.h
@@ -30,11 +30,12 @@ using Tensor = framework::Tensor;
 
 template <typename DeviceContext, typename T>
 inline void ReorderInitState(const DeviceContext& ctx,
-                             const framework::Tensor& src, const size_t* index,
+                             const framework::Tensor& src,
+                             framework::Vector<size_t> index_lod,
                              framework::Tensor* dst, bool indexed_src) {
   math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
   dst->mutable_data<T>(src.dims(), ctx.GetPlace());
-  row_shuffle(ctx, src, index, *dst, indexed_src);
+  row_shuffle(ctx, src, index_lod, *dst, indexed_src);
 }
 
 template <typename DeviceContext, typename T>
@@ -76,7 +77,9 @@ class GRUKernel : public framework::OpKernel<T> {
     gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
     Tensor ordered_h0;
-    const size_t* order = batch_gate->lod()[2].data();
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
     if (h0) {
       // Since the batch computing for GRU reorders the input sequences
       // according to their length. The initialized cell state also needs
@@ -159,7 +162,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
     zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));
 
     Tensor ordered_h0, ordered_h0_grad;
-    const size_t* order = batch_gate->lod()[2].data();
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
     if (h0) {
       ReorderInitState<DeviceContext, T>(dev_ctx, *h0, order, &ordered_h0,
                                          true);
diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu
index d97390fa1c53fa0bdf16ab34cb209b994621f83c..07372808bbf078bd2e9b0bb5782b95a046253f46 100644
--- a/paddle/operators/lookup_table_op.cu
+++ b/paddle/operators/lookup_table_op.cu
@@ -125,8 +125,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     new_rows.resize(ids_dim[0]);
     auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
 
-    memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data,
-                 ids_dim[0] * sizeof(int64_t), stream);
+    memory::Copy(platform::CPUPlace(), new_rows.cuda_data(), gpu_place,
+                 ids_data, ids_dim[0] * sizeof(int64_t), stream);
 
     d_table->set_rows(new_rows);
diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h
index c57ee414dc5b3417549c8ac3a7fd57a9c8f452df..72e95b75e29c88c5944607ceaa40435bac7a745c 100644
--- a/paddle/operators/lstm_op.h
+++ b/paddle/operators/lstm_op.h
@@ -27,11 +27,12 @@ using Tensor = framework::Tensor;
 
 template <typename DeviceContext, typename T>
 inline void ReorderInitState(const DeviceContext& ctx,
-                             const framework::Tensor& src, const size_t* index,
+                             const framework::Tensor& src,
+                             framework::Vector<size_t> index_lod,
                              framework::Tensor* dst, bool indexed_src) {
   math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
   dst->mutable_data<T>(src.dims(), ctx.GetPlace());
-  row_shuffle(ctx, src, index, *dst, indexed_src);
+  row_shuffle(ctx, src, index_lod, *dst, indexed_src);
 }
 
 template <typename DeviceContext, typename T>
@@ -84,7 +85,9 @@ class LSTMKernel : public framework::OpKernel<T> {
     }
     lstm_value.prev_state_value = nullptr;
     Tensor ordered_c0;
-    const size_t* order = batch_gate->lod()[2].data();
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
     if (cell_t0) {
       // Since the batch computing for LSTM reorders the input sequence
       // according to their length. The initialized cell state also needs
@@ -202,7 +205,8 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     // ordered_h0_g/c0_g is the reordered gradient of hidden/cell
     // initialization.
     Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
-    const size_t* order = batch_gate->lod()[2].data();
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
     if (c0) {
       ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0,
                                          true);
diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h
index ee82d5c10a5421b181e525f49a263d4808ede62f..e064a155dfadd8104fa80727a962cb2e24ade29f 100644
--- a/paddle/operators/lstmp_op.h
+++ b/paddle/operators/lstmp_op.h
@@ -34,7 +34,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename DeviceContext, typename T>
 inline void ReorderInitState(const DeviceContext& ctx,
-                             const framework::Tensor& src, const size_t* index,
+                             const framework::Tensor& src,
+                             framework::Vector<size_t> index,
                              framework::Tensor* dst, bool indexed_src) {
   math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
   dst->mutable_data<T>(src.dims(), ctx.GetPlace());
@@ -109,7 +110,9 @@ class LSTMPKernel : public framework::OpKernel<T> {
     }
     lstmp_value.prev_state_value = nullptr;
     Tensor ordered_c0;
-    const size_t* order = batch_gate->lod()[2].data();
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
     if (cell_t0) {
       // Since the batch computing for LSTMP reorders the input sequence
       // according to their length. The initialized cell state also needs
@@ -275,7 +278,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
     // ordered_h0_g/c0_g is the reordered gradient of hidden/cell
     // initialization.
     Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
-    const size_t* order = batch_gate->lod()[2].data();
+
+    framework::Vector<size_t> order(batch_gate->lod()[2]);
+
     if (c0) {
       ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0,
                                          true);
diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu
index 0ee456f9bc61436bd0f2f8ef20dd1654e7e56d56..acdd87cb3550bc5f3891aed6fefd4301a3395f9f 100644
--- a/paddle/operators/math/selected_rows_functor.cu
+++ b/paddle/operators/math/selected_rows_functor.cu
@@ -31,7 +31,7 @@ struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
     PADDLE_ENFORCE_EQ(in1_height, input2.height());
     output->set_height(in1_height);
 
-    auto& in1_rows = input1.rows();
+    framework::Vector<int64_t> in1_rows(input1.rows());
    auto& in2_rows = input2.rows();
     std::vector<int64_t> out_rows;
     out_rows.reserve(in1_rows.size() + in2_rows.size());
@@ -108,7 +108,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
     PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
 
     auto& in1_value = input1.value();
-    auto& in1_rows = input1.rows();
+    framework::Vector<int64_t> in1_rows(input1.rows());
 
     int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
     PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
@@ -126,7 +126,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
     dim3 grid(1, in1_rows.size());
     SelectedRowsAddTensorKernel<
         T, block_size><<<grid, threads, 0, context.stream()>>>(
-        in1_data, in1_rows.data(), out_data, in1_row_numel);
+        in1_data, in1_rows.cuda_data(), out_data, in1_row_numel);
 
     auto out_eigen = framework::EigenVector<T>::Flatten(*output);
     auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
@@ -146,7 +146,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
     auto in1_height = input1.height();
     PADDLE_ENFORCE_EQ(in1_height, input2->height());
 
-    auto& in1_rows = input1.rows();
+    framework::Vector<int64_t> in1_rows(input1.rows());
     auto& in2_rows = *(input2->mutable_rows());
 
     auto& in1_value = input1.value();
@@ -204,7 +204,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
     PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
 
     auto& in1_value = input1.value();
-    auto& in1_rows = input1.rows();
+    framework::Vector<int64_t> in1_rows(input1.rows());
 
     int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
     PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
@@ -216,7 +216,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
     dim3 grid(1, in1_rows.size());
     SelectedRowsAddToTensorKernel<
         T, block_size><<<grid, threads, 0, context.stream()>>>(
-        in1_data, in1_rows.data(), in2_data, in1_row_numel);
+        in1_data, in1_rows.cuda_data(), in2_data, in1_row_numel);
   }
 };
 
@@ -257,7 +257,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
   framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                      const framework::SelectedRows& input) {
     framework::SelectedRows out;
-    auto input_rows = input.rows();
+    framework::Vector<int64_t> input_rows(input.rows());
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
 
@@ -283,9 +283,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
     MergeAddKernel<
         T, 256><<<grid1, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(input_data, input.rows().data(), out_data,
-                                   out.rows().data(), out.rows().size(),
-                                   input_width);
+                      .stream()>>>(input_data, input_rows.cuda_data(), out_data,
+                                   out.mutable_rows()->cuda_data(),
+                                   out.rows().size(), input_width);
 
     return out;
   }
 };
@@ -370,8 +370,8 @@ struct UpdateToTensor<platform::CUDADeviceContext, T> {
     dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
     dim3 grid(1, in1_rows.size());
     UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
-        grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op,
-                                              in2_data, in1_row_numel);
+        grid, threads, 0, context.stream()>>>(in1_data, in1_rows.cuda_data(),
+                                              op, in2_data, in1_row_numel);
   }
 };
 }  // namespace scatter
diff --git a/paddle/operators/math/sequence2batch.cc b/paddle/operators/math/sequence2batch.cc
index e459a42ca251a9fc79f745f48a118ce898a0f77e..17abce1c2f809f75edb2c5dc46709094c2ce10c3 100644
--- a/paddle/operators/math/sequence2batch.cc
+++ b/paddle/operators/math/sequence2batch.cc
@@ -23,8 +23,10 @@ template <typename T>
 class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& src, const size_t* index,
-                  framework::Tensor& dst, bool is_src_index) {
+                  const framework::Tensor& src,
+                  framework::Vector<size_t> index_lod, framework::Tensor& dst,
+                  bool is_src_index) {
+    size_t* index = index_lod.data();
     auto src_dims = src.dims();
     auto dst_dims = dst.dims();
     PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
diff --git a/paddle/operators/math/sequence2batch.cu b/paddle/operators/math/sequence2batch.cu
index 452ae8951000872b706f7e4227a62dbf98109e7e..f27631271a42b4d64abef00d7f119b85e32edda4 100644
--- a/paddle/operators/math/sequence2batch.cu
+++ b/paddle/operators/math/sequence2batch.cu
@@ -42,8 +42,10 @@ template <typename T>
 class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& src, const size_t* index,
-                  framework::Tensor& dst, bool is_src_index) {
+                  const framework::Tensor& src,
+                  framework::Vector<size_t> index_lod, framework::Tensor& dst,
+                  bool is_src_index) {
+    size_t* index = index_lod.cuda_data();
     auto src_dims = src.dims();
     auto dst_dims = dst.dims();
     PADDLE_ENFORCE_EQ(src_dims.size(), 2,
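The CPU and CUDA specializations above differ only in how they pull a raw pointer out of the same framework::Vector<size_t> argument (data() versus cuda_data()), which is what lets the gru/lstm/lstmp call sites hand over the LoD level itself instead of a raw pointer. A sketch of the now-uniform call-site pattern, with tensor setup omitted and identifiers taken from the diff:

    // The LoD level carrying the batch order converts to a
    // framework::Vector<size_t>; the functor specialization selected by
    // DeviceContext decides whether it is read on the host or uploaded.
    framework::Vector<size_t> order(batch_gate->lod()[2]);
    math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
    row_shuffle(ctx, src, order, *dst, true /* is_src_index */);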
diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h
index a5c43a2c7d4d729c35a20a27de2a23141e6019bc..6db0427b4174a09dd254d771e8d3d215cc6571a9 100644
--- a/paddle/operators/math/sequence2batch.h
+++ b/paddle/operators/math/sequence2batch.h
@@ -35,7 +35,7 @@ class CopyMatrixRowsFunctor {
   // copy the input src to the indexed rows of output dst.
   // The indexed rows are based on the input index.
   void operator()(const DeviceContext& context, const framework::Tensor& src,
-                  const size_t* index, framework::Tensor& dst,
+                  framework::Vector<size_t> index_lod, framework::Tensor& dst,
                   bool is_src_index);
 };
 
@@ -66,7 +66,7 @@ class LoDTensor2BatchFunctor {
       PADDLE_ENFORCE_EQ(lods[1].size(),
                         static_cast<size_t>(lod_tensor.dims()[0]));
 
       CopyMatrixRowsFunctor<DeviceContext, T> to_batch;
-      to_batch(context, lod_tensor, lods[1].data(), batch, true);
+      to_batch(context, lod_tensor, lods[1], batch, true);
       return;
     }
 
@@ -144,7 +144,7 @@ class LoDTensor2BatchFunctor {
     batch.set_lod(batch_lods);
 
     CopyMatrixRowsFunctor<DeviceContext, T> to_batch;
-    to_batch(context, lod_tensor, seq2batch_idx, batch, true);
+    to_batch(context, lod_tensor, batch_lods[1], batch, true);
   }
 };
 
@@ -159,8 +159,7 @@ class Batch2LoDTensorFunctor {
     PADDLE_ENFORCE_EQ(in_lod[1].size(),
                       static_cast<size_t>(lod_tensor.dims()[0]));
     CopyMatrixRowsFunctor<DeviceContext, T> to_seq;
-    size_t* index = in_lod[1].data();
-    to_seq(context, batch, index, lod_tensor, false);
+    to_seq(context, batch, in_lod[1], lod_tensor, false);
   }
 };
 
diff --git a/paddle/operators/math/sequence_padding.cu b/paddle/operators/math/sequence_padding.cu
index a38df26f59569c4fd54a1ba5691b2cd5f3245344..65c9cfe4a0ec14d220ad237baa71703a783ed0fa 100644
--- a/paddle/operators/math/sequence_padding.cu
+++ b/paddle/operators/math/sequence_padding.cu
@@ -120,12 +120,14 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     T* padding_data = padding.data<T>();
     if (norm_by_times) {
       SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
-          padding_data, const_cast<T*>(seq_data), abs_offset_lod[level].data(),
-          sequence_width, max_sequence_length, num_sequences);
+          padding_data, const_cast<T*>(seq_data),
+          abs_offset_lod[level].cuda_data(), sequence_width,
+          max_sequence_length, num_sequences);
     } else {
       SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
-          padding_data, const_cast<T*>(seq_data), abs_offset_lod[level].data(),
-          sequence_width, max_sequence_length, num_sequences);
+          padding_data, const_cast<T*>(seq_data),
+          abs_offset_lod[level].cuda_data(), sequence_width,
+          max_sequence_length, num_sequences);
     }
   }
 };
@@ -193,12 +195,14 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     T* seq_data = seq.data<T>();
     if (norm_by_times) {
       SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
-          const_cast<T*>(padding_data), seq_data, abs_offset_lod[level].data(),
-          sequence_width, max_sequence_length, num_sequences);
+          const_cast<T*>(padding_data), seq_data,
+          abs_offset_lod[level].cuda_data(), sequence_width,
+          max_sequence_length, num_sequences);
     } else {
       SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
-          const_cast<T*>(padding_data), seq_data, abs_offset_lod[level].data(),
-          sequence_width, max_sequence_length, num_sequences);
+          const_cast<T*>(padding_data), seq_data,
+          abs_offset_lod[level].cuda_data(), sequence_width,
+          max_sequence_length, num_sequences);
     }
   }
 };
diff --git a/paddle/operators/math/sequence_pooling.cu b/paddle/operators/math/sequence_pooling.cu
index 4c9e6b375ce7251747b9cd443d86cca0858c84ef..f66534a6812a66c737445ea96914a393077d7d65 100644
--- a/paddle/operators/math/sequence_pooling.cu
+++ b/paddle/operators/math/sequence_pooling.cu
@@ -73,7 +73,7 @@ class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> {
     dim3 grid(num_seq, 1);
     auto stream = context.stream();
     KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
-        in_data, starts.data(), out_data, max_index, num_seq, dim);
+        in_data, starts.cuda_data(), out_data, max_index, num_seq, dim);
   }
 };
 
diff --git a/paddle/operators/math/sequence_scale.cu b/paddle/operators/math/sequence_scale.cu
index ceaabd8e0fd81c927fbd4333c0aa7954b8da8513..fd4e28f6113729cd1fa9dc179bd9b601d29b8a7f 100644
--- a/paddle/operators/math/sequence_scale.cu
+++ b/paddle/operators/math/sequence_scale.cu
@@ -46,7 +46,7 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
 
     SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
         num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
-        seq_data, abs_offset_lod[level].data(), scales, seq_width);
+        seq_data, abs_offset_lod[level].cuda_data(), scales, seq_width);
   }
 };
 
diff --git a/paddle/operators/row_conv_op.cu b/paddle/operators/row_conv_op.cu
index 41f2c5b9de91ade15b4010f56377675cfd1b611c..b3825212e1ac41b13a2f4cad2c128da39c5f6e71 100644
--- a/paddle/operators/row_conv_op.cu
+++ b/paddle/operators/row_conv_op.cu
@@ -307,7 +307,7 @@ class RowConvKernel<platform::CUDADeviceContext, T>
     int input_dim = X->dims()[1];
     int num_sequence = batch_indices.size() - 1;
     int future_context = Filter->dims()[0];
-    size_t *idx = batch_indices.data();
+    size_t *idx = batch_indices.cuda_data();
     auto stream = context.cuda_device_context().stream();
 
     if (future_context <= 32) {
@@ -345,7 +345,7 @@ class RowConvGradKernel<platform::CUDADeviceContext, T>
     int input_dim = X->dims()[1];
     int num_sequence = batch_indices.size() - 1;
     int future_context = Filter->dims()[0];
-    size_t *idx = batch_indices.data();
+    size_t *idx = batch_indices.cuda_data();
 
     auto &device_ctx = context.cuda_device_context();
     math::SetConstant<platform::CUDADeviceContext, T> zero;
diff --git a/paddle/operators/sequence_erase_op.cu b/paddle/operators/sequence_erase_op.cu
index f1e3b96acd0259de2b3ca1348834bd17e1e174a2..a5311f15f0c607c880a6f12c0bef10b2dd8c8a79 100644
--- a/paddle/operators/sequence_erase_op.cu
+++ b/paddle/operators/sequence_erase_op.cu
@@ -96,9 +96,8 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                 PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
         num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
-
     // Set LoD for output
-    thrust::host_vector<size_t> out_lod0 = dev_out_lod;
+    std::vector<size_t> out_lod0(dev_out_lod.begin(), dev_out_lod.end());
     framework::LoD out_lod;
     out_lod.push_back(out_lod0);
     out->set_lod(out_lod);
diff --git a/paddle/operators/sgd_op.cu b/paddle/operators/sgd_op.cu
index 42f8f8b2f072f9d204dfadcd732926b5c98dc617..29f5aa3542c26c76a1b80da61ec6752019216131 100644
--- a/paddle/operators/sgd_op.cu
+++ b/paddle/operators/sgd_op.cu
@@ -89,7 +89,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
 
       auto& in_value = grad->value();
-      auto& in_rows = grad->rows();
+      framework::Vector<int64_t> in_rows(grad->rows());
 
       int64_t in_row_numel = in_value.numel() / in_rows.size();
       PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
@@ -102,7 +102,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
       dim3 grid(1, in_rows.size());
       SparseSGDFunctorKernel<
          T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
-          in_data, in_rows.data(), learning_rate->data<T>(), out_data,
+          in_data, in_rows.cuda_data(), learning_rate->data<T>(), out_data,
           in_row_numel);
 
     } else {
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 490397afdd4de0cc1aafde746d31b1d800eded3b..a880d9bdbc63aacc1f2cdbc0d7da001a59c7b372 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -124,44 +124,25 @@ PYBIND11_PLUGIN(core) {
       .def(
           "__init__",
          [](LoDTensor &instance, const std::vector<std::vector<size_t>> &lod) {
-#ifndef PADDLE_WITH_CUDA
-            new (&instance) LoDTensor(lod);
-#else
-            LoD new_lod;
-            new_lod.reserve(lod.size());
-            std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
-            new (&instance) LoDTensor(new_lod);
-#endif
+            LoD new_lod;
+            new_lod.reserve(lod.size());
+            std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
+            new (&instance) LoDTensor(new_lod);
           })
       .def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
       .def("set_lod",
           [](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
-#ifndef PADDLE_WITH_CUDA
-            self.set_lod(lod);
-#else
             LoD new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
             self.set_lod(new_lod);
-#endif
           })
      .def("lod",
          [](LoDTensor &self) -> std::vector<std::vector<size_t>> {
-#ifndef PADDLE_WITH_CUDA
-            return self.lod();
-#else
-            auto lod = self.lod();
-            std::vector<std::vector<size_t>> new_lod;
-            new_lod.reserve(lod.size());
-            std::transform(lod.begin(), lod.end(), std::back_inserter(new_lod),
-                           [](Vector<size_t> item) -> std::vector<size_t> {
-                             std::vector<size_t> v;
-                             v.reserve(item.size());
-                             std::copy(item.begin(), item.end(),
-                                       std::back_inserter(v));
-                             return v;
-                           });
-            return new_lod;
-#endif
+            auto lod = self.lod();
+            std::vector<std::vector<size_t>> new_lod;
+            new_lod.reserve(lod.size());
+            std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
+            return new_lod;
          });
 
   py::class_<SelectedRows>(m, "SelectedRows")
diff --git a/python/paddle/v2/fluid/tests/test_tensor.py b/python/paddle/v2/fluid/tests/test_tensor.py
index d5cc235f588ad37b0d1293dc9894952c97411757..0219bef42b3ba133dda7412c1036cf989a170a36 100644
--- a/python/paddle/v2/fluid/tests/test_tensor.py
+++ b/python/paddle/v2/fluid/tests/test_tensor.py
@@ -108,9 +108,31 @@ class TestTensor(unittest.TestCase):
         scope = core.Scope()
         place = core.CPUPlace()
         lod_py = [[0, 2, 5], [0, 2, 4, 5]]
-        lod_tensor = core.LoDTensor(lod_py)
+        lod_tensor = core.LoDTensor()
 
         lod_tensor.set_dims([5, 2, 3, 4])
+        lod_tensor.set_lod(lod_py)
+        lod_tensor.alloc_float(place)
+        tensor_array = numpy.array(lod_tensor)
+        tensor_array[0, 0, 0, 0] = 1.0
+        tensor_array[0, 0, 0, 1] = 2.0
+        lod_tensor.set(tensor_array, place)
+
+        lod_v = numpy.array(lod_tensor)
+        self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0])
+        self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1])
+        self.assertListEqual(lod_py, lod_tensor.lod())
+
+    def test_lod_tensor_gpu_init(self):
+        if not core.is_compiled_with_cuda():
+            return
+        scope = core.Scope()
+        place = core.CUDAPlace(0)
+        lod_py = [[0, 2, 5], [0, 2, 4, 5]]
+        lod_tensor = core.LoDTensor()
+
+        lod_tensor.set_dims([5, 2, 3, 4])
+        lod_tensor.set_lod(lod_py)
         lod_tensor.alloc_float(place)
         tensor_array = numpy.array(lod_tensor)
         tensor_array[0, 0, 0, 0] = 1.0