Commit fbdb5b7b, authored by: dzhwinter

"fix based on comment"

Parent: a80bf702
@@ -25,27 +25,17 @@ using LoDTensor = framework::LoDTensor;
 template <typename T>
 __global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
                                        const size_t* ref_lod,
+                                       const size_t* offset,
                                        const size_t lod_size,
                                        /* default=1,
                                           the instance length*/
                                        const int x_item_length, T* out_data) {
-  constexpr int N = 1024;
-  __shared__ int mem[N];
-  int offset = 0;
-  for (int i = 0; i < lod_size; ++i) {
-    mem[i] = offset;
-    if (i < lod_size - 1) {
-      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
-    }
-  }
-  __syncthreads();
   int bid = blockIdx.x;
   if (bid >= lod_size - 1) return;
   int x_item_count = x_lod[bid + 1] - x_lod[bid];
   int repeats = ref_lod[bid + 1] - ref_lod[bid];
-  int out_offset = mem[bid];
+  int out_offset = static_cast<int>(offset[bid]);
   int x_offset = x_lod[bid];
   for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
     for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
@@ -59,32 +49,17 @@ __global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
 }
 
 template <typename T>
-__global__ void sequence_expand_grad_kernel(const T* dout_data,
-                                            const size_t* ref_lod,
-                                            const size_t* dx_lod,
-                                            const size_t lod_size,
-                                            /* default=1,
-                                               the instance length*/
-                                            const int x_item_length,
-                                            T* dx_data) {
-  // TODO(dzhwinter) : too many atomicAdd
-  // use shared memory to reduce memory visits
-  constexpr int N = 1024;
-  __shared__ int mem[N];
-  int offset = 0;
-  for (int i = 0; i < lod_size; ++i) {
-    mem[i] = offset;
-    if (i < lod_size - 1) {
-      offset += (ref_lod[i + 1] - ref_lod[i]) * (dx_lod[i + 1] - dx_lod[i]);
-    }
-  }
-  __syncthreads();
+__global__ void sequence_expand_grad_kernel(
+    const T* dout_data, const size_t* ref_lod, const size_t* dx_lod,
+    const size_t* offset, const size_t lod_size,
+    /* default=1,
+       the instance length*/
+    const int x_item_length, T* dx_data) {
   int bid = blockIdx.x;
   if (bid >= lod_size - 1) return;
   int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
   int repeats = ref_lod[bid + 1] - ref_lod[bid];
-  int out_offset = mem[bid];
+  int out_offset = static_cast<int>(offset[bid]);
   int x_offset = dx_lod[bid];
   for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
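Both kernels assign one block per sequence (`bid`), with threads along z iterating over repeats and threads along y over items within the sequence; the inner body is truncated in this view. As a semantic reference, here is a hedged single-threaded version of what sequence_expand produces. It is an assumption based on the visible structure of the kernels, not code from this commit; `SequenceExpandRef` and the `main()` are illustrative names.

```cpp
#include <cstdio>
#include <vector>

// Single-threaded reference: sequence i of x (rows x_lod[i]..x_lod[i+1]) is
// written (ref_lod[i+1] - ref_lod[i]) times, back to back, starting at output
// row offset[i]. Each row carries item_len features.
void SequenceExpandRef(const std::vector<float>& x, int item_len,
                       const std::vector<size_t>& x_lod,
                       const std::vector<size_t>& ref_lod,
                       const std::vector<size_t>& offset,
                       std::vector<float>* out) {
  for (size_t i = 0; i + 1 < x_lod.size(); ++i) {
    size_t len = x_lod[i + 1] - x_lod[i];          // rows in sequence i
    size_t repeats = ref_lod[i + 1] - ref_lod[i];  // how often it is copied
    for (size_t r = 0; r < repeats; ++r)
      for (size_t row = 0; row < len; ++row)
        for (int k = 0; k < item_len; ++k)
          (*out)[(offset[i] + r * len + row) * item_len + k] =
              x[(x_lod[i] + row) * item_len + k];
  }
}

int main() {
  // Two sequences of lengths 2 and 3 (x_lod = {0,2,5}), repeated 3 and 1
  // times (ref_lod = {0,3,4}); offsets {0,6,9} match GetOutputOffset below.
  std::vector<float> x = {1, 2, 10, 20, 30};
  std::vector<size_t> x_lod = {0, 2, 5}, ref_lod = {0, 3, 4};
  std::vector<size_t> offset = {0, 6, 9};
  std::vector<float> out(9);
  SequenceExpandRef(x, 1, x_lod, ref_lod, offset, &out);
  for (float v : out) printf("%g ", v);  // 1 2 1 2 1 2 10 20 30
  printf("\n");
}
```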
@@ -101,6 +76,19 @@ __global__ void sequence_expand_grad_kernel(const T* dout_data,
   }
 }
 
+void GetOutputOffset(const framework::Vector<size_t>& x_lod,
+                     const framework::Vector<size_t>& ref_lod,
+                     framework::Vector<size_t>& out_offset) {
+  size_t offset = 0;
+  int lod_size = static_cast<int>(x_lod.size());
+  for (int i = 0; i < static_cast<int>(x_lod.size()); ++i) {
+    out_offset[i] = offset;
+    if (i < lod_size - 1) {
+      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
+    }
+  }
+}
+
 template <typename T>
 struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
   void operator()(
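The new GetOutputOffset helper hoists onto the host the prefix sum that both kernels used to rebuild in shared memory: out_offset[i] is the row at which sequence i starts in the expanded output, accumulating (repeats of sequence i) x (length of sequence i). For intuition, a standalone sketch with std::vector standing in for framework::Vector; the `main()` values are purely illustrative.

```cpp
#include <iostream>
#include <vector>

// Host-side prefix computation, mirroring GetOutputOffset above but written
// against std::vector so it runs outside the Paddle framework.
std::vector<size_t> GetOutputOffsetSketch(const std::vector<size_t>& x_lod,
                                          const std::vector<size_t>& ref_lod) {
  std::vector<size_t> out_offset(x_lod.size(), 0);
  size_t offset = 0;
  for (size_t i = 0; i < x_lod.size(); ++i) {
    out_offset[i] = offset;
    if (i + 1 < x_lod.size()) {
      // Sequence i has length x_lod[i+1] - x_lod[i] and is repeated
      // ref_lod[i+1] - ref_lod[i] times in the expanded output.
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
    }
  }
  return out_offset;
}

int main() {
  std::vector<size_t> x_lod = {0, 2, 5};    // lengths 2 and 3
  std::vector<size_t> ref_lod = {0, 3, 4};  // repeated 3 and 1 times
  for (size_t v : GetOutputOffsetSketch(x_lod, ref_lod)) std::cout << v << ' ';
  std::cout << '\n';  // prints: 0 6 9
}
```

The last entry (9 here) is the total number of expanded rows; the kernels only read offset[bid] for bid < lod_size - 1.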
@@ -109,6 +97,9 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
       const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
       LoDTensor* out) {
     int x_item_length = x.numel() / x.dims()[0];
+    framework::Vector<size_t> out_offset(x_lod.size());
+    GetOutputOffset(x_lod, ref_lod, out_offset);
+
     int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
     int thread_y = 16;
     int thread_z = 1024 / thread_x / thread_y;
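The launch configuration is unchanged, but the sizing is worth a sanity check: the block is dimensioned so that thread_x * thread_y * thread_z stays at the usual 1024-threads-per-block CUDA limit. A quick arithmetic check, assuming a hypothetical ref_lod.size() of 4:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
  int ref_lod_size = 4;  // assumed example value, not from the commit
  int thread_x = std::min(32, std::max(ref_lod_size, 16));  // -> 16
  int thread_y = 16;
  int thread_z = 1024 / thread_x / thread_y;                // -> 4
  printf("block = (%d, %d, %d), threads = %d\n", thread_x, thread_y,
         thread_z, thread_x * thread_y * thread_z);  // (16, 16, 4), 1024
}
```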
@@ -118,7 +109,8 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
     dim3 grid_size(block_x, 1);
     sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
         x.data<T>(), x_lod.CUDAData(context.GetPlace()),
-        ref_lod.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
+        ref_lod.CUDAData(context.GetPlace()),
+        out_offset.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
         out->mutable_data<T>(context.GetPlace()));
   }
 };
@@ -131,6 +123,9 @@ struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
       const framework::Vector<size_t>& ref_lod, /*expand based lod*/
       LoDTensor* dx) {
     int x_item_length = framework::product(dx->dims()) / dx->dims()[0];
+    framework::Vector<size_t> out_offset(x_lod.size());
+    GetOutputOffset(x_lod, ref_lod, out_offset);
+
     int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
     int thread_y = 16;
     int thread_z = 1024 / thread_x / thread_y;
@@ -139,7 +134,8 @@ struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
     dim3 grid_size(block_x, 1);
     sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
         dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
-        x_lod.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
+        x_lod.CUDAData(context.GetPlace()),
+        out_offset.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
         dx->mutable_data<T>(context.GetPlace()));
   }
 };
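Taken together, the commit replaces a per-block shared-memory offset table, which capped the number of sequences at 1024 (`constexpr int N = 1024`) and was recomputed serially by every thread of every block before `__syncthreads()`, with a single host-side pass whose result is uploaded once through `Vector::CUDAData` and read as `offset[bid]`. A minimal self-contained CUDA sketch of that pattern; the names here (`copy_rows_kernel` and friends) are illustrative, not Paddle APIs:

```cu
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Each block reads its output start position from a host-precomputed table
// in device memory instead of rebuilding it in shared memory.
__global__ void copy_rows_kernel(const float* in, const size_t* offset,
                                 int n_rows, int row_len, float* out) {
  int bid = blockIdx.x;  // one block per input row
  if (bid >= n_rows) return;
  size_t out_start = offset[bid];
  for (int j = threadIdx.x; j < row_len; j += blockDim.x) {
    out[out_start + j] = in[bid * row_len + j];
  }
}

int main() {
  const int n_rows = 2, row_len = 4;
  std::vector<float> h_in = {0, 1, 2, 3, 10, 11, 12, 13};
  std::vector<size_t> h_off = {0, 4};  // row i starts at h_off[i] in out
  float *d_in, *d_out;
  size_t* d_off;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, h_in.size() * sizeof(float));
  cudaMalloc(&d_off, h_off.size() * sizeof(size_t));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_off, h_off.data(), h_off.size() * sizeof(size_t),
             cudaMemcpyHostToDevice);
  copy_rows_kernel<<<n_rows, 32>>>(d_in, d_off, n_rows, row_len, d_out);
  std::vector<float> h_out(h_in.size());
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%g ", v);  // 0 1 2 3 10 11 12 13
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_off);
}
```

Besides lifting the 1024-sequence cap, this drops the redundant O(lod_size) loop each block ran at kernel entry, at the cost of one small host-to-device transfer per launch.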