diff --git a/docs/user_guide/benchmark.rst b/docs/user_guide/benchmark.rst index 3a9258e9c5a1eeb3eaff1e2473c3d60adb3b2923..f47f01230d20b101433511dd499f63d98adef59a 100644 --- a/docs/user_guide/benchmark.rst +++ b/docs/user_guide/benchmark.rst @@ -88,7 +88,7 @@ or for Bazel users: .. code-block:: bash - python tools/python/converter.py run --config=/path/to/your/model_deployment.yml --benchmark + python tools/converter.py run --config=/path/to/your/model_deployment.yml --benchmark ====== Output diff --git a/mace/core/quantize.h b/mace/core/quantize.h index 439c9522b20212495e3fc7d0e2a5a9b79df012f0..2424f03cc053cb361b88fbd1d30d8193ddda184c 100644 --- a/mace/core/quantize.h +++ b/mace/core/quantize.h @@ -77,6 +77,11 @@ inline Q Saturate(float value) { } } +template +inline Q Quantize(float value, float scale, int32_t zero_point) { + return Saturate(std::roundf(value / scale + zero_point)); +} + inline void FindMinMax(const float *input, const index_t size, float *min_val, float *max_val) { diff --git a/mace/ops/activation.cc b/mace/ops/activation.cc index 7d3b1e4d792a1354cc7e10e6d256e2593d0ee9cc..c542d98e467e74edf17198d713d5cca126e44bde 100644 --- a/mace/ops/activation.cc +++ b/mace/ops/activation.cc @@ -121,6 +121,10 @@ void RegisterActivation(OpRegistry *op_registry) { DeviceType::CPU, float); MACE_REGISTER_BF16_OP(op_registry, "Activation", ActivationOp, DeviceType::CPU); +#ifdef MACE_ENABLE_QUANTIZE + MACE_REGISTER_OP(op_registry, "Activation", ActivationOp, + DeviceType::CPU, uint8_t); +#endif // MACE_ENABLE_QUANTIZE MACE_REGISTER_GPU_OP(op_registry, "Activation", ActivationOp); MACE_REGISTER_OP_CONDITION( op_registry, diff --git a/mace/ops/arm/base/activation.cc b/mace/ops/arm/base/activation.cc index 6531616ae0ab8b2b749e886a3e2f4431ceb50856..90135b17974c5d04108489eade3f05b7e8d93d9c 100644 --- a/mace/ops/arm/base/activation.cc +++ b/mace/ops/arm/base/activation.cc @@ -37,36 +37,32 @@ template void Activation::DoActivation(const OpContext *context, const Tensor *input, Tensor *output) { - const T *input_data = input->data(); - T *output_data = output->mutable_data(); - const index_t size = input->size(); - utils::ThreadPool &thread_pool = context->device()->cpu_runtime()->thread_pool(); switch (type_) { case RELU: { - ActivateRelu(&thread_pool, input_data, size, output_data); + ActivateRelu(&thread_pool, input, output); break; } case RELUX: { - ActivateRelux(&thread_pool, input_data, size, output_data); + ActivateRelux(&thread_pool, input, output); break; } case LEAKYRELU: { - ActivateLeakyRelu(&thread_pool, input_data, size, output_data); + ActivateLeakyRelu(&thread_pool, input, output); break; } case TANH: { - ActivateTanh(&thread_pool, input_data, size, output_data); + ActivateTanh(&thread_pool, input, output); break; } case SIGMOID: { - ActivateSigmoid(&thread_pool, input_data, size, output_data); + ActivateSigmoid(&thread_pool, input, output); break; } @@ -84,6 +80,11 @@ void RegisterActivationDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, Activation, delegator::ActivationParam, MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, float, ImplType::NEON)); +#ifdef MACE_ENABLE_QUANTIZE + MACE_REGISTER_DELEGATOR( + registry, Activation, delegator::ActivationParam, + MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, uint8_t, ImplType::NEON)); +#endif // MACE_ENABLE_QUANTIZE } } // namespace arm diff --git a/mace/ops/arm/base/activation.h b/mace/ops/arm/base/activation.h index aac917c9b37642ae8b452331cf86d5b0e51407f4..3e82b1d6d9a21d2b535195c7802a3b8a4b5aab1b 100644 
--- a/mace/ops/arm/base/activation.h +++ b/mace/ops/arm/base/activation.h @@ -35,16 +35,16 @@ class Activation : public delegator::Activation { void DoActivation(const OpContext *context, const Tensor *input, Tensor *output); - void ActivateRelu(utils::ThreadPool *thread_pool, const T *input_data, - const index_t input_size, T *output_data); - void ActivateRelux(utils::ThreadPool *thread_pool, const T *input_data, - const index_t input_size, T *output_data); - void ActivateLeakyRelu(utils::ThreadPool *thread_pool, const T *input_data, - const index_t input_size, T *output_data); - void ActivateTanh(utils::ThreadPool *thread_pool, const T *input_data, - const index_t input_size, T *output_data); - void ActivateSigmoid(utils::ThreadPool *thread_pool, const T *input_data, - const index_t input_size, T *output_data); + void ActivateRelu(utils::ThreadPool *thread_pool, const Tensor *input, + Tensor *output); + void ActivateRelux(utils::ThreadPool *thread_pool, const Tensor *input, + Tensor *output); + void ActivateLeakyRelu(utils::ThreadPool *thread_pool, const Tensor *input, + Tensor *output); + void ActivateTanh(utils::ThreadPool *thread_pool, const Tensor *input, + Tensor *output); + void ActivateSigmoid(utils::ThreadPool *thread_pool, const Tensor *input, + Tensor *output); }; } // namespace arm diff --git a/mace/ops/arm/base/bias_add.cc b/mace/ops/arm/base/bias_add.cc index 42357a48e8ce04f5199c39e0c428abcd1562f6e6..3ae9e4162fcb99560b9c46a212c0f82ca493c5ba 100644 --- a/mace/ops/arm/base/bias_add.cc +++ b/mace/ops/arm/base/bias_add.cc @@ -19,8 +19,11 @@ namespace ops { namespace arm { template -MaceStatus BiasAdd::Compute(const OpContext *context, const Tensor *input, - const Tensor *bias, Tensor *output) { +MaceStatus BiasAdd::Compute(const OpContext *context, + const Tensor *input, + const Tensor *bias, + Tensor *output, + const bool isNCHW) { if (input != output) { if (bias == nullptr) { output->Copy(*input); @@ -29,13 +32,13 @@ MaceStatus BiasAdd::Compute(const OpContext *context, const Tensor *input, Tensor::MappingGuard input_guard(input); Tensor::MappingGuard bias_guard(bias); Tensor::MappingGuard output_guard(output); - AddBias(context, input, bias, output); + AddBias(context, input, bias, output, isNCHW); } } else { if (bias != nullptr) { Tensor::MappingGuard input_guard(input); Tensor::MappingGuard bias_guard(bias); - AddBias(context, input, bias, output); + AddBias(context, input, bias, output, isNCHW); } } @@ -43,28 +46,26 @@ MaceStatus BiasAdd::Compute(const OpContext *context, const Tensor *input, } template -void BiasAdd::AddBias(const OpContext *context, const Tensor *input, - const Tensor *bias, mace::Tensor *output) { - auto input_data = input->data(); - auto bias_data = bias->data(); - auto output_data = output->mutable_data(); - - const index_t batch = input->dim(0); - const index_t channels = input->dim(1); - - const index_t height = input->dim(2); - const index_t width = input->dim(3); - const index_t image_size = height * width; - +void BiasAdd::AddBias(const OpContext *context, + const Tensor *input, + const Tensor *bias, + mace::Tensor *output, + const bool isNCHW) { utils::ThreadPool &thread_pool = context->device()->cpu_runtime()->thread_pool(); - if (bias->dim_size() == 1) { - Add1DimBias(&thread_pool, input_data, bias_data, - output_data, batch, channels, image_size); + if (isNCHW) { + if (bias->dim_size() == 1) { + AddBiasNCHW<1>(&thread_pool, input, bias, output); + } else { + AddBiasNCHW<2>(&thread_pool, input, bias, output); + } } else { - 
Add2DimsBias(&thread_pool, input_data, bias_data, - output_data, batch, channels, image_size); + if (bias->dim_size() == 1) { + AddBiasNHWC<1>(&thread_pool, input, bias, output); + } else { + AddBiasNHWC<2>(&thread_pool, input, bias, output); + } } } @@ -72,6 +73,11 @@ void RegisterBiasAddDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, BiasAdd, DelegatorParam, MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, float, ImplType::NEON)); +#ifdef MACE_ENABLE_QUANTIZE + MACE_REGISTER_DELEGATOR( + registry, BiasAdd, DelegatorParam, + MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, uint8_t, ImplType::NEON)); +#endif // MACE_ENABLE_QUANTIZE } } // namespace arm diff --git a/mace/ops/arm/base/bias_add.h b/mace/ops/arm/base/bias_add.h index b0e2e1c09ef19a0d77a817bf55f2992282973b31..855d828680022fe739cc15152b7f0c7fd05a90c6 100644 --- a/mace/ops/arm/base/bias_add.h +++ b/mace/ops/arm/base/bias_add.h @@ -27,24 +27,41 @@ class BiasAdd : public delegator::BiasAdd { explicit BiasAdd(const DelegatorParam ¶m) : delegator::BiasAdd(param) {} ~BiasAdd() = default; - MaceStatus Compute(const OpContext *context, const Tensor *input, - const Tensor *bias, Tensor *output) override; + MaceStatus Compute(const OpContext *context, + const Tensor *input, + const Tensor *bias, + Tensor *output, + const bool isNCHW = true) override; private: - void AddBias(const OpContext *context, const Tensor *input, - const Tensor *bias, Tensor *output); - - void Add1DimBias(utils::ThreadPool *thread_pool, const T *input_data, - const T *bias_data, T *output_data, - const index_t batch, const index_t channels, - const index_t image_size); - - void Add2DimsBias(utils::ThreadPool *thread_pool, const T *input_data, - const T *bias_data, T *output_data, - const index_t batch, const index_t channels, - const index_t image_size); + void AddBias(const OpContext *context, + const Tensor *input, + const Tensor *bias, + Tensor *output, + const bool isNCHW = true); + + template + void AddBiasNCHW(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); + template + void AddBiasNHWC(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); }; +template +inline index_t bias_index(index_t offset, index_t channel) { + return offset + channel; +} + +template <> +inline index_t bias_index<1>(index_t offset, index_t channel) { + MACE_UNUSED(offset); + return channel; +} } // namespace arm } // namespace ops } // namespace mace diff --git a/mace/ops/arm/fp32/activation.cc b/mace/ops/arm/fp32/activation.cc index add68ad01e3b0ea93fcce29ba05768ee3d696ae7..4f3daa0f7a762d4f8bd732c654785fb50d0cd74d 100644 --- a/mace/ops/arm/fp32/activation.cc +++ b/mace/ops/arm/fp32/activation.cc @@ -23,9 +23,11 @@ namespace arm { template<> void Activation::ActivateRelu(utils::ThreadPool *thread_pool, - const float *input_data, - const index_t input_size, - float *output_data) { + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); const float32x4_t vzero = vdupq_n_f32(0.f); const index_t block_count = input_size / 4; @@ -53,9 +55,11 @@ void Activation::ActivateRelu(utils::ThreadPool *thread_pool, template<> void Activation::ActivateRelux(utils::ThreadPool *thread_pool, - const float *input_data, - const index_t input_size, - float *output_data) { + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = 
output->mutable_data(); + const index_t input_size = input->size(); const float32x4_t vzero = vdupq_n_f32(0.f); const float32x4_t vlimit = vdupq_n_f32(limit_); const index_t block_count = input_size / 4; @@ -85,9 +89,11 @@ void Activation::ActivateRelux(utils::ThreadPool *thread_pool, template<> void Activation::ActivateLeakyRelu(utils::ThreadPool *thread_pool, - const float *input_data, - const index_t input_size, - float *output_data) { + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); const float32x4_t vzero = vdupq_n_f32(0.f); const float32x4_t valpha = vdupq_n_f32(leakyrelu_coefficient_); const index_t block_count = input_size / 4; @@ -119,9 +125,12 @@ void Activation::ActivateLeakyRelu(utils::ThreadPool *thread_pool, template<> void Activation::ActivateTanh(utils::ThreadPool *thread_pool, - const float *input_data, - const index_t input_size, - float *output_data) { + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + thread_pool->Compute1D( [=](index_t start, index_t end, index_t step) { for (index_t i = start; i < end; i += step) { @@ -133,9 +142,12 @@ void Activation::ActivateTanh(utils::ThreadPool *thread_pool, template<> void Activation::ActivateSigmoid(utils::ThreadPool *thread_pool, - const float *input_data, - const index_t input_size, - float *output_data) { + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + thread_pool->Compute1D( [=](index_t start, index_t end, index_t step) { for (index_t i = start; i < end; i += step) { diff --git a/mace/ops/arm/fp32/bias_add.cc b/mace/ops/arm/fp32/bias_add.cc index 042d306d8475ca850ee61cdc0d14185038543ecb..2c0d83261fa33d9f1ac4fb9715669d2e34db820d 100644 --- a/mace/ops/arm/fp32/bias_add.cc +++ b/mace/ops/arm/fp32/bias_add.cc @@ -12,82 +12,112 @@ // See the License for the specific language governing permissions and // limitations under the License. 
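The reworked NCHW/NHWC bias kernels that follow select the bias element through the `bias_index<Dim>` helper declared in `mace/ops/arm/base/bias_add.h`: with a 1-D bias the channel index is used directly, while a 2-D (batched, Caffe-style) bias is indexed by `batch * channels + channel`. A minimal sketch of that selection logic, using simplified stand-alone names rather than the MACE types:

.. code-block:: cpp

    #include <cstdint>

    using index_t = int64_t;

    // Dim == 1: one bias value per channel, shared across the batch.
    // Dim == 2: one bias value per (batch, channel) pair, as Caffe allows.
    // batch_offset is expected to be b * channels, mirroring the kernels below.
    template <int Dim>
    index_t BiasIndexSketch(index_t batch_offset, index_t channel) {
      return Dim == 1 ? channel : batch_offset + channel;
    }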
-#include - #include "mace/ops/arm/base/bias_add.h" +#include + namespace mace { namespace ops { namespace arm { -template<> -void BiasAdd::Add1DimBias( - utils::ThreadPool *thread_pool, const float *input_data, - const float *bias_data, float *output_data, const index_t batch, - const index_t channels, const index_t image_size) { +template <> +template +void BiasAdd::AddBiasNCHW(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output) { + auto input_data = input->data(); + auto bias_data = bias->data(); + auto output_data = output->mutable_data(); + + const index_t batch = input->dim(0); + const index_t channels = input->dim(1); + const index_t image_size = input->dim(2) * input->dim(3); const index_t block_count = image_size / 4; const index_t remain = image_size % 4; - thread_pool->Compute2D([=](index_t start0, index_t end0, index_t step0, - index_t start1, index_t end1, index_t step1) { - for (index_t b = start0; b < end0; b += step0) { - const index_t b_offset = b * channels; - for (index_t c = start1; c < end1; c += step1) { - const index_t offset = (b_offset + c) * image_size; - auto input_ptr = input_data + offset; - auto output_ptr = output_data + offset; - const float bias = bias_data[c]; - float32x4_t vbias = vdupq_n_f32(bias); + thread_pool->Compute2D( + [=](index_t start0, index_t end0, index_t step0, index_t start1, + index_t end1, index_t step1) { + for (index_t b = start0; b < end0; b += step0) { + const index_t b_offset = b * channels; + for (index_t c = start1; c < end1; c += step1) { + const index_t offset = (b_offset + c) * image_size; + auto input_ptr = input_data + offset; + auto output_ptr = output_data + offset; + const float bias = bias_data[bias_index(b_offset, c)]; + float32x4_t vbias = vdupq_n_f32(bias); - for (index_t i = 0; i < block_count; ++i) { - float32x4_t v = vld1q_f32(input_ptr); - v = vaddq_f32(v, vbias); - vst1q_f32(output_ptr, v); + for (index_t i = 0; i < block_count; ++i) { + float32x4_t v = vld1q_f32(input_ptr); + v = vaddq_f32(v, vbias); + vst1q_f32(output_ptr, v); - input_ptr += 4; - output_ptr += 4; - } - for (index_t i = 0; i < remain; ++i) { - (*output_ptr++) = (*input_ptr++) + bias; + input_ptr += 4; + output_ptr += 4; + } + for (index_t i = 0; i < remain; ++i) { + (*output_ptr++) = (*input_ptr++) + bias; + } + } } - } - } - }, 0, batch, 1, 0, channels, 1); + }, + 0, batch, 1, 0, channels, 1); } -template<> -void BiasAdd::Add2DimsBias( - utils::ThreadPool *thread_pool, const float *input_data, - const float *bias_data, float *output_data, const index_t batch, - const index_t channels, const index_t image_size) { - const index_t block_count = image_size / 4; - const index_t remain = image_size % 4; - thread_pool->Compute2D([=](index_t start0, index_t end0, index_t step0, - index_t start1, index_t end1, index_t step1) { - for (index_t b = start0; b < end0; b += step0) { - const index_t b_offset = b * channels; - for (index_t c = start1; c < end1; c += step1) { - const index_t offset = (b_offset + c) * image_size; - auto input_ptr = input_data + offset; - auto output_ptr = output_data + offset; - const float bias = bias_data[b * channels + c]; - float32x4_t vbias = vdupq_n_f32(bias); - - for (index_t i = 0; i < block_count; ++i) { - float32x4_t v = vld1q_f32(input_ptr); - v = vaddq_f32(v, vbias); - vst1q_f32(output_ptr, v); +template <> +template +void BiasAdd::AddBiasNHWC(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output) { + const float *input_ptr = 
input->data(); + const float *bias_ptr = bias->data(); + float *output_ptr = output->mutable_data(); - input_ptr += 4; - output_ptr += 4; + const std::vector &shape = input->shape(); + const index_t channels = *shape.rbegin(); + const auto batch = shape[0]; + if (Dim == 2) { + MACE_CHECK(batch == bias->shape()[0]); + } + const index_t fused_hw = std::accumulate(shape.begin() + 1, shape.end() - 1, + 1, std::multiplies()); + thread_pool->Compute2D( + [=](index_t start0, index_t end0, index_t step0, index_t start1, + index_t end1, index_t step1) { + for (index_t i = start0; i < end0; i += step0) { + auto offset = i * fused_hw; + auto bias_offset = i * channels; + for (index_t j = start1; j < end1; j += step1) { + index_t pos = (offset + j) * channels; + for (index_t c = 0; c < channels; ++c, ++pos) { + output_ptr[pos] = + input_ptr[pos] + bias_ptr[bias_index(bias_offset, c)]; + } + } } - for (index_t i = 0; i < remain; ++i) { - (*output_ptr++) = (*input_ptr++) + bias; - } - } - } - }, 0, batch, 1, 0, channels, 1); + }, + 0, batch, 1, 0, fused_hw, 1); } +template void BiasAdd::AddBiasNCHW<1>(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); +template void BiasAdd::AddBiasNCHW<2>(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); + +template void BiasAdd::AddBiasNHWC<1>(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); +template void BiasAdd::AddBiasNHWC<2>(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); + } // namespace arm } // namespace ops } // namespace mace diff --git a/mace/ops/arm/q8/activation.cc b/mace/ops/arm/q8/activation.cc new file mode 100644 index 0000000000000000000000000000000000000000..875f8ba442ee55b75036b4a1776b1de73564e056 --- /dev/null +++ b/mace/ops/arm/q8/activation.cc @@ -0,0 +1,128 @@ +// Copyright 2020 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
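The uint8 ReLU/ReLU-X kernels in the new file below work directly on quantized values: because the affine mapping real = scale * (q - zero_point) is monotonic, clamping the real value at 0 (and at the ReLU-X limit) is equivalent to clamping q at the zero point (and at the quantized limit). A minimal scalar sketch of the same idea, with stand-alone names that are not the MACE delegator API:

.. code-block:: cpp

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar uint8 ReLU: clamp the quantized value at the zero point.
    uint8_t ReluQ8(uint8_t q, int32_t zero_point) {
      return static_cast<uint8_t>(std::max<int32_t>(q, zero_point));
    }

    // Scalar uint8 ReLU-X: additionally clamp at the quantized limit,
    // computed with the usual asymmetric scheme q = round(x / scale) + zp.
    uint8_t ReluxQ8(uint8_t q, float scale, int32_t zero_point, float limit) {
      const int32_t q_limit = std::min<int32_t>(
          255, std::max<int32_t>(
                   0, static_cast<int32_t>(std::roundf(limit / scale)) +
                          zero_point));
      return static_cast<uint8_t>(
          std::min<int32_t>(std::max<int32_t>(q, zero_point), q_limit));
    }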
+ +#include +#include + +#include "mace/core/quantize.h" +#include "mace/ops/arm/base/activation.h" + +namespace mace { +namespace ops { +namespace arm { + +template<> +void Activation::ActivateRelu(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + output->SetScale(input->scale()); + output->SetZeroPoint(input->zero_point()); + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + const uint8x16_t vzero = vdupq_n_u8(input->zero_point()); + const index_t block_count = input_size / 16; + + thread_pool->Compute1D( + [=](index_t start, index_t end, index_t step) { + auto input_ptr = input_data + start * 16; + auto output_ptr = output_data + start * 16; + + for (index_t i = start; i < end; i += step) { + uint8x16_t v = vld1q_u8(input_ptr); + v = vmaxq_u8(v, vzero); + vst1q_u8(output_ptr, v); + + input_ptr += 16; + output_ptr += 16; + } + }, 0, block_count, 1); + + // remain + for (index_t i = block_count * 16; i < input_size; ++i) { + output_data[i] = std::max(input->zero_point(), input_data[i]); + } +} + +template<> +void Activation::ActivateRelux(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + output->SetScale(input->scale()); + output->SetZeroPoint(input->zero_point()); + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + const uint8x16_t vzero = vdupq_n_u8(input->zero_point()); + const uint8_t limit = + Quantize(limit_, input->scale(), input->zero_point()); + const uint8x16_t vlimit = vdupq_n_u8(limit); + const index_t block_count = input_size / 16; + + thread_pool->Compute1D( + [=](index_t start, index_t end, index_t step) { + auto input_ptr = input_data + start * 16; + auto output_ptr = output_data + start * 16; + + for (index_t i = start; i < end; i += step) { + uint8x16_t v = vld1q_u8(input_ptr); + v = vmaxq_u8(v, vzero); + v = vminq_u8(v, vlimit); + vst1q_u8(output_ptr, v); + + input_ptr += 16; + output_ptr += 16; + } + }, 0, block_count, 1); + + // remain + for (index_t i = block_count * 16; i < input_size; ++i) { + output_data[i] = + std::max(input->zero_point(), std::min(limit, input_data[i])); + } +} + +template<> +void Activation::ActivateLeakyRelu(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + MACE_UNUSED(thread_pool); + MACE_UNUSED(input); + MACE_UNUSED(output); + MACE_NOT_IMPLEMENTED; +} + +template<> +void Activation::ActivateTanh(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + MACE_UNUSED(thread_pool); + MACE_UNUSED(input); + MACE_UNUSED(output); + MACE_NOT_IMPLEMENTED; +} + +template<> +void Activation::ActivateSigmoid(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + MACE_UNUSED(thread_pool); + MACE_UNUSED(input); + MACE_UNUSED(output); + MACE_NOT_IMPLEMENTED; +} + +} // namespace arm +} // namespace ops +} // namespace mace diff --git a/mace/ops/arm/q8/bias_add.cc b/mace/ops/arm/q8/bias_add.cc new file mode 100644 index 0000000000000000000000000000000000000000..c69727dee56188bf9c612c722c052980903263e0 --- /dev/null +++ b/mace/ops/arm/q8/bias_add.cc @@ -0,0 +1,279 @@ +// Copyright 2020 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/arm/base/bias_add.h" + +#include + +#include "mace/core/quantize.h" + +namespace mace { +namespace ops { +namespace arm { + +template <> +template +void BiasAdd::AddBiasNCHW(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output) { + auto input_data = input->data(); + auto bias_data = bias->data(); + auto output_data = output->mutable_data(); + + const index_t batch = input->dim(0); + const index_t channels = input->dim(1); + const index_t image_size = input->dim(2) * input->dim(3); + const index_t block_count = image_size / 8; + const index_t remain = image_size % 8; + + constexpr int left_shift = 20; + const double doubled_scale = 2 * std::max(input->scale(), bias->scale()); + const double adjusted_input_scale = input->scale() / doubled_scale; + const double adjusted_bias_scale = bias->scale() / doubled_scale; + const double adjusted_output_scale = + doubled_scale / ((1 << left_shift) * output->scale()); + + int32_t input_multiplier; + int32_t bias_multiplier; + int32_t output_multiplier; + int32_t input_shift; + int32_t bias_shift; + int32_t output_shift; + QuantizeMultiplier(adjusted_input_scale, &input_multiplier, &input_shift); + QuantizeMultiplier(adjusted_bias_scale, &bias_multiplier, &bias_shift); + QuantizeMultiplier(adjusted_output_scale, &output_multiplier, &output_shift); + const auto left_shift_dup = vdupq_n_s32(left_shift); + const auto input_shift_dup = vdupq_n_s32(input_shift); + + thread_pool->Compute2D( + [=](index_t start0, index_t end0, index_t step0, index_t start1, + index_t end1, index_t step1) { + for (index_t b = start0; b < end0; b += step0) { + const index_t b_offset = b * channels; + for (index_t c = start1; c < end1; c += step1) { + const index_t offset = (b_offset + c) * image_size; + auto input_ptr = input_data + offset; + auto output_ptr = output_data + offset; + + const int32_t offset_bias = + bias_data[bias_index(b_offset, c)] - bias->zero_point(); + const int32_t shifted_bias = offset_bias * (1 << left_shift); + const int32_t multiplied_bias = gemmlowp::RoundingDivideByPOT( + gemmlowp::SaturatingRoundingDoublingHighMul(shifted_bias, + bias_multiplier), + -bias_shift); + const auto bias_low_s32 = vdupq_n_s32(multiplied_bias); + const auto bias_high_s32 = vdupq_n_s32(multiplied_bias); + + for (index_t i = 0; i < block_count; ++i) { + const auto input_val = vld1_u8(input_ptr); + const auto input_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input_val)); + const auto offset_input = + vaddq_s16(input_val_s16, vdupq_n_s16(-input->zero_point())); + auto input_low_s32 = vmovl_s16(vget_low_s16(offset_input)); + auto input_high_s32 = vmovl_s16(vget_high_s16(offset_input)); + input_low_s32 = vshlq_s32(input_low_s32, left_shift_dup); + input_high_s32 = vshlq_s32(input_high_s32, left_shift_dup); + input_low_s32 = vqrdmulhq_n_s32(input_low_s32, input_multiplier); + input_high_s32 = + vqrdmulhq_n_s32(input_high_s32, input_multiplier); + input_low_s32 = vshlq_s32(input_low_s32, input_shift_dup); + input_high_s32 = vshlq_s32(input_high_s32, input_shift_dup); + auto sum_low =
vaddq_s32(input_low_s32, bias_low_s32); + auto sum_high = vaddq_s32(input_high_s32, bias_high_s32); + sum_low = vqrdmulhq_n_s32(sum_low, output_multiplier); + sum_high = vqrdmulhq_n_s32(sum_high, output_multiplier); + sum_low = gemmlowp::RoundingDivideByPOT(sum_low, -output_shift); + sum_high = gemmlowp::RoundingDivideByPOT(sum_high, -output_shift); + const auto sum_low_s16 = vmovn_s32(sum_low); + const auto sum_high_s16 = vmovn_s32(sum_high); + const auto output_val = + vaddq_s16(vcombine_s16(sum_low_s16, sum_high_s16), + vdupq_n_s16(output->zero_point())); + vst1_u8(output_ptr, vqmovun_s16(output_val)); + + input_ptr += 8; + output_ptr += 8; + } + + for (index_t i = 0; i < remain; ++i) { + const int32_t offset_input = input_ptr[i] - input->zero_point(); + const int32_t shifted_input = offset_input * (1 << left_shift); + const int32_t multiplied_input = gemmlowp::RoundingDivideByPOT( + gemmlowp::SaturatingRoundingDoublingHighMul(shifted_input, + input_multiplier), + -input_shift); + int32_t sum = multiplied_input + multiplied_bias; + const int32_t output_val = + gemmlowp::RoundingDivideByPOT( + gemmlowp::SaturatingRoundingDoublingHighMul( + sum, output_multiplier), + -output_shift) + + output->zero_point(); + output_ptr[i] = Saturate(output_val); + } + } + } + }, + 0, batch, 1, 0, channels, 1); +} + +template <> +template +void BiasAdd::AddBiasNHWC(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output) { + auto input_data = input->data(); + auto bias_data = bias->data(); + auto output_data = output->mutable_data(); + + const std::vector &shape = input->shape(); + const index_t channels = *shape.rbegin(); + const index_t block_count = channels / 8; + const index_t remain = channels % 8; + + constexpr int left_shift = 20; + const double doubled_scale = 2 * std::max(input->scale(), bias->scale()); + const double adjusted_input_scale = input->scale() / doubled_scale; + const double adjusted_bias_scale = bias->scale() / doubled_scale; + const double adjusted_output_scale = + doubled_scale / ((1 << left_shift) * output->scale()); + + int32_t input_multiplier; + int32_t bias_multiplier; + int32_t output_multiplier; + int32_t input_shift; + int32_t bias_shift; + int32_t output_shift; + QuantizeMultiplier(adjusted_input_scale, &input_multiplier, &input_shift); + QuantizeMultiplier(adjusted_bias_scale, &bias_multiplier, &bias_shift); + QuantizeMultiplier(adjusted_output_scale, &output_multiplier, &output_shift); + const auto left_shift_dup = vdupq_n_s32(left_shift); + const auto input_shift_dup = vdupq_n_s32(input_shift); + const auto bias_shift_dup = vdupq_n_s32(bias_shift); + + const auto batch = shape[0]; + if (Dim == 2) { + MACE_CHECK(batch == bias->shape()[0]); + } + const index_t fused_hw = std::accumulate(shape.begin() + 1, shape.end() - 1, + 1, std::multiplies()); + thread_pool->Compute2D( + [=](index_t start0, index_t end0, index_t step0, index_t start1, + index_t end1, index_t step1) { + for (index_t i = start0; i < end0; i += step0) { + auto offset = i * fused_hw; + auto bias_offset = i * channels; + for (index_t j = start1; j < end1; j += step1) { + index_t pos = (offset + j) * channels; + auto input_ptr = input_data + pos; + auto output_ptr = output_data + pos; + auto bias_ptr = bias_data + bias_index(bias_offset, 0); + for (index_t c = 0; c < block_count; ++c) { + const auto input_val = vld1_u8(input_ptr); + const auto bias_val = vld1_u8(bias_ptr); + const auto input_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input_val)); + const auto 
bias_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(bias_val)); + const auto offset_input = + vaddq_s16(input_val_s16, vdupq_n_s16(-input->zero_point())); + const auto offset_bias = + vaddq_s16(bias_val_s16, vdupq_n_s16(-bias->zero_point())); + auto input_low_s32 = vmovl_s16(vget_low_s16(offset_input)); + auto input_high_s32 = vmovl_s16(vget_high_s16(offset_input)); + auto bias_low_s32 = vmovl_s16(vget_low_s16(offset_bias)); + auto bias_high_s32 = vmovl_s16(vget_high_s16(offset_bias)); + input_low_s32 = vshlq_s32(input_low_s32, left_shift_dup); + input_high_s32 = vshlq_s32(input_high_s32, left_shift_dup); + bias_low_s32 = vshlq_s32(bias_low_s32, left_shift_dup); + bias_high_s32 = vshlq_s32(bias_high_s32, left_shift_dup); + input_low_s32 = vqrdmulhq_n_s32(input_low_s32, input_multiplier); + input_high_s32 = + vqrdmulhq_n_s32(input_high_s32, input_multiplier); + bias_low_s32 = vqrdmulhq_n_s32(bias_low_s32, bias_multiplier); + bias_high_s32 = vqrdmulhq_n_s32(bias_high_s32, bias_multiplier); + input_low_s32 = vshlq_s32(input_low_s32, input_shift_dup); + input_high_s32 = vshlq_s32(input_high_s32, input_shift_dup); + bias_low_s32 = vshlq_s32(bias_low_s32, bias_shift_dup); + bias_high_s32 = vshlq_s32(bias_high_s32, bias_shift_dup); + int32x4_t sum_low = vaddq_s32(input_low_s32, bias_low_s32); + int32x4_t sum_high = vaddq_s32(input_high_s32, bias_high_s32); + sum_low = vqrdmulhq_n_s32(sum_low, output_multiplier); + sum_high = vqrdmulhq_n_s32(sum_high, output_multiplier); + sum_low = gemmlowp::RoundingDivideByPOT(sum_low, -output_shift); + sum_high = gemmlowp::RoundingDivideByPOT(sum_high, -output_shift); + const auto sum_low_s16 = vmovn_s32(sum_low); + const auto sum_high_s16 = vmovn_s32(sum_high); + const auto output_val = + vaddq_s16(vcombine_s16(sum_low_s16, sum_high_s16), + vdupq_n_s16(output->zero_point())); + vst1_u8(output_ptr, vqmovun_s16(output_val)); + + input_ptr += 8; + bias_ptr += 8; + output_ptr += 8; + } + for (index_t c = 0; c < remain; ++c) { + const int32_t offset_input = input_ptr[c] - input->zero_point(); + const int32_t offset_bias = bias_ptr[c] - bias->zero_point(); + const int32_t shifted_input = offset_input * (1 << left_shift); + const int32_t shifted_bias = offset_bias * (1 << left_shift); + const int32_t multiplied_input = gemmlowp::RoundingDivideByPOT( + gemmlowp::SaturatingRoundingDoublingHighMul(shifted_input, + input_multiplier), + -input_shift); + const int32_t multiplied_bias = gemmlowp::RoundingDivideByPOT( + gemmlowp::SaturatingRoundingDoublingHighMul(shifted_bias, + bias_multiplier), + -bias_shift); + + int32_t sum = multiplied_input + multiplied_bias; + + const int32_t output_val = + gemmlowp::RoundingDivideByPOT( + gemmlowp::SaturatingRoundingDoublingHighMul( + sum, output_multiplier), + -output_shift) + + output->zero_point(); + output_ptr[c] = Saturate(output_val); + } + } + } + }, + 0, batch, 1, 0, fused_hw, 1); +} + +template void BiasAdd::AddBiasNCHW<1>(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); +template void BiasAdd::AddBiasNCHW<2>(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); +template void BiasAdd::AddBiasNHWC<1>(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); +template void BiasAdd::AddBiasNHWC<2>(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output); +} // namespace arm +} // namespace ops +} // namespace mace diff --git a/mace/ops/arm/q8/eltwise.cc 
b/mace/ops/arm/q8/eltwise.cc index 97e50e1bc8faa1aa4ccfc2fd022f33879b04839b..ae1dd50f0be2de3123ba991af39dda7361bc4a71 100644 --- a/mace/ops/arm/q8/eltwise.cc +++ b/mace/ops/arm/q8/eltwise.cc @@ -23,7 +23,31 @@ namespace mace { namespace ops { namespace arm { namespace q8 { - +namespace { +template +inline T EltCompute(T input0, T input1) { + MACE_UNUSED(input0); + MACE_UNUSED(input1); + MACE_NOT_IMPLEMENTED; + return input0; +} +template <> +inline int32x4_t EltCompute(int32x4_t input0, int32x4_t input1) { + return vaddq_s32(input0, input1); +} +template <> +inline int32x4_t EltCompute(int32x4_t input0, int32x4_t input1) { + return vsubq_s32(input0, input1); +} +template <> +inline int32_t EltCompute(int32_t input0, int32_t input1) { + return input0 + input1; +} +template <> +inline int32_t EltCompute(int32_t input0, int32_t input1) { + return input0 - input1; +} +} // namespace class Eltwise : public delegator::Eltwise { public: explicit Eltwise(const delegator::EltwiseParam ¶m) @@ -32,16 +56,32 @@ class Eltwise : public delegator::Eltwise { MaceStatus Compute(const OpContext *context, const Tensor *input0, const Tensor *input1, Tensor *output) override; + + private: + template + MaceStatus ComputeSumSub(const OpContext *context, const Tensor *input0, + const Tensor *input1, Tensor *output); }; MaceStatus Eltwise::Compute(const OpContext *context, const Tensor *input0, const Tensor *input1, Tensor *output) { - MACE_UNUSED(context); - MACE_CHECK(type_ == SUM || type_ == SUB, - "Quantized Elementwise only support SUM and SUB now."); + if (type_ == SUM) { + return ComputeSumSub(context, input0, input1, output); + } else if (type_ == SUB) { + return ComputeSumSub(context, input0, input1, output); + } else { + MACE_NOT_IMPLEMENTED; + return MaceStatus::MACE_INVALID_ARGS; + } +} +template +MaceStatus Eltwise::ComputeSumSub(const OpContext *context, + const Tensor *input0, + const Tensor *input1, + Tensor *output) { constexpr int left_shift = 20; const double doubled_scale = 2 * std::max(input0->scale(), input1->scale()); const double adjusted_input0_scale = input0->scale() / doubled_scale; @@ -101,14 +141,8 @@ MaceStatus Eltwise::Compute(const OpContext *context, input0_high_s32 = vshlq_s32(input0_high_s32, input0_shift_dup); input1_low_s32 = vshlq_s32(input1_low_s32, input1_shift_dup); input1_high_s32 = vshlq_s32(input1_high_s32, input1_shift_dup); - int32x4_t res_low, res_high; - if (type_ == SUM) { - res_low = vaddq_s32(input0_low_s32, input1_low_s32); - res_high = vaddq_s32(input0_high_s32, input1_high_s32); - } else { - res_low = vsubq_s32(input0_low_s32, input1_low_s32); - res_high = vsubq_s32(input0_high_s32, input1_high_s32); - } + int32x4_t res_low = EltCompute(input0_low_s32, input1_low_s32); + int32x4_t res_high = EltCompute(input0_high_s32, input1_high_s32); res_low = vqrdmulhq_n_s32(res_low, output_multiplier); res_high = vqrdmulhq_n_s32(res_high, output_multiplier); res_low = gemmlowp::RoundingDivideByPOT(res_low, -output_shift); @@ -141,12 +175,7 @@ MaceStatus Eltwise::Compute(const OpContext *context, input1_multiplier), -input1_shift); - int32_t res; - if (type_ == SUM) { - res = multiplied_input0 + multiplied_input1; - } else { - res = multiplied_input0 - multiplied_input1; - } + int32_t res = EltCompute(multiplied_input0, multiplied_input1); const int32_t output_val = gemmlowp::RoundingDivideByPOT( @@ -162,6 +191,15 @@ MaceStatus Eltwise::Compute(const OpContext *context, return MaceStatus::MACE_SUCCESS; } +template MaceStatus Eltwise::ComputeSumSub(const OpContext *context, 
+ const Tensor *input0, + const Tensor *input1, + Tensor *output); +template MaceStatus Eltwise::ComputeSumSub(const OpContext *context, + const Tensor *input0, + const Tensor *input1, + Tensor *output); + void RegisterEltwiseDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, Eltwise, delegator::EltwiseParam, diff --git a/mace/ops/arm/q8/gemv.cc b/mace/ops/arm/q8/gemv.cc index 4e45ae2ac753ad37414b6418837010bc11c22555..9dfd54106f5867f1f7cf6973fc47e08518bf0856 100644 --- a/mace/ops/arm/q8/gemv.cc +++ b/mace/ops/arm/q8/gemv.cc @@ -163,8 +163,8 @@ MaceStatus Gemv::Compute(const OpContext *context, } if (is_output_type_uint8_) { - *output_ptr = - Saturate(std::roundf(ret * output_multiplier_float)); + *output_ptr = Saturate(std::roundf( + ret * output_multiplier_float + output->zero_point())); } else { *output_ptr = ret; } diff --git a/mace/ops/bias_add.cc b/mace/ops/bias_add.cc index 4d476ea34df2b9c045039fc6f10bd9722af18565..88f7ab7701950e3266b87c00ff8e6a71cd96f96e 100644 --- a/mace/ops/bias_add.cc +++ b/mace/ops/bias_add.cc @@ -50,63 +50,15 @@ class BiasAddOp : public Operation { const Tensor *bias = this->Input(1); Tensor *output = this->Output(0); - if (input->dim_size() == 4 && (has_data_format_ - || input->data_format() == DataFormat::NCHW)) { // NCHW - MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2, - "bias must be 1-dimensional or n*c for caffee.", - MakeString(bias->shape())); - bias_add_delegator_->Compute(context, input, bias, output); - } else { // NHWC - MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2, - "bias must be 1 or 2 dimensionals for caffee.", + MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2, + "bias must be 1 or 2 dimensionals for caffe.", bias->dim_size(), MakeString(bias->shape())); - // TODO(liyin): remove it and tranform bias to add (eltwise) - MACE_RETURN_IF_ERROR(output->ResizeLike(input)); - - Tensor::MappingGuard input_mapper(input); - Tensor::MappingGuard bias_mapper(bias); - Tensor::MappingGuard output_mapper(output); - - const T *input_ptr = input->data(); - const T *bias_ptr = bias->data(); - T *output_ptr = output->mutable_data(); - - const std::vector &shape = input->shape(); - const index_t channels = *shape.rbegin(); - utils::ThreadPool - &thread_pool = context->device()->cpu_runtime()->thread_pool(); - if (bias->dim_size() == 1) { - const index_t fused_batch = std::accumulate( - shape.begin(), shape.end() - 1, 1, std::multiplies()); - thread_pool.Compute1D([=](index_t start, index_t end, index_t step) { - for (index_t n = start; n < end; n += step) { - index_t pos = n * channels; - for (index_t c = 0; c < channels; ++c) { - output_ptr[pos] = input_ptr[pos] + bias_ptr[c]; - ++pos; - } - } - }, 0, fused_batch, 1); - } else { // bias is 2d - const auto n = shape[0]; - MACE_CHECK(n == bias->shape()[0]); - const index_t fused_hw = std::accumulate( - shape.begin() + 1, shape.end() - 1, 1, std::multiplies()); - const auto ch_size = bias->shape()[1]; - thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, - index_t start1, index_t end1, index_t step1) { - for (index_t i = start0; i < end0; i += step0) { - auto offset = i * fused_hw; - auto bias_offset = i * ch_size; - for (index_t j = start1; j < end1; j += step1) { - index_t pos = (offset + i) * channels; - for (index_t c = 0; c < channels; ++c, ++pos) { - output_ptr[pos] = input_ptr[pos] + bias_ptr[bias_offset + c]; - } - } - } - }, 0, n, 1, 0, fused_hw, 1); - } + if (input->dim_size() == 4 && + ((has_data_format_ && 
DataTypeToEnum::value != DT_UINT8) || + input->data_format() == DataFormat::NCHW)) { // NCHW + bias_add_delegator_->Compute(context, input, bias, output, true); + } else { // NHWC + bias_add_delegator_->Compute(context, input, bias, output, false); } return MaceStatus::MACE_SUCCESS; @@ -160,9 +112,11 @@ class BiasAddOp : public Operation { #endif // MACE_ENABLE_OPENCL void RegisterBiasAdd(OpRegistry *op_registry) { - MACE_REGISTER_OP(op_registry, "BiasAdd", BiasAddOp, - DeviceType::CPU, float); + MACE_REGISTER_OP(op_registry, "BiasAdd", BiasAddOp, DeviceType::CPU, float); MACE_REGISTER_BF16_OP(op_registry, "BiasAdd", BiasAddOp, DeviceType::CPU); +#ifdef MACE_ENABLE_QUANTIZE + MACE_REGISTER_OP(op_registry, "BiasAdd", BiasAddOp, DeviceType::CPU, uint8_t); +#endif // MACE_ENABLE_QUANTIZE MACE_REGISTER_GPU_OP(op_registry, "BiasAdd", BiasAddOp); MACE_REGISTER_OP_CONDITION( op_registry, diff --git a/mace/ops/delegator/bias_add.h b/mace/ops/delegator/bias_add.h index f5fdea0deea984cf2450d2f17cd29c6913a35bd9..29b19c6327b9d08716931dbcbad51853c4de1305 100644 --- a/mace/ops/delegator/bias_add.h +++ b/mace/ops/delegator/bias_add.h @@ -33,7 +33,8 @@ class BiasAdd : public OpDelegator { virtual MaceStatus Compute(const OpContext *context, const Tensor *input, const Tensor *bias, - Tensor *output) = 0; + Tensor *output, + const bool isNCHW = true) = 0; }; } // namespace delegator diff --git a/mace/ops/pad.cc b/mace/ops/pad.cc index e995ba6c176b27396f34eb2d66c5d76f76956d45..bc595385434b3633ab1a179fa91b2a936c5a028b 100644 --- a/mace/ops/pad.cc +++ b/mace/ops/pad.cc @@ -16,6 +16,9 @@ #include #include "mace/core/ops/operator.h" +#ifdef MACE_ENABLE_QUANTIZE +#include "mace/core/quantize.h" +#endif // MACE_ENABLE_QUANTIZE #include "mace/core/registry/ops_registry.h" #include "mace/ops/common/pad_type.h" #ifdef MACE_ENABLE_OPENCL @@ -24,6 +27,29 @@ #include "mace/utils/memory.h" #include "mace/utils/math.h" +namespace { +int get_src_idx(int out, int in_size, int pad, int l_add, int r_add) { + const int diff_left = pad - out; + int in; + + if (diff_left > 0) { + in = diff_left + l_add; + + } else { + const int diff_right = out - (in_size + pad); + + if (diff_right >= 0) { + in = in_size - diff_right + r_add; + + } else { + in = -diff_left; + } + } + + return in; +} +} // namespace + namespace mace { namespace ops { @@ -143,31 +169,128 @@ class PadOp : public Operation { } private: - int get_src_idx(int out, int in_size, int pad, int l_add, int r_add) { - const int diff_left = pad - out; - int in; + PadType type_; + std::vector paddings_; + float constant_value_; +}; - if (diff_left > 0) { - in = diff_left + l_add; +#ifdef MACE_ENABLE_QUANTIZE +template<> +class PadOp : public Operation { + public: + explicit PadOp(OpConstructContext *context) + : Operation(context), + type_( + static_cast(Operation::GetOptionalArg( + "pad_type", static_cast(PadType::CONSTANT)))), + paddings_(Operation::GetRepeatedArgs("paddings")), + constant_value_(Operation::GetOptionalArg( + "constant_value", 0.0)) { + MACE_CHECK(paddings_.size() == 8); + } - } else { - const int diff_right = out - (in_size + pad); + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + const Tensor *input = this->Input(0); + Tensor *output = this->Output(0); + MACE_CHECK( + this->paddings_.size() == static_cast(input->dim_size()) * 2); + auto input_shape = input->shape(); + for (size_t i = 0; i < paddings_.size(); ++i) { + if (type_ == PadType::REFLECT || type_ == PadType::SYMMETRIC) { + MACE_CHECK(paddings_[i] < input_shape[i / 2], 
paddings_[i], + " vs ", input_shape[i / 2]); + } + MACE_CHECK(paddings_[i] >= 0); + } + output->SetScale(input->scale()); + output->SetZeroPoint(input->zero_point()); + MACE_RETURN_IF_ERROR(output->Resize({input_shape[0] + this->paddings_[0] + + this->paddings_[1], + input_shape[1] + this->paddings_[2] + + this->paddings_[3], + input_shape[2] + this->paddings_[4] + + this->paddings_[5], + input_shape[3] + this->paddings_[6] + + this->paddings_[7]})); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + const auto input_ptr = input->data(); + auto output_ptr = output->mutable_data(); + + const index_t batch = input->dim(0); + const index_t height = input->dim(1); + const index_t width = input->dim(2); + const index_t channel = input->dim(3); + + if (type_ == PadType::CONSTANT) { + uint8_t constant = Quantize( + this->constant_value_, input->scale(), input->zero_point()); + std::fill(output_ptr, output_ptr + output->size(), constant); + + for (index_t b = 0; b < batch; ++b) { + for (index_t h = 0; h < height; ++h) { + for (index_t w = 0; w < width; ++w) { + const index_t in_offset = (((b * height + h) * width) + + w) * channel; + const index_t out_offset = + (((b + this->paddings_[0]) * output->dim(1) + + (h + this->paddings_[2])) * output->dim(2) + + (w + this->paddings_[4])) * output->dim(3) + + this->paddings_[6]; + memcpy(output_ptr + out_offset, + input_ptr + in_offset, + channel * sizeof(uint8_t)); + } + } + } + } else if (type_ == PadType::REFLECT || type_ == PadType::SYMMETRIC) { + const index_t o_batch = output->dim(0); + const index_t o_height = output->dim(1); + const index_t o_width = output->dim(2); + const index_t o_channel = output->dim(3); + const int l_add = type_ == PadType::REFLECT ? 0 : -1; + const int r_add = type_ == PadType::REFLECT ? 
-2 : -1; - if (diff_right >= 0) { - in = in_size - diff_right + r_add; + for (index_t b = 0; b < o_batch; ++b) { + index_t b_in = get_src_idx(b, batch, paddings_[0], l_add, r_add); + for (index_t h = 0; h < o_height; ++h) { + index_t h_in = get_src_idx(h, height, paddings_[2], l_add, r_add); + for (index_t w = 0; w < o_width; ++w) { + index_t w_in = get_src_idx(w, width, paddings_[4], l_add, r_add); + const index_t in_offset = + (((b_in * height + h_in) * width) + w_in) * channel; + index_t out_offset = + (((b * o_height + h) * o_width) + w) * o_channel; - } else { - in = -diff_left; + for (index_t i = 0, j = paddings_[6] + l_add; + i < paddings_[6]; ++i, --j) { + output_ptr[out_offset++] = input_ptr[in_offset + j]; + } + memcpy(output_ptr + out_offset, input_ptr + in_offset, + channel * sizeof(uint8_t)); + out_offset += channel; + for (index_t i = 0, j = channel + r_add; i < paddings_[7]; + ++i, --j) { + output_ptr[out_offset++] = input_ptr[in_offset + j]; + } + } + } } + } else { + LOG(FATAL) << "Pad op doesn't support type " << type_; } - return in; + return MaceStatus::MACE_SUCCESS; } + private: PadType type_; std::vector paddings_; float constant_value_; }; +#endif // MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_OPENCL template<> @@ -202,6 +325,9 @@ class PadOp : public Operation { void RegisterPad(OpRegistry *op_registry) { MACE_REGISTER_OP(op_registry, "Pad", PadOp, DeviceType::CPU, float); MACE_REGISTER_BF16_OP(op_registry, "Pad", PadOp, DeviceType::CPU); +#ifdef MACE_ENABLE_QUANTIZE + MACE_REGISTER_OP(op_registry, "Pad", PadOp, DeviceType::CPU, uint8_t); +#endif // MACE_ENABLE_QUANTIZE MACE_REGISTER_GPU_OP(op_registry, "Pad", PadOp); } diff --git a/mace/ops/ref/bias_add.cc b/mace/ops/ref/bias_add.cc index bab1199e10152340b88d0e20183cffe0bfab20cc..eb65f80cce18ac1f6bc93d35a22c0dce6610fa8d 100644 --- a/mace/ops/ref/bias_add.cc +++ b/mace/ops/ref/bias_add.cc @@ -24,19 +24,29 @@ class BiasAdd : public delegator::BiasAdd { explicit BiasAdd(const DelegatorParam ¶m) : delegator::BiasAdd(param) {} ~BiasAdd() = default; - MaceStatus Compute(const OpContext *context, const Tensor *input, - const Tensor *bias, Tensor *output) override; + MaceStatus Compute(const OpContext *context, + const Tensor *input, + const Tensor *bias, + Tensor *output, + const bool isNCHW = true) override; private: - void AddBias(const OpContext *context, const Tensor *input, - const Tensor *bias, Tensor *output); + void AddBiasNCHW(const OpContext *context, + const Tensor *input, + const Tensor *bias, + Tensor *output); + void AddBiasNHWC(const OpContext *context, + const Tensor *input, + const Tensor *bias, + Tensor *output); }; template MaceStatus BiasAdd::Compute(const OpContext *context, const Tensor *input, const Tensor *bias, - Tensor *output) { + Tensor *output, + const bool isNCHW) { Tensor::MappingGuard input_guard(input); Tensor::MappingGuard bias_guard(bias); if (input != output) { @@ -45,11 +55,19 @@ MaceStatus BiasAdd::Compute(const OpContext *context, output->Copy(*input); } else { Tensor::MappingGuard output_guard(output); - AddBias(context, input, bias, output); + if (isNCHW) { + AddBiasNCHW(context, input, bias, output); + } else { + AddBiasNHWC(context, input, bias, output); + } } } else { if (bias != nullptr) { - AddBias(context, input, bias, output); + if (isNCHW) { + AddBiasNCHW(context, input, bias, output); + } else { + AddBiasNHWC(context, input, bias, output); + } } } @@ -57,10 +75,10 @@ MaceStatus BiasAdd::Compute(const OpContext *context, } template -void BiasAdd::AddBias(const OpContext 
*context, - const Tensor *input, - const Tensor *bias, - mace::Tensor *output) { +void BiasAdd::AddBiasNCHW(const OpContext *context, + const Tensor *input, + const Tensor *bias, + mace::Tensor *output) { MACE_UNUSED(context); auto input_data = input->data(); auto bias_data = bias->data(); @@ -87,6 +105,46 @@ void BiasAdd::AddBias(const OpContext *context, } } +template +void BiasAdd::AddBiasNHWC(const OpContext *context, + const Tensor *input, + const Tensor *bias, + mace::Tensor *output) { + MACE_UNUSED(context); + auto input_data = input->data(); + auto bias_data = bias->data(); + auto output_data = output->mutable_data(); + + const auto &shape = input->shape(); + const index_t channels = *shape.rbegin(); + + if (bias->dim_size() == 1) { + const index_t fused_batch = std::accumulate(shape.begin(), shape.end() - 1, + 1, std::multiplies()); + index_t pos = 0; + for (index_t b = 0; b < fused_batch; ++b) { + for (index_t c = 0; c < channels; ++c, ++pos) { + output_data[pos] = input_data[pos] + bias_data[c]; + } + } + } else { + const auto batch = shape[0]; + MACE_CHECK(batch == bias->shape()[0]); + const index_t fused_hw = std::accumulate( + shape.begin() + 1, shape.end() - 1, 1, std::multiplies()); + for (index_t b = 0; b < batch; ++b) { + index_t offset = b * fused_hw; + auto bias_offset = b * channels; + for (index_t hw = 0; hw < fused_hw; ++hw) { + index_t pos = (offset + hw) * channels; + for (index_t c = 0; c < channels; ++c, ++pos) { + output_data[pos] = input_data[pos] + bias_data[bias_offset + c]; + } + } + } + } +} + void RegisterBiasAddDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, BiasAdd, DelegatorParam, diff --git a/mace/ops/ref/q8/eltwise.cc b/mace/ops/ref/q8/eltwise.cc index b34a62ea5b3f1763418a35a581ce80474ceb2f85..a34be3e49c28a18212d88879c30b34ad4748f594 100644 --- a/mace/ops/ref/q8/eltwise.cc +++ b/mace/ops/ref/q8/eltwise.cc @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
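In the reference quantized Eltwise below, the gemmlowp fixed-point pipeline is dropped in favor of a simpler path: dequantize each operand to a real value, apply SUM/SUB, then requantize with the shared `Quantize` helper from `mace/core/quantize.h` (the NEON path keeps the fixed-point arithmetic). A per-element sketch of that computation, with stand-alone names that are not the MACE API:

.. code-block:: cpp

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // q0/q1 are uint8 inputs with their own (scale, zero_point); the result is
    // requantized with the output tensor's parameters and saturated to uint8,
    // assuming the usual asymmetric scheme real = scale * (q - zero_point).
    uint8_t EltwiseSumRefSketch(uint8_t q0, float scale0, int32_t zp0,
                                uint8_t q1, float scale1, int32_t zp1,
                                float out_scale, int32_t out_zp) {
      const float real0 = scale0 * (q0 - zp0);
      const float real1 = scale1 * (q1 - zp1);
      const float real_out = real0 + real1;  // SUB would use real0 - real1
      const int32_t q =
          static_cast<int32_t>(std::roundf(real_out / out_scale + out_zp));
      return static_cast<uint8_t>(std::min(255, std::max(0, q)));
    }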
-#include -#include - -#include "mace/ops/common/gemmlowp_util.h" +#include "mace/core/quantize.h" #include "mace/ops/delegator/eltwise.h" #include "mace/utils/logging.h" @@ -38,29 +35,6 @@ MaceStatus Eltwise::Compute(const OpContext *context, const Tensor *input0, const Tensor *input1, Tensor *output) { - constexpr int left_shift = 20; - const double doubled_scale = 2 * std::max(input0->scale(), input1->scale()); - const double adjusted_input0_scale = input0->scale() / doubled_scale; - const double adjusted_input1_scale = input1->scale() / doubled_scale; - const double adjusted_output_scale = - doubled_scale / ((1 << left_shift) * output->scale()); - - int32_t input0_multiplier; - int32_t input1_multiplier; - int32_t output_multiplier; - int32_t input0_shift; - int32_t input1_shift; - int32_t output_shift; - QuantizeMultiplier(adjusted_input0_scale, - &input0_multiplier, - &input0_shift); - QuantizeMultiplier(adjusted_input1_scale, - &input1_multiplier, - &input1_shift); - QuantizeMultiplier(adjusted_output_scale, - &output_multiplier, - &output_shift); - Tensor::MappingGuard input0_guard(input0); Tensor::MappingGuard input1_guard(input1); Tensor::MappingGuard output_guard(output); @@ -73,34 +47,19 @@ MaceStatus Eltwise::Compute(const OpContext *context, &thread_pool = context->device()->cpu_runtime()->thread_pool(); thread_pool.Compute1D([=](index_t start, index_t end, index_t step) { for (index_t i = start; i < end; i += step) { - const int32_t offset_input0 = input0_ptr[i] - input0->zero_point(); - const int32_t offset_input1 = input1_ptr[i] - input1->zero_point(); - const int32_t shifted_input0 = offset_input0 * (1 << left_shift); - const int32_t shifted_input1 = offset_input1 * (1 << left_shift); - const int32_t multiplied_input0 = - gemmlowp::RoundingDivideByPOT( - gemmlowp::SaturatingRoundingDoublingHighMul(shifted_input0, - input0_multiplier), - -input0_shift); - const int32_t multiplied_input1 = - gemmlowp::RoundingDivideByPOT( - gemmlowp::SaturatingRoundingDoublingHighMul(shifted_input1, - input1_multiplier), - -input1_shift); - + float real_input0 = + input0->scale() * (input0_ptr[i] - input0->zero_point()); + float real_input1 = + input1->scale() * (input1_ptr[i] - input1->zero_point()); int32_t res; if (type_ == SUM) { - res = multiplied_input0 + multiplied_input1; + res = real_input0 + real_input1; } else { - res = multiplied_input0 - multiplied_input1; + res = real_input0 - real_input1; } - const int32_t output_val = - gemmlowp::RoundingDivideByPOT( - gemmlowp::SaturatingRoundingDoublingHighMul(res, - output_multiplier), - -output_shift) + output->zero_point(); - output_ptr[i] = Saturate(output_val); + output_ptr[i] = + Quantize(res, output->scale(), output->zero_point()); } }, 0, output->size(), 1); diff --git a/mace/ops/transpose.cc b/mace/ops/transpose.cc index 2d15d05312030f05d03f7853d3cac07582fc6dd8..4929164469d491e162bf39ff0d6466ceff19c697 100644 --- a/mace/ops/transpose.cc +++ b/mace/ops/transpose.cc @@ -48,6 +48,8 @@ class TransposeOp : public Operation { output_shape.push_back(input_shape[dims_[i]]); } MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + output->SetScale(input->scale()); + output->SetZeroPoint(input->zero_point()); Tensor::MappingGuard input_guard(input); Tensor::MappingGuard output_guard(output); @@ -69,6 +71,10 @@ void RegisterTranspose(OpRegistry *op_registry) { DeviceType::CPU, half); MACE_REGISTER_BF16_OP(op_registry, "Transpose", TransposeOp, DeviceType::CPU); +#ifdef MACE_ENABLE_QUANTIZE + MACE_REGISTER_OP(op_registry, "Transpose", 
TransposeOp, + DeviceType::CPU, uint8_t); +#endif // MACE_ENABLE_QUANTIZE } } // namespace ops diff --git a/mace/tools/mace_run.cc b/mace/tools/mace_run.cc index 61fc3369e5a9b4bf350cc2ae547bde778b6bd333..73238bf8b54f7d22b009e2f7f076b451fc39d8d3 100644 --- a/mace/tools/mace_run.cc +++ b/mace/tools/mace_run.cc @@ -225,6 +225,8 @@ bool RunModel(const std::string &model_name, } } + // model_weights_data should outlive the MaceEngine when device_type is CPU, + // unless half/uint8 weights are used to compress the model data size. std::unique_ptr model_weights_data = make_unique(); if (FLAGS_model_data_file != "") { diff --git a/test/ccbenchmark/mace/ops/bias_add_benchmark.cc b/test/ccbenchmark/mace/ops/bias_add_benchmark.cc index 7c3a17e45aaf6e4769a539e9fbfb02db64eb59dc..477fc04ad57a6eca1a5ce5183fb6a6999a8791a4 100644 --- a/test/ccbenchmark/mace/ops/bias_add_benchmark.cc +++ b/test/ccbenchmark/mace/ops/bias_add_benchmark.cc @@ -27,14 +27,14 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) { OpsTestNet net; // Add input data - if (D == DeviceType::CPU) { + if (D == DeviceType::CPU && DataTypeToEnum::value != DT_UINT8) { net.AddRandomInput("Input", {batch, channels, height, width}); - } else if (D == DeviceType::GPU) { - net.AddRandomInput("Input", {batch, height, width, channels}); } else { - MACE_NOT_IMPLEMENTED; + net.AddRandomInput("Input", {batch, height, width, channels}); } net.AddRandomInput("Bias", {channels}, true, true); + net.GetTensor("Input")->SetScale(0.1); + net.GetTensor("Bias")->SetScale(0.1); OpDefBuilder("BiasAdd", "BiasAddBM") .Input("Input") .Input("Bias") .Output("Output") .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); + net.Setup(D); + net.GetTensor("Output")->SetScale(0.1); + // Warm-up for (int i = 0; i < 5; ++i) { - net.RunOp(D); + net.Run(); } - net.Sync(); mace::testing::StartTiming(); while (iters--) { - net.RunOp(D); + net.Run(); } - net.Sync(); } } // namespace @@ -67,12 +68,21 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) { } \ MACE_BENCHMARK(MACE_BM_BIAS_ADD_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) -#ifdef MACE_ENABLE_OPENCL +#if defined(MACE_ENABLE_OPENCL) && defined(MACE_ENABLE_QUANTIZE) #define MACE_BM_BIAS_ADD(N, C, H, W) \ MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_BIAS_ADD_MACRO(N, C, H, W, uint8_t, CPU); \ MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, GPU); \ MACE_BM_BIAS_ADD_MACRO(N, C, H, W, half, GPU); -#else +#elif defined(MACE_ENABLE_OPENCL) +#define MACE_BM_BIAS_ADD(N, C, H, W) \ + MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, GPU); \ + MACE_BM_BIAS_ADD_MACRO(N, C, H, W, half, GPU); +#elif defined(MACE_ENABLE_QUANTIZE) +#define MACE_BM_BIAS_ADD(N, C, H, W) \ + MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_BIAS_ADD_MACRO(N, C, H, W, uint8_t, CPU); +#else #define MACE_BM_BIAS_ADD(N, C, H, W) \ MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); #endif diff --git a/test/ccunit/mace/ops/activation_test.cc b/test/ccunit/mace/ops/activation_test.cc index c2c9588226e91b4de6e237bf5785a18c8d1798c7..27cfefb7a861e717ade5b1808aa1e7b76fbe4891 100644 --- a/test/ccunit/mace/ops/activation_test.cc +++ b/test/ccunit/mace/ops/activation_test.cc @@ -308,6 +308,62 @@ TEST_F(ActivationOpTest, OPENCLSimpleSigmoid) { TestSimpleSigmoid(); } +namespace { +void TestQuantized(const index_t size, const char *type) { + OpsTestNet net; + std::vector
+  net.AddRandomInput<DeviceType::CPU, float>(
+      "Input", input_shape, false, false);
+  net.AddRandomInput<DeviceType::CPU, float>(
+      "Output", input_shape, false, true, true);
+  OpDefBuilder("Activation", "ActivationTest")
+      .Input("Input")
+      .Output("Output")
+      .AddStringArg("activation", type)
+      .AddIntArg("T", DT_FLOAT)
+      .Finalize(net.NewOperatorDef());
+
+  net.RunOp(CPU);
+
+  OpDefBuilder("Quantize", "QuantizeInput")
+      .Input("Input")
+      .Output("QuantizedInput")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  net.AddRandomInput<DeviceType::CPU, uint8_t>("QuantizedOutput", input_shape);
+  OpDefBuilder("Activation", "QuantizedActivationTest")
+      .Input("QuantizedInput")
+      .Output("QuantizedOutput")
+      .AddStringArg("activation", type)
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("Dequantize", "DeQuantizeTest")
+      .Input("QuantizedOutput")
+      .Output("DequantizedOutput")
+      .OutputType({DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  // Check
+  ExpectTensorSimilar<float>(*net.GetOutput("Output"),
+                             *net.GetTensor("DequantizedOutput"), 0.01);
+}
+}  // namespace
+
+TEST_F(ActivationOpTest, Quantized) {
+  TestQuantized(64, "RELU");
+  TestQuantized(64, "RELUX");
+  TestQuantized(37, "RELU");
+  TestQuantized(37, "RELUX");
+}
+
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
diff --git a/test/ccunit/mace/ops/bias_add_test.cc b/test/ccunit/mace/ops/bias_add_test.cc
index 07f6172608ada11ea9ef3d817498b4a0abb0c690..a1d37e0012d138d3b1e65bf0cb22bc469bd70168 100644
--- a/test/ccunit/mace/ops/bias_add_test.cc
+++ b/test/ccunit/mace/ops/bias_add_test.cc
@@ -214,6 +214,102 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
   ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
 }
 
+namespace {
+void TestQuantized(const bool batched_bias,
+                   const bool has_data_format) {
+  static unsigned int seed = time(NULL);
+  index_t batch = 1 + rand_r(&seed) % 10;
+  index_t channels = 3 + rand_r(&seed) % 50;
+  index_t height = 64 + rand_r(&seed) % 50;
+  index_t width = 64 + rand_r(&seed) % 50;
+
+  OpsTestNet net;
+  std::vector<index_t> input_shape{batch, height, width, channels};
+  net.AddRandomInput<DeviceType::CPU, float>("Input", input_shape, false, false);
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
+  if (batched_bias) {
+    net.AddRandomInput<DeviceType::CPU, float>("Bias", {batch, channels}, true);
+  } else {
+    net.AddRandomInput<DeviceType::CPU, float>("Bias", {channels}, true);
+  }
+
+  net.AddRandomInput<DeviceType::CPU, float>(
+      "OutputNCHW", input_shape, false, true, true);
+  OpDefBuilder("BiasAdd", "BiasAddTest")
+      .Input("InputNCHW")
+      .Input("Bias")
+      .Output("OutputNCHW")
+      .AddIntArg("has_data_format", has_data_format)
+      .AddIntArg("T", DT_FLOAT)
+      .Finalize(net.NewOperatorDef());
+
+  net.RunOp(CPU);
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
+
+  OpDefBuilder("Quantize", "QuantizeInput")
+      .Input("Input")
+      .Output("QuantizedInput")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("Quantize", "QuantizeBias")
+      .Input("Bias")
+      .Output("QuantizedBias")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("Quantize", "QuantizeOutput")
+      .Input("Output")
+      .Output("ExpectedQuantizedOutput")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
net.AddRandomInput("QuantizedOutput", input_shape); + OpDefBuilder("BiasAdd", "BiasAddTest") + .Input("QuantizedInput") + .Input("QuantizedBias") + .Output("QuantizedOutput") + .AddIntArg("has_data_format", has_data_format) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.Setup(DeviceType::CPU); + Tensor *eq_output = net.GetTensor("ExpectedQuantizedOutput"); + Tensor *q_output = net.GetTensor("QuantizedOutput"); + q_output->SetScale(eq_output->scale()); + q_output->SetZeroPoint(eq_output->zero_point()); + net.Run(); + + OpDefBuilder("Dequantize", "DeQuantizeTest") + .Input("QuantizedOutput") + .Output("DequantizedOutput") + .OutputType({DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + // Check + ExpectTensorSimilar(*net.GetOutput("Output"), + *net.GetTensor("DequantizedOutput"), 0.01); +} +} // namespace + +TEST_F(BiasAddOpTest, Quantized) { + TestQuantized(false, false); + TestQuantized(false, true); + TestQuantized(true, false); + TestQuantized(true, true); +} + + } // namespace test } // namespace ops } // namespace mace diff --git a/test/ccunit/mace/ops/matmul_test.cc b/test/ccunit/mace/ops/matmul_test.cc index 4ab2ec767f4ba9ee25f71c5ff51de095d59fb7c2..fb1858b99674e050724fde902083bafa7dd6dc13 100644 --- a/test/ccunit/mace/ops/matmul_test.cc +++ b/test/ccunit/mace/ops/matmul_test.cc @@ -182,8 +182,8 @@ void QuantOutputUint8(const std::vector &batch, if (rhs_batched) { rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end()); } - net.AddRandomInput("A", lhs_shape); - net.AddRandomInput("B", rhs_shape); + net.AddRandomInput("A", lhs_shape, false, false); + net.AddRandomInput("B", rhs_shape, false, false); OpDefBuilder("MatMul", "MatMulTest") .Input("A") @@ -276,8 +276,8 @@ void QuantOutputInt32(const std::vector &batch, if (rhs_batched) { rhs_shape.insert(rhs_shape.begin(), batch.begin(), batch.end()); } - net.AddRandomInput("A", lhs_shape); - net.AddRandomInput("B", rhs_shape); + net.AddRandomInput("A", lhs_shape, false, false); + net.AddRandomInput("B", rhs_shape, false, false); OpDefBuilder("MatMul", "MatMulTest") .Input("A") @@ -408,10 +408,16 @@ TEST_F(MatMulOpTest, QuantOutputUint8) { QuantOutputUint8({1}, 64, 128, 32, true, true); QuantOutputUint8({1}, 64, 32, 128, true, true); QuantOutputUint8({2, 3}, 64, 32, 128, true, true); + QuantOutputUint8({1}, 1, 30000, 256, false, true); + QuantOutputUint8({1}, 30000, 256, 1, false, false); + QuantOutputUint8({2}, 1, 256, 128, false, true); + QuantOutputUint8({3}, 128, 256, 1, false, false); // UnAligned QuantOutputUint8({16}, 31, 61, 67, false, true); QuantOutputUint8({31}, 31, 61, 67, true, false); QuantOutputUint8({2, 3}, 31, 61, 67, true, true); + QuantOutputUint8({1}, 1, 30001, 253, false, true); + QuantOutputUint8({2}, 253, 300, 1, false, false); QuantOutputUint8({2, 3}, 31, 61, 67, true, true, true, false); QuantOutputUint8({2, 3}, 31, 61, 67, true, true, false, true); diff --git a/test/ccunit/mace/ops/pad_test.cc b/test/ccunit/mace/ops/pad_test.cc index 3d785ac7603b75d9a2e11ca65faeefb1cc40abbc..44b12d825f5fb8d5d9464ef429857457f3bb55d7 100644 --- a/test/ccunit/mace/ops/pad_test.cc +++ b/test/ccunit/mace/ops/pad_test.cc @@ -470,6 +470,78 @@ TEST_F(PadTest, SymmetricCPU) { expected_data, paddings, PadType::SYMMETRIC); } +namespace { +void TestQuantized(const std::vector &paddings, + const int pad_type) { + static unsigned int seed = time(NULL); + index_t batch = 1 + rand_r(&seed) % 10; + index_t channels = 3 + rand_r(&seed) % 50; + index_t height = 64 + 
+  index_t width = 64 + rand_r(&seed) % 50;
+
+  OpsTestNet net;
+  std::vector<index_t> input_shape{batch, height, width, channels};
+  net.AddRandomInput<DeviceType::CPU, float>(
+      "Input", input_shape, false, false);
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
+  OpDefBuilder("Pad", "PadTest")
+      .Input("TInput")
+      .Output("TOutput")
+      .AddIntsArg("paddings", paddings)
+      .AddIntArg("pad_type", pad_type)
+      .AddFloatArg("constant_value", 1.0)
+      .AddIntArg("has_data_format", 1)
+      .AddIntArg("T", DT_FLOAT)
+      .Finalize(net.NewOperatorDef());
+
+  net.RunOp();
+  net.TransformDataFormat<DeviceType::CPU, float>(
+      "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
+
+  OpDefBuilder("Quantize", "QuantizeInput")
+      .Input("Input")
+      .Output("QuantizedInput")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  net.AddRandomInput<DeviceType::CPU, uint8_t>("QuantizedOutput", input_shape);
+  OpDefBuilder("Pad", "QuantizedPadTest")
+      .Input("QuantizedInput")
+      .Output("QuantizedOutput")
+      .AddIntsArg("paddings", paddings)
+      .AddIntArg("pad_type", pad_type)
+      .AddFloatArg("constant_value", 1.0)
+      .AddIntArg("has_data_format", 1)
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("Dequantize", "DeQuantizeTest")
+      .Input("QuantizedOutput")
+      .Output("DequantizedOutput")
+      .OutputType({DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  // Check
+  ExpectTensorSimilar<float>(*net.GetOutput("Output"),
+                             *net.GetTensor("DequantizedOutput"), 0.01);
+}
+}  // namespace
+
+TEST_F(PadTest, Quantized) {
+  for (int i = PadType::CONSTANT; i <= PadType::SYMMETRIC; i++) {
+    TestQuantized({0, 0, 2, 2, 1, 1, 0, 0}, i);
+    TestQuantized({0, 0, 2, 0, 1, 0, 0, 0}, i);
+    TestQuantized({0, 0, 0, 1, 0, 2, 0, 0}, i);
+  }
+}
+
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
diff --git a/test/ccutils/mace/ops/testing/test_utils.h b/test/ccutils/mace/ops/testing/test_utils.h
index ef830781d7fd597599180d8882c47eda1800a3e8..5ca247e1cbab0a8eea90d1438e97aad38e07a721 100644
--- a/test/ccutils/mace/ops/testing/test_utils.h
+++ b/test/ccutils/mace/ops/testing/test_utils.h
@@ -315,8 +315,13 @@ void ExpectTensorSimilar(const Tensor &x,
   double norm_product = sqrt(x_norm) * sqrt(y_norm);
   double error = rel_err * std::abs(dot_product);
 
-  EXPECT_NEAR(dot_product, norm_product, error)
-      << "Shape " << ShapeToString(x);
+  // When y_norm is 0, dot_product and norm_product are both 0
+  if (y_norm == 0.0) {
+    EXPECT_NEAR(x_norm, y_norm, rel_err) << "Shape " << ShapeToString(x);
+  } else {
+    EXPECT_NEAR(dot_product, norm_product, error)
+        << "Shape " << ShapeToString(x);
+  }
 }
 
 }  // namespace test
diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py
index 8f032acefbc352b94cd488f8653a3a1392ce5e94..ab81aa65f5e54cd964d7e78ee029b7c500528be3 100644
--- a/tools/python/transform/transformer.py
+++ b/tools/python/transform/transformer.py
@@ -866,7 +866,8 @@ class Transformer(base_converter.ConverterInterface):
         for op in net.op:
             if (((op.type == MaceOp.Conv2D.name
                   or op.type == MaceOp.DepthwiseConv2d.name
-                  or op.type == MaceOp.FullyConnected.name)
+                  or op.type == MaceOp.FullyConnected.name
+                  or op.type == MaceOp.MatMul.name)
                  and len(op.input) == 2)
                     or (op.type == MaceOp.Deconv2D.name
                         and ((ConverterUtil.get_arg(
@@ -1753,7 +1754,8 @@ class Transformer(base_converter.ConverterInterface):
             check_conv =\
                 ops[0].type in [MaceOp.Conv2D.name,
                                 MaceOp.DepthwiseConv2d.name,
-                                MaceOp.FullyConnected.name]\
+                                MaceOp.FullyConnected.name,
+                                MaceOp.MatMul.name]\
                 and ops[0].input[2] == tensor.name
             # in tensorflow deconv's bias is the forth input
             if ops[0].type in [MaceOp.Deconv2D.name,
@@ -2036,7 +2038,8 @@ class Transformer(base_converter.ConverterInterface):
                                MaceOp.BatchToSpaceND.name,
                                MaceOp.SpaceToBatchND.name,
                                MaceOp.SpaceToDepth.name,
-                               MaceOp.DepthToSpace.name]:
+                               MaceOp.DepthToSpace.name,
+                               MaceOp.Transpose.name]:
                 del op.quantize_info[:]
                 producer_op = self._producer[op.input[0]]
                 if producer_op.output[0] in self._option.input_nodes: