提交 18088a9f 编写于 作者: 李寅

Merge branch 'softmax_nhwc' into 'master'

support NHWC format and opt code for softmax op.

See merge request !1158
......@@ -111,7 +111,8 @@ class Tensor {
scale_(0.f),
zero_point_(0),
minval_(0.f),
maxval_(0.f) {}
maxval_(0.f),
data_format_(DataFormat::NONE) {}
Tensor(BufferBase *buffer, DataType dtype,
bool is_weight = false,
......@@ -125,7 +126,8 @@ class Tensor {
scale_(0.f),
zero_point_(0),
minval_(0.f),
maxval_(0.f) {}
maxval_(0.f),
data_format_(DataFormat::NONE) {}
Tensor(const BufferSlice &buffer_slice,
DataType dtype,
......@@ -140,7 +142,8 @@ class Tensor {
scale_(0.f),
zero_point_(0),
minval_(0.f),
maxval_(0.f) {
maxval_(0.f),
data_format_(DataFormat::NONE) {
buffer_ = &buffer_slice_;
}
......
......@@ -43,120 +43,202 @@ class SoftmaxOp<DeviceType::CPU, float> : public Operation {
public:
explicit SoftmaxOp(OpConstructContext *context)
: Operation(context),
use_log_(Operation::GetOptionalArg<bool>("use_log", false)) {}
use_log_(Operation::GetOptionalArg<bool>("use_log", false)),
has_df_(Operation::GetOptionalArg<int>("has_data_format", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
if (isNCHW(input)) { // NCHW
return RunForNCHW(context);
} else {
return RunForNHWC(context);
}
}
protected:
bool use_log_;
bool has_df_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
protected:
MaceStatus RunForNCHW(OpContext *context) {
const Tensor *input = this->Input(INPUT);
const float *input_data = input->data<float>();
Tensor *output = this->Output(OUTPUT);
float *output_data = output->mutable_data<float>();
MACE_CHECK(input->dim_size() == 4, "The dim size of NCHW should be 4.");
index_t hw_stride = input->dim(3);
index_t hw_size = hw_stride * input->dim(2);
index_t class_stride = hw_size;
index_t class_size = class_stride * input->dim(1);
index_t batch_stride = class_size;
index_t batch_size = batch_stride * input->dim(0);
Buffer cache_buffer(context->device()->allocator());
MACE_RETURN_IF_ERROR(cache_buffer.Allocate(hw_size * sizeof(float)));
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
float std_lowest = std::numeric_limits<float>::lowest();
float *cache_ptr = cache_buffer.mutable_data<float>();
// softmax for nchw image
if (input->dim_size() == 4) {
const index_t batch = input->dim(0);
const index_t class_count = input->dim(1);
const index_t class_size = input->dim(2) * input->dim(3);
const index_t batch_size = class_count * class_size;
for (index_t b_offset = 0;
b_offset < batch_size; b_offset += batch_stride) {
const float *input_b_base = input_data + b_offset;
float *output_b_base = output_data + b_offset;
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
const auto raw_step_size = step * sizeof(float);
for (index_t k = start; k < end; k += step) {
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
cache_k_ptr[i] = std_lowest;
}
}
for (index_t b = 0; b < batch; ++b) {
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t c_offset = 0; c_offset < class_size;
c_offset += class_stride) {
const float *input_c_base = input_b_base + c_offset;
for (index_t k = start; k < end; k += step) {
const float *input_ptr = input_data + b * batch_size + k;
float *output_ptr = output_data + b * batch_size + k;
float max_val = std::numeric_limits<float>::lowest();
index_t channel_offset = 0;
for (index_t c = 0; c < class_count; ++c) {
float data = input_ptr[channel_offset];
if (data > max_val) {
max_val = data;
}
channel_offset += class_size;
const float *input_ptr = input_c_base + k;
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
cache_k_ptr[i] = std::max(cache_k_ptr[i], input_ptr[i]);
}
}
}
channel_offset = 0;
float sum = 0;
for (index_t c = 0; c < class_count; ++c) {
float exp_value = ::exp(input_ptr[channel_offset] - max_val);
sum += exp_value;
output_ptr[channel_offset] = exp_value;
channel_offset += class_size;
for (index_t c_offset = 0; c_offset < class_size;
c_offset += class_stride) {
const float *input_c_base = input_b_base + c_offset;
float *output_c_base = output_b_base + c_offset;
for (index_t k = start; k < end; k += step) {
const float *input_ptr = input_c_base + k;
float *output_ptr = output_c_base + k;
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
output_ptr[i] = ::exp(input_ptr[i] - cache_k_ptr[i]);
}
}
}
sum = std::max(sum, std::numeric_limits<float>::min());
channel_offset = 0;
if (use_log_) {
for (index_t c = 0; c < class_count; ++c) {
output_ptr[channel_offset] /= sum;
output_ptr[channel_offset] =
std::log(output_ptr[channel_offset]);
channel_offset += class_size;
}
} else {
for (index_t c = 0; c < class_count; ++c) {
output_ptr[channel_offset] /= sum;
channel_offset += class_size;
for (index_t k = start; k < end; k += step) {
memset(cache_ptr + k, 0, raw_step_size);
}
for (index_t c_offset = 0; c_offset < class_size;
c_offset += class_stride) {
float *output_c_base = output_b_base + c_offset;
for (index_t k = start; k < end; k += step) {
float *output_ptr = output_c_base + k;
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
cache_k_ptr[i] += output_ptr[i];
}
}
}
for (index_t c_offset = 0; c_offset < class_size;
c_offset += class_stride) {
float *output_c_base = output_b_base + c_offset;
for (index_t k = start; k < end; k += step) {
float *output_ptr = output_c_base + k;
float *cache_k_ptr = cache_ptr + k;
for (index_t i = 0; i < step; ++i) {
output_ptr[i] = output_ptr[i] / cache_k_ptr[i];
}
}
}
if (use_log_) {
for (index_t c_offset = 0; c_offset < class_size;
c_offset += class_stride) {
float *output_c_base = output_b_base + c_offset;
for (index_t k = start; k < end; k += step) {
float *output_ptr = output_c_base + k;
for (index_t i = 0; i < step; ++i) {
output_ptr[i] = std::log(output_ptr[i]);
}
}
} // k
}, 0, class_size, 1);
} // b
} else if (input->dim_size() == 2 || input->dim_size() == 3) {
// normal 2d softmax and 3d softmax (dim(0) is batch)
index_t class_size = 0;
index_t class_count = 0;
if (input->dim_size() == 2) {
class_size = input->dim(0);
class_count = input->dim(1);
} else {
class_size = input->dim(0) * input->dim(1);
class_count = input->dim(2);
}
}
} // use_log_
}, 0, hw_size, hw_stride);
}
return MaceStatus::MACE_SUCCESS;
}
MaceStatus RunForNHWC(OpContext *context) {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
float *output_data = output->mutable_data<float>();
MACE_CHECK(input->dim_size() >= 2, "The input->dim_size() >= 2 failed.");
index_t class_size = input->dim(input->dim_size() - 1);
index_t hw_stride = class_size;
index_t hw_size = std::accumulate(input->shape().begin() + 1,
input->shape().end() - 1,
hw_stride,
std::multiplies<index_t>());
index_t batch_stride = hw_size;
index_t batch_size = std::accumulate(input->shape().begin(),
input->shape().end(),
1,
std::multiplies<index_t>());
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
const float *input_data = input->data<float>();
float std_lowest = std::numeric_limits<float>::lowest();
for (index_t b_offset = 0; b_offset < batch_size;
b_offset += batch_stride) {
const float *input_b_ptr = input_data + b_offset;
float *output_b_ptr = output_data + b_offset;
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t k = start; k < end; k += step) {
const float *input_ptr = input_data + k * class_count;
float *output_ptr = output_data + k * class_count;
const float *input_ptr = input_b_ptr + k;
float *output_ptr = output_b_ptr + k;
float max_val = std::numeric_limits<float>::lowest();
for (index_t c = 0; c < class_count; ++c) {
float max_val = std_lowest;
for (index_t c = 0; c < class_size; ++c) {
max_val = std::max(max_val, input_ptr[c]);
}
float sum = 0;
for (index_t c = 0; c < class_count; ++c) {
float exp_value = std::exp(input_ptr[c] - max_val);
for (index_t c = 0; c < class_size; ++c) {
float exp_value = ::exp(input_ptr[c] - max_val);
sum += exp_value;
output_ptr[c] = exp_value;
}
sum = std::max(sum, std::numeric_limits<float>::min());
if (use_log_) {
for (index_t c = 0; c < class_count; ++c) {
for (index_t c = 0; c < class_size; ++c) {
output_ptr[c] /= sum;
output_ptr[c] = std::log(output_ptr[c]);
}
} else {
for (index_t c = 0; c < class_count; ++c) {
for (index_t c = 0; c < class_size; ++c) {
output_ptr[c] /= sum;
}
}
}
}, 0, class_size, 1);
} else {
MACE_NOT_IMPLEMENTED;
}
} // k
}, 0, hw_size, hw_stride);
} // b_offset
return MaceStatus::MACE_SUCCESS;
}
protected:
bool use_log_;
inline bool isNCHW(const Tensor *input) {
auto data_format = input->data_format();
auto dim_size = input->dim_size();
return dim_size == 4 && (has_df_ || data_format == DataFormat::NCHW);
}
};
#ifdef MACE_ENABLE_QUANTIZE
......@@ -173,8 +255,8 @@ class SoftmaxOp<DeviceType::CPU, uint8_t> : public Operation {
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
MACE_CHECK(!use_log_, "MACE dose not support quantized logsoftmax yet.");
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
// Ignore range stat, fix range to [0, 1]. For large depth, each softmax
// output may be too small (<<1), which causes precision issue. But it is
......@@ -403,6 +485,8 @@ class SoftmaxOp<DeviceType::CPU, uint8_t> : public Operation {
protected:
bool use_log_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#endif // MACE_ENABLE_QUANTIZE
......@@ -421,8 +505,8 @@ class SoftmaxOp<DeviceType::GPU, float> : public Operation {
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
return kernel_->Compute(context, input, output);
......@@ -430,6 +514,8 @@ class SoftmaxOp<DeviceType::GPU, float> : public Operation {
private:
std::unique_ptr<OpenCLSoftmaxKernel> kernel_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#endif // MACE_ENABLE_OPENCL
......
......@@ -22,17 +22,17 @@ namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void SoftmaxBenchmark(
int iters, int batch, int channels, int height, int width) {
template<DeviceType D, typename T>
void SoftmaxBenchmark(int iters, int batch, int channels,
int height, int width, DataFormat data_format) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
if (D == DeviceType::CPU) {
if (D == DeviceType::CPU && data_format == DataFormat::NCHW) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
} else if (D == DeviceType::GPU) {
} else if (D == DeviceType::GPU || data_format == DataFormat::NHWC) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
} else {
MACE_NOT_IMPLEMENTED;
......@@ -42,6 +42,7 @@ void SoftmaxBenchmark(
.Input("Input")
.Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("has_data_format", data_format == DataFormat::NCHW)
.Finalize(net.NewOperatorDef());
// Warm-up
......@@ -58,10 +59,12 @@ void SoftmaxBenchmark(
}
#ifdef MACE_ENABLE_QUANTIZE
template <>
template<>
void SoftmaxBenchmark<CPU, uint8_t>(
int iters, int batch, int channels, int height, int width) {
int iters, int batch, int channels, int height,
int width, DataFormat data_format) {
mace::testing::StopTiming();
MACE_UNUSED(data_format);
OpsTestNet net;
......@@ -100,33 +103,39 @@ void SoftmaxBenchmark<CPU, uint8_t>(
} // namespace
#define MACE_BM_SOFTMAX_MACRO(N, C, H, W, TYPE, DEVICE) \
static void MACE_BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
#define MACE_BM_SOFTMAX_MACRO(N, C, H, W, TYPE, DEVICE, DF) \
static void \
MACE_BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE##_##DF( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
SoftmaxBenchmark<DEVICE, TYPE>(iters, N, C, H, W); \
SoftmaxBenchmark<DEVICE, TYPE>(iters, N, C, H, W, (DataFormat::DF)); \
} \
MACE_BENCHMARK(MACE_BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
MACE_BENCHMARK( \
MACE_BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE##_##DF)
#if defined(MACE_ENABLE_OPENCL) && defined(MACE_ENABLE_QUANTIZE)
#define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU)
#define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NCHW); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU, NHWC)
#elif defined(MACE_ENABLE_OPENCL)
#define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU)
#define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NCHW); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU, NHWC)
#elif defined(MACE_ENABLE_QUANTIZE)
#define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU)
#define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NCHW); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NHWC); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU, NHWC)
#else
#define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU)
#define MACE_BM_SOFTMAX(N, C, H, W) \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NCHW); \
MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU, NHWC)
#endif
MACE_BM_SOFTMAX(1, 2, 512, 512);
......
......@@ -55,6 +55,7 @@ void Simple(bool use_log = false) {
OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntArg("has_data_format", 1)
.AddIntArg("use_log", static_cast<int>(use_log))
.Finalize(net.NewOperatorDef());
......
......@@ -244,6 +244,7 @@ class OpsTestNet {
}
}
}
output->set_data_format(DataFormat::NCHW);
} else if (src_format == DataFormat::NCHW &&
dst_format == DataFormat::NHWC) {
index_t batch = input_shape[0];
......@@ -265,6 +266,7 @@ class OpsTestNet {
}
}
}
output->set_data_format(DataFormat::NHWC);
} else {
MACE_NOT_IMPLEMENTED;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册