提交 3fbff1ee 编写于 作者: S sweetsky0901

for code review 5

上级 350cc61f
...@@ -89,6 +89,7 @@ public: ...@@ -89,6 +89,7 @@ public:
if (input_data[input_idx] == output_data[output_idx]) { if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx]; input_grad_data[input_idx] += output_grad_data[output_idx];
continue_match = false; continue_match = false;
break;
} }
} }
} }
......
...@@ -65,6 +65,7 @@ __global__ void KernelMaxoutGrad( ...@@ -65,6 +65,7 @@ __global__ void KernelMaxoutGrad(
if (input_data[data_idx + g * feat_len] == output_data[i]) { if (input_data[data_idx + g * feat_len] == output_data[i]) {
max_index = data_idx + g * feat_len; max_index = data_idx + g * feat_len;
continue_match = false; continue_match = false;
break;
} }
} }
if (max_index != -1) { if (max_index != -1) {
......
...@@ -17,6 +17,11 @@ ...@@ -17,6 +17,11 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace, REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace,
float>); float>);
REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel<paddle::platform::GPUPlace,
double>);
REGISTER_OP_GPU_KERNEL(maxout_grad, REGISTER_OP_GPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::GPUPlace, ops::MaxOutGradKernel<paddle::platform::GPUPlace,
float>); float>);
REGISTER_OP_GPU_KERNEL(maxout_grad,
ops::MaxOutGradKernel<paddle::platform::GPUPlace,
double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册