提交 cdb2da6f 编写于 作者: 李寅

Improve Softmax perf

上级 dbbf8596
......@@ -43,6 +43,7 @@ struct SoftmaxFunctor<DeviceType::CPU, float> {
const index_t batch = input->dim(0);
const index_t class_count = input->dim(1);
const index_t class_size = input->dim(2) * input->dim(3);
const index_t batch_size = class_count * class_size;
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
......@@ -50,46 +51,37 @@ struct SoftmaxFunctor<DeviceType::CPU, float> {
float *output_data = output->mutable_data<float>();
for (index_t b = 0; b < batch; ++b) {
std::vector<float>
max_val(class_size, std::numeric_limits<float>::lowest());
std::vector<float> sum_val(class_size, 0.f);
// calculate max for each class
for (index_t c = 0; c < class_count; ++c) {
const float
*input_ptr = input_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
max_val[k] = std::max(max_val[k], input_ptr[k]);
}
}
// calculate data - max for each class
#pragma omp parallel for
for (index_t c = 0; c < class_count; ++c) {
const float
*input_ptr = input_data + (b * class_count + c) * class_size;
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] = ::exp(input_ptr[k] - max_val[k]);
for (index_t k = 0; k < class_size; ++k) {
const float *input_ptr = input_data + b * batch_size + k;
float *output_ptr = output_data + b * batch_size + k;
float max_val = std::numeric_limits<float>::lowest();
index_t channel_offset = 0;
for (index_t c = 0; c < class_count; ++c) {
float data = input_ptr[channel_offset];
if (data > max_val) {
max_val = data;
}
channel_offset += class_size;
}
}
// calculate sum for each class
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
sum_val[k] += output_ptr[k];
channel_offset = 0;
float sum = 0;
for (index_t c = 0; c < class_count; ++c) {
float exp_value = ::exp(input_ptr[channel_offset] - max_val);
sum += exp_value;
output_ptr[channel_offset] = exp_value;
channel_offset += class_size;
}
}
// calculate (data - max) / sum for each class
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] /= sum_val[k];
channel_offset = 0;
for (index_t c = 0; c < class_count; ++c) {
output_ptr[channel_offset] /= sum;
channel_offset += class_size;
}
}
}
} // k
} // b
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册