提交 18088a9f 编写于 作者: 李寅

Merge branch 'softmax_nhwc' into 'master'

support NHWC format and opt code for softmax op.

See merge request !1158
...@@ -111,7 +111,8 @@ class Tensor { ...@@ -111,7 +111,8 @@ class Tensor {
scale_(0.f), scale_(0.f),
zero_point_(0), zero_point_(0),
minval_(0.f), minval_(0.f),
maxval_(0.f) {} maxval_(0.f),
data_format_(DataFormat::NONE) {}
Tensor(BufferBase *buffer, DataType dtype, Tensor(BufferBase *buffer, DataType dtype,
bool is_weight = false, bool is_weight = false,
...@@ -125,7 +126,8 @@ class Tensor { ...@@ -125,7 +126,8 @@ class Tensor {
scale_(0.f), scale_(0.f),
zero_point_(0), zero_point_(0),
minval_(0.f), minval_(0.f),
maxval_(0.f) {} maxval_(0.f),
data_format_(DataFormat::NONE) {}
Tensor(const BufferSlice &buffer_slice, Tensor(const BufferSlice &buffer_slice,
DataType dtype, DataType dtype,
...@@ -140,7 +142,8 @@ class Tensor { ...@@ -140,7 +142,8 @@ class Tensor {
scale_(0.f), scale_(0.f),
zero_point_(0), zero_point_(0),
minval_(0.f), minval_(0.f),
maxval_(0.f) { maxval_(0.f),
data_format_(DataFormat::NONE) {
buffer_ = &buffer_slice_; buffer_ = &buffer_slice_;
} }
......
...@@ -43,120 +43,202 @@ class SoftmaxOp<DeviceType::CPU, float> : public Operation { ...@@ -43,120 +43,202 @@ class SoftmaxOp<DeviceType::CPU, float> : public Operation {
public: public:
explicit SoftmaxOp(OpConstructContext *context) explicit SoftmaxOp(OpConstructContext *context)
: Operation(context), : Operation(context),
use_log_(Operation::GetOptionalArg<bool>("use_log", false)) {} use_log_(Operation::GetOptionalArg<bool>("use_log", false)),
has_df_(Operation::GetOptionalArg<int>("has_data_format", 0)) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
const Tensor *input = this->Input(0); const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(0); Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input)); MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_guard(input); Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
if (isNCHW(input)) { // NCHW
return RunForNCHW(context);
} else {
return RunForNHWC(context);
}
}
protected:
bool use_log_;
bool has_df_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
protected:
MaceStatus RunForNCHW(OpContext *context) {
const Tensor *input = this->Input(INPUT);
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
Tensor *output = this->Output(OUTPUT);
float *output_data = output->mutable_data<float>(); float *output_data = output->mutable_data<float>();
MACE_CHECK(input->dim_size() == 4, "The dim size of NCHW should be 4.");
index_t hw_stride = input->dim(3);
index_t hw_size = hw_stride * input->dim(2);
index_t class_stride = hw_size;
index_t class_size = class_stride * input->dim(1);
index_t batch_stride = class_size;
index_t batch_size = batch_stride * input->dim(0);
Buffer cache_buffer(context->device()->allocator());
MACE_RETURN_IF_ERROR(cache_buffer.Allocate(hw_size * sizeof(float)));
utils::ThreadPool utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool(); &thread_pool = context->device()->cpu_runtime()->thread_pool();
float std_lowest = std::numeric_limits<float>::lowest();
float *cache_ptr = cache_buffer.mutable_data<float>();
// softmax for nchw image for (index_t b_offset = 0;
if (input->dim_size() == 4) { b_offset < batch_size; b_offset += batch_stride) {
const index_t batch = input->dim(0); const float *input_b_base = input_data + b_offset;
const index_t class_count = input->dim(1); float *output_b_base = output_data + b_offset;
const index_t class_size = input->dim(2) * input->dim(3);
const index_t batch_size = class_count * class_size;
for (index_t b = 0; b < batch; ++b) {
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) { thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
const auto raw_step_size = step * sizeof(float);
for (index_t k = start; k < end; k += step) {
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
cache_k_ptr[i] = std_lowest;
}
}
for (index_t c_offset = 0; c_offset < class_size;
c_offset += class_stride) {
const float *input_c_base = input_b_base + c_offset;
for (index_t k = start; k < end; k += step) { for (index_t k = start; k < end; k += step) {
const float *input_ptr = input_data + b * batch_size + k; const float *input_ptr = input_c_base + k;
float *output_ptr = output_data + b * batch_size + k; float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
cache_k_ptr[i] = std::max(cache_k_ptr[i], input_ptr[i]);
}
}
}
float max_val = std::numeric_limits<float>::lowest(); for (index_t c_offset = 0; c_offset < class_size;
index_t channel_offset = 0; c_offset += class_stride) {
for (index_t c = 0; c < class_count; ++c) { const float *input_c_base = input_b_base + c_offset;
float data = input_ptr[channel_offset]; float *output_c_base = output_b_base + c_offset;
if (data > max_val) { for (index_t k = start; k < end; k += step) {
max_val = data; const float *input_ptr = input_c_base + k;
float *output_ptr = output_c_base + k;
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
output_ptr[i] = ::exp(input_ptr[i] - cache_k_ptr[i]);
}
} }
channel_offset += class_size;
} }
channel_offset = 0; for (index_t k = start; k < end; k += step) {
float sum = 0; memset(cache_ptr + k, 0, raw_step_size);
for (index_t c = 0; c < class_count; ++c) { }
float exp_value = ::exp(input_ptr[channel_offset] - max_val);
sum += exp_value; for (index_t c_offset = 0; c_offset < class_size;
output_ptr[channel_offset] = exp_value; c_offset += class_stride) {
channel_offset += class_size; float *output_c_base = output_b_base + c_offset;
for (index_t k = start; k < end; k += step) {
float *output_ptr = output_c_base + k;
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
cache_k_ptr[i] += output_ptr[i];
}
}
}
for (index_t c_offset = 0; c_offset < class_size;
c_offset += class_stride) {
float *output_c_base = output_b_base + c_offset;
for (index_t k = start; k < end; k += step) {
float *output_ptr = output_c_base + k;
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
output_ptr[i] = output_ptr[i] / cache_k_ptr[i];
}
}
} }
sum = std::max(sum, std::numeric_limits<float>::min());
channel_offset = 0;
if (use_log_) { if (use_log_) {
for (index_t c = 0; c < class_count; ++c) { for (index_t c_offset = 0; c_offset < class_size;
output_ptr[channel_offset] /= sum; c_offset += class_stride) {
output_ptr[channel_offset] = float *output_c_base = output_b_base + c_offset;
std::log(output_ptr[channel_offset]); for (index_t k = start; k < end; k += step) {
channel_offset += class_size; float *output_ptr = output_c_base + k;
for (index_t i = 0; i < step; ++i) {
output_ptr[i] = std::log(output_ptr[i]);
} }
} else {
for (index_t c = 0; c < class_count; ++c) {
output_ptr[channel_offset] /= sum;
channel_offset += class_size;
} }
} }
} // k } // use_log_
}, 0, class_size, 1); }, 0, hw_size, hw_stride);
} // b
} else if (input->dim_size() == 2 || input->dim_size() == 3) {
// normal 2d softmax and 3d softmax (dim(0) is batch)
index_t class_size = 0;
index_t class_count = 0;
if (input->dim_size() == 2) {
class_size = input->dim(0);
class_count = input->dim(1);
} else {
class_size = input->dim(0) * input->dim(1);
class_count = input->dim(2);
} }
return MaceStatus::MACE_SUCCESS;
}
MaceStatus RunForNHWC(OpContext *context) {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
float *output_data = output->mutable_data<float>();
MACE_CHECK(input->dim_size() >= 2, "The input->dim_size() >= 2 failed.");
index_t class_size = input->dim(input->dim_size() - 1);
index_t hw_stride = class_size;
index_t hw_size = std::accumulate(input->shape().begin() + 1,
input->shape().end() - 1,
hw_stride,
std::multiplies<index_t>());
index_t batch_stride = hw_size;
index_t batch_size = std::accumulate(input->shape().begin(),
input->shape().end(),
1,
std::multiplies<index_t>());
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
const float *input_data = input->data<float>();
float std_lowest = std::numeric_limits<float>::lowest();
for (index_t b_offset = 0; b_offset < batch_size;
b_offset += batch_stride) {
const float *input_b_ptr = input_data + b_offset;
float *output_b_ptr = output_data + b_offset;
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) { thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t k = start; k < end; k += step) { for (index_t k = start; k < end; k += step) {
const float *input_ptr = input_data + k * class_count; const float *input_ptr = input_b_ptr + k;
float *output_ptr = output_data + k * class_count; float *output_ptr = output_b_ptr + k;
float max_val = std::numeric_limits<float>::lowest(); float max_val = std_lowest;
for (index_t c = 0; c < class_count; ++c) { for (index_t c = 0; c < class_size; ++c) {
max_val = std::max(max_val, input_ptr[c]); max_val = std::max(max_val, input_ptr[c]);
} }
float sum = 0; float sum = 0;
for (index_t c = 0; c < class_count; ++c) { for (index_t c = 0; c < class_size; ++c) {
float exp_value = std::exp(input_ptr[c] - max_val); float exp_value = ::exp(input_ptr[c] - max_val);
sum += exp_value; sum += exp_value;
output_ptr[c] = exp_value; output_ptr[c] = exp_value;
} }
sum = std::max(sum, std::numeric_limits<float>::min());
if (use_log_) { if (use_log_) {
for (index_t c = 0; c < class_count; ++c) { for (index_t c = 0; c < class_size; ++c) {
output_ptr[c] /= sum; output_ptr[c] /= sum;
output_ptr[c] = std::log(output_ptr[c]); output_ptr[c] = std::log(output_ptr[c]);
} }
} else { } else {
for (index_t c = 0; c < class_count; ++c) { for (index_t c = 0; c < class_size; ++c) {
output_ptr[c] /= sum; output_ptr[c] /= sum;
} }
} }
} } // k
}, 0, class_size, 1); }, 0, hw_size, hw_stride);
} else { } // b_offset
MACE_NOT_IMPLEMENTED;
}
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
protected: inline bool isNCHW(const Tensor *input) {
bool use_log_; auto data_format = input->data_format();
auto dim_size = input->dim_size();
return dim_size == 4 && (has_df_ || data_format == DataFormat::NCHW);
}
}; };
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
...@@ -173,8 +255,8 @@ class SoftmaxOp<DeviceType::CPU, uint8_t> : public Operation { ...@@ -173,8 +255,8 @@ class SoftmaxOp<DeviceType::CPU, uint8_t> : public Operation {
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
MACE_CHECK(!use_log_, "MACE dose not support quantized logsoftmax yet."); MACE_CHECK(!use_log_, "MACE dose not support quantized logsoftmax yet.");
const Tensor *input = this->Input(0); const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(0); Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input)); MACE_RETURN_IF_ERROR(output->ResizeLike(input));
// Ignore range stat, fix range to [0, 1]. For large depth, each softmax // Ignore range stat, fix range to [0, 1]. For large depth, each softmax
// output may be too small (<<1), which causes precision issue. But it is // output may be too small (<<1), which causes precision issue. But it is
...@@ -403,6 +485,8 @@ class SoftmaxOp<DeviceType::CPU, uint8_t> : public Operation { ...@@ -403,6 +485,8 @@ class SoftmaxOp<DeviceType::CPU, uint8_t> : public Operation {
protected: protected:
bool use_log_; bool use_log_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
...@@ -421,8 +505,8 @@ class SoftmaxOp<DeviceType::GPU, float> : public Operation { ...@@ -421,8 +505,8 @@ class SoftmaxOp<DeviceType::GPU, float> : public Operation {
} }
} }
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(0); const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(0); Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input)); MACE_RETURN_IF_ERROR(output->ResizeLike(input));
return kernel_->Compute(context, input, output); return kernel_->Compute(context, input, output);
...@@ -430,6 +514,8 @@ class SoftmaxOp<DeviceType::GPU, float> : public Operation { ...@@ -430,6 +514,8 @@ class SoftmaxOp<DeviceType::GPU, float> : public Operation {
private: private:
std::unique_ptr<OpenCLSoftmaxKernel> kernel_; std::unique_ptr<OpenCLSoftmaxKernel> kernel_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
......
...@@ -22,17 +22,17 @@ namespace ops { ...@@ -22,17 +22,17 @@ namespace ops {
namespace test { namespace test {
namespace { namespace {
template <DeviceType D, typename T> template<DeviceType D, typename T>
void SoftmaxBenchmark( void SoftmaxBenchmark(int iters, int batch, int channels,
int iters, int batch, int channels, int height, int width) { int height, int width, DataFormat data_format) {
mace::testing::StopTiming(); mace::testing::StopTiming();
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
if (D == DeviceType::CPU) { if (D == DeviceType::CPU && data_format == DataFormat::NCHW) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width}); net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU || data_format == DataFormat::NHWC) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels}); net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -42,6 +42,7 @@ void SoftmaxBenchmark( ...@@ -42,6 +42,7 @@ void SoftmaxBenchmark(
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("has_data_format", data_format == DataFormat::NCHW)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Warm-up // Warm-up
...@@ -58,10 +59,12 @@ void SoftmaxBenchmark( ...@@ -58,10 +59,12 @@ void SoftmaxBenchmark(
} }
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
template <> template<>
void SoftmaxBenchmark<CPU, uint8_t>( void SoftmaxBenchmark<CPU, uint8_t>(
int iters, int batch, int channels, int height, int width) { int iters, int batch, int channels, int height,
int width, DataFormat data_format) {
mace::testing::StopTiming(); mace::testing::StopTiming();
MACE_UNUSED(data_format);
OpsTestNet net; OpsTestNet net;
...@@ -100,33 +103,39 @@ void SoftmaxBenchmark<CPU, uint8_t>( ...@@ -100,33 +103,39 @@ void SoftmaxBenchmark<CPU, uint8_t>(
} // namespace } // namespace
#define MACE_BM_SOFTMAX_MACRO(N, C, H, W, TYPE, DEVICE) \ #define MACE_BM_SOFTMAX_MACRO(N, C, H, W, TYPE, DEVICE, DF) \
static void MACE_BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \ static void \
MACE_BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE##_##DF( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
SoftmaxBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \ SoftmaxBenchmark<DEVICE, TYPE>(iters, N, C, H, W, (DataFormat::DF)); \
} \ } \
MACE_BENCHMARK(MACE_BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) MACE_BENCHMARK( \
MACE_BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE##_##DF)
#if defined(MACE_ENABLE_OPENCL) && defined(MACE_ENABLE_QUANTIZE) #if defined(MACE_ENABLE_OPENCL) && defined(MACE_ENABLE_QUANTIZE)
#define MACE_BM_SOFTMAX(N, C, H, W) \ #define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NCHW); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU); \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU); \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU) MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU, NHWC)
#elif defined(MACE_ENABLE_OPENCL) #elif defined(MACE_ENABLE_OPENCL)
#define MACE_BM_SOFTMAX(N, C, H, W) \ #define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NCHW); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU); \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU) MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU, NHWC)
#elif defined(MACE_ENABLE_QUANTIZE) #elif defined(MACE_ENABLE_QUANTIZE)
#define MACE_BM_SOFTMAX(N, C, H, W) \ #define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NCHW); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU) MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU, NHWC)
#else #else
#define MACE_BM_SOFTMAX(N, C, H, W) \ #define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU) MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NCHW); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NHWC)
#endif #endif
MACE_BM_SOFTMAX(1, 2, 512, 512); MACE_BM_SOFTMAX(1, 2, 512, 512);
......
...@@ -55,6 +55,7 @@ void Simple(bool use_log = false) { ...@@ -55,6 +55,7 @@ void Simple(bool use_log = false) {
OpDefBuilder("Softmax", "SoftmaxTest") OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
.AddIntArg("has_data_format", 1)
.AddIntArg("use_log", static_cast<int>(use_log)) .AddIntArg("use_log", static_cast<int>(use_log))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
......
...@@ -244,6 +244,7 @@ class OpsTestNet { ...@@ -244,6 +244,7 @@ class OpsTestNet {
} }
} }
} }
output->set_data_format(DataFormat::NCHW);
} else if (src_format == DataFormat::NCHW && } else if (src_format == DataFormat::NCHW &&
dst_format == DataFormat::NHWC) { dst_format == DataFormat::NHWC) {
index_t batch = input_shape[0]; index_t batch = input_shape[0];
...@@ -265,6 +266,7 @@ class OpsTestNet { ...@@ -265,6 +266,7 @@ class OpsTestNet {
} }
} }
} }
output->set_data_format(DataFormat::NHWC);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册