提交 04fd9893 编写于 作者: S sweetsky0901

for code review 6

上级 95cbbd77
......@@ -85,11 +85,9 @@ public:
int output_idx = blen + clen + f;
for (int g = 0; g < groups && continue_match; ++g) {
int input_idx = input_idx0 + fea_size * g;
input_grad_data[input_idx] = 0;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
break;
}
}
}
......
......@@ -69,8 +69,7 @@ __global__ void KernelMaxoutGrad(
}
}
if (max_index != -1) {
// atomic add
platform::CudaAtomicAdd(input_grad + max_index, output_grad[index]);
input_grad[max_index] += output_grad[index];
}
}
}
......
......@@ -15,13 +15,11 @@
#include "paddle/operators/maxout_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace,
float>);
REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace,
double>);
REGISTER_OP_GPU_KERNEL(maxout,
ops::MaxOutKernel<paddle::platform::GPUPlace, float>,
ops::MaxOutKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
float>);
REGISTER_OP_GPU_KERNEL(maxout_grad,
float>,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
double>);
double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册