提交 65dbeb6a 编写于 作者: Q qijun

fix gpu build error

上级 9e3a9eb2
...@@ -25,9 +25,9 @@ Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device< ...@@ -25,9 +25,9 @@ Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device<
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
DeviceType* OpKernel::KernelContext::get_eigen_device<platform::GPUPlace>() Eigen::GpuDevice* OpKernel::KernelContext::get_eigen_device<
const { platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<DeviceType>(); return device_context_.get_eigen_device<Eigen::GpuDevice>();
} }
#endif #endif
......
...@@ -32,7 +32,7 @@ __global__ void KeRowConv(real* y, const real* x, const real* w, ...@@ -32,7 +32,7 @@ __global__ void KeRowConv(real* y, const real* x, const real* w,
for (int i = tidy; i < context; i += blky) { for (int i = tidy; i < context; i += blky) {
sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0; sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0;
} }
__syncthreads(); __syncthreads();
for (int i = 0; i < numSeq; ++i) { for (int i = 0; i < numSeq; ++i) {
...@@ -144,12 +144,15 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy, ...@@ -144,12 +144,15 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy,
int yoff = start + j; int yoff = start + j;
// transpose // transpose
sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0; sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0; x[yoff * width + xoff] : 0.0;
sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ?
dy[yoff * width + xoff] : 0.0;
__syncthreads(); __syncthreads();
if (tidy < (context - 1)) { if (tidy < (context - 1)) {
yoff = yoff - context + 1; yoff = yoff - context + 1;
sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0; sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ?
dy[yoff * width + xoff] : 0.0;
} }
__syncthreads(); __syncthreads();
...@@ -199,11 +202,13 @@ __global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy, ...@@ -199,11 +202,13 @@ __global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy,
int yoff = start + j; int yoff = start + j;
// transpose // transpose
sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0; sh_x[tidx][tidy] = (xoff < width && yoff < end) ?
x[yoff * width + xoff] : 0.0;
__syncthreads(); __syncthreads();
for (int t = 0; t < context; t++) { for (int t = 0; t < context; t++) {
sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0; sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start &&
yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0;
__syncthreads(); __syncthreads();
real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx]; real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx];
...@@ -239,7 +244,7 @@ __global__ void KeRowConvBwData(real* dx, const real* w, const real* dy, ...@@ -239,7 +244,7 @@ __global__ void KeRowConvBwData(real* dx, const real* w, const real* dy,
for (int i = tidy; i < context; i += blky) { for (int i = tidy; i < context; i += blky) {
sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0; sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0;
} }
__syncthreads(); __syncthreads();
for (int i = 0; i < numSeq; ++i) { for (int i = 0; i < numSeq; ++i) {
...@@ -312,7 +317,7 @@ void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG, ...@@ -312,7 +317,7 @@ void RowConvGrad<DEVICE_TYPE_GPU>(const GpuMatrix& outG,
dim3 dimBlock(32, 32); dim3 dimBlock(32, 32);
dim3 dimGrid(DIVUP(width, dimBlock.x), 1); dim3 dimGrid(DIVUP(width, dimBlock.x), 1);
real* dw = filterG.getData(); real* dw = filterG.getData();
if (contextLength <= 32) { if (contextLength <= 32) {
KeRowConvBwWeight<32, 32, 32> KeRowConvBwWeight<32, 32, 32>
<<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>> <<<dimGrid, dimBlock, 0, STREAM_DEFAULT>>>
(dw, x, dy, starts, height, width, numSeq, contextLength); (dw, x, dy, starts, height, width, numSeq, contextLength);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册