hl_cuda_cnn.cu 25.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zhangjinchao01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <float.h>
#include "hl_base.h"
#include "hl_cnn.h"
L
liaogang 已提交
18
#include "hl_device_functions.cuh"
Z
zhangjinchao01 已提交
19

L
liaogang 已提交
20 21 22 23
/**
 * Max-pooling forward kernel.
 *
 * Intended for a 1-D launch with at least `nthreads` threads. Every thread
 * whose global index is below `nthreads` produces one pooled value. The flat
 * index is split, fastest-varying first, into pooled column, pooled row,
 * channel and frame.
 *
 * The pooling window starts at (ph * strideH - offsetH, pw * strideW - offsetW)
 * and spans ksizeH rows by ksizeW columns. It is clipped to the
 * height x width input plane. Consecutive output frames are `tgtStride`
 * elements apart in tgtData.
 */
__global__ void KeMaxPoolForward(const int nthreads,
                                 const real* inputData,
                                 const int channels,
                                 const int height,
                                 const int width,
                                 const int pooledH,
                                 const int pooledW,
                                 const int ksizeW,
                                 const int ksizeH,
                                 const int strideH,
                                 const int strideW,
                                 const int offsetH,
                                 const int offsetW,
                                 real* tgtData,
                                 const int tgtStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const int outCol = tid % pooledW;
  const int outRow = (tid / pooledW) % pooledH;
  const int chan = (tid / pooledW / pooledH) % channels;
  const int frame = tid / pooledW / pooledH / channels;

  // The window end is derived from the unclamped origin, and only then is
  // the origin clamped to the first real row/column.
  const int rowOrigin = outRow * strideH - offsetH;
  const int colOrigin = outCol * strideW - offsetW;
  const int rowEnd = min(rowOrigin + ksizeH, height);
  const int colEnd = min(colOrigin + ksizeW, width);
  const int rowBegin = max(rowOrigin, 0);
  const int colBegin = max(colOrigin, 0);

  const real* plane = inputData + (frame * channels + chan) * height * width;
  real best = -FLT_MAX;
  for (int r = rowBegin; r < rowEnd; ++r) {
    const real* rowPtr = plane + r * width;
    for (int c = colBegin; c < colEnd; ++c) {
      // Strict comparison, so a NaN element never replaces the running max.
      if (rowPtr[c] > best) {
        best = rowPtr[c];
      }
    }
  }

  const int tgtIndex =
      tid % (pooledW * pooledH * channels) + frame * tgtStride;
  tgtData[tgtIndex] = best;
}

L
liaogang 已提交
61 62
/**
 * Host launcher for KeMaxPoolForward on STREAM_DEFAULT.
 *
 * Launches one thread per pooled output element
 * (pooledH * pooledW * channels * frameCnt) in 1-D blocks of 1024 threads.
 * Arguments are forwarded to the kernel as follows:
 *   - sizeX/sizeY go to the kernel's ksizeW/ksizeH parameters.
 *   - paddingH/paddingW go to its offsetH/offsetW parameters.
 *
 * When there is no output element (e.g. an empty batch), the kernel launch
 * is skipped. A grid of zero blocks is an invalid launch configuration.
 */
void hl_maxpool_forward(const int frameCnt,
                        const real* inputData,
                        const int channels,
                        const int height,
                        const int width,
                        const int pooledH,
                        const int pooledW,
                        const int sizeX,
                        const int sizeY,
                        const int strideH,
                        const int strideW,
                        const int paddingH,
                        const int paddingW,
                        real* tgtData,
                        const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  if (num_kernels > 0) {
    int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    KeMaxPoolForward<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
        num_kernels,
        inputData,
        channels,
        height,
        width,
        pooledH,
        pooledW,
        sizeX,
        sizeY,
        strideH,
        strideW,
        paddingH,
        paddingW,
        tgtData,
        tgtStride);
  }
  CHECK_SYNC("hl_maxpool_forward failed");
}

L
liaogang 已提交
99 100 101 102 103 104
/**
 * Max-pooling backward kernel.
 *
 * One thread per input element (`nthreads` of them). The flat index is
 * split, fastest-varying first, into column, row, channel and frame of the
 * input.
 *
 * Each thread visits every pooled output whose window could contain its
 * pixel. It sums the output gradient of each one whose recorded max value
 * equals the pixel's value. Ties therefore send gradient to every equal
 * input. The result is blended into targetGrad as
 *   targetGrad = scaleB * targetGrad + scaleA * sum.
 * Pooled frames in outData/outGrad are `outStride` elements apart.
 */
__global__ void KeMaxPoolBackward(const int nthreads,
                                  const real* inputData,
                                  const real* outData,
                                  const real* outGrad,
                                  const int channels,
                                  const int height,
                                  const int width,
                                  const int pooledH,
                                  const int pooledW,
                                  const int sizeX,
                                  const int sizeY,
                                  const int strideH,
                                  const int strideW,
                                  const int padH,
                                  const int padW,
                                  real scaleA,
                                  real scaleB,
                                  real* targetGrad,
                                  const int outStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Pixel position expressed in padded coordinates.
  const int x = tid % width + padW;
  const int y = (tid / width) % height + padH;
  const int chan = (tid / width / height) % channels;
  const int frame = tid / width / height / channels;

  // Range of pooled rows/columns whose window covers (y, x).
  const int rowFirst = (y < sizeY) ? 0 : (y - sizeY) / strideH + 1;
  const int colFirst = (x < sizeX) ? 0 : (x - sizeX) / strideW + 1;
  const int rowLast = (y >= 0) ? min(y / strideH + 1, pooledH) : 0;
  const int colLast = (x >= 0) ? min(x / strideW + 1, pooledW) : 0;

  const int planeOffset = frame * outStride + chan * pooledH * pooledW;
  const real* maxPlane = outData + planeOffset;
  const real* gradPlane = outGrad + planeOffset;

  const real pixel = inputData[tid];
  real acc = 0;
  for (int ph = rowFirst; ph < rowLast; ++ph) {
    for (int pw = colFirst; pw < colLast; ++pw) {
      const int k = ph * pooledW + pw;
      if (maxPlane[k] == pixel) {
        acc += gradPlane[k];
      }
    }
  }
  targetGrad[tid] = scaleB * targetGrad[tid] + scaleA * acc;
}

L
liaogang 已提交
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
/**
 * Host launcher for KeMaxPoolBackward on STREAM_DEFAULT.
 *
 * Launches one thread per input element (height * width * channels *
 * frameCnt) in 1-D blocks of 1024 threads. paddingH/paddingW are forwarded
 * as the kernel's padH/padW.
 *
 * When there is no input element, the launch is skipped, because a
 * zero-block grid is an invalid launch configuration. The sync-check
 * message now follows the "<function> failed" wording used by the other
 * launchers in this file.
 */
void hl_maxpool_backward(const int frameCnt,
                         const real* inputData,
                         const real* outData,
                         const real* outGrad,
                         const int channels,
                         const int height,
                         const int width,
                         const int pooledH,
                         const int pooledW,
                         const int sizeX,
                         const int sizeY,
                         const int strideH,
                         const int strideW,
                         const int paddingH,
                         const int paddingW,
                         real scaleA,
                         real scaleB,
                         real* targetGrad,
                         const int outStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = height * width * channels * frameCnt;
  if (num_kernels > 0) {
    int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    KeMaxPoolBackward<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
        num_kernels,
        inputData,
        outData,
        outGrad,
        channels,
        height,
        width,
        pooledH,
        pooledW,
        sizeX,
        sizeY,
        strideH,
        strideW,
        paddingH,
        paddingW,
        scaleA,
        scaleB,
        targetGrad,
        outStride);
  }
  CHECK_SYNC("hl_maxpool_backward failed");
}

L
liaogang 已提交
190 191
/**
 * Average-pooling forward kernel.
 *
 * One thread per pooled output (`nthreads`). The flat index is split,
 * fastest-varying first, into pooled column, pooled row, channel and frame.
 *
 * The divisor is the window area clipped to the padded extent
 * (height + padH, width + padW), so padded cells count toward the
 * average. Only cells inside the real height x width plane are summed.
 * Consecutive output frames are `tgtStride` elements apart.
 */
__global__ void KeAvgPoolForward(const int nthreads,
                                 const real* inputData,
                                 const int channels,
                                 const int height,
                                 const int width,
                                 const int pooledH,
                                 const int pooledW,
                                 const int sizeX,
                                 const int sizeY,
                                 const int strideH,
                                 const int strideW,
                                 const int padH,
                                 const int padW,
                                 real* tgtData,
                                 const int tgtStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const int outCol = tid % pooledW;
  const int outRow = (tid / pooledW) % pooledH;
  const int chan = (tid / pooledW / pooledH) % channels;
  const int frame = tid / pooledW / pooledH / channels;

  // Window clipped to the padded extent; its area is the divisor.
  int rowBegin = outRow * strideH - padH;
  int colBegin = outCol * strideW - padW;
  int rowEnd = min(rowBegin + sizeY, height + padH);
  int colEnd = min(colBegin + sizeX, width + padW);
  const int poolSize = (rowEnd - rowBegin) * (colEnd - colBegin);

  // Narrow to the part that overlaps the real input for summation.
  rowBegin = max(rowBegin, 0);
  colBegin = max(colBegin, 0);
  rowEnd = min(rowEnd, height);
  colEnd = min(colEnd, width);

  const real* plane = inputData + (frame * channels + chan) * height * width;
  real sum = 0;
  for (int r = rowBegin; r < rowEnd; ++r) {
    for (int c = colBegin; c < colEnd; ++c) {
      sum += plane[r * width + c];
    }
  }

  const int tgtIndex =
      tid % (pooledW * pooledH * channels) + frame * tgtStride;
  tgtData[tgtIndex] = sum / poolSize;
}

L
liaogang 已提交
235 236
/**
 * Host launcher for KeAvgPoolForward on STREAM_DEFAULT.
 *
 * Launches one thread per pooled output element (pooledH * pooledW *
 * channels * frameCnt) in 1-D blocks of 1024 threads. paddingH/paddingW
 * are forwarded as the kernel's padH/padW.
 *
 * When there is no output element, the launch is skipped, because a
 * zero-block grid is an invalid launch configuration.
 */
void hl_avgpool_forward(const int frameCnt,
                        const real* inputData,
                        const int channels,
                        const int height,
                        const int width,
                        const int pooledH,
                        const int pooledW,
                        const int sizeX,
                        const int sizeY,
                        const int strideH,
                        const int strideW,
                        const int paddingH,
                        const int paddingW,
                        real* tgtData,
                        const int tgtStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = pooledH * pooledW * channels * frameCnt;
  if (num_kernels > 0) {
    int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    KeAvgPoolForward<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
        num_kernels,
        inputData,
        channels,
        height,
        width,
        pooledH,
        pooledW,
        sizeX,
        sizeY,
        strideH,
        strideW,
        paddingH,
        paddingW,
        tgtData,
        tgtStride);
  }
  CHECK_SYNC("hl_avgpool_forward failed");
}

L
liaogang 已提交
270 271 272 273
/**
 * Average-pooling backward kernel.
 *
 * One thread per input element (`nthreads`). The flat index is split,
 * fastest-varying first, into column, row, channel and frame.
 *
 * Each thread gathers outGrad / poolsize from every pooled output whose
 * window covers its pixel. The poolsize is the window area clipped to the
 * padded extent, matching KeAvgPoolForward's divisor. The result is
 * blended as tgtGrad = scaleB * tgtGrad + scaleA * sum. Pooled frames in
 * outGrad are `outStride` elements apart.
 */
__global__ void KeAvgPoolBackward(const int nthreads,
                                  const real* outGrad,
                                  const int channels,
                                  const int height,
                                  const int width,
                                  const int pooledH,
                                  const int pooledW,
                                  const int sizeX,
                                  const int sizeY,
                                  const int strideH,
                                  const int strideW,
                                  const int padH,
                                  const int padW,
                                  real scaleA,
                                  real scaleB,
                                  real* tgtGrad,
                                  const int outStride) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  // Pixel position expressed in padded coordinates.
  const int x = tid % width + padW;
  const int y = (tid / width) % height + padH;
  const int chan = (tid / width / height) % channels;
  const int frame = tid / width / height / channels;

  // Range of pooled rows/columns whose window covers (y, x).
  const int rowFirst = (y < sizeY) ? 0 : (y - sizeY) / strideH + 1;
  const int colFirst = (x < sizeX) ? 0 : (x - sizeX) / strideW + 1;
  const int rowLast = (y >= 0) ? min(y / strideH + 1, pooledH) : 0;
  const int colLast = (x >= 0) ? min(x / strideW + 1, pooledW) : 0;

  const real* gradPlane =
      outGrad + (frame * outStride + chan * pooledH * pooledW);

  real acc = 0;
  for (int ph = rowFirst; ph < rowLast; ++ph) {
    // Height of this window after clipping to the padded extent.
    const int winTop = ph * strideH - padH;
    const int winRows = min(winTop + sizeY, height + padH) - winTop;
    for (int pw = colFirst; pw < colLast; ++pw) {
      const int winLeft = pw * strideW - padW;
      const int winCols = min(winLeft + sizeX, width + padW) - winLeft;
      acc += gradPlane[ph * pooledW + pw] / (winRows * winCols);
    }
  }
  tgtGrad[tid] = scaleB * tgtGrad[tid] + scaleA * acc;
}

L
liaogang 已提交
316 317
/**
 * Host launcher for KeAvgPoolBackward on STREAM_DEFAULT.
 *
 * Launches one thread per input element (height * width * channels *
 * frameCnt) in 1-D blocks of 1024 threads. backGrad is passed as the
 * kernel's tgtGrad, and paddingH/paddingW as its padH/padW.
 *
 * When there is no input element, the launch is skipped, because a
 * zero-block grid is an invalid launch configuration.
 */
void hl_avgpool_backward(const int frameCnt,
                         const real* outGrad,
                         const int channels,
                         const int height,
                         const int width,
                         const int pooledH,
                         const int pooledW,
                         const int sizeX,
                         const int sizeY,
                         const int strideH,
                         const int strideW,
                         const int paddingH,
                         const int paddingW,
                         real scaleA,
                         real scaleB,
                         real* backGrad,
                         const int outStride) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = height * width * channels * frameCnt;
  if (num_kernels > 0) {
    int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    KeAvgPoolBackward<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
        num_kernels,
        outGrad,
        channels,
        height,
        width,
        pooledH,
        pooledW,
        sizeX,
        sizeY,
        strideH,
        strideW,
        paddingH,
        paddingW,
        scaleA,
        scaleB,
        backGrad,
        outStride);
  }
  CHECK_SYNC("hl_avgpool_backward failed");
}

L
liaogang 已提交
356
/**
 * Bilinear interpolation forward kernel.
 *
 * One thread per element of the outputH x outputW `out` buffer. Each output
 * row is split into numChannels planes of outputW / numChannels elements.
 * The `in` row is split the same way into planes of inputW / numChannels.
 * The kernel addresses `in` with the output row index, so presumably
 * inputH == outputH; TODO confirm with callers.
 *
 * The destination pixel (oy, ox) maps to the source point
 * (ratioH * oy, ratioW * ox). The four neighbours are blended with the
 * fractional weights. At the last source row/column the "next" neighbour
 * collapses onto the current one.
 *
 * inputH and outImgH are accepted but not read here.
 */
__global__ void KeBilinearInterpFw(const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  const int nthreads = outputH * outputW;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const int sample = tid / outputW;
  const int col = tid % outputW;
  const int inPlane = inputW / numChannels;
  const int outPlane = outputW / numChannels;
  const int chan = col / outPlane;

  // Vertical source coordinate: integer row plus fractional weight.
  const int oy = (col % outPlane) / outImgW;
  const int iy = ratioH * oy;
  const int dy = (iy < inImgH - 1) ? 1 : 0;
  const real ly1 = ratioH * oy - iy;
  const real ly0 = 1.f - ly1;

  // Horizontal source coordinate. tid % outImgW is expected to equal
  // col % outImgW (assumes outputW is a multiple of outImgW).
  const int ox = tid % outImgW;
  const int ix = ratioW * ox;
  const int dx = (ix < inImgW - 1) ? 1 : 0;
  const real lx1 = ratioW * ox - ix;
  const real lx0 = 1.f - lx1;

  const real* inPos =
      &in[sample * inputW + chan * inPlane + iy * inImgW + ix];

  const real topLeft = inPos[0];
  const real topRight = inPos[dx];
  const real bottomLeft = inPos[dy * inImgW];
  const real bottomRight = inPos[dy * inImgW + dx];

  out[sample * outputW + col] =
      ly0 * (lx0 * topLeft + lx1 * topRight) +
      ly1 * (lx0 * bottomLeft + lx1 * bottomRight);
}

/**
 * Host launcher for KeBilinearInterpFw on STREAM_DEFAULT.
 *
 * Launches one thread per element of the outputH x outputW output buffer,
 * in 1-D blocks of 1024 threads. When the output is empty, the launch is
 * skipped, because a zero-block grid is an invalid launch configuration.
 */
void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels,
                         const real ratioH,
                         const real ratioW) {
  const int kThreadsPerBlock = 1024;
  int threadNum = outputH * outputW;
  if (threadNum > 0) {
    int blocks = (threadNum + kThreadsPerBlock - 1) / kThreadsPerBlock;
    KeBilinearInterpFw<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
        inData,
        inImgH,
        inImgW,
        inputH,
        inputW,
        outData,
        outImgH,
        outImgW,
        outputH,
        outputW,
        numChannels,
        ratioH,
        ratioW);
  }
  CHECK_SYNC("hl_bilinear_forward failed");
}

L
liaogang 已提交
433
/**
 * Bilinear interpolation backward kernel.
 *
 * Uses the same thread-to-output mapping and source-point computation as
 * KeBilinearInterpFw. Each output gradient is scattered into the four
 * contributing source cells of `in`, weighted by the bilinear weights.
 * Several output pixels can hit the same source cell, so the scatter uses
 * paddle::paddleAtomicAdd.
 *
 * inputH and outImgH are accepted but not read here.
 */
__global__ void KeBilinearInterpBw(real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  const int nthreads = outputH * outputW;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const int sample = tid / outputW;
  const int col = tid % outputW;
  const int inPlane = inputW / numChannels;
  const int outPlane = outputW / numChannels;
  const int chan = col / outPlane;

  // Vertical source coordinate: integer row plus fractional weight.
  const int oy = (col % outPlane) / outImgW;
  const int iy = ratioH * oy;
  const int dy = (iy < inImgH - 1) ? 1 : 0;
  const real ly1 = ratioH * oy - iy;
  const real ly0 = 1.f - ly1;

  // Horizontal source coordinate (see KeBilinearInterpFw).
  const int ox = tid % outImgW;
  const int ix = ratioW * ox;
  const int dx = (ix < inImgW - 1) ? 1 : 0;
  const real lx1 = ratioW * ox - ix;
  const real lx0 = 1.f - lx1;

  // Corner weights, each the product of a vertical and a horizontal weight.
  const real wTopLeft = ly0 * lx0;
  const real wTopRight = ly0 * lx1;
  const real wBottomLeft = ly1 * lx0;
  const real wBottomRight = ly1 * lx1;

  real* inPos = &in[sample * inputW + chan * inPlane + iy * inImgW + ix];
  const real* outPos = &out[sample * outputW + col];

  paddle::paddleAtomicAdd(&inPos[0], wTopLeft * outPos[0]);
  paddle::paddleAtomicAdd(&inPos[dx], wTopRight * outPos[0]);
  paddle::paddleAtomicAdd(&inPos[dy * inImgW], wBottomLeft * outPos[0]);
  paddle::paddleAtomicAdd(&inPos[dy * inImgW + dx],
                          wBottomRight * outPos[0]);
}

/**
 * Host launcher for KeBilinearInterpBw on STREAM_DEFAULT.
 *
 * Launches one thread per element of the outputH x outputW output-gradient
 * buffer, in 1-D blocks of 1024 threads. The kernel accumulates into
 * inGrad with atomic adds.
 *
 * When the output is empty, the launch is skipped, because a zero-block
 * grid is an invalid launch configuration.
 */
void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels,
                          const real ratioH,
                          const real ratioW) {
  const int kThreadsPerBlock = 1024;
  int threadNum = outputH * outputW;
  if (threadNum > 0) {
    int blocks = (threadNum + kThreadsPerBlock - 1) / kThreadsPerBlock;
    KeBilinearInterpBw<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
        inGrad,
        inImgH,
        inImgW,
        inputH,
        inputW,
        outGrad,
        outImgH,
        outImgW,
        outputH,
        outputW,
        numChannels,
        ratioH,
        ratioW);
  }
  CHECK_SYNC("hl_bilinear_backward failed");
}

L
liaogang 已提交
511 512 513 514 515 516 517
/**
 * Maxout forward kernel.
 *
 * One thread per output element (`nthreads`). Each sample owns `size`
 * output elements and size * groups input elements. Output channel c (of
 * featLen elements) takes the element-wise maximum over the `groups` input
 * channels c * groups .. c * groups + groups - 1.
 *
 * The winning group index (the first one on ties) is written to idData for
 * use by the backward pass.
 */
__global__ void maxoutFpCompute(size_t nthreads,
                                const real* inData,
                                real* outData,
                                int* idData,
                                size_t size,
                                size_t featLen,
                                size_t groups) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const size_t sample = tid / size;
  const size_t offset = tid % size;
  const size_t chan = offset / featLen;
  const size_t pos = offset % featLen;

  // First candidate of the group; candidates are featLen elements apart.
  const real* candidates =
      inData + (sample * size + chan * featLen) * groups + pos;

  real best = candidates[0];
  int bestGroup = 0;
  for (size_t g = 1; g < groups; ++g) {
    const real value = candidates[g * featLen];
    if (value > best) {
      best = value;
      bestGroup = g;
    }
  }
  outData[tid] = best;
  idData[tid] = bestGroup;
}

L
liaogang 已提交
540 541 542 543 544 545 546
/**
 * Host launcher for maxoutFpCompute on STREAM_DEFAULT.
 *
 * Launches one thread per output element (size * batchSize) in 1-D blocks
 * of 1024 threads. When there are no output elements (e.g. batchSize == 0),
 * the launch is skipped, because a zero-block grid is an invalid launch
 * configuration.
 */
void hl_maxout_forward(const real* inData,
                       real* outData,
                       int* idData,
                       size_t batchSize,
                       size_t size,
                       size_t featLen,
                       size_t groups) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = size * batchSize;
  if (num_kernels > 0) {
    int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    maxoutFpCompute<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
        num_kernels, inData, outData, idData, size, featLen, groups);
  }
  CHECK_SYNC("hl_maxout_forward failed");
}

L
liaogang 已提交
554 555 556 557 558 559 560
/**
 * Maxout backward kernel.
 *
 * One thread per output-gradient element (`nthreads`). Each output gradient
 * is added to the input-gradient slot of the group recorded in idData by
 * the forward pass. The layout matches maxoutFpCompute: `size` outputs and
 * size * groups inputs per sample.
 *
 * Each output element selects exactly one input slot, so no two threads
 * target the same address and a plain += is used.
 */
__global__ void maxoutBpCompute(size_t nthreads,
                                real* inGrad,
                                const real* outGrad,
                                const int* idData,
                                size_t size,
                                size_t featLen,
                                size_t groups) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= nthreads) {
    return;
  }

  const size_t sample = tid / size;
  const size_t offset = tid % size;
  const size_t chan = offset / featLen;
  const size_t pos = offset % featLen;

  const size_t outBase = sample * size;
  const size_t inBase = outBase * groups;
  const int winner = idData[outBase + offset];

  const size_t slot = (chan * groups + winner) * featLen + pos;
  inGrad[inBase + slot] += outGrad[outBase + offset];
}

L
liaogang 已提交
574 575 576 577 578 579 580
/**
 * Host launcher for maxoutBpCompute on STREAM_DEFAULT.
 *
 * Launches one thread per output-gradient element (size * batchSize) in
 * 1-D blocks of 1024 threads. When there are no elements, the launch is
 * skipped, because a zero-block grid is an invalid launch configuration.
 */
void hl_maxout_backward(real* inGrad,
                        const real* outGrad,
                        const int* idData,
                        size_t batchSize,
                        size_t size,
                        size_t featLen,
                        size_t groups) {
  const int kThreadsPerBlock = 1024;
  int num_kernels = size * batchSize;
  if (num_kernels > 0) {
    int blocks = (num_kernels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    maxoutBpCompute<<<blocks, kThreadsPerBlock, 0, STREAM_DEFAULT>>>(
        num_kernels, inGrad, outGrad, idData, size, featLen, groups);
  }
  CHECK_SYNC("hl_maxout_backward failed");
}