Commit 8d9babf2 authored by wanghaox

maxout code review 2nd

Parent f319fb1c
@@ -42,11 +42,11 @@ class MaxOutFunctor<platform::CPUPlace, MaxOutProcess, T> {
   const T* input_data = input.data<T>();
   T* output_data = output->mutable_data<T>(context.GetPlace());
-  for (int i = 0; i < batch_size; i++) {
+  for (int i = 0; i < batch_size; ++i) {
     int new_bindex = c_size * i;
     for (int c = 0; c < output_channels; ++c) {
       int new_cindex = fea_size * c;
-      for (int f = 0; f < fea_size; f++) {
+      for (int f = 0; f < fea_size; ++f) {
         T ele = maxout_process.initial();
         for (int ph = 0; ph < groups; ++ph) {
           maxout_process.compute(ele,
@@ -82,15 +82,15 @@ public:
   const T* output_grad_data = output_grad.data<T>();
   T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
-  for (int i = 0; i < batch_size; i++) {
+  for (int i = 0; i < batch_size; ++i) {
     int blen = fea_size * output_channels * i;
     for (int c = 0; c < output_channels; ++c) {
       int clen = fea_size * c;
-      for (int f = 0; f < fea_size; f++) {
+      for (int f = 0; f < fea_size; ++f) {
         int input_idx = 0;
         bool stop = false;
         int output_idx = blen + clen + f;
-        for (int g = 0; g < groups && !stop; g++) {
+        for (int g = 0; g < groups && !stop; ++g) {
           input_idx = (blen + clen) * groups + fea_size * g + f;
           input_grad_data[input_idx] = 0;
           if (input_data[input_idx] == output_data[output_idx]) {
......
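For reference, the CPU functor above implements maxout over channel groups: each output channel is the element-wise maximum of `groups` consecutive input channels at the same spatial position. A minimal standalone sketch of the same loop structure and indexing, using plain arrays and a hypothetical `maxout_forward` helper rather than the Paddle functor itself:

#include <algorithm>
#include <limits>
#include <vector>

// Hypothetical standalone version of the forward loop shown in the diff.
// Input is NCHW with HW flattened into fea_size; output has channels / groups
// channels, and each output element is the max over its group of inputs.
void maxout_forward(const std::vector<float>& input, std::vector<float>& output,
                    int batch_size, int channels, int fea_size, int groups) {
  const int output_channels = channels / groups;
  const int c_size = fea_size * output_channels;  // output size per batch item
  for (int i = 0; i < batch_size; ++i) {
    int new_bindex = c_size * i;
    for (int c = 0; c < output_channels; ++c) {
      int new_cindex = fea_size * c;
      for (int f = 0; f < fea_size; ++f) {
        float ele = std::numeric_limits<float>::lowest();
        for (int ph = 0; ph < groups; ++ph) {
          // Same input indexing as the backward pass in the diff:
          // (blen + clen) * groups + fea_size * g + f
          int in_idx = (new_bindex + new_cindex) * groups + fea_size * ph + f;
          ele = std::max(ele, input[in_idx]);
        }
        output[new_bindex + new_cindex + f] = ele;
      }
    }
  }
}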
@@ -21,9 +21,10 @@ namespace math {
 template <typename MaxOutProcess, typename T>
 __global__ void KernelMaxOut(const int nthreads, const T* input_data,
-                             T* output_data, const int channels,
+                             const int channels,
                              const int input_height, const int input_width,
-                             int groups, MaxOutProcess maxout_process) {
+                             int groups, T* output_data,
+                             MaxOutProcess maxout_process) {
   const int size = input_height * input_width * channels / groups;
   const int feat_len = input_height * input_width;
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
@@ -58,7 +59,7 @@ __global__ void KernelMaxoutGrad(
         (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
     int maxIndex = -1;
     bool stop = false;
-    for (int g = 0; g < groups && !stop; g++) {
+    for (int g = 0; g < groups && !stop; ++g) {
       if (input_data[data_idx + g * feat_len] == output_data[index]) {
         maxIndex = data_idx + g * feat_len;
         stop = true;
@@ -99,9 +100,9 @@ class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
         MaxOutProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(nthreads, input_data, output_data, input_channels,
+                 .stream()>>>(nthreads, input_data, input_channels,
                               input_height, input_width, groups,
-                              maxout_process);
+                              output_data, maxout_process);
   }
 };
 /*
......
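Two things change in the CUDA file: the loop increments switch to prefix form, and KernelMaxOut's parameters are reordered so the output pointer follows all inputs (the kernel launch is updated to match). The index arithmetic is untouched; the following self-contained sketch shows how a flat output index is assumed to decompose into coordinates, using the data_idx formula visible in the diff (the example sizes and the decomposition of index are illustrative assumptions, not lines from the kernel):

#include <cstdio>

int main() {
  // Assumed example sizes; only the data_idx formula comes from the diff.
  const int channels = 6, groups = 3, input_height = 2, input_width = 2;
  const int feat_len = input_height * input_width;    // spatial size, H * W
  const int size = feat_len * channels / groups;      // output size per batch

  const int index = 7;                                // one flat output index
  const int batch_idx = index / size;
  const int channel_idx = (index % size) / feat_len;  // output channel
  const int feat_idx = index % feat_len;              // spatial offset

  // First input element of the group feeding this output; the grad kernel
  // scans g = 0 .. groups-1 at stride feat_len until it finds the max.
  const int data_idx =
      (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
  printf("batch=%d channel=%d feat=%d data_idx=%d\n",
         batch_idx, channel_idx, feat_idx, data_idx);
  return 0;
}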
@@ -54,13 +54,11 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
     int groups = context.template Attr<int>("groups");
+    auto& device_ctx = context.device_context();
+    math::SetConstant<Place, T> zero;
     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
-      auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
-      temp.device(context.GetEigenDevice<Place>()) =
-          temp.constant(static_cast<T>(0));
+      zero(device_ctx, in_x_grad, static_cast<T>(0.0));
       paddle::operators::math::MaxOutGradFunctor<Place, T>
           maxout_backward;
......
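The maxout_op.h change swaps the hand-rolled Eigen zero-fill for the shared math::SetConstant functor, so gradient buffers are cleared the same way across operators. A toy sketch of what a SetConstant-style functor does (assumption: the real Paddle helper takes a device context and a tensor; this host-only stand-in just fills a buffer):

#include <algorithm>
#include <vector>

// Toy stand-in for math::SetConstant: fill every element with one value.
template <typename T>
struct SetConstant {
  void operator()(std::vector<T>* tensor, T value) const {
    std::fill(tensor->begin(), tensor->end(), value);
  }
};

int main() {
  std::vector<float> in_x_grad(16, 1.0f);
  SetConstant<float> zero;
  zero(&in_x_grad, 0.0f);  // same effect as the removed Eigen expression
  return 0;
}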
@@ -26,8 +26,6 @@ class TestMaxOutOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        print self.inputs
-        print self.outputs
         self.check_grad(['X'], 'Out')

     def init_test_case(self):
......