diff --git a/mace/kernels/softmax.h b/mace/kernels/softmax.h
index bd21547d2cf8294913781f1c1cb6bb3828170edb..ac8c99131c4132cf4375ac06fdc443db39912edc 100644
--- a/mace/kernels/softmax.h
+++ b/mace/kernels/softmax.h
@@ -43,6 +43,7 @@ struct SoftmaxFunctor {
     const index_t batch = input->dim(0);
     const index_t class_count = input->dim(1);
     const index_t class_size = input->dim(2) * input->dim(3);
+    const index_t batch_size = class_count * class_size;
 
     Tensor::MappingGuard input_guard(input);
     Tensor::MappingGuard output_guard(output);
@@ -50,46 +51,37 @@ struct SoftmaxFunctor {
     float *output_data = output->mutable_data<float>();
 
     for (index_t b = 0; b < batch; ++b) {
-      std::vector<float>
-          max_val(class_size, std::numeric_limits<float>::lowest());
-      std::vector<float> sum_val(class_size, 0.f);
-
-      // calculate max for each class
-      for (index_t c = 0; c < class_count; ++c) {
-        const float
-            *input_ptr = input_data + (b * class_count + c) * class_size;
-        for (index_t k = 0; k < class_size; ++k) {
-          max_val[k] = std::max(max_val[k], input_ptr[k]);
-        }
-      }
-
-      // calculate data - max for each class
 #pragma omp parallel for
-      for (index_t c = 0; c < class_count; ++c) {
-        const float
-            *input_ptr = input_data + (b * class_count + c) * class_size;
-        float *output_ptr = output_data + (b * class_count + c) * class_size;
-        for (index_t k = 0; k < class_size; ++k) {
-          output_ptr[k] = ::exp(input_ptr[k] - max_val[k]);
+      for (index_t k = 0; k < class_size; ++k) {
+        const float *input_ptr = input_data + b * batch_size + k;
+        float *output_ptr = output_data + b * batch_size + k;
+
+        float max_val = std::numeric_limits<float>::lowest();
+        index_t channel_offset = 0;
+        for (index_t c = 0; c < class_count; ++c) {
+          float data = input_ptr[channel_offset];
+          if (data > max_val) {
+            max_val = data;
+          }
+          channel_offset += class_size;
         }
-      }
 
-      // calculate sum for each class
-      for (index_t c = 0; c < class_count; ++c) {
-        float *output_ptr = output_data + (b * class_count + c) * class_size;
-        for (index_t k = 0; k < class_size; ++k) {
-          sum_val[k] += output_ptr[k];
+        channel_offset = 0;
+        float sum = 0;
+        for (index_t c = 0; c < class_count; ++c) {
+          float exp_value = ::exp(input_ptr[channel_offset] - max_val);
+          sum += exp_value;
+          output_ptr[channel_offset] = exp_value;
+          channel_offset += class_size;
         }
-      }
 
-      // calculate (data - max) / sum for each class
-      for (index_t c = 0; c < class_count; ++c) {
-        float *output_ptr = output_data + (b * class_count + c) * class_size;
-        for (index_t k = 0; k < class_size; ++k) {
-          output_ptr[k] /= sum_val[k];
+        channel_offset = 0;
+        for (index_t c = 0; c < class_count; ++c) {
+          output_ptr[channel_offset] /= sum;
+          channel_offset += class_size;
        }
-      }
-    }
+      }  // k
+    }  // b
   }
 };
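For reference, a minimal standalone sketch of the computation the reworked kernel performs: softmax over the channel axis of an NCHW tensor, evaluated independently at each spatial position. This is illustration only, not part of the patch; the function name is hypothetical, and it uses plain index arithmetic instead of the channel_offset striding and the OpenMP pragma used above.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

using index_t = int64_t;

// Hypothetical reference implementation: softmax over C for every (b, h, w)
// of an NCHW tensor, where class_size == H * W.
void softmax_nchw_reference(const float *input, float *output,
                            index_t batch, index_t class_count,
                            index_t class_size) {
  const index_t batch_size = class_count * class_size;
  for (index_t b = 0; b < batch; ++b) {
    for (index_t k = 0; k < class_size; ++k) {
      const float *in = input + b * batch_size + k;
      float *out = output + b * batch_size + k;

      // 1. Maximum over channels, for numerical stability.
      float max_val = std::numeric_limits<float>::lowest();
      for (index_t c = 0; c < class_count; ++c) {
        max_val = std::max(max_val, in[c * class_size]);
      }

      // 2. Exponentiate shifted values and accumulate their sum.
      float sum = 0.f;
      for (index_t c = 0; c < class_count; ++c) {
        const float e = std::exp(in[c * class_size] - max_val);
        out[c * class_size] = e;
        sum += e;
      }

      // 3. Normalize so the channel values sum to 1.
      for (index_t c = 0; c < class_count; ++c) {
        out[c * class_size] /= sum;
      }
    }
  }
}

Parallelizing over k (spatial positions) rather than over channels lets each thread keep the max and sum reductions in scalar locals, which is what allows the patch to drop the per-position max_val/sum_val vectors of the old version.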