Commit e4c35d83 authored by dzhwinter

"add details"

Parent 26822bd7
@@ -54,15 +54,15 @@ __global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data,
       int tid_z = blockIdx.z * blockDim.z + threadIdx.z;
       int item_start = tid_x / element_len;
       for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) {
-        shm[item_start + tid_z] += doutx_data[item_start * scale + tid_z];
+        shm[item_start + tid_z] += dout_data[item_start * scale + tid_z];
       }
     }
   }
   // synchronize before write to dx
   __syncthreads();
-  for (int idx = blockDimx * blockIdx.x + threadIdx.x;
+  for (int idx = blockDim.x * blockIdx.x + threadIdx.x;
        idx < static_cast<int>(dout_size); idx += blockDim.x * gridDim.x) {
-    dx_data[idx] = shm[idx;]
+    dx_data[idx] = shm[idx];
   }
 }
@@ -86,19 +86,18 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
 template <typename T>
 struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
-  void operator()(const platform::CUDADeviceContext& ctx, const LoDTensor& x,
-                  const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
+  void operator()(const platform::CUDADeviceContext& context,
+                  const LoDTensor& x, const LoDTensor& out,
+                  const LoDTensor& dout, LoDTensor* dx) {
     auto x_dims = x.dims();
     size_t element_len = framework::product(x_dims) / x_dims[0];
-    const T* x_data = x->data<T>();
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    auto out_starts = out->lod().back();
+    auto out_starts = out.lod().back();
     dim3 block_size(16, 32, element_len);
     dim3 grid_size(10, 10);
     size_t out_size = framework::product(dx->dims());
-    sequence_expand_kernel<<<grid_size, block_size, out_size * sizeof(T),
+    sequence_expand_grad_kernel<<<grid_size, block_size, out_size * sizeof(T),
                              context.stream()>>>(
         dout.data<T>(), dx->mutable_data<T>(context.GetPlace()),
         out_starts.CUDAData(context.GetPlace()), out_starts.size(), element_len,
         out_size);
...
@@ -40,7 +40,7 @@ struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
                   LoDTensor* out) {
     auto x_dims = x.dims();
     size_t element_len = framework::product(x_dims) / x_dims[0];
-    const T* x_data = x->data<T>();
+    const T* x_data = x.data<T>();
     T* out_data = out->mutable_data<T>(context.GetPlace());
     auto out_starts = out->lod().back();
@@ -92,12 +92,12 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
  * */
 template <typename T>
 struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& ctx, const LoDTensor& x,
+  void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
                   const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
     auto out_last_level = out.lod().back();
-    const T* d_out_data = d_out.data<T>();
-    T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
-    size_t element_len = d_out.numel() / d_out.dims()[0];
+    const T* d_out_data = dout.data<T>();
+    T* d_x_data = dx->mutable_data<T>(context.GetPlace());
+    size_t element_len = dout.numel() / dout.dims()[0];
     for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
       size_t repeat = out_last_level[i + 1] - out_last_level[i];
       Eigen::TensorMap<
@@ -117,13 +117,15 @@ template <typename DeviceContext, typename T>
 class SequenceExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* x = context.Input<LoDTensor>("X");
     auto* out = context.Input<LoDTensor>("Out");
+    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
     d_x->set_lod(x->lod());
-    SequenceExpandGradFunctor(context.template device_context(), *x, *out,
-                              d_out, d_x);
+    SequenceExpandGradFunctor<DeviceContext, T> functor;
+    functor(context.template device_context<DeviceContext>(), *x, *out, *d_out,
+            d_x);
   }
 };
...
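For reference, the two corrected loops in the CUDA hunks above both use the standard grid-stride pattern (start at blockDim.x * blockIdx.x + threadIdx.x, advance by blockDim.x * gridDim.x). Below is a minimal, self-contained sketch of that pattern under simplified assumptions; the kernel and variable names (scale_kernel, d_in, d_out) are hypothetical and are not part of the patched operator.

// Minimal sketch of a grid-stride loop: each thread starts at its global
// index and strides by the total number of launched threads, so any n is
// covered regardless of the launch configuration.
#include <cstdio>
#include <vector>

__global__ void scale_kernel(const float* in, float* out, int n, float scale) {
  for (int idx = blockDim.x * blockIdx.x + threadIdx.x; idx < n;
       idx += blockDim.x * gridDim.x) {
    out[idx] = in[idx] * scale;
  }
}

int main() {
  const int n = 1 << 20;
  std::vector<float> h_in(n, 1.0f), h_out(n, 0.0f);
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // A fixed, modest grid; the grid-stride loop decouples the launch size
  // from the problem size.
  scale_kernel<<<128, 256>>>(d_in, d_out, n, 2.0f);
  cudaMemcpy(h_out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("out[0] = %f\n", h_out[0]);  // expected: 2.0

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}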