From 26822bd774a99d19d5bb37f4890e82aacd57c391 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 20 Mar 2018 04:04:58 -0700 Subject: [PATCH] "add sequence kernel" --- paddle/fluid/operators/sequence_expand_op.cu | 107 +++++++++++++------ paddle/fluid/operators/sequence_expand_op.h | 86 ++++++++------- 2 files changed, 123 insertions(+), 70 deletions(-) diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu index 6477af89f1..9cdb89f8fd 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_expand_op.cu @@ -21,48 +21,89 @@ namespace operators { using LoDTensor = framework::LoDTensor; template -__global__ sequence_expand_kernel(const T* x_data, T* out_data, size_t* lod, - size_t element_len) { - int BLOCK_SIZE = 1024; - __shared__ T shm_lod[BLOCK_SIZE]; - for (int idx = threadIdx.x; idx < BLOCK_SIZE; ++idx) { - shm_lod[idx] = lod[idx]; +__global__ void sequence_expand_kernel(const T* x_data, T* out_data, + const size_t* lod, size_t lod_size, + size_t element_len) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + for (; tid_x < static_cast(lod_size - 1); + tid_x += blockDim.x * gridDim.x) { + int scale = lod[tid_x + 1] - lod[tid_x]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < scale; tid_y += blockDim.y * gridDim.y) { + int tid_z = blockIdx.z * blockDim.z + threadIdx.z; + int item_start = tid_x / element_len; + for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) { + out_data[item_start * scale + tid_z] = x_data[item_start + tid_z]; + } + } } - for (int idx = threadIdx.x + blockIdx.x * blockDim.x; idx < lod.size(); - idx += blockDim.x * gridDim.x) { - int scale = lod[i] +} + +template +__global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data, + const size_t* lod, size_t lod_size, + size_t element_len, + size_t dout_size) { + extern __shared__ T shm[]; + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + for (; tid_x < static_cast(lod_size - 1); + tid_x += blockDim.x * gridDim.x) { + int scale = lod[tid_x + 1] - lod[tid_x]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < scale; tid_y += blockDim.y * gridDim.y) { + int tid_z = blockIdx.z * blockDim.z + threadIdx.z; + int item_start = tid_x / element_len; + for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) { + shm[item_start + tid_z] += doutx_data[item_start * scale + tid_z]; + } + } + } + // synchronize before write to dx + __syncthreads(); + for (int idx = blockDimx * blockIdx.x + threadIdx.x; + idx < static_cast(dout_size); idx += blockDim.x * gridDim.x) { + dx_data[idx] = shm[idx;] } } template -void SequenceExpandFunctor::operator()( - const platform::CPUDeviceContext& context, const LoDTensor& x, - LoDTensor* out) { - x_dims = x.dims(); - size_t element_len = framework::product(x_dims) / x_dims[0]; - T* out_data = out->mutable_data(context.GetPlace()); - auto out_starts = out->lod().back(); +struct SequenceExpandFunctor { + void operator()(const platform::CUDADeviceContext& context, + const LoDTensor& x, LoDTensor* out) { + auto x_dims = x.dims(); + size_t element_len = framework::product(x_dims) / x_dims[0]; + T* out_data = out->mutable_data(context.GetPlace()); + auto out_starts = out->lod().back(); - const int kThreadsPerBlock = 1024; - int block_cols = kThreadsPerBlock; - if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32. - block_cols = ((out_cols + 31) >> 5) << 5; + dim3 block_size(16, 32, element_len); + dim3 grid_size(10, 10); + sequence_expand_kernel<<>>( + x.data(), out->mutable_data(context.GetPlace()), + out_starts.CUDAData(context.GetPlace()), out_starts.size(), + element_len); } - int block_rows = kThreadsPerBlock / block_cols; - dim3 block_size = dim3(block_cols, block_rows, 1); +}; - int max_threads = context.GetMaxPhysicalThreadCount(); - int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); +template +struct SequenceExpandGradFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const LoDTensor& x, + const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) { + auto x_dims = x.dims(); + size_t element_len = framework::product(x_dims) / x_dims[0]; + const T* x_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + auto out_starts = out->lod().back(); - int grid_cols = - std::min((out_cols + block_cols - 1) / block_cols, max_blocks); - int grid_rows = - std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1)); - dim3 grid_size = dim3(grid_cols, grid_rows, 1); - sequence_expand_kernel<<>>( - x.data(), out->mutable_data(context.GetPlace()), - out_starts.CUDAData(context.GetPlace()), element_len); -} + dim3 block_size(16, 32, element_len); + dim3 grid_size(10, 10); + size_t out_size = framework::product(dx->dims()); + sequence_expand_kernel<<>>( + dout.data(), dx->mutable_data(context.GetPlace()), + out_starts.CUDAData(context.GetPlace()), out_starts.size(), element_len, + out_size); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index 12e4018b95..3b66bf3d8c 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -28,31 +28,36 @@ struct SequenceExpandFunctor { void operator()(const DeviceContext& ctx, const LoDTensor& x, LoDTensor* out); }; -// template -// struct SequenceExpandGradFunctor {}; +template +struct SequenceExpandGradFunctor { + void operator()(const DeviceContext& ctx, const LoDTensor& x, + const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx); +}; template -void SequenceExpandFunctor::operator()( - const platform::CPUDeviceContext& context, const LoDTensor& x, - LoDTensor* out) { - x_dims = x.dims(); - size_t element_len = framework::product(x_dims) / x_dims[0]; - T* out_data = out->mutable_data(context.GetPlace()); - auto out_starts = out->lod().back(); +struct SequenceExpandFunctor { + void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x, + LoDTensor* out) { + auto x_dims = x.dims(); + size_t element_len = framework::product(x_dims) / x_dims[0]; + const T* x_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + auto out_starts = out->lod().back(); - for (size_t i = 0; i < out_starts.size() - 1; i++) { - int scale = out_starts[i + 1] - out_starts[i]; - Eigen::TensorMap< - Eigen::Tensor> - x_t(x_data, 1, element_len); - Eigen::TensorMap> - out_t(out_data, scale, element_len); - Eigen::array cast({{scale, 1}}); - out_t.device(*context.eigen_device()) = x_t.broadcast(cast); - x_data += element_len; - out_data += element_len * scale; + for (size_t i = 0; i < out_starts.size() - 1; i++) { + int scale = out_starts[i + 1] - out_starts[i]; + Eigen::TensorMap< + Eigen::Tensor> + x_t(x_data, 1, element_len); + Eigen::TensorMap> + out_t(out_data, scale, element_len); + Eigen::array cast({{scale, 1}}); + out_t.device(*context.eigen_device()) = x_t.broadcast(cast); + x_data += element_len; + out_data += element_len * scale; + } } -} +}; template class SequenceExpandKernel : public framework::OpKernel { @@ -60,7 +65,6 @@ class SequenceExpandKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* out = context.Output("Out"); - const T* x_data = x->data(); auto x_dims = x->dims(); auto* y = context.Input("Y"); PADDLE_ENFORCE(!y->lod().empty(), "y should have lod"); @@ -86,19 +90,14 @@ class SequenceExpandKernel : public framework::OpKernel { * Grad(X).lod = Input(X).lod * * */ -template -class SequenceExpandGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* d_out = context.Input(framework::GradVarName("Out")); - auto* x = context.Input("X"); - auto* out = context.Input("Out"); - auto* d_x = context.Output(framework::GradVarName("X")); - auto out_last_level = out->lod().back(); - d_x->set_lod(x->lod()); - const T* d_out_data = d_out->data(); +template +struct SequenceExpandGradFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const LoDTensor& x, + const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) { + auto out_last_level = out.lod().back(); + const T* d_out_data = d_out.data(); T* d_x_data = d_x->mutable_data(context.GetPlace()); - size_t element_len = d_out->numel() / d_out->dims()[0]; + size_t element_len = d_out.numel() / d_out.dims()[0]; for (size_t i = 0; i < out_last_level.size() - 1; ++i) { size_t repeat = out_last_level[i + 1] - out_last_level[i]; Eigen::TensorMap< @@ -106,14 +105,27 @@ class SequenceExpandGradKernel : public framework::OpKernel { d_out_t(d_out_data, static_cast(repeat), element_len); Eigen::TensorMap> d_x_t(d_x_data, static_cast(element_len)); - auto place = - context.template device_context().eigen_device(); - d_x_t.device(*place) = d_out_t.sum(Eigen::array({{0}})); + d_x_t.device(*context.eigen_device()) = + d_out_t.sum(Eigen::array({{0}})); d_out_data += (repeat * element_len); d_x_data += element_len; } } }; +template +class SequenceExpandGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* x = context.Input("X"); + auto* out = context.Input("Out"); + auto* d_x = context.Output(framework::GradVarName("X")); + d_x->set_lod(x->lod()); + SequenceExpandGradFunctor(context.template device_context(), *x, *out, + d_out, d_x); + } +}; + } // namespace operators } // namespace paddle -- GitLab