提交 e4c35d83 编写于 作者: D dzhwinter

"add details"

上级 26822bd7
......@@ -54,15 +54,15 @@ __global__ void sequence_expand_grad_kernel(const T* dout_data, T* dx_data,
int tid_z = blockIdx.z * blockDim.z + threadIdx.z;
int item_start = tid_x / element_len;
for (; tid_z < element_len; tid_z += blockDim.z * gridDim.z) {
shm[item_start + tid_z] += doutx_data[item_start * scale + tid_z];
shm[item_start + tid_z] += dout_data[item_start * scale + tid_z];
}
}
}
// synchronize before write to dx
__syncthreads();
for (int idx = blockDimx * blockIdx.x + threadIdx.x;
for (int idx = blockDim.x * blockIdx.x + threadIdx.x;
idx < static_cast<int>(dout_size); idx += blockDim.x * gridDim.x) {
dx_data[idx] = shm[idx;]
dx_data[idx] = shm[idx];
}
}
......@@ -86,19 +86,18 @@ struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
template <typename T>
struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const LoDTensor& x,
const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
void operator()(const platform::CUDADeviceContext& context,
const LoDTensor& x, const LoDTensor& out,
const LoDTensor& dout, LoDTensor* dx) {
auto x_dims = x.dims();
size_t element_len = framework::product(x_dims) / x_dims[0];
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back();
auto out_starts = out.lod().back();
dim3 block_size(16, 32, element_len);
dim3 grid_size(10, 10);
size_t out_size = framework::product(dx->dims());
sequence_expand_kernel<<<grid_size, block_size, out_size * sizeof(T),
context.stream()>>>(
sequence_expand_grad_kernel<<<grid_size, block_size, out_size * sizeof(T),
context.stream()>>>(
dout.data<T>(), dx->mutable_data<T>(context.GetPlace()),
out_starts.CUDAData(context.GetPlace()), out_starts.size(), element_len,
out_size);
......
......@@ -40,7 +40,7 @@ struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
LoDTensor* out) {
auto x_dims = x.dims();
size_t element_len = framework::product(x_dims) / x_dims[0];
const T* x_data = x->data<T>();
const T* x_data = x.data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto out_starts = out->lod().back();
......@@ -92,12 +92,12 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
* */
template <typename T>
struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const LoDTensor& x,
void operator()(const platform::CPUDeviceContext& context, const LoDTensor& x,
const LoDTensor& out, const LoDTensor& dout, LoDTensor* dx) {
auto out_last_level = out.lod().back();
const T* d_out_data = d_out.data<T>();
T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
size_t element_len = d_out.numel() / d_out.dims()[0];
const T* d_out_data = dout.data<T>();
T* d_x_data = dx->mutable_data<T>(context.GetPlace());
size_t element_len = dout.numel() / dout.dims()[0];
for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
size_t repeat = out_last_level[i + 1] - out_last_level[i];
Eigen::TensorMap<
......@@ -117,13 +117,15 @@ template <typename DeviceContext, typename T>
class SequenceExpandGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = context.Input<LoDTensor>("X");
auto* out = context.Input<LoDTensor>("Out");
auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
d_x->set_lod(x->lod());
SequenceExpandGradFunctor(context.template device_context(), *x, *out,
d_out, d_x);
SequenceExpandGradFunctor<DeviceContext, T> functor;
functor(context.template device_context<DeviceContext>(), *x, *out, *d_out,
d_x);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册