未验证 提交 ae7d1c1f 编写于 作者: D dzhwinter 提交者: GitHub

Fix/lod (#7714)

* "Need to re-design LoD "

* "add lod design"

* "fix lod gpu ptr pointer"

* "removed commented code"

* "fix CI"

* "remove set lod in pybind"

* "fix style check"

* "fix CI"

* "fix long type template error"

* "pybind reorder to use Place"

* "fix ci"

* "fix ci"

* fix ci

* "sperate as a new file"

* "fix CI"

* "fix ci"

* small fix

* "add test"

* "fix adam op"

* "fix lstmp op"

* "fix adam op"

* "follow comments"

* "fix ci"
上级 865a714e
...@@ -22,7 +22,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) ...@@ -22,7 +22,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
......
...@@ -24,8 +24,6 @@ limitations under the License. */ ...@@ -24,8 +24,6 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include <glog/logging.h>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -18,11 +18,11 @@ limitations under the License. */ ...@@ -18,11 +18,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include <thrust/system/cuda/experimental/pinned_allocator.h>
#endif #endif
#include <glog/logging.h> #include <glog/logging.h>
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/mixed_vector.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h" #include "paddle/framework/tensor_util.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
...@@ -31,15 +31,6 @@ limitations under the License. */ ...@@ -31,15 +31,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#ifndef PADDLE_WITH_CUDA
template <typename T>
using Vector = std::vector<T>;
#else
template <typename T>
using Vector = thrust::host_vector<
T, thrust::system::cuda::experimental::pinned_allocator<T>>;
#endif
/* /*
* LoD is short for Level of Details. * LoD is short for Level of Details.
* *
...@@ -55,7 +46,15 @@ using Vector = thrust::host_vector< ...@@ -55,7 +46,15 @@ using Vector = thrust::host_vector<
* 0 2 4 7 * 0 2 4 7
* 0 2 5 7 10 12 15 20 * 0 2 5 7 10 12 15 20
*/ */
using LoD = std::vector<Vector<size_t>>; struct LoD : public std::vector<Vector<size_t>> {
using std::vector<Vector<size_t>>::vector;
void CopyFromCUDA() {
for (auto it = this->begin(); it != this->end(); ++it) {
it->CopyFromCUDA();
}
}
};
std::ostream& operator<<(std::ostream& os, const LoD& lod); std::ostream& operator<<(std::ostream& os, const LoD& lod);
std::ostream& operator<<(std::ostream& os, const LoDTensor& t); std::ostream& operator<<(std::ostream& os, const LoDTensor& t);
...@@ -109,7 +108,10 @@ bool CheckAbsLoD(const LoD& in, int tensor_height = -1); ...@@ -109,7 +108,10 @@ bool CheckAbsLoD(const LoD& in, int tensor_height = -1);
*/ */
class LoDTensor : public Tensor { class LoDTensor : public Tensor {
public: public:
LoDTensor() {} LoDTensor() : Tensor() {}
/* Constructor with place should only be used in pybind */
explicit LoDTensor(const platform::Place& place) : Tensor(place) {}
explicit LoDTensor(const LoD& lod) : lod_(lod) {} explicit LoDTensor(const LoD& lod) : lod_(lod) {}
......
...@@ -23,6 +23,17 @@ ...@@ -23,6 +23,17 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
TEST(LoD, data) {
LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5});
lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));
auto& v = lod[0];
for (size_t i = 0; i < v.size(); ++i) {
EXPECT_EQ(v[i], i);
}
}
TEST(LodExpand, test) { TEST(LodExpand, test) {
LoD lod{{0, 2}}; LoD lod{{0, 2}};
LoDTensor tensor; LoDTensor tensor;
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <stdio.h>
#include "paddle/framework/init.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
...@@ -26,7 +28,48 @@ __global__ void test(size_t* a, int size) { ...@@ -26,7 +28,48 @@ __global__ void test(size_t* a, int size) {
} }
} }
TEST(Vector, Normal) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::memory;
paddle::framework::InitDevices();
paddle::framework::Vector<size_t> vec({1, 2, 3});
size_t* ptr = vec.data();
for (size_t i = 0; i < vec.size(); ++i) {
EXPECT_EQ(vec[i], *(ptr + i));
}
vec.clear();
vec.CopyFromCUDA();
std::vector<size_t> v = {1, 2, 3};
for (size_t i = 0; i < v.size(); ++i) {
EXPECT_EQ(v[i], vec[i]);
}
}
TEST(LoD, data) {
paddle::framework::InitDevices();
paddle::framework::LoD lod{{0, 1, 2}};
lod.push_back({0, 2, 4, 5});
lod.push_back(std::vector<size_t>({0, 1, 6, 8, 10, 11}));
auto& v = lod[0];
test<<<1, 1>>>(v.cuda_data(), v.size());
cudaDeviceSynchronize();
v.CopyFromCUDA();
for (size_t i = 0; i < v.size(); ++i) {
EXPECT_EQ(v[i], i * 2);
}
}
TEST(LoDTensor, LoDInGPU) { TEST(LoDTensor, LoDInGPU) {
paddle::framework::InitDevices();
paddle::framework::LoDTensor lod_tensor; paddle::framework::LoDTensor lod_tensor;
paddle::platform::CUDAPlace place(0); paddle::platform::CUDAPlace place(0);
...@@ -42,8 +85,9 @@ TEST(LoDTensor, LoDInGPU) { ...@@ -42,8 +85,9 @@ TEST(LoDTensor, LoDInGPU) {
auto lod = lod_tensor.lod(); auto lod = lod_tensor.lod();
test<<<1, 8>>>(lod[0].data(), lod[0].size()); test<<<1, 8>>>(lod[0].cuda_data(), lod[0].size());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
lod.CopyFromCUDA();
for (size_t i = 0; i < src_lod[0].size(); ++i) { for (size_t i = 0; i < src_lod[0].size(); ++i) {
EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <initializer_list>
#include <vector>
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace framework {
/**
* @brief Vector support both cpu and gpu.
* host vector lifetime is same with Vector
* device vector is lazily malloc and modified.
*/
template <typename T>
class Vector : public std::vector<T> {
public:
/* NOTE(dzhwinter):
* Data always store and modified on Host.
* If the data is modified when use cuda_data interface,
* You need to call the CopyFromCUDA explicitly to synchronize data.
*
*/
enum class kDataPosition {
kDataOnHost = 0,
kDataOnDevice = 1,
};
public:
using std::vector<T>::vector;
Vector() {}
Vector(const std::vector<T> &v) : std::vector<T>(v) {} // NOLINT
virtual ~Vector() {
#ifdef PADDLE_WITH_CUDA
if (cuda_ptr_ != nullptr) {
memory::Free<platform::CUDAPlace>(place_, static_cast<void *>(cuda_ptr_));
}
#endif
}
T *cuda_data() {
CopyToCUDA();
PADDLE_ENFORCE_NOT_NULL(
cuda_ptr_, "No data or Insufficient CUDA memory to allocation");
return static_cast<T *>(cuda_ptr_);
}
T *data() { return std::vector<T>::data(); }
const T *data() const { return std::vector<T>::data(); }
void CopyToCUDA();
void CopyFromCUDA();
void CopyToPeer(platform::Place);
private:
void *cuda_ptr_ = nullptr;
size_t cuda_size_ = 0;
/*The DataPosition is unused now,
if we want support random access from cpu and cuda,
we need to overload all the vector method */
kDataPosition position_ = kDataPosition::kDataOnHost;
platform::CUDAPlace place_;
};
template <typename T>
void Vector<T>::CopyToCUDA() {
#ifdef PADDLE_WITH_CUDA
if (cuda_ptr_ == nullptr) {
cuda_ptr_ =
memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T));
}
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *cuda_ctx = pool.GetByPlace(place_);
memory::Copy(place_, static_cast<void *>(cuda_ptr_), platform::CPUPlace(),
static_cast<const void *>(this->data()),
this->size() * sizeof(T), cuda_ctx->stream());
cuda_ctx->Wait();
cuda_size_ = this->size();
#endif
}
template <typename T>
void Vector<T>::CopyFromCUDA() {
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *cuda_ctx = pool.GetByPlace(place_);
if (cuda_ptr_ == nullptr) {
LOG(WARNING) << "No uncommited cuda data.";
return;
}
this->resize(cuda_size_);
memory::Copy(platform::CPUPlace(), static_cast<void *>(this->data()), place_,
static_cast<const void *>(cuda_ptr_), this->size() * sizeof(T),
cuda_ctx->stream());
cuda_ctx->Wait();
#endif
}
template <typename T>
void Vector<T>::CopyToPeer(platform::Place peer_place) {
if (platform::is_cpu_place(peer_place)) {
return;
}
#ifdef PADDLE_WITH_CUDA
auto *cuda_ctx = platform::DeviceContextPool::Instance().GetByPlace(place_);
void *peer_cuda_ptr_ = memory::Alloc<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(peer_place), this->size() * sizeof(T));
memory::Copy(boost::get<platform::CUDAPlace>(peer_place),
static_cast<void *>(peer_cuda_ptr_), place_,
static_cast<const void *>(cuda_ptr_), this->size() * sizeof(T),
cuda_ctx->stream());
cuda_ctx->Wait();
memory::Free<platform::CUDAPlace>(place_, static_cast<void *>(cuda_ptr_));
place_ = boost::get<platform::CUDAPlace>(peer_place);
cuda_ptr_ = peer_cuda_ptr_;
#endif
}
template class Vector<int>;
template class Vector<unsigned>;
template class Vector<size_t>;
template class Vector<int64_t>;
} // namespace framework
} // namespace paddle
...@@ -47,6 +47,11 @@ class Tensor { ...@@ -47,6 +47,11 @@ class Tensor {
public: public:
Tensor() : offset_(0) {} Tensor() : offset_(0) {}
/*! Constructor with place should only be used in pybind. */
explicit Tensor(const platform::Place& place) : offset_(0) {
holder_->set_place(place);
}
/*! Return a pointer to mutable memory block. */ /*! Return a pointer to mutable memory block. */
template <typename T> template <typename T>
inline T* data(); inline T* data();
...@@ -137,6 +142,7 @@ class Tensor { ...@@ -137,6 +142,7 @@ class Tensor {
virtual std::type_index type() const = 0; virtual std::type_index type() const = 0;
virtual platform::Place place() const = 0; virtual platform::Place place() const = 0;
virtual void set_type(std::type_index type) = 0; virtual void set_type(std::type_index type) = 0;
virtual void set_place(platform::Place place) = 0;
}; };
template <typename Place> template <typename Place>
...@@ -156,6 +162,7 @@ class Tensor { ...@@ -156,6 +162,7 @@ class Tensor {
virtual void* ptr() const { return static_cast<void*>(ptr_.get()); } virtual void* ptr() const { return static_cast<void*>(ptr_.get()); }
virtual std::type_index type() const { return type_; } virtual std::type_index type() const { return type_; }
virtual void set_type(std::type_index type) { type_ = type; } virtual void set_type(std::type_index type) { type_ = type; }
virtual void set_place(platform::Place place) { place_ = place; }
/*! the pointer of memory block. */ /*! the pointer of memory block. */
std::unique_ptr<uint8_t, memory::PODDeleter<uint8_t, Place>> ptr_; std::unique_ptr<uint8_t, memory::PODDeleter<uint8_t, Place>> ptr_;
......
set(FLUID_CORE_MODULES proto_desc paddle_memory executor prune init) set(FLUID_CORE_MODULES proto_desc paddle_memory lod_tensor executor prune init)
cc_library(paddle_fluid_api cc_library(paddle_fluid_api
SRCS io.cc SRCS io.cc
......
...@@ -82,7 +82,7 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> { ...@@ -82,7 +82,7 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
math::scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func; math::scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
auto grad_merge = merge_func(context, grad); auto grad_merge = merge_func(context, grad);
auto* grad_merge_data = grad_merge.mutable_value()->template data<T>(); auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
auto& merge_rows = grad_merge.rows(); framework::Vector<int64_t> merge_rows(grad_merge.rows());
// 2. m += g_m * g_m // 2. m += g_m * g_m
math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func; math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func;
auto grad_square = sqare_func(context, grad_merge, grad_merge); auto grad_square = sqare_func(context, grad_merge, grad_merge);
...@@ -101,8 +101,8 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> { ...@@ -101,8 +101,8 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
SparseAdagradFunctorKernel< SparseAdagradFunctorKernel<
T, 256><<<grid2, threads, 0, T, 256><<<grid2, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_merge_data, grad_merge.rows().data(), .stream()>>>(grad_merge_data, merge_rows.cuda_data(), lr,
lr, param_data, moment_data, grad_width, param_data, moment_data, grad_width,
epsilon); epsilon);
} }
}; };
......
...@@ -199,7 +199,12 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -199,7 +199,12 @@ class AdamOpKernel : public framework::OpKernel<T> {
merge_func(ctx.template device_context<DeviceContext>(), grad); merge_func(ctx.template device_context<DeviceContext>(), grad);
auto& grad_tensor = grad_merge.value(); auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>(); const T* grad_data = grad_tensor.template data<T>();
auto* rows = grad_merge.rows().data(); int64_t* rows = nullptr;
if (platform::is_gpu_place(ctx.GetPlace())) {
rows = grad_merge.mutable_rows()->cuda_data();
} else {
rows = grad_merge.mutable_rows()->data();
}
auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
SparseAdamFunctor<T> functor( SparseAdamFunctor<T> functor(
......
...@@ -69,12 +69,11 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> { ...@@ -69,12 +69,11 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>( MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
num_tokens, tokens, num_seq, input_lod[level].data(), blank, num_tokens, tokens, num_seq, input_lod[level].cuda_data(), blank,
merge_repeated, dev_out_lod0_ptr, output_data); merge_repeated, dev_out_lod0_ptr, output_data);
// set output lod // set output lod
thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(), std::vector<size_t> host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end());
dev_out_lod0.end());
framework::LoD out_lod; framework::LoD out_lod;
out_lod.push_back(host_out_lod0); out_lod.push_back(host_out_lod0);
output->set_lod(out_lod); output->set_lod(out_lod);
......
...@@ -30,11 +30,12 @@ using Tensor = framework::Tensor; ...@@ -30,11 +30,12 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx, inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src, const size_t* index, const framework::Tensor& src,
framework::Vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) { framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle; math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace()); dst->mutable_data<T>(src.dims(), ctx.GetPlace());
row_shuffle(ctx, src, index, *dst, indexed_src); row_shuffle(ctx, src, index_lod, *dst, indexed_src);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -76,7 +77,9 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -76,7 +77,9 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.state_weight = gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size); const_cast<T*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0; Tensor ordered_h0;
const size_t* order = batch_gate->lod()[2].data();
framework::Vector<size_t> order(batch_gate->lod()[2]);
if (h0) { if (h0) {
// Since the batch computing for GRU reorders the input sequences // Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs // according to their length. The initialized cell state also needs
...@@ -159,7 +162,9 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -159,7 +162,9 @@ class GRUGradKernel : public framework::OpKernel<T> {
zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0)); zero(dev_ctx, &batch_reset_hidden_prev_grad, static_cast<T>(0.0));
Tensor ordered_h0, ordered_h0_grad; Tensor ordered_h0, ordered_h0_grad;
const size_t* order = batch_gate->lod()[2].data();
framework::Vector<size_t> order(batch_gate->lod()[2]);
if (h0) { if (h0) {
ReorderInitState<DeviceContext, T>(dev_ctx, *h0, order, &ordered_h0, ReorderInitState<DeviceContext, T>(dev_ctx, *h0, order, &ordered_h0,
true); true);
......
...@@ -125,8 +125,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -125,8 +125,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
new_rows.resize(ids_dim[0]); new_rows.resize(ids_dim[0]);
auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace()); auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data, memory::Copy(platform::CPUPlace(), new_rows.cuda_data(), gpu_place,
ids_dim[0] * sizeof(int64_t), stream); ids_data, ids_dim[0] * sizeof(int64_t), stream);
d_table->set_rows(new_rows); d_table->set_rows(new_rows);
......
...@@ -27,11 +27,12 @@ using Tensor = framework::Tensor; ...@@ -27,11 +27,12 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx, inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src, const size_t* index, const framework::Tensor& src,
framework::Vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) { framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle; math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace()); dst->mutable_data<T>(src.dims(), ctx.GetPlace());
row_shuffle(ctx, src, index, *dst, indexed_src); row_shuffle(ctx, src, index_lod, *dst, indexed_src);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -84,7 +85,9 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -84,7 +85,9 @@ class LSTMKernel : public framework::OpKernel<T> {
} }
lstm_value.prev_state_value = nullptr; lstm_value.prev_state_value = nullptr;
Tensor ordered_c0; Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data();
framework::Vector<size_t> order(batch_gate->lod()[2]);
if (cell_t0) { if (cell_t0) {
// Since the batch computing for LSTM reorders the input sequence // Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized cell state also needs // according to their length. The initialized cell state also needs
...@@ -202,7 +205,8 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -202,7 +205,8 @@ class LSTMGradKernel : public framework::OpKernel<T> {
// ordered_h0_g/c0_g is the reordered gradient of hidden/cell // ordered_h0_g/c0_g is the reordered gradient of hidden/cell
// initialization. // initialization.
Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g; Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
const size_t* order = batch_gate->lod()[2].data(); framework::Vector<size_t> order(batch_gate->lod()[2]);
if (c0) { if (c0) {
ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0, ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0,
true); true);
......
...@@ -34,7 +34,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; ...@@ -34,7 +34,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx, inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src, const size_t* index, const framework::Tensor& src,
framework::Vector<size_t> index,
framework::Tensor* dst, bool indexed_src) { framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle; math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
dst->mutable_data<T>(src.dims(), ctx.GetPlace()); dst->mutable_data<T>(src.dims(), ctx.GetPlace());
...@@ -109,7 +110,9 @@ class LSTMPKernel : public framework::OpKernel<T> { ...@@ -109,7 +110,9 @@ class LSTMPKernel : public framework::OpKernel<T> {
} }
lstmp_value.prev_state_value = nullptr; lstmp_value.prev_state_value = nullptr;
Tensor ordered_c0; Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data();
framework::Vector<size_t> order(batch_gate->lod()[2]);
if (cell_t0) { if (cell_t0) {
// Since the batch computing for LSTMP reorders the input sequence // Since the batch computing for LSTMP reorders the input sequence
// according to their length. The initialized cell state also needs // according to their length. The initialized cell state also needs
...@@ -275,7 +278,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -275,7 +278,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
// ordered_h0_g/c0_g is the reordered gradient of hidden/cell // ordered_h0_g/c0_g is the reordered gradient of hidden/cell
// initialization. // initialization.
Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g; Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
const size_t* order = batch_gate->lod()[2].data();
framework::Vector<size_t> order(batch_gate->lod()[2]);
if (c0) { if (c0) {
ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0, ReorderInitState<DeviceContext, T>(device_ctx, *c0, order, &ordered_c0,
true); true);
......
...@@ -31,7 +31,7 @@ struct SelectedRowsAdd<platform::CUDADeviceContext, T> { ...@@ -31,7 +31,7 @@ struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
PADDLE_ENFORCE_EQ(in1_height, input2.height()); PADDLE_ENFORCE_EQ(in1_height, input2.height());
output->set_height(in1_height); output->set_height(in1_height);
auto& in1_rows = input1.rows(); framework::Vector<int64_t> in1_rows(input1.rows());
auto& in2_rows = input2.rows(); auto& in2_rows = input2.rows();
std::vector<int64_t> out_rows; std::vector<int64_t> out_rows;
out_rows.reserve(in1_rows.size() + in2_rows.size()); out_rows.reserve(in1_rows.size() + in2_rows.size());
...@@ -108,7 +108,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> { ...@@ -108,7 +108,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
auto& in1_value = input1.value(); auto& in1_value = input1.value();
auto& in1_rows = input1.rows(); framework::Vector<int64_t> in1_rows(input1.rows());
int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
...@@ -126,7 +126,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> { ...@@ -126,7 +126,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
dim3 grid(1, in1_rows.size()); dim3 grid(1, in1_rows.size());
SelectedRowsAddTensorKernel< SelectedRowsAddTensorKernel<
T, block_size><<<grid, threads, 0, context.stream()>>>( T, block_size><<<grid, threads, 0, context.stream()>>>(
in1_data, in1_rows.data(), out_data, in1_row_numel); in1_data, in1_rows.cuda_data(), out_data, in1_row_numel);
auto out_eigen = framework::EigenVector<T>::Flatten(*output); auto out_eigen = framework::EigenVector<T>::Flatten(*output);
auto in2_eigen = framework::EigenVector<T>::Flatten(input2); auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
...@@ -146,7 +146,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> { ...@@ -146,7 +146,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
auto in1_height = input1.height(); auto in1_height = input1.height();
PADDLE_ENFORCE_EQ(in1_height, input2->height()); PADDLE_ENFORCE_EQ(in1_height, input2->height());
auto& in1_rows = input1.rows(); framework::Vector<int64_t> in1_rows(input1.rows());
auto& in2_rows = *(input2->mutable_rows()); auto& in2_rows = *(input2->mutable_rows());
auto& in1_value = input1.value(); auto& in1_value = input1.value();
...@@ -204,7 +204,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> { ...@@ -204,7 +204,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = input1.value(); auto& in1_value = input1.value();
auto& in1_rows = input1.rows(); framework::Vector<int64_t> in1_rows(input1.rows());
int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
...@@ -216,7 +216,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> { ...@@ -216,7 +216,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
dim3 grid(1, in1_rows.size()); dim3 grid(1, in1_rows.size());
SelectedRowsAddToTensorKernel< SelectedRowsAddToTensorKernel<
T, block_size><<<grid, threads, 0, context.stream()>>>( T, block_size><<<grid, threads, 0, context.stream()>>>(
in1_data, in1_rows.data(), in2_data, in1_row_numel); in1_data, in1_rows.cuda_data(), in2_data, in1_row_numel);
} }
}; };
...@@ -257,7 +257,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -257,7 +257,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
framework::SelectedRows operator()(const platform::CUDADeviceContext& context, framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input) { const framework::SelectedRows& input) {
framework::SelectedRows out; framework::SelectedRows out;
auto input_rows = input.rows(); framework::Vector<int64_t> input_rows(input.rows());
std::set<int64_t> row_set(input_rows.begin(), input_rows.end()); std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end()); std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
...@@ -283,9 +283,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -283,9 +283,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
MergeAddKernel< MergeAddKernel<
T, 256><<<grid1, threads, 0, T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context) reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(input_data, input.rows().data(), out_data, .stream()>>>(input_data, input_rows.cuda_data(), out_data,
out.rows().data(), out.rows().size(), out.mutable_rows()->cuda_data(),
input_width); out.rows().size(), input_width);
return out; return out;
} }
}; };
...@@ -370,8 +370,8 @@ struct UpdateToTensor<platform::CUDADeviceContext, T> { ...@@ -370,8 +370,8 @@ struct UpdateToTensor<platform::CUDADeviceContext, T> {
dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1); dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
dim3 grid(1, in1_rows.size()); dim3 grid(1, in1_rows.size());
UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<< UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op, grid, threads, 0, context.stream()>>>(in1_data, in1_rows.cuda_data(),
in2_data, in1_row_numel); op, in2_data, in1_row_numel);
} }
}; };
} // namespace scatter } // namespace scatter
......
...@@ -23,8 +23,10 @@ template <typename T> ...@@ -23,8 +23,10 @@ template <typename T>
class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> { class CopyMatrixRowsFunctor<platform::CPUDeviceContext, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& src, const size_t* index, const framework::Tensor& src,
framework::Tensor& dst, bool is_src_index) { framework::Vector<size_t> index_lod, framework::Tensor& dst,
bool is_src_index) {
size_t* index = index_lod.data();
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst.dims(); auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2UL, PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
......
...@@ -42,8 +42,10 @@ template <typename T> ...@@ -42,8 +42,10 @@ template <typename T>
class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> { class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& src, const size_t* index, const framework::Tensor& src,
framework::Tensor& dst, bool is_src_index) { framework::Vector<size_t> index_lod, framework::Tensor& dst,
bool is_src_index) {
size_t* index = index_lod.cuda_data();
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst.dims(); auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2, PADDLE_ENFORCE_EQ(src_dims.size(), 2,
......
...@@ -35,7 +35,7 @@ class CopyMatrixRowsFunctor { ...@@ -35,7 +35,7 @@ class CopyMatrixRowsFunctor {
// copy the input src to the indexed rows of output dst. // copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index. // The indexed rows are based on the input index.
void operator()(const DeviceContext& context, const framework::Tensor& src, void operator()(const DeviceContext& context, const framework::Tensor& src,
const size_t* index, framework::Tensor& dst, framework::Vector<size_t> index_lod, framework::Tensor& dst,
bool is_src_index); bool is_src_index);
}; };
...@@ -66,7 +66,7 @@ class LoDTensor2BatchFunctor { ...@@ -66,7 +66,7 @@ class LoDTensor2BatchFunctor {
PADDLE_ENFORCE_EQ(lods[1].size(), PADDLE_ENFORCE_EQ(lods[1].size(),
static_cast<size_t>(lod_tensor.dims()[0])); static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<DeviceContext, T> to_batch; CopyMatrixRowsFunctor<DeviceContext, T> to_batch;
to_batch(context, lod_tensor, lods[1].data(), batch, true); to_batch(context, lod_tensor, lods[1], batch, true);
return; return;
} }
...@@ -144,7 +144,7 @@ class LoDTensor2BatchFunctor { ...@@ -144,7 +144,7 @@ class LoDTensor2BatchFunctor {
batch.set_lod(batch_lods); batch.set_lod(batch_lods);
CopyMatrixRowsFunctor<DeviceContext, T> to_batch; CopyMatrixRowsFunctor<DeviceContext, T> to_batch;
to_batch(context, lod_tensor, seq2batch_idx, batch, true); to_batch(context, lod_tensor, batch_lods[1], batch, true);
} }
}; };
...@@ -159,8 +159,7 @@ class Batch2LoDTensorFunctor { ...@@ -159,8 +159,7 @@ class Batch2LoDTensorFunctor {
PADDLE_ENFORCE_EQ(in_lod[1].size(), PADDLE_ENFORCE_EQ(in_lod[1].size(),
static_cast<size_t>(lod_tensor.dims()[0])); static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<DeviceContext, T> to_seq; CopyMatrixRowsFunctor<DeviceContext, T> to_seq;
size_t* index = in_lod[1].data(); to_seq(context, batch, in_lod[1], lod_tensor, false);
to_seq(context, batch, index, lod_tensor, false);
} }
}; };
......
...@@ -120,12 +120,14 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -120,12 +120,14 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
T* padding_data = padding.data<T>(); T* padding_data = padding.data<T>();
if (norm_by_times) { if (norm_by_times) {
SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>( SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data), abs_offset_lod[level].data(), padding_data, const_cast<T*>(seq_data),
sequence_width, max_sequence_length, num_sequences); abs_offset_lod[level].cuda_data(), sequence_width,
max_sequence_length, num_sequences);
} else { } else {
SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>( SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
padding_data, const_cast<T*>(seq_data), abs_offset_lod[level].data(), padding_data, const_cast<T*>(seq_data),
sequence_width, max_sequence_length, num_sequences); abs_offset_lod[level].cuda_data(), sequence_width,
max_sequence_length, num_sequences);
} }
} }
}; };
...@@ -193,12 +195,14 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -193,12 +195,14 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
T* seq_data = seq.data<T>(); T* seq_data = seq.data<T>();
if (norm_by_times) { if (norm_by_times) {
SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>( SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data, abs_offset_lod[level].data(), const_cast<T*>(padding_data), seq_data,
sequence_width, max_sequence_length, num_sequences); abs_offset_lod[level].cuda_data(), sequence_width,
max_sequence_length, num_sequences);
} else { } else {
SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>( SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
const_cast<T*>(padding_data), seq_data, abs_offset_lod[level].data(), const_cast<T*>(padding_data), seq_data,
sequence_width, max_sequence_length, num_sequences); abs_offset_lod[level].cuda_data(), sequence_width,
max_sequence_length, num_sequences);
} }
} }
}; };
......
...@@ -73,7 +73,7 @@ class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> { ...@@ -73,7 +73,7 @@ class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> {
dim3 grid(num_seq, 1); dim3 grid(num_seq, 1);
auto stream = context.stream(); auto stream = context.stream();
KeMaxSequencePool<T><<<grid, threads, 0, stream>>>( KeMaxSequencePool<T><<<grid, threads, 0, stream>>>(
in_data, starts.data(), out_data, max_index, num_seq, dim); in_data, starts.cuda_data(), out_data, max_index, num_seq, dim);
} }
}; };
......
...@@ -46,7 +46,7 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -46,7 +46,7 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<< SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>( num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
seq_data, abs_offset_lod[level].data(), scales, seq_width); seq_data, abs_offset_lod[level].cuda_data(), scales, seq_width);
} }
}; };
......
...@@ -307,7 +307,7 @@ class RowConvKernel<platform::CUDADeviceContext, T> ...@@ -307,7 +307,7 @@ class RowConvKernel<platform::CUDADeviceContext, T>
int input_dim = X->dims()[1]; int input_dim = X->dims()[1];
int num_sequence = batch_indices.size() - 1; int num_sequence = batch_indices.size() - 1;
int future_context = Filter->dims()[0]; int future_context = Filter->dims()[0];
size_t *idx = batch_indices.data(); size_t *idx = batch_indices.cuda_data();
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
if (future_context <= 32) { if (future_context <= 32) {
...@@ -345,7 +345,7 @@ class RowConvGradKernel<platform::CUDADeviceContext, T> ...@@ -345,7 +345,7 @@ class RowConvGradKernel<platform::CUDADeviceContext, T>
int input_dim = X->dims()[1]; int input_dim = X->dims()[1];
int num_sequence = batch_indices.size() - 1; int num_sequence = batch_indices.size() - 1;
int future_context = Filter->dims()[0]; int future_context = Filter->dims()[0];
size_t *idx = batch_indices.data(); size_t *idx = batch_indices.cuda_data();
auto &device_ctx = context.cuda_device_context(); auto &device_ctx = context.cuda_device_context();
math::SetConstant<platform::CUDADeviceContext, T> zero; math::SetConstant<platform::CUDADeviceContext, T> zero;
......
...@@ -96,9 +96,8 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> { ...@@ -96,9 +96,8 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr); num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);
// Set LoD for output // Set LoD for output
thrust::host_vector<size_t> out_lod0 = dev_out_lod; std::vector<size_t> out_lod0(dev_out_lod.begin(), dev_out_lod.end());
framework::LoD out_lod; framework::LoD out_lod;
out_lod.push_back(out_lod0); out_lod.push_back(out_lod0);
out->set_lod(out_lod); out->set_lod(out_lod);
......
...@@ -89,7 +89,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> { ...@@ -89,7 +89,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(in_height, out_dims[0]); PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
auto& in_value = grad->value(); auto& in_value = grad->value();
auto& in_rows = grad->rows(); framework::Vector<int64_t> in_rows(grad->rows());
int64_t in_row_numel = in_value.numel() / in_rows.size(); int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height); PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
...@@ -102,7 +102,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> { ...@@ -102,7 +102,7 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
dim3 grid(1, in_rows.size()); dim3 grid(1, in_rows.size());
SparseSGDFunctorKernel< SparseSGDFunctorKernel<
T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>( T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
in_data, in_rows.data(), learning_rate->data<T>(), out_data, in_data, in_rows.cuda_data(), learning_rate->data<T>(), out_data,
in_row_numel); in_row_numel);
} else { } else {
......
...@@ -124,44 +124,25 @@ PYBIND11_PLUGIN(core) { ...@@ -124,44 +124,25 @@ PYBIND11_PLUGIN(core) {
.def( .def(
"__init__", "__init__",
[](LoDTensor &instance, const std::vector<std::vector<size_t>> &lod) { [](LoDTensor &instance, const std::vector<std::vector<size_t>> &lod) {
#ifndef PADDLE_WITH_CUDA LoD new_lod;
new (&instance) LoDTensor(lod); new_lod.reserve(lod.size());
#else std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
LoD new_lod; new (&instance) LoDTensor(new_lod);
new_lod.reserve(lod.size());
std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
new (&instance) LoDTensor(new_lod);
#endif
}) })
.def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); }) .def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
.def("set_lod", .def("set_lod",
[](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) { [](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
#ifndef PADDLE_WITH_CUDA
self.set_lod(lod);
#else
LoD new_lod; LoD new_lod;
new_lod.reserve(lod.size()); new_lod.reserve(lod.size());
std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod)); std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
self.set_lod(new_lod); self.set_lod(new_lod);
#endif
}) })
.def("lod", [](LoDTensor &self) -> std::vector<std::vector<size_t>> { .def("lod", [](LoDTensor &self) -> std::vector<std::vector<size_t>> {
#ifndef PADDLE_WITH_CUDA auto lod = self.lod();
return self.lod(); std::vector<std::vector<size_t>> new_lod;
#else new_lod.reserve(lod.size());
auto lod = self.lod(); std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
std::vector<std::vector<size_t>> new_lod; return new_lod;
new_lod.reserve(lod.size());
std::transform(lod.begin(), lod.end(), std::back_inserter(new_lod),
[](Vector<size_t> item) ->
std::vector<size_t> {
std::vector<size_t> v;
v.reserve(item.size());
std::copy(item.begin(), item.end(), std::back_inserter(v));
return v;
});
return new_lod;
#endif
}); });
py::class_<SelectedRows>(m, "SelectedRows") py::class_<SelectedRows>(m, "SelectedRows")
......
...@@ -108,9 +108,31 @@ class TestTensor(unittest.TestCase): ...@@ -108,9 +108,31 @@ class TestTensor(unittest.TestCase):
scope = core.Scope() scope = core.Scope()
place = core.CPUPlace() place = core.CPUPlace()
lod_py = [[0, 2, 5], [0, 2, 4, 5]] lod_py = [[0, 2, 5], [0, 2, 4, 5]]
lod_tensor = core.LoDTensor(lod_py) lod_tensor = core.LoDTensor()
lod_tensor.set_dims([5, 2, 3, 4]) lod_tensor.set_dims([5, 2, 3, 4])
lod_tensor.set_lod(lod_py)
lod_tensor.alloc_float(place)
tensor_array = numpy.array(lod_tensor)
tensor_array[0, 0, 0, 0] = 1.0
tensor_array[0, 0, 0, 1] = 2.0
lod_tensor.set(tensor_array, place)
lod_v = numpy.array(lod_tensor)
self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0])
self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1])
self.assertListEqual(lod_py, lod_tensor.lod())
def test_lod_tensor_gpu_init(self):
if not core.is_compiled_with_cuda():
return
scope = core.Scope()
place = core.CUDAPlace(0)
lod_py = [[0, 2, 5], [0, 2, 4, 5]]
lod_tensor = core.LoDTensor()
lod_tensor.set_dims([5, 2, 3, 4])
lod_tensor.set_lod(lod_py)
lod_tensor.alloc_float(place) lod_tensor.alloc_float(place)
tensor_array = numpy.array(lod_tensor) tensor_array = numpy.array(lod_tensor)
tensor_array[0, 0, 0, 0] = 1.0 tensor_array[0, 0, 0, 0] = 1.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册