Commit cdb2da6f authored by 李寅

Improve Softmax perf

Parent dbbf8596
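For reference, both the old and the new kernel compute the numerically stable softmax over the channel dimension of an NCHW tensor: for each spatial position, softmax(x)_c = exp(x_c - max_j x_j) / sum_j exp(x_j - max_j x_j). Subtracting the per-position maximum before exponentiating prevents overflow in exp and leaves the result unchanged.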
@@ -43,6 +43,7 @@ struct SoftmaxFunctor<DeviceType::CPU, float> {
     const index_t batch = input->dim(0);
     const index_t class_count = input->dim(1);
     const index_t class_size = input->dim(2) * input->dim(3);
+    const index_t batch_size = class_count * class_size;
 
     Tensor::MappingGuard input_guard(input);
     Tensor::MappingGuard output_guard(output);
@@ -50,46 +51,37 @@ struct SoftmaxFunctor<DeviceType::CPU, float> {
     float *output_data = output->mutable_data<float>();
 
     for (index_t b = 0; b < batch; ++b) {
-      std::vector<float>
-          max_val(class_size, std::numeric_limits<float>::lowest());
-      std::vector<float> sum_val(class_size, 0.f);
-
-      // calculate max for each class
-      for (index_t c = 0; c < class_count; ++c) {
-        const float
-            *input_ptr = input_data + (b * class_count + c) * class_size;
-        for (index_t k = 0; k < class_size; ++k) {
-          max_val[k] = std::max(max_val[k], input_ptr[k]);
-        }
-      }
-
-      // calculate data - max for each class
 #pragma omp parallel for
-      for (index_t c = 0; c < class_count; ++c) {
-        const float
-            *input_ptr = input_data + (b * class_count + c) * class_size;
-        float *output_ptr = output_data + (b * class_count + c) * class_size;
-        for (index_t k = 0; k < class_size; ++k) {
-          output_ptr[k] = ::exp(input_ptr[k] - max_val[k]);
-        }
-      }
-
-      // calculate sum for each class
-      for (index_t c = 0; c < class_count; ++c) {
-        float *output_ptr = output_data + (b * class_count + c) * class_size;
-        for (index_t k = 0; k < class_size; ++k) {
-          sum_val[k] += output_ptr[k];
-        }
-      }
-
-      // calculate (data - max) / sum for each class
-      for (index_t c = 0; c < class_count; ++c) {
-        float *output_ptr = output_data + (b * class_count + c) * class_size;
-        for (index_t k = 0; k < class_size; ++k) {
-          output_ptr[k] /= sum_val[k];
-        }
-      }
-    }
+      for (index_t k = 0; k < class_size; ++k) {
+        const float *input_ptr = input_data + b * batch_size + k;
+        float *output_ptr = output_data + b * batch_size + k;
+
+        float max_val = std::numeric_limits<float>::lowest();
+        index_t channel_offset = 0;
+        for (index_t c = 0; c < class_count; ++c) {
+          float data = input_ptr[channel_offset];
+          if (data > max_val) {
+            max_val = data;
+          }
+          channel_offset += class_size;
+        }
+
+        channel_offset = 0;
+        float sum = 0;
+        for (index_t c = 0; c < class_count; ++c) {
+          float exp_value = ::exp(input_ptr[channel_offset] - max_val);
+          sum += exp_value;
+          output_ptr[channel_offset] = exp_value;
+          channel_offset += class_size;
+        }
+
+        channel_offset = 0;
+        for (index_t c = 0; c < class_count; ++c) {
+          output_ptr[channel_offset] /= sum;
+          channel_offset += class_size;
+        }
+      }  // k
+    }  // b
   }
 };
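Below is a minimal standalone sketch (not MACE code; softmax_nchw and all other names are illustrative) of the loop structure the new version uses: each spatial position k is handled independently with scalar max/sum accumulators, instead of keeping max_val/sum_val vectors of length class_size and sweeping the whole batch once per stage, and the OpenMP parallelism moves to the class_size loop.

// Hypothetical sketch of the restructured softmax on a plain NCHW buffer.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// input/output hold batch * class_count * class_size floats in NCHW layout
// (class_size = H * W). Each spatial position k is independent, so the
// k loop is the natural place to parallelize.
void softmax_nchw(const float *input, float *output,
                  int batch, int class_count, int class_size) {
  const int batch_size = class_count * class_size;
  for (int b = 0; b < batch; ++b) {
#pragma omp parallel for
    for (int k = 0; k < class_size; ++k) {
      const float *in = input + b * batch_size + k;
      float *out = output + b * batch_size + k;

      // 1. max over channels for this position (numerical stability)
      float max_val = std::numeric_limits<float>::lowest();
      for (int c = 0; c < class_count; ++c) {
        max_val = std::max(max_val, in[c * class_size]);
      }
      // 2. exp(x - max) and running sum
      float sum = 0.f;
      for (int c = 0; c < class_count; ++c) {
        float e = std::exp(in[c * class_size] - max_val);
        out[c * class_size] = e;
        sum += e;
      }
      // 3. normalize by the sum
      for (int c = 0; c < class_count; ++c) {
        out[c * class_size] /= sum;
      }
    }
  }
}

int main() {
  // 1 batch, 3 classes, 2x2 spatial => class_size = 4
  std::vector<float> in = {0.f, 1.f, 2.f, 3.f,
                           1.f, 1.f, 1.f, 1.f,
                           2.f, 0.f, 1.f, 3.f};
  std::vector<float> out(in.size());
  softmax_nchw(in.data(), out.data(), 1, 3, 4);
  for (float v : out) std::printf("%.4f ", v);
  std::printf("\n");
  return 0;
}

Compile with -fopenmp to enable the pragma; without it the sketch still runs serially and produces the same values.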