提交 757c32f2 编写于 作者: Z zchen0211

lstm unit gpu

上级 2b10d322
...@@ -35,7 +35,7 @@ __device__ Dtype cuda_tanh(const Dtype x) { ...@@ -35,7 +35,7 @@ __device__ Dtype cuda_tanh(const Dtype x) {
} }
template <typename T> template <typename T>
__global__ void LSTMUnitKernel(const int nthreads, const int dim, const int t, __global__ void LSTMUnitKernel(const int nthreads, const int dim,
const T* C_prev, const T* X, T* C, T* H, const T* C_prev, const T* X, T* C, T* H,
const T forget_bias) { const T forget_bias) {
CUDA_1D_KERNEL_LOOP(index, nthreads) { CUDA_1D_KERNEL_LOOP(index, nthreads) {
...@@ -159,9 +159,9 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel { ...@@ -159,9 +159,9 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel {
int n = N * D; int n = N * D;
int grid = (n + block - 1) / block; int grid = (n + block - 1) / block;
LSTMUnitGradientKernel<T><<<N * D, block>>>(n, D, C_prev, X, C, H, C_diff, LSTMUnitGradientKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, C_diff,
H_diff, C_prev_diff, X_diff, H_diff, C_prev_diff, X_diff,
T forget_bias) forget_bias);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册