diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu
index 8a35bc908e8d73f9bdc5f0fbf268fb40e5d228ce..8119afce1a4827be11552c704e75ce14657bd425 100644
--- a/paddle/fluid/operators/sequence_expand_op.cu
+++ b/paddle/fluid/operators/sequence_expand_op.cu
@@ -25,27 +25,17 @@ using LoDTensor = framework::LoDTensor;
 template <typename T>
 __global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
                                        const size_t* ref_lod,
+                                       const size_t* offset,
                                        const size_t lod_size,
                                        /* default=1,
                                           the instance length*/
                                        const int x_item_length, T* out_data) {
-  constexpr int N = 1024;
-  __shared__ int mem[N];
-  int offset = 0;
-  for (int i = 0; i < lod_size; ++i) {
-    mem[i] = offset;
-    if (i < lod_size - 1) {
-      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
-    }
-  }
-  __syncthreads();
-
   int bid = blockIdx.x;
   if (bid >= lod_size - 1) return;
 
   int x_item_count = x_lod[bid + 1] - x_lod[bid];
   int repeats = ref_lod[bid + 1] - ref_lod[bid];
-  int out_offset = mem[bid];
+  int out_offset = static_cast<int>(offset[bid]);
   int x_offset = x_lod[bid];
   for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
     for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
@@ -59,32 +49,17 @@ __global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
 }
 
 template <typename T>
-__global__ void sequence_expand_grad_kernel(const T* dout_data,
-                                            const size_t* ref_lod,
-                                            const size_t* dx_lod,
-                                            const size_t lod_size,
-                                            /* default=1,
-                                               the instance length*/
-                                            const int x_item_length,
-                                            T* dx_data) {
-  // TODO(dzhwinter) : too many atomicAdd
-  // use shared memory to reduce memory visits
-  constexpr int N = 1024;
-  __shared__ int mem[N];
-  int offset = 0;
-  for (int i = 0; i < lod_size; ++i) {
-    mem[i] = offset;
-    if (i < lod_size - 1) {
-      offset += (ref_lod[i + 1] - ref_lod[i]) * (dx_lod[i + 1] - dx_lod[i]);
-    }
-  }
-  __syncthreads();
-
+__global__ void sequence_expand_grad_kernel(
+    const T* dout_data, const size_t* ref_lod, const size_t* dx_lod,
+    const size_t* offset, const size_t lod_size,
+    /* default=1,
+       the instance length*/
+    const int x_item_length, T* dx_data) {
   int bid = blockIdx.x;
   if (bid >= lod_size - 1) return;
   int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
   int repeats = ref_lod[bid + 1] - ref_lod[bid];
-  int out_offset = mem[bid];
+  int out_offset = static_cast<int>(offset[bid]);
   int x_offset = dx_lod[bid];
 
   for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
@@ -101,6 +76,19 @@ __global__ void sequence_expand_grad_kernel(const T* dout_data,
   }
 }
 
+void GetOutputOffset(const framework::Vector<size_t>& x_lod,
+                     const framework::Vector<size_t>& ref_lod,
+                     framework::Vector<size_t>& out_offset) {
+  size_t offset = 0;
+  int lod_size = static_cast<int>(x_lod.size());
+  for (int i = 0; i < static_cast<int>(x_lod.size()); ++i) {
+    out_offset[i] = offset;
+    if (i < lod_size - 1) {
+      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
+    }
+  }
+}
+
 template <typename T>
 struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
   void operator()(
@@ -109,6 +97,9 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
       const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
       LoDTensor* out) {
     int x_item_length = x.numel() / x.dims()[0];
+    framework::Vector<size_t> out_offset(x_lod.size());
+    GetOutputOffset(x_lod, ref_lod, out_offset);
+
     int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
     int thread_y = 16;
     int thread_z = 1024 / thread_x / thread_y;
@@ -118,7 +109,8 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
     sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
         x.data<T>(), x_lod.CUDAData(context.GetPlace()),
-        ref_lod.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
+        ref_lod.CUDAData(context.GetPlace()),
+        out_offset.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
         out->mutable_data<T>(context.GetPlace()));
   }
 };
 
@@ -131,6 +123,9 @@ struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
      const framework::Vector<size_t>& ref_lod, /*expand based lod*/
      LoDTensor* dx) {
    int x_item_length = framework::product(dx->dims()) / dx->dims()[0];
+    framework::Vector<size_t> out_offset(x_lod.size());
+    GetOutputOffset(x_lod, ref_lod, out_offset);
+
    int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
    int thread_y = 16;
    int thread_z = 1024 / thread_x / thread_y;
@@ -139,7 +134,8 @@ struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
     dim3 grid_size(block_x, 1);
     sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
         dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
-        x_lod.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
+        x_lod.CUDAData(context.GetPlace()),
+        out_offset.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
         dx->mutable_data<T>(context.GetPlace()));
   }
 };
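Note (not part of the patch): the diff removes the per-block shared-memory offset loop from both kernels and instead precomputes the segment offsets once on the host in GetOutputOffset, copies them to the device, and indexes them directly as offset[bid]. Below is a minimal standalone sketch of that precomputation, using std::vector in place of framework::Vector and a hypothetical name ComputeOutputOffsets: out_offset[i] is the row at which source segment i starts in the expanded output, i.e. the running sum over preceding segments of (repeat count) * (segment length).

#include <cstddef>
#include <iostream>
#include <vector>

// Standalone sketch of the offset precomputation; mirrors what GetOutputOffset
// computes but is not Paddle code. x_lod and ref_lod are cumulative LoD arrays.
std::vector<size_t> ComputeOutputOffsets(const std::vector<size_t>& x_lod,
                                         const std::vector<size_t>& ref_lod) {
  std::vector<size_t> out_offset(x_lod.size(), 0);
  size_t offset = 0;
  for (size_t i = 0; i < x_lod.size(); ++i) {
    out_offset[i] = offset;  // start row of segment i in the expanded output
    if (i + 1 < x_lod.size()) {
      // Segment i has length x_lod[i+1]-x_lod[i] and is repeated
      // ref_lod[i+1]-ref_lod[i] times, so it occupies that many output rows.
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
    }
  }
  return out_offset;
}

int main() {
  // Two segments of lengths 2 and 3, repeated 2 and 1 times respectively:
  // expected offsets are {0, 4, 7}.
  std::vector<size_t> x_lod = {0, 2, 5};
  std::vector<size_t> ref_lod = {0, 2, 3};
  for (size_t v : ComputeOutputOffsets(x_lod, ref_lod)) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}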