From e4c35d837d79c4b1a4f30e42efe143f64ec10e71 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Tue, 20 Mar 2018 04:43:00 -0700 Subject: [PATCH] "add details" --- paddle/fluid/operators/sequence_expand_op.cu | 19 +++++++++---------- paddle/fluid/operators/sequence_expand_op.h | 18 ++++++++++-------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/sequence_expand_op.cu b/paddle/fluid/operators/sequence_expand_op.cu index 9cdb89f8f..cae0a6928 100644 --- a/paddle/fluid/operators/sequence_expand_op.cu +++ b/paddle/fluid/operators/sequence_expand_op.cu @@ -54,15 +54,15 @@ __global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data, int tid_z = blockIdx.z * blockDim.z + threadIdx.z; int item_start = tid_x / element_len; for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) { - shm[item_start + tid_z] += doutx_data[item_start * scale + tid_z]; + shm[item_start + tid_z] += dout_data[item_start * scale + tid_z]; } } } // synchronize before write to dx __syncthreads(); - for (int idx = blockDimx * blockIdx.x + threadIdx.x; + for (int idx = blockDim.x * blockIdx.x + threadIdx.x; idx < static_cast(dout_size); idx += blockDim.x * gridDim.x) { - dx_data[idx] = shm[idx;] + dx_data[idx] = shm[idx]; } } @@ -86,19 +86,18 @@ struct SequenceExpandFunctor { template struct SequenceExpandGradFunctor { - void operator()(const platform::CUDADeviceContext& ctx, const LoDTensor& x, - const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) { + void operator()(const platform::CUDADeviceContext& context, + const LoDTensor& x, const LoDTensor& out, + const LoDTensor& dout, LoDTensor* dx) { auto x_dims = x.dims(); size_t element_len = framework::product(x_dims) / x_dims[0]; - const T* x_data = x->data(); - T* out_data = out->mutable_data(context.GetPlace()); - auto out_starts = out->lod().back(); + auto out_starts = out.lod().back(); dim3 block_size(16, 32, element_len); dim3 grid_size(10, 10); size_t out_size = framework::product(dx->dims()); - sequence_expand_kernel<<>>( + sequence_expand_grad_kernel<<>>( dout.data(), dx->mutable_data(context.GetPlace()), out_starts.CUDAData(context.GetPlace()), out_starts.size(), element_len, out_size); diff --git a/paddle/fluid/operators/sequence_expand_op.h b/paddle/fluid/operators/sequence_expand_op.h index 3b66bf3d8..11890b30a 100644 --- a/paddle/fluid/operators/sequence_expand_op.h +++ b/paddle/fluid/operators/sequence_expand_op.h @@ -40,7 +40,7 @@ struct SequenceExpandFunctor { LoDTensor* out) { auto x_dims = x.dims(); size_t element_len = framework::product(x_dims) / x_dims[0]; - const T* x_data = x->data(); + const T* x_data = x.data(); T* out_data = out->mutable_data(context.GetPlace()); auto out_starts = out->lod().back(); @@ -92,12 +92,12 @@ class SequenceExpandKernel : public framework::OpKernel { * */ template struct SequenceExpandGradFunctor { - void operator()(const platform::CPUDeviceContext& ctx, const LoDTensor& x, + void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x, const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) { auto out_last_level = out.lod().back(); - const T* d_out_data = d_out.data(); - T* d_x_data = d_x->mutable_data(context.GetPlace()); - size_t element_len = d_out.numel() / d_out.dims()[0]; + const T* d_out_data = dout.data(); + T* d_x_data = dx->mutable_data(context.GetPlace()); + size_t element_len = dout.numel() / dout.dims()[0]; for (size_t i = 0; i < out_last_level.size() - 1; ++i) { size_t repeat = out_last_level[i + 1] - out_last_level[i]; Eigen::TensorMap< @@ -117,13 +117,15 @@ template class SequenceExpandGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* d_out = context.Input(framework::GradVarName("Out")); auto* x = context.Input("X"); auto* out = context.Input("Out"); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_x = context.Output(framework::GradVarName("X")); d_x->set_lod(x->lod()); - SequenceExpandGradFunctor(context.template device_context(), *x, *out, - d_out, d_x); + SequenceExpandGradFunctor functor; + functor(context.template device_context(), *x, *out, *d_out, + d_x); } }; -- GitLab