提交 3fbff1ee 编写于 作者: S sweetsky0901

for code review 5

上级 350cc61f
......@@ -89,6 +89,7 @@ public:
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false;
break;
}
}
}
......
......@@ -65,6 +65,7 @@ __global__ void KernelMaxoutGrad(
if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len;
continue_match = false;
break;
}
}
if (max_index != -1) {
......
......@@ -17,6 +17,11 @@
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace,
float>);
REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace,
double>);
REGISTER_OP_GPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
float>);
REGISTER_OP_GPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册