From 8d9babf20407d1ea21ad66cf5c07ec61adb7398d Mon Sep 17 00:00:00 2001 From: wanghaox Date: Wed, 15 Nov 2017 15:47:00 +0800 Subject: [PATCH] maxout code review 2nd --- paddle/operators/math/maxouting.cc | 10 +++++----- paddle/operators/math/maxouting.cu | 11 ++++++----- paddle/operators/maxout_op.h | 8 +++----- python/paddle/v2/framework/tests/test_maxout_op.py | 2 -- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index a634e49f48..b733af7410 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -42,11 +42,11 @@ class MaxOutFunctor { const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); - for (int i = 0; i < batch_size; i++) { + for (int i = 0; i < batch_size; ++i) { int new_bindex = c_size * i; for (int c = 0; c < output_channels; ++c) { int new_cindex = fea_size * c; - for (int f = 0; f < fea_size; f++) { + for (int f = 0; f < fea_size; ++f) { T ele = maxout_process.initial(); for (int ph = 0; ph < groups; ++ph) { maxout_process.compute(ele, @@ -82,15 +82,15 @@ public: const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad.mutable_data(context.GetPlace()); - for (int i = 0; i < batch_size; i++) { + for (int i = 0; i < batch_size; ++i) { int blen = fea_size * output_channels * i; for (int c = 0; c < output_channels; ++c) { int clen = fea_size * c; - for (int f = 0; f < fea_size; f++) { + for (int f = 0; f < fea_size; ++f) { int input_idx = 0; bool stop = false; int output_idx = blen + clen + f; - for (int g = 0; g < groups && !stop; g++) { + for (int g = 0; g < groups && !stop; ++g) { input_idx = (blen + clen) * groups + fea_size * g + f; input_grad_data[input_idx] = 0; if (input_data[input_idx] == output_data[output_idx]) { diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index 42acaa2c73..c2da29e356 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -21,9 +21,10 @@ namespace math { template __global__ void KernelMaxOut(const int nthreads, const T* input_data, - T* output_data, const int channels, + const int channels, const int input_height, const int input_width, - int groups, MaxOutProcess maxout_process) { + int groups, T* output_data, + MaxOutProcess maxout_process) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; @@ -58,7 +59,7 @@ __global__ void KernelMaxoutGrad( (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; int maxIndex = -1; bool stop = false; - for (int g = 0; g < groups && !stop; g++) { + for (int g = 0; g < groups && !stop; ++g) { if (input_data[data_idx + g * feat_len] == output_data[index]) { maxIndex = data_idx + g * feat_len; stop = true; @@ -99,9 +100,9 @@ class MaxOutFunctor { MaxOutProcess, T><<(context) - .stream()>>>(nthreads, input_data, output_data, input_channels, + .stream()>>>(nthreads, input_data, input_channels, input_height, input_width, groups, - maxout_process); + output_data, maxout_process); } }; /* diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h index 3f5897abd2..aab878af0f 100644 --- a/paddle/operators/maxout_op.h +++ b/paddle/operators/maxout_op.h @@ -54,13 +54,11 @@ class MaxOutGradKernel : public framework::OpKernel { int groups = context.template Attr("groups"); - - + auto& device_ctx = context.device_context(); + math::SetConstant zero; if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); - auto temp = framework::EigenVector::Flatten(*in_x_grad); - temp.device(context.GetEigenDevice()) = - temp.constant(static_cast(0)); + zero(device_ctx, in_x_grad, static_cast(0.0)); paddle::operators::math::MaxOutGradFunctor maxout_backward; diff --git a/python/paddle/v2/framework/tests/test_maxout_op.py b/python/paddle/v2/framework/tests/test_maxout_op.py index 406147ef24..a7c47108f1 100644 --- a/python/paddle/v2/framework/tests/test_maxout_op.py +++ b/python/paddle/v2/framework/tests/test_maxout_op.py @@ -26,8 +26,6 @@ class TestMaxOutOp(OpTest): self.check_output() def test_check_grad(self): - print self.inputs - print self.outputs self.check_grad(['X'], 'Out') def init_test_case(self): -- GitLab