Commit 8d9babf2 authored by: W wanghaox

maxout code review 2nd

Parent f319fb1c
@@ -42,11 +42,11 @@ class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
     const T* input_data = input.data<T>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
-    for (int i = 0; i < batch_size; i++) {
+    for (int i = 0; i < batch_size; ++i) {
       int new_bindex = c_size * i;
       for (int c = 0; c < output_channels; ++c) {
         int new_cindex = fea_size * c;
-        for (int f = 0; f < fea_size; f++) {
+        for (int f = 0; f < fea_size; ++f) {
           T ele = maxout_process.initial();
           for (int ph = 0; ph < groups; ++ph) {
             maxout_process.compute(ele,
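
For reference while reviewing, a minimal standalone sketch of the reduction these loops implement (the flat NCHW layout with in_channels = output_channels * groups is an assumption for illustration; this is not the functor itself):

#include <algorithm>
#include <vector>

// Sketch only: output[i][c][f] = max over g of input[i][c * groups + g][f],
// with the spatial dimensions flattened into fea_size.
std::vector<float> MaxoutForwardSketch(const std::vector<float>& input,
                                       int batch_size, int output_channels,
                                       int fea_size, int groups) {
  const int in_channels = output_channels * groups;
  std::vector<float> output(batch_size * output_channels * fea_size);
  for (int i = 0; i < batch_size; ++i) {
    for (int c = 0; c < output_channels; ++c) {
      for (int f = 0; f < fea_size; ++f) {
        float ele = input[(i * in_channels + c * groups) * fea_size + f];
        for (int g = 1; g < groups; ++g) {
          ele = std::max(
              ele, input[(i * in_channels + c * groups + g) * fea_size + f]);
        }
        output[(i * output_channels + c) * fea_size + f] = ele;
      }
    }
  }
  return output;
}
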
@@ -82,15 +82,15 @@ public:
     const T* output_grad_data = output_grad.data<T>();
     T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
-    for (int i = 0; i < batch_size; i++) {
+    for (int i = 0; i < batch_size; ++i) {
       int blen = fea_size * output_channels * i;
       for (int c = 0; c < output_channels; ++c) {
         int clen = fea_size * c;
-        for (int f = 0; f < fea_size; f++) {
+        for (int f = 0; f < fea_size; ++f) {
           int input_idx = 0;
           bool stop = false;
           int output_idx = blen + clen + f;
-          for (int g = 0; g < groups && !stop; g++) {
+          for (int g = 0; g < groups && !stop; ++g) {
             input_idx = (blen + clen) * groups + fea_size * g + f;
             input_grad_data[input_idx] = 0;
             if (input_data[input_idx] == output_data[output_idx]) {
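
Likewise, a minimal sketch of the gradient routing performed by the loops above: each group window starts at zero and only the position that produced the maximum receives the output gradient, with the first match winning just like the stop flag (same assumed flat layout; illustration only, not the functor):

#include <vector>

void MaxoutBackwardSketch(const std::vector<float>& input,
                          const std::vector<float>& output,
                          const std::vector<float>& output_grad,
                          std::vector<float>* input_grad, int batch_size,
                          int output_channels, int fea_size, int groups) {
  input_grad->assign(input.size(), 0.0f);
  for (int i = 0; i < batch_size; ++i) {
    int blen = fea_size * output_channels * i;
    for (int c = 0; c < output_channels; ++c) {
      int clen = fea_size * c;
      for (int f = 0; f < fea_size; ++f) {
        int output_idx = blen + clen + f;
        for (int g = 0; g < groups; ++g) {
          int input_idx = (blen + clen) * groups + fea_size * g + f;
          if (input[input_idx] == output[output_idx]) {
            (*input_grad)[input_idx] += output_grad[output_idx];
            break;  // first match wins, as with the stop flag
          }
        }
      }
    }
  }
}
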
...
@@ -21,9 +21,10 @@ namespace math {
 template <typename MaxOutProcess, typename T>
 __global__ void KernelMaxOut(const int nthreads, const T* input_data,
-                             T* output_data, const int channels,
+                             const int channels,
                              const int input_height, const int input_width,
-                             int groups, MaxOutProcess maxout_process) {
+                             int groups, T* output_data,
+                             MaxOutProcess maxout_process) {
   const int size = input_height * input_width * channels / groups;
   const int feat_len = input_height * input_width;
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
@@ -58,7 +59,7 @@ __global__ void KernelMaxoutGrad(
       (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
   int maxIndex = -1;
   bool stop = false;
-  for (int g = 0; g < groups && !stop; g++) {
+  for (int g = 0; g < groups && !stop; ++g) {
     if (input_data[data_idx + g * feat_len] == output_data[index]) {
       maxIndex = data_idx + g * feat_len;
       stop = true;
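
For orientation, the index arithmetic these kernels use (one thread per output element) can be unpacked as below; plain C++ is used purely for illustration, the kernels perform the same computation per thread:

// Maps one flat output index back to its group window in the input, matching
// the size / feat_len terms in KernelMaxOut and KernelMaxoutGrad.
struct MaxoutIndices {
  int batch_idx;
  int channel_idx;
  int feat_idx;
  int data_idx;  // first input element of the group window
};

MaxoutIndices DecomposeOutputIndex(int index, int channels, int input_height,
                                   int input_width, int groups) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  MaxoutIndices m;
  m.batch_idx = index / size;
  m.channel_idx = (index % size) / feat_len;
  m.feat_idx = index % feat_len;
  m.data_idx =
      (m.batch_idx * size + m.channel_idx * feat_len) * groups + m.feat_idx;
  return m;
}
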
@@ -99,9 +100,9 @@ class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
         MaxOutProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(nthreads, input_data, output_data, input_channels,
-                              input_height, input_width, groups,
-                              maxout_process);
+                 .stream()>>>(nthreads, input_data, input_channels,
+                              input_height, input_width, groups,
+                              output_data, maxout_process);
   }
 };
 /*
...
@@ -54,13 +54,11 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
     int groups = context.template Attr<int>("groups");
+    auto& device_ctx = context.device_context();
+    math::SetConstant<Place, T> zero;
     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
-      auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
-      temp.device(context.GetEigenDevice<Place>()) =
-          temp.constant(static_cast<T>(0));
+      zero(device_ctx, in_x_grad, static_cast<T>(0.0));
       paddle::operators::math::MaxOutGradFunctor<Place, T>
           maxout_backward;
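
The SetConstant call adopts a zero-then-accumulate pattern: the whole gradient tensor is cleared once up front so that positions the backward pass never touches are guaranteed to stay zero. A minimal sketch of the same pattern with plain containers (std::fill stands in for math::SetConstant here, purely for illustration):

#include <algorithm>
#include <vector>

void ZeroThenAccumulateSketch(std::vector<float>* in_x_grad) {
  // Plays the role of zero(device_ctx, in_x_grad, static_cast<T>(0.0)).
  std::fill(in_x_grad->begin(), in_x_grad->end(), 0.0f);
  // ... MaxOutGradFunctor-style accumulation at the argmax slots follows.
}
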
...
@@ -26,8 +26,6 @@ class TestMaxOutOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        print self.inputs
-        print self.outputs
         self.check_grad(['X'], 'Out')
 
     def init_test_case(self):
...