Commit 3e3b214b authored by Liangliang He

Fix softmax performance issue

Parent commit: 3236ad7c
......@@ -264,7 +264,6 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
bias ? bias[i] : 0);
}
}
// benchmark omp collapsed(2)
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < round_up_channels; c += kOutputChannelBlockSize) {
......@@ -326,7 +325,6 @@ void Conv2dNeonPixelK1x1S1(
const index_t total_loops = total_pixels >> 3;
const index_t loop_remaining = total_pixels & 7;
// benchmark omp collapsed(2)
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < channels; ++c) {
......
......@@ -6,55 +6,55 @@
#define MACE_KERNELS_SOFTMAX_H_
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
// CPU softmax over the last dimension of `logits`.
// All leading dimensions are flattened into `batch_size` rows of
// `num_classes` elements each; softmax is computed independently per row.
//
// NOTE(review): this span was a diff with old/new lines interleaved
// (duplicated operator() signatures, duplicated accumulate calls, two
// copies of the row loop). Reconstructed here as the coherent post-commit
// version: one `omp parallel` region with a per-thread scratch buffer.
template <DeviceType D, typename T>
struct SoftmaxFunctor {
  // logits: input tensor, softmax applied along its last axis.
  // output: pre-allocated tensor of the same shape as logits.
  // future: unused on CPU; kept for interface parity with device kernels.
  void operator()(const Tensor *logits, Tensor *output, StatsFuture *future) {
    Tensor::MappingGuard logits_guard(logits);
    Tensor::MappingGuard output_guard(output);
    const T *logits_ptr = logits->data<T>();
    T *output_ptr = output->mutable_data<T>();
    auto &logits_shape = logits->shape();
    // Product of all dimensions except the last one.
    const index_t batch_size =
        std::accumulate(logits_shape.begin(), logits_shape.end() - 1, 1,
                        std::multiplies<index_t>());
    const index_t num_classes = logits_shape.back();

#pragma omp parallel
    {
      // Per-thread scratch buffer, allocated once per thread instead of
      // once per row — this is the performance fix this commit makes.
      std::vector<T> exp_data(num_classes);
#pragma omp for
      for (index_t i = 0; i < batch_size; ++i) {
        const index_t pos = i * num_classes;
        // Subtract the row max before exp() for numerical stability.
        T max_value = logits_ptr[pos];
        for (index_t c = 1; c < num_classes; ++c) {
          max_value = std::max(max_value, logits_ptr[pos + c]);
        }
        // TODO: check overflow?
        T sum = 0;
        for (index_t c = 0; c < num_classes; ++c) {
          exp_data[c] = ::exp((logits_ptr[pos + c] - max_value));
          sum += exp_data[c];
        }
        for (index_t c = 0; c < num_classes; ++c) {
          output_ptr[pos + c] = exp_data[c] / sum;
        }
      }
    }
  }
};
template<typename T>
template <typename T>
struct SoftmaxFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *logits,
Tensor *output,
StatsFuture *future);
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future);
cl::Kernel kernel_;
};
......
......@@ -46,12 +46,13 @@ static void SoftmaxBenchmark(
net.Sync();
}
// Registers one softmax benchmark for a given shape, data type, and device.
// Reports MACs and bytes processed per iteration to the MACE test harness.
//
// NOTE(review): the diff interleaved the old and reformatted versions of this
// macro (duplicated #define, MaccProcessed, and SoftmaxBenchmark lines);
// reconstructed as the single post-commit definition.
#define BM_SOFTMAX_MACRO(N, C, H, W, TYPE, DEVICE)                  \
  static void BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
      int iters) {                                                  \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
    mace::testing::MaccProcessed(tot);                              \
    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));             \
    SoftmaxBenchmark<DEVICE, TYPE>(iters, N, C, H, W);              \
  }                                                                 \
  BENCHMARK(BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
......@@ -60,9 +61,9 @@ static void SoftmaxBenchmark(
BM_SOFTMAX_MACRO(N, C, H, W, float, OPENCL); \
BM_SOFTMAX_MACRO(N, C, H, W, half, OPENCL);
// Benchmark shape registrations (N, C, H, W).
// NOTE(review): this list is a diff with removed and added cases interleaved
// and the -/+ markers stripped; which lines belong to the old vs. new
// benchmark set cannot be determined from this view — confirm against the
// upstream commit before treating this as the final list.
BM_SOFTMAX(1, 1, 512, 512);
BM_SOFTMAX(1, 3, 128, 128);
BM_SOFTMAX(1, 2, 512, 512);
BM_SOFTMAX(1, 3, 512, 512);
BM_SOFTMAX(1, 32, 112, 112);
BM_SOFTMAX(1, 64, 256, 256);
BM_SOFTMAX(1, 4, 512, 512);
BM_SOFTMAX(1, 10, 256, 256);
BM_SOFTMAX(1, 1024, 7, 7);
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册