提交 ec01107c 编写于 作者: L liyin

Replace thread pool

上级 7d3849a5
......@@ -49,9 +49,9 @@ docs:
platform_compatible_tests:
stage: platform_compatible_tests
script:
- bazel build mace/core:core --define openmp=true
- bazel build --config arm_linux_gnueabihf --define openmp=true --define opencl=true --define neon=true //mace/libmace:libmace.so
- bazel build --config aarch64_linux_gnu --define openmp=true --define opencl=true --define neon=true //mace/libmace:libmace.so
- bazel build mace/core:core --define openmp=false
- bazel build --config arm_linux_gnueabihf --define openmp=false --define opencl=true --define neon=true //mace/libmace:libmace.so
- bazel build --config aarch64_linux_gnu --define openmp=false --define opencl=true --define neon=true //mace/libmace:libmace.so
build_libraries:
stage: build_libraries
......@@ -202,13 +202,13 @@ so_size_check:
stage: so_size_check
script:
- DYNAMIC_LIB_PATH="bazel-bin/mace/libmace/libmace.so"
- bazel build -s --config android --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=true --define opencl=false --define quantize=false --cpu=armeabi-v7a
- bazel build -s --config android --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=false --define opencl=false --define quantize=false --cpu=armeabi-v7a
- CURRENT_LIBMACE_SO_SIZE=`ls -l $DYNAMIC_LIB_PATH --block-size=K -s | cut -f 1 -d "K"`
- TARGET_MACE_WORK_DIR=`mktemp -d`
- pushd $TARGET_MACE_WORK_DIR
- GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" git clone git@github.com:XiaoMi/mace.git
- pushd mace
- bazel build -s --config android --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=true --define opencl=false --define quantize=false --cpu=armeabi-v7a
- bazel build -s --config android --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=false --define opencl=false --define quantize=false --cpu=armeabi-v7a
- TARGET_LIBMACE_SO_SIZE=`ls -l $DYNAMIC_LIB_PATH --block-size=K -s | cut -f 1 -d "K"`
- popd
- popd
......
......@@ -79,19 +79,19 @@ new_http_archive(
http_archive(
name = "gemmlowp",
sha256 = "4e9cd60f7871ae9e06dcea5fec1a98ddf1006b32a85883480273e663f143f303",
strip_prefix = "gemmlowp-master-66fb41a7cafd2034a50e0b32791359897d657f7a",
sha256 = "afbea037aee2d21b625985238486b4219396f9c2550b0fde3157fab4d2580205",
strip_prefix = "gemmlowp-master-1f6d8d442805a400c74e63a4a017390733df2e28",
urls = [
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/gemmlowp-master-66fb41a7cafd2034a50e0b32791359897d657f7a.zip",
"http://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/gemmlowp-master-1f6d8d442805a400c74e63a4a017390733df2e28.zip",
],
)
http_archive(
name = "tflite",
sha256 = "1bb4571ee5cbde427ecfed076b39edaad96ace897ab86bb2495bdb93c706b203",
strip_prefix = "tensorflow-mace-ffc8cc7e8c9d1894753509e88b17e251bc6255e3",
sha256 = "8b4c1b2ad2d31da9859e17b0ad551b12e1db7ff2faf7e83218901ab48d9fa91a",
strip_prefix = "tensorflow-mace-dfabaf85145e4d5ad39f34a0cea57b44c32dbe43",
urls = [
"http://cnbj1.fds.api.xiaomi.com/mace/third-party/tflite/tensorflow-mace-ffc8cc7e8c9d1894753509e88b17e251bc6255e3_custom.zip",
"http://cnbj1.fds.api.xiaomi.com/mace/third-party/tflite/tensorflow-mace-dfabaf85145e4d5ad39f34a0cea57b44c32dbe43.zip",
],
)
......
......@@ -252,8 +252,7 @@ int Main(int argc, char **argv) {
MaceEngineConfig config(device_type);
mace_status = config.SetCPUThreadPolicy(
FLAGS_omp_num_threads,
static_cast<CPUAffinityPolicy >(FLAGS_cpu_affinity_policy),
true);
static_cast<CPUAffinityPolicy >(FLAGS_cpu_affinity_policy));
if (mace_status != MaceStatus::MACE_SUCCESS) {
LOG(INFO) << "Set openmp or cpu affinity failed.";
}
......
......@@ -21,10 +21,10 @@ namespace mace {
CPUDevice::CPUDevice(const int num_threads,
const CPUAffinityPolicy policy,
const bool use_gemmlowp)
utils::ThreadPool *thread_pool)
: cpu_runtime_(make_unique<CPURuntime>(num_threads,
policy,
use_gemmlowp)),
thread_pool)),
scratch_buffer_(make_unique<ScratchBuffer>(GetCPUAllocator())) {}
CPUDevice::~CPUDevice() = default;
......
......@@ -46,7 +46,7 @@ class CPUDevice : public Device {
public:
CPUDevice(const int num_threads,
const CPUAffinityPolicy policy,
const bool use_gemmlowp);
utils::ThreadPool *thread_pool);
virtual ~CPUDevice();
#ifdef MACE_ENABLE_OPENCL
......
......@@ -136,7 +136,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
make_unique<CPUDevice>(
target_device->cpu_runtime()->num_threads(),
target_device->cpu_runtime()->policy(),
target_device->cpu_runtime()->use_gemmlowp())) {
&target_device->cpu_runtime()->thread_pool())) {
MACE_LATENCY_LOGGER(1, "Constructing SerialNet");
// quantize model flag
bool is_quantize_model = IsQuantizedModel(*net_def);
......@@ -154,7 +154,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
}
for (auto &tensor : net_def->tensors()) {
tensor_shape_map[tensor.name()] =
std::vector<index_t>(tensor.dims().begin(), tensor.dims().end());
std::vector<index_t>(tensor.dims().begin(), tensor.dims().end());
}
bool has_data_format = false;
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif // MACE_ENABLE_NEON
#include "mace/core/quantize.h"
namespace mace {
#ifdef MACE_ENABLE_NEON
template<>
void QuantizeUtil<uint8_t>::QuantizeWithScaleAndZeropoint(
const float *input,
const index_t size,
float scale,
int32_t zero_point,
uint8_t *output) {
const float32x4_t vround = vdupq_n_f32(0.5);
const float32x4_t
vzero = vaddq_f32(vround, vcvtq_f32_s32(vdupq_n_s32(zero_point)));
const float recip_scale = 1.f / scale;
const float32x4_t vrecip_scale = vdupq_n_f32(recip_scale);
const index_t block_count = size / 16;
thread_pool_->Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
float32x4_t vi0 = vld1q_f32(input + i * 16);
float32x4_t vi1 = vld1q_f32(input + i * 16 + 4);
float32x4_t vi2 = vld1q_f32(input + i * 16 + 8);
float32x4_t vi3 = vld1q_f32(input + i * 16 + 12);
int32x4_t vo0_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi0, vrecip_scale));
int32x4_t vo1_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi1, vrecip_scale));
int32x4_t vo2_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi2, vrecip_scale));
int32x4_t vo3_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi3, vrecip_scale));
uint8x8_t vo0_u8 =
vqmovun_s16(vcombine_s16(vqmovn_s32(vo0_s32), vqmovn_s32(vo1_s32)));
uint8x8_t vo1_u8 =
vqmovun_s16(vcombine_s16(vqmovn_s32(vo2_s32), vqmovn_s32(vo3_s32)));
uint8x16_t vo = vcombine_u8(vo0_u8, vo1_u8);
vst1q_u8(output + i * 16, vo);
}
}, 0, block_count, 1);
for (index_t i = block_count * 16; i < size; ++i) {
output[i] =
Saturate<uint8_t>(roundf(zero_point + recip_scale * input[i]));
}
}
template<>
void QuantizeUtil<uint8_t>::Dequantize(const uint8_t *input,
const index_t size,
const float scale,
const int32_t zero_point,
float *output) {
const index_t block_count = size / 16;
const int32x4_t vzero = vdupq_n_s32(zero_point);
const float32x4_t vscale = vdupq_n_f32(scale);
thread_pool_->Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
uint8x16_t vi = vld1q_u8(input + i * 16);
float32x4x4_t vo = {
vmulq_f32(vscale,
vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
vget_low_u16(vmovl_u8(vget_low_u8(vi))))), vzero))),
vmulq_f32(vscale,
vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
vget_high_u16(vmovl_u8(vget_low_u8(vi))))), vzero))),
vmulq_f32(vscale,
vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
vget_low_u16(vmovl_u8(vget_high_u8(vi))))), vzero))),
vmulq_f32(vscale,
vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
vget_high_u16(vmovl_u8(vget_high_u8(vi))))), vzero))),
};
vst1q_f32(output + i * 16, vo.val[0]);
vst1q_f32(output + i * 16 + 4, vo.val[1]);
vst1q_f32(output + i * 16 + 8, vo.val[2]);
vst1q_f32(output + i * 16 + 12, vo.val[3]);
}
}, 0, block_count, 1);
for (index_t i = block_count * 16; i < size; ++i) {
output[i] = scale * (input[i] - zero_point);
}
}
template<>
void QuantizeUtil<int32_t>::Dequantize(const int *input,
const index_t size,
const float scale,
const int32_t zero_point,
float *output) {
const index_t block_count = size / 4;
const int32x4_t vzero = vdupq_n_s32(zero_point);
const float32x4_t vscale = vdupq_n_f32(scale);
thread_pool_->Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
int32x4_t vi = vld1q_s32(input + i * 4);
float32x4_t vo = vmulq_f32(vscale, vcvtq_f32_s32(vsubq_s32(vi, vzero)));
vst1q_f32(output + i * 4, vo);
}
}, 0, block_count, 1);
for (index_t i = block_count * 4; i < size; ++i) {
output[i] = scale * (input[i] - zero_point);
}
}
#endif
} // namespace mace
......@@ -12,18 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_UTILS_QUANTIZE_H_
#define MACE_UTILS_QUANTIZE_H_
#ifndef MACE_CORE_QUANTIZE_H_
#define MACE_CORE_QUANTIZE_H_
#include <algorithm>
#include <cmath>
#include <limits>
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif // MACE_ENABLE_NEON
#include "mace/utils/logging.h"
#include "mace/utils/thread_pool.h"
#include "mace/core/tensor.h"
namespace mace {
......@@ -92,185 +90,6 @@ inline void FindMinMax(const float *input,
*max_val = max_v;
}
template<typename T>
inline void QuantizeWithScaleAndZeropoint(const float *input,
const index_t size,
float scale,
int32_t zero_point,
T *output) {
float recip_scale = 1 / scale;
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < size; ++i) {
output[i] = Saturate<T>(roundf(zero_point + recip_scale * input[i]));
}
}
template<typename T>
inline void Quantize(const float *input,
const index_t size,
bool non_zero,
T *output,
float *scale,
int32_t *zero_point) {
float in_min_data;
float in_max_data;
FindMinMax(input, size, &in_min_data, &in_max_data);
AdjustRange<T>(in_min_data, in_max_data, non_zero,
scale, zero_point);
QuantizeWithScaleAndZeropoint(input, size, *scale, *zero_point, output);
}
template<typename T>
inline void Quantize(const Tensor &input,
Tensor *output,
float *min_out,
float *max_out) {
MACE_CHECK(input.size() != 0);
Tensor::MappingGuard input_guard(&input);
Tensor::MappingGuard output_guard(output);
auto *input_data = input.data<float>();
auto *output_data = output->mutable_data<T>();
float scale;
int32_t zero_point;
Quantize(input_data, input.size(), false, output_data, &scale, &zero_point);
*min_out = scale * (std::numeric_limits<T>::lowest() - zero_point);
*max_out = scale * (std::numeric_limits<T>::max() - zero_point);
}
template<typename T>
inline void Dequantize(const T *input,
const index_t size,
const float scale,
const int32_t zero_point,
float *output) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < size; ++i) {
output[i] = scale * (input[i] - zero_point);
}
}
#if defined(MACE_ENABLE_NEON)
template<>
inline void QuantizeWithScaleAndZeropoint<uint8_t>(const float *input,
const index_t size,
float scale,
int32_t zero_point,
uint8_t *output) {
const float32x4_t vround = vdupq_n_f32(0.5);
const float32x4_t
vzero = vaddq_f32(vround, vcvtq_f32_s32(vdupq_n_s32(zero_point)));
const float recip_scale = 1.f / scale;
const float32x4_t vrecip_scale = vdupq_n_f32(recip_scale);
const index_t block_count = size / 16;
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < block_count; ++i) {
float32x4_t vi0 = vld1q_f32(input + i * 16);
float32x4_t vi1 = vld1q_f32(input + i * 16 + 4);
float32x4_t vi2 = vld1q_f32(input + i * 16 + 8);
float32x4_t vi3 = vld1q_f32(input + i * 16 + 12);
int32x4_t vo0_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi0, vrecip_scale));
int32x4_t vo1_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi1, vrecip_scale));
int32x4_t vo2_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi2, vrecip_scale));
int32x4_t vo3_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi3, vrecip_scale));
uint8x8_t vo0_u8 =
vqmovun_s16(vcombine_s16(vqmovn_s32(vo0_s32), vqmovn_s32(vo1_s32)));
uint8x8_t vo1_u8 =
vqmovun_s16(vcombine_s16(vqmovn_s32(vo2_s32), vqmovn_s32(vo3_s32)));
uint8x16_t vo = vcombine_u8(vo0_u8, vo1_u8);
vst1q_u8(output + i * 16, vo);
}
#pragma omp parallel for schedule(runtime)
for (index_t i = block_count * 16; i < size; ++i) {
output[i] = Saturate<uint8_t>(roundf(zero_point + recip_scale * input[i]));
}
}
template<>
inline void Dequantize<int32_t>(const int32_t *input,
const index_t size,
const float scale,
const int32_t zero_point,
float *output) {
const index_t block_count = size / 4;
const int32x4_t vzero = vdupq_n_s32(zero_point);
const float32x4_t vscale = vdupq_n_f32(scale);
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < block_count; ++i) {
int32x4_t vi = vld1q_s32(input + i * 4);
float32x4_t vo = vmulq_f32(vscale, vcvtq_f32_s32(vsubq_s32(vi, vzero)));
vst1q_f32(output + i * 4, vo);
}
for (index_t i = block_count * 4; i < size; ++i) {
output[i] = scale * (input[i] - zero_point);
}
}
template<>
inline void Dequantize<uint8_t>(const uint8_t *input,
const index_t size,
const float scale,
const int32_t zero_point,
float *output) {
const index_t block_count = size / 16;
const int32x4_t vzero = vdupq_n_s32(zero_point);
const float32x4_t vscale = vdupq_n_f32(scale);
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < block_count; ++i) {
uint8x16_t vi = vld1q_u8(input + i * 16);
float32x4x4_t vo = {
vmulq_f32(vscale,
vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
vget_low_u16(vmovl_u8(vget_low_u8(vi))))), vzero))),
vmulq_f32(vscale,
vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
vget_high_u16(vmovl_u8(vget_low_u8(vi))))), vzero))),
vmulq_f32(vscale,
vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
vget_low_u16(vmovl_u8(vget_high_u8(vi))))), vzero))),
vmulq_f32(vscale,
vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
vget_high_u16(vmovl_u8(vget_high_u8(vi))))), vzero))),
};
vst1q_f32(output + i * 16, vo.val[0]);
vst1q_f32(output + i * 16 + 4, vo.val[1]);
vst1q_f32(output + i * 16 + 8, vo.val[2]);
vst1q_f32(output + i * 16 + 12, vo.val[3]);
}
for (index_t i = block_count * 16; i < size; ++i) {
output[i] = scale * (input[i] - zero_point);
}
}
#endif // MACE_ENABLE_NEON
template<typename T>
inline void DeQuantize(const Tensor &input,
const float min_in,
const float max_in,
Tensor *output) {
MACE_CHECK(input.size() != 0);
Tensor::MappingGuard input_guard(&input);
Tensor::MappingGuard output_guard(output);
auto *input_data = input.data<T>();
auto *output_data = output->mutable_data<float>();
float scale;
int32_t zero_point;
AdjustRange<T>(min_in, max_in, false, &scale, &zero_point);
Dequantize(input_data, input.size(), scale, zero_point, output_data);
}
inline void QuantizeMultiplier(double multiplier,
int32_t *output_multiplier,
int32_t *shift) {
......@@ -296,6 +115,118 @@ inline void GetOutputMultiplierAndShift(
MACE_CHECK(*right_shift >= 0);
}
template<typename T>
class QuantizeUtil {
public:
explicit QuantizeUtil(utils::ThreadPool *thread_pool)
: thread_pool_(thread_pool) {}
void QuantizeWithScaleAndZeropoint(const float *input,
const index_t size,
float scale,
int32_t zero_point,
T *output) {
float recip_scale = 1 / scale;
thread_pool_->Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output[i] = Saturate<T>(roundf(zero_point + recip_scale * input[i]));
}
}, 0, size, 1);
}
void Quantize(const float *input,
const index_t size,
bool non_zero,
T *output,
float *scale,
int32_t *zero_point) {
float in_min_data;
float in_max_data;
FindMinMax(input, size, &in_min_data, &in_max_data);
AdjustRange<T>(in_min_data, in_max_data, non_zero,
scale, zero_point);
QuantizeWithScaleAndZeropoint(input, size, *scale, *zero_point, output);
}
void Quantize(const Tensor &input,
Tensor *output,
float *min_out,
float *max_out) {
MACE_CHECK(input.size() != 0);
Tensor::MappingGuard input_guard(&input);
Tensor::MappingGuard output_guard(output);
auto *input_data = input.data<float>();
auto *output_data = output->mutable_data<T>();
float scale;
int32_t zero_point;
Quantize(input_data, input.size(), false, output_data, &scale, &zero_point);
*min_out = scale * (std::numeric_limits<T>::lowest() - zero_point);
*max_out = scale * (std::numeric_limits<T>::max() - zero_point);
}
void Dequantize(const T *input,
const index_t size,
const float scale,
const int32_t zero_point,
float *output) {
thread_pool_->Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output[i] = scale * (input[i] - zero_point);
}
}, 0, size, 1);
}
void DeQuantize(const Tensor &input,
const float min_in,
const float max_in,
Tensor *output) {
MACE_CHECK(input.size() != 0);
Tensor::MappingGuard input_guard(&input);
Tensor::MappingGuard output_guard(output);
auto *input_data = input.data<T>();
auto *output_data = output->mutable_data<float>();
float scale;
int32_t zero_point;
AdjustRange<T>(min_in, max_in, false, &scale, &zero_point);
Dequantize(input_data, input.size(), scale, zero_point, output_data);
}
private:
utils::ThreadPool *thread_pool_;
};
#ifdef MACE_ENABLE_NEON
template<>
void QuantizeUtil<uint8_t>::QuantizeWithScaleAndZeropoint(
const float *input,
const index_t size,
float scale,
int32_t zero_point,
uint8_t *output);
template<>
void QuantizeUtil<uint8_t>::Dequantize(const uint8_t *input,
const index_t size,
const float scale,
const int32_t zero_point,
float *output);
template<>
void QuantizeUtil<int32_t>::Dequantize(const int *input,
const index_t size,
const float scale,
const int32_t zero_point,
float *output);
#endif
} // namespace mace
#endif // MACE_UTILS_QUANTIZE_H_
#endif // MACE_CORE_QUANTIZE_H_
......@@ -68,7 +68,7 @@ MaceStatus SetOpenMPThreadsAndAffinityCPUs(int omp_num_threads,
#else
MACE_UNUSED(omp_num_threads);
MACE_UNUSED(schedule_policy);
LOG(WARNING) << "Set OpenMP threads number failed: OpenMP not enabled.";
VLOG(2) << "Set OpenMP threads number failed: OpenMP not enabled.";
#endif
#ifdef MACE_ENABLE_OPENMP
......@@ -143,7 +143,7 @@ MaceStatus CPURuntime::SetOpenMPThreadsAndAffinityPolicy(
#ifdef MACE_ENABLE_OPENMP
omp_set_num_threads(num_threads_hint);
#else
LOG(WARNING) << "Set OpenMP threads number failed: OpenMP not enabled.";
VLOG(2) << "Set OpenMP threads number failed: OpenMP not enabled.";
#endif
return MaceStatus::MACE_SUCCESS;
}
......
......@@ -35,24 +35,17 @@ class CPURuntime {
public:
CPURuntime(const int num_threads,
CPUAffinityPolicy policy,
bool use_gemmlowp)
utils::ThreadPool *thread_pool)
: num_threads_(num_threads),
policy_(policy),
gemm_context_(nullptr),
thread_pool_(static_cast<size_t>(num_threads), policy) {
thread_pool_(thread_pool) {
#ifdef MACE_ENABLE_QUANTIZE
if (use_gemmlowp) {
MACE_CHECK_NOTNULL(GetGemmlowpContext());
}
#else
MACE_UNUSED(use_gemmlowp);
MACE_CHECK_NOTNULL(GetGemmlowpContext());
#endif // MACE_ENABLE_QUANTIZE
SetOpenMPThreadsAndAffinityPolicy(num_threads_,
policy_,
gemm_context_);
// TODO(liyin): After we replace OpenMP to thread_pool, uncomment the
// following line.
// thread_pool_.Init();
}
#ifdef MACE_ENABLE_QUANTIZE
......@@ -80,12 +73,8 @@ class CPURuntime {
return policy_;
}
bool use_gemmlowp() const {
return gemm_context_ != nullptr;
}
utils::ThreadPool &thread_pool() {
return thread_pool_;
return *thread_pool_;
}
private:
......@@ -97,7 +86,7 @@ class CPURuntime {
int num_threads_;
CPUAffinityPolicy policy_;
void *gemm_context_;
utils::ThreadPool thread_pool_;
utils::ThreadPool *thread_pool_;
};
} // namespace mace
......
......@@ -31,8 +31,9 @@ namespace mace {
class HexagonDevice : public CPUDevice {
public:
explicit HexagonDevice(DeviceType device_type)
: CPUDevice(0, AFFINITY_NONE, false),
explicit HexagonDevice(DeviceType device_type,
utils::ThreadPool *thread_pool)
: CPUDevice(0, AFFINITY_NONE, thread_pool),
device_type_(device_type) {}
DeviceType device_type() const override {
......@@ -44,9 +45,9 @@ class HexagonDevice : public CPUDevice {
};
std::unique_ptr<HexagonControlWrapper> CreateHexagonControlWrapper(
DeviceType device_type) {
Device *device) {
std::unique_ptr<HexagonControlWrapper> hexagon_controller;
auto device_type = device->device_type();
switch (device_type) {
#ifdef MACE_ENABLE_HEXAGON
case HEXAGON:
......@@ -55,11 +56,10 @@ std::unique_ptr<HexagonControlWrapper> CreateHexagonControlWrapper(
#endif
#ifdef MACE_ENABLE_HTA
case HTA:
hexagon_controller = make_unique<HexagonHTAWrapper>();
hexagon_controller = make_unique<HexagonHTAWrapper>(device);
break;
#endif
default:
LOG(FATAL) << "Not supported Hexagon device type: " << device_type;
default:LOG(FATAL) << "Not supported Hexagon device type: " << device_type;
}
return hexagon_controller;
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/runtime/hexagon/hexagon_dsp_wrapper.h"
#include <algorithm>
#include <iomanip>
#include <map>
......@@ -22,7 +24,6 @@
#include <string>
#include <utility>
#include "mace/core/runtime/hexagon/hexagon_dsp_wrapper.h"
#include "mace/core/runtime/hexagon/hexagon_dsp_ops.h"
#include "mace/core/types.h"
#include "mace/port/env.h"
......
......@@ -26,11 +26,15 @@
#include "mace/core/runtime/hexagon/hexagon_hta_ops.h"
#include "mace/core/types.h"
#include "mace/utils/memory.h"
#include "mace/utils/quantize.h"
#include "mace/core/quantize.h"
#include "third_party/hta/hta_hexagon_api.h"
namespace mace {
HexagonHTAWrapper::HexagonHTAWrapper(Device *device)
: device_(device), quantize_util_(&device->cpu_runtime()->thread_pool()) {
}
int HexagonHTAWrapper::GetVersion() {
int version;
MACE_CHECK(hexagon_hta_nn_version(&version) == 0, "get version error");
......@@ -237,8 +241,8 @@ bool HexagonHTAWrapper::ExecuteGraph(const Tensor &input_tensor,
}
bool HexagonHTAWrapper::ExecuteGraphNew(
const std::map<std::string, Tensor*> &input_tensors,
std::map<std::string, Tensor*> *output_tensors) {
const std::map<std::string, Tensor *> &input_tensors,
std::map<std::string, Tensor *> *output_tensors) {
VLOG(2) << "Execute graph new: " << nn_id_;
uint32_t num_inputs = static_cast<uint32_t>(input_tensors.size());
uint32_t num_outputs = static_cast<uint32_t>(output_tensors->size());
......@@ -261,11 +265,11 @@ bool HexagonHTAWrapper::ExecuteGraphNew(
const float *input_data = input_tensor->data<float>();
uint8_t *input_data_u8 = input_info_[i].tensor_u8->mutable_data<uint8_t>();
QuantizeWithScaleAndZeropoint(input_data,
input_tensor->size(),
input_info_[i].scale,
input_info_[i].zero_point,
input_data_u8);
quantize_util_.QuantizeWithScaleAndZeropoint(input_data,
input_tensor->size(),
input_info_[i].scale,
input_info_[i].zero_point,
input_data_u8);
inputs[i].data = const_cast<unsigned char *>(
reinterpret_cast<const unsigned char *>(
......@@ -315,11 +319,11 @@ bool HexagonHTAWrapper::ExecuteGraphNew(
const uint8_t *output_data_u8 = output_info_[i].tensor_u8->data<uint8_t>();
float *output_data = output_tensor->mutable_data<float>();
Dequantize(output_data_u8,
output_info_[i].tensor_u8->size(),
output_info_[i].scale,
output_info_[i].zero_point,
output_data);
quantize_util_.Dequantize(output_data_u8,
output_info_[i].tensor_u8->size(),
output_info_[i].scale,
output_info_[i].zero_point,
output_data);
}
return res == 0;
......
......@@ -19,15 +19,18 @@
#include <string>
#include <vector>
#include "mace/utils/thread_pool.h"
#include "mace/core/quantize.h"
#include "mace/core/runtime/hexagon/hexagon_control_wrapper.h"
#include "mace/core/tensor.h"
#include "mace/core/device.h"
#include "mace/public/mace.h"
namespace mace {
class HexagonHTAWrapper : public HexagonControlWrapper {
public:
HexagonHTAWrapper() = default;
explicit HexagonHTAWrapper(Device *device);
int GetVersion() override;
bool Config() override;
......@@ -46,6 +49,9 @@ class HexagonHTAWrapper : public HexagonControlWrapper {
void ResetPerfInfo() override;
void SetDebugLevel(int level) override;
private:
Device *device_;
QuantizeUtil<uint8_t> quantize_util_;
MACE_DISABLE_COPY_AND_ASSIGN(HexagonHTAWrapper);
};
} // namespace mace
......
......@@ -25,8 +25,10 @@ GPUDevice::GPUDevice(std::shared_ptr<Tuner<uint32_t>> tuner,
std::shared_ptr<KVStorage> opencl_binary_storage,
const int num_threads,
CPUAffinityPolicy cpu_affinity_policy,
bool use_gemmlowp) :
CPUDevice(num_threads, cpu_affinity_policy, use_gemmlowp),
utils::ThreadPool *thread_pool) :
CPUDevice(num_threads,
cpu_affinity_policy,
thread_pool),
runtime_(new OpenCLRuntime(opencl_cache_storage, priority, perf,
opencl_binary_storage, tuner)),
allocator_(new OpenCLAllocator(runtime_.get())),
......@@ -35,7 +37,7 @@ GPUDevice::GPUDevice(std::shared_ptr<Tuner<uint32_t>> tuner,
GPUDevice::~GPUDevice() = default;
GPURuntime* GPUDevice::gpu_runtime() {
GPURuntime *GPUDevice::gpu_runtime() {
return gpu_runtime_.get();
}
......
......@@ -33,7 +33,7 @@ class GPUDevice : public CPUDevice {
std::shared_ptr<KVStorage> opencl_binary_storage = nullptr,
const int num_threads = -1,
CPUAffinityPolicy cpu_affinity_policy = AFFINITY_NONE,
bool use_gemmlowp = false);
utils::ThreadPool *thread_pool = nullptr);
~GPUDevice();
GPURuntime *gpu_runtime() override;
Allocator *allocator() override;
......
......@@ -20,6 +20,8 @@
#include <utility>
#include <vector>
#include "mace/core/types.h"
#define MACE_BENCHMARK(n) \
static ::mace::testing::Benchmark *__benchmark_##n = \
(new ::mace::testing::Benchmark(#n, (n)))
......
......@@ -33,8 +33,7 @@ int main(int argc, char **argv) {
// config runtime
mace::ops::test::OpTestContext::Get(
FLAGS_omp_num_threads,
static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy),
true);
static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy));
mace::testing::Benchmark::Run(FLAGS_filter.c_str());
return 0;
......
......@@ -54,6 +54,12 @@ MACE_MAPPING_DATA_TYPE_AND_ENUM(half, DT_HALF);
MACE_MAPPING_DATA_TYPE_AND_ENUM(float, DT_FLOAT);
MACE_MAPPING_DATA_TYPE_AND_ENUM(uint8_t, DT_UINT8);
MACE_MAPPING_DATA_TYPE_AND_ENUM(int32_t, DT_INT32);
enum FrameworkType {
TENSORFLOW = 0,
CAFFE = 1,
};
} // namespace mace
#endif // MACE_CORE_TYPES_H_
......@@ -19,7 +19,7 @@
#include "mace/core/arg_helper.h"
#include "mace/core/memory_optimizer.h"
#include "mace/utils/quantize.h"
#include "mace/core/quantize.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_runtime.h"
......@@ -95,8 +95,8 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
model_data_size = std::max(
model_data_size,
static_cast<index_t>(const_tensor.offset() +
const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type())));
const_tensor.data_size() *
GetEnumTypeSize(const_tensor.data_type())));
}
VLOG(3) << "Model data size: " << model_data_size;
......@@ -163,11 +163,13 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
auto quantized_data = reinterpret_cast<const uint8_t *>(
model_data + const_tensor.offset());
auto dequantized_data = tensor->mutable_data<float>();
Dequantize(quantized_data,
tensor->size(),
const_tensor.scale(),
const_tensor.zero_point(),
dequantized_data);
QuantizeUtil<uint8_t>
quantize_util(&device->cpu_runtime()->thread_pool());
quantize_util.Dequantize(quantized_data,
tensor->size(),
const_tensor.scale(),
const_tensor.zero_point(),
dequantized_data);
} else {
tensor->CopyBytes(model_data + const_tensor.offset(),
const_tensor.data_size() *
......@@ -185,14 +187,14 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
if (device_type == DeviceType::CPU) {
tensor_buffer_ = std::unique_ptr<Buffer>(
new Buffer(device->allocator(),
const_cast<unsigned char*>(model_data),
const_cast<unsigned char *>(model_data),
model_data_size));
} else {
tensor_buffer_ = std::unique_ptr<Buffer>(
new Buffer(device->allocator()));
MACE_RETURN_IF_ERROR(tensor_buffer_->Allocate(model_data_size));
tensor_buffer_->Map(nullptr);
tensor_buffer_->Copy(const_cast<unsigned char*>(model_data),
tensor_buffer_->Copy(const_cast<unsigned char *>(model_data),
0, model_data_size);
tensor_buffer_->UnMap();
}
......
......@@ -112,8 +112,7 @@ Java_com_xiaomi_mace_JniMaceUtils_maceMobilenetCreateEngine(
mace::MaceEngineConfig config(mace_context.device_type);
status = config.SetCPUThreadPolicy(
omp_num_threads,
static_cast<mace::CPUAffinityPolicy>(cpu_affinity_policy),
true);
static_cast<mace::CPUAffinityPolicy>(cpu_affinity_policy));
if (status != mace::MaceStatus::MACE_SUCCESS) {
__android_log_print(ANDROID_LOG_ERROR,
"image_classify attrs",
......
......@@ -5,6 +5,7 @@ load(
"if_darwin",
"if_hexagon_enabled",
"if_hta_enabled",
"if_linux",
"if_opencl_enabled",
"if_openmp_enabled",
)
......@@ -21,13 +22,12 @@ cc_binary(
linkopts = [
"-lm",
"-ldl",
] + if_darwin(
[],
] + if_linux(["-lpthread"]) + if_darwin(
["-lpthread"],
default_value = ["-fuse-ld=gold"],
) + if_openmp_enabled([
"-fopenmp",
]) + if_android([
"-ldl",
"-pie",
"-llog",
]),
......@@ -60,11 +60,10 @@ cc_binary(
linkopts = [
"-lm",
"-ldl",
] + if_darwin(
[],
] + if_linux(["-lpthread"]) + if_darwin(
["-lpthread"],
default_value = ["-fuse-ld=gold"],
) + if_android([
"-ldl",
"-pie",
"-llog",
]),
......
......@@ -149,7 +149,7 @@ GPUContextBuilder &GPUContextBuilder::SetOpenCLBinaryPaths(
return *this;
}
GPUContextBuilder& GPUContextBuilder::SetOpenCLBinary(
GPUContextBuilder &GPUContextBuilder::SetOpenCLBinary(
const unsigned char *data, const size_t size) {
impl_->SetOpenCLBinary(data, size);
return *this;
......@@ -161,7 +161,7 @@ GPUContextBuilder &GPUContextBuilder::SetOpenCLParameterPath(
return *this;
}
GPUContextBuilder& GPUContextBuilder::SetOpenCLParameter(
GPUContextBuilder &GPUContextBuilder::SetOpenCLParameter(
const unsigned char *data, const size_t size) {
impl_->SetOpenCLParameter(data, size);
return *this;
......@@ -181,8 +181,7 @@ class MaceEngineConfig::Impl {
MaceStatus SetGPUHints(GPUPerfHint perf_hint, GPUPriorityHint priority_hint);
MaceStatus SetCPUThreadPolicy(int num_threads_hint,
CPUAffinityPolicy policy,
bool use_gemmlowp);
CPUAffinityPolicy policy);
inline DeviceType device_type() const {
return device_type_;
......@@ -196,10 +195,6 @@ class MaceEngineConfig::Impl {
return cpu_affinity_policy_;
}
inline bool use_gemmlowp() const {
return use_gemmlowp_;
}
inline std::shared_ptr<GPUContext> gpu_context() const {
return gpu_context_;
}
......@@ -216,7 +211,6 @@ class MaceEngineConfig::Impl {
DeviceType device_type_;
int num_threads_;
CPUAffinityPolicy cpu_affinity_policy_;
bool use_gemmlowp_;
std::shared_ptr<GPUContext> gpu_context_;
GPUPriorityHint gpu_priority_hint_;
GPUPerfHint gpu_perf_hint_;
......@@ -226,7 +220,6 @@ MaceEngineConfig::Impl::Impl(const DeviceType device_type)
: device_type_(device_type),
num_threads_(-1),
cpu_affinity_policy_(CPUAffinityPolicy::AFFINITY_NONE),
use_gemmlowp_(false),
gpu_context_(new GPUContext),
gpu_priority_hint_(GPUPriorityHint::PRIORITY_LOW),
gpu_perf_hint_(GPUPerfHint::PERF_NORMAL) {}
......@@ -247,15 +240,12 @@ MaceStatus MaceEngineConfig::Impl::SetGPUHints(
MaceStatus MaceEngineConfig::Impl::SetCPUThreadPolicy(
int num_threads,
CPUAffinityPolicy policy,
bool use_gemmlowp) {
CPUAffinityPolicy policy) {
num_threads_ = num_threads;
cpu_affinity_policy_ = policy;
use_gemmlowp_ = use_gemmlowp;
return MaceStatus::MACE_SUCCESS;
}
MaceEngineConfig::MaceEngineConfig(
const DeviceType device_type)
: impl_(new MaceEngineConfig::Impl(device_type)) {}
......@@ -275,9 +265,8 @@ MaceStatus MaceEngineConfig::SetGPUHints(
MaceStatus MaceEngineConfig::SetCPUThreadPolicy(
int num_threads_hint,
CPUAffinityPolicy policy,
bool use_gemmlowp) {
return impl_->SetCPUThreadPolicy(num_threads_hint, policy, use_gemmlowp);
CPUAffinityPolicy policy) {
return impl_->SetCPUThreadPolicy(num_threads_hint, policy);
}
// Mace Tensor
......@@ -407,6 +396,7 @@ class MaceEngine::Impl {
#endif
std::map<std::string, mace::InputOutputInfo> input_info_map_;
std::map<std::string, mace::InputOutputInfo> output_info_map_;
std::unique_ptr<utils::ThreadPool> thread_pool_;
MACE_DISABLE_COPY_AND_ASSIGN(Impl);
};
......@@ -418,16 +408,19 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
device_(nullptr),
ws_(new Workspace()),
net_(nullptr),
is_quantized_model_(false)
is_quantized_model_(false),
thread_pool_(new utils::ThreadPool(config.impl_->num_threads(),
config.impl_->cpu_affinity_policy()))
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
, hexagon_controller_(nullptr)
, hexagon_controller_(nullptr)
#endif
{
LOG(INFO) << "Creating MaceEngine, MACE version: " << MaceVersion();
thread_pool_->Init();
if (device_type_ == DeviceType::CPU) {
device_.reset(new CPUDevice(config.impl_->num_threads(),
config.impl_->cpu_affinity_policy(),
config.impl_->use_gemmlowp()));
thread_pool_.get()));
}
#ifdef MACE_ENABLE_OPENCL
if (device_type_ == DeviceType::GPU) {
......@@ -439,12 +432,13 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
config.impl_->gpu_context()->opencl_binary_storage(),
config.impl_->num_threads(),
config.impl_->cpu_affinity_policy(),
config.impl_->use_gemmlowp()));
thread_pool_.get()));
}
#endif
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
if (device_type_ == DeviceType::HEXAGON || device_type_ == DeviceType::HTA) {
device_.reset(new HexagonDevice(device_type_));
if (device_type_ == DeviceType::HEXAGON
|| device_type_ == DeviceType::HTA) {
device_.reset(new HexagonDevice(device_type_, thread_pool_.get()));
}
#endif
MACE_CHECK_NOTNULL(device_);
......@@ -506,7 +500,7 @@ MaceStatus MaceEngine::Impl::Init(
}
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
if (device_type_ == HEXAGON || device_type_ == HTA) {
hexagon_controller_ = CreateHexagonControlWrapper(device_type_);
hexagon_controller_ = CreateHexagonControlWrapper(device_.get());
MACE_CHECK(hexagon_controller_->Config(), "hexagon config error");
MACE_CHECK(hexagon_controller_->Init(), "hexagon init error");
hexagon_controller_->SetDebugLevel(
......@@ -518,26 +512,26 @@ MaceStatus MaceEngine::Impl::Init(
}
} else {
#endif
MACE_RETURN_IF_ERROR(ws_->LoadModelTensor(*net_def,
device_.get(),
model_data));
MemoryOptimizer mem_optimizer;
// Init model
net_ = std::unique_ptr<NetBase>(new SerialNet(op_registry_.get(),
net_def,
ws_.get(),
device_.get(),
&mem_optimizer));
// Preallocate all output tensors of ops
MACE_RETURN_IF_ERROR(ws_->PreallocateOutputTensor(*net_def,
&mem_optimizer,
device_.get()));
if (device_type_ == DeviceType::GPU) {
ws_->RemoveAndReloadBuffer(*net_def, model_data, device_->allocator());
}
MACE_RETURN_IF_ERROR(net_->Init());
MACE_RETURN_IF_ERROR(ws_->LoadModelTensor(*net_def,
device_.get(),
model_data));
MemoryOptimizer mem_optimizer;
// Init model
net_ = std::unique_ptr<NetBase>(new SerialNet(op_registry_.get(),
net_def,
ws_.get(),
device_.get(),
&mem_optimizer));
// Preallocate all output tensors of ops
MACE_RETURN_IF_ERROR(ws_->PreallocateOutputTensor(*net_def,
&mem_optimizer,
device_.get()));
if (device_type_ == DeviceType::GPU) {
ws_->RemoveAndReloadBuffer(*net_def, model_data, device_->allocator());
}
MACE_RETURN_IF_ERROR(net_->Init());
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
}
#endif
......@@ -554,10 +548,10 @@ MaceStatus MaceEngine::Impl::Init(
auto fs = GetFileSystem();
MACE_RETURN_IF_ERROR(fs->NewReadOnlyMemoryRegionFromFile(
model_data_file.c_str(), &model_data_));
model_data_file.c_str(), &model_data_));
MACE_RETURN_IF_ERROR(Init(net_def, input_nodes, output_nodes,
reinterpret_cast<const unsigned char *>(model_data_->data())));
reinterpret_cast<const unsigned char *>(model_data_->data())));
if (device_type_ == DeviceType::GPU || device_type_ == DeviceType::HEXAGON ||
device_type_ == DeviceType::HTA ||
......@@ -611,18 +605,18 @@ MaceStatus MaceEngine::Impl::TransposeInput(
Tensor::MappingGuard input_guard(input_tensor);
if (input_dt == DataType::DT_FLOAT) {
auto input_data = input_tensor->mutable_data<float>();
return ops::Transpose(input.second.data<float>().get(),
return ops::Transpose(thread_pool_.get(),
input.second.data<float>().get(),
input.second.shape(),
dst_dims,
input_data,
input_dt);
input_data);
} else if (input_dt == DataType::DT_INT32) {
auto input_data = input_tensor->mutable_data<int>();
return ops::Transpose(input.second.data<int>().get(),
return ops::Transpose(thread_pool_.get(),
input.second.data<int>().get(),
input.second.shape(),
dst_dims,
input_data,
input_dt);
input_data);
} else {
LOG(FATAL) << "MACE do not support the input data type: " << input_dt;
}
......@@ -668,7 +662,7 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
output->second.data_format() == NCHW) {
dst_dims = {0, 3, 1, 2};
} else {
LOG(FATAL) <<"Not supported output data format: "
LOG(FATAL) << "Not supported output data format: "
<< output->second.data_format() << " vs "
<< output_tensor->data_format();
}
......@@ -688,17 +682,18 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
Tensor::MappingGuard output_guard(output_tensor);
if (output_dt == DataType::DT_FLOAT) {
auto output_data = output_tensor->data<float>();
return ops::Transpose(output_data,
return ops::Transpose(thread_pool_.get(),
output_data,
output_tensor->shape(),
dst_dims,
output->second.data<float>().get());
} else if (output_dt == DataType::DT_INT32) {
auto output_data = output_tensor->data<int>();
return ops::Transpose(output_data,
return ops::Transpose(thread_pool_.get(),
output_data,
output_tensor->shape(),
dst_dims,
output->second.data<int>().get(),
output_dt);
output->second.data<int>().get());
} else {
LOG(FATAL) << "MACE do not support the output data type: " << output_dt;
return MaceStatus::MACE_INVALID_ARGS;
......@@ -719,8 +714,8 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
output_size * sizeof(float));
} else if (output_dt == DataType::DT_INT32) {
std::memcpy(output->second.data<int>().get(),
output_tensor->data<int>(),
output_size * sizeof(int));
output_tensor->data<int>(),
output_size * sizeof(int));
} else {
LOG(FATAL) << "MACE do not support the output data type: " << output_dt;
}
......@@ -736,8 +731,8 @@ MaceStatus MaceEngine::Impl::Run(
std::map<std::string, MaceTensor> *outputs,
RunMetadata *run_metadata) {
MACE_CHECK_NOTNULL(outputs);
std::map<std::string, Tensor*> input_tensors;
std::map<std::string, Tensor*> output_tensors;
std::map<std::string, Tensor *> input_tensors;
std::map<std::string, Tensor *> output_tensors;
for (auto &input : inputs) {
if (input_info_map_.find(input.first) == input_info_map_.end()) {
LOG(FATAL) << "'" << input.first
......@@ -766,7 +761,7 @@ MaceStatus MaceEngine::Impl::Run(
hexagon_controller_->ExecuteGraphNew(input_tensors, &output_tensors);
} else {
#endif
MACE_RETURN_IF_ERROR(net_->Run(run_metadata));
MACE_RETURN_IF_ERROR(net_->Run(run_metadata));
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
}
#endif
......@@ -785,7 +780,7 @@ MaceStatus MaceEngine::Impl::Run(
return MaceStatus::MACE_SUCCESS;
}
MaceEngine::MaceEngine(const MaceEngineConfig &config):
MaceEngine::MaceEngine(const MaceEngineConfig &config) :
impl_(make_unique<MaceEngine::Impl>(config)) {}
MaceEngine::~MaceEngine() = default;
......@@ -797,7 +792,6 @@ MaceStatus MaceEngine::Init(const NetDef *net_def,
return impl_->Init(net_def, input_nodes, output_nodes, model_data);
}
MaceStatus MaceEngine::Init(const NetDef *net_def,
const std::vector<std::string> &input_nodes,
const std::vector<std::string> &output_nodes,
......
......@@ -279,7 +279,6 @@ cc_library(
srcs = glob(
[
"*.cc",
"arm/*.cc", # remove it after refactor
],
exclude = [
"*_test.cc",
......@@ -303,7 +302,6 @@ cc_library(
hdrs = glob(
[
"*.h",
"arm/*.h", # remove it after refactor
],
exclude = [
"ops_registry.h",
......
......@@ -15,9 +15,14 @@
#include "mace/ops/activation.h"
#include <memory>
#include "mace/core/operator.h"
#if defined(MACE_ENABLE_NEON)
#include "mace/ops/arm/fp32/activation.h"
#else
#include "mace/ops/ref/activation.h"
#endif
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/buffer_transformer.h"
#include "mace/ops/opencl/image/activation.h"
......@@ -27,52 +32,54 @@
namespace mace {
namespace ops {
template <DeviceType D, class T>
template<DeviceType D, class T>
class ActivationOp;
template <>
template<>
class ActivationOp<DeviceType::CPU, float> : public Operation {
public:
explicit ActivationOp(OpConstructContext *context)
: Operation(context),
activation_(ops::StringToActivationType(
activation_type_(ops::StringToActivationType(
Operation::GetOptionalArg<std::string>("activation",
"NOOP"))),
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit",
0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)) {}
activation_delegator_(activation_type_,
Operation::GetOptionalArg<float>("max_limit",
0.0f),
Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
if (activation_ == PRELU) {
if (activation_type_ == PRELU) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
MACE_CHECK(this->InputSize() > 1);
const Tensor *alpha = this->Input(1);
const float *alpha_ptr = alpha->data<float>();
const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3);
PReLUActivation(input_ptr, outer_size, input->dim(1), inner_size,
PReLUActivation(context, input_ptr, outer_size, input->dim(1), inner_size,
alpha_ptr, output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_,
relux_max_limit_, leakyrelu_coefficient_);
activation_delegator_.Compute(context, input, output);
}
return MaceStatus::MACE_SUCCESS;
}
private:
ActivationType activation_;
float relux_max_limit_;
float leakyrelu_coefficient_;
ActivationType activation_type_;
#if defined(MACE_ENABLE_NEON)
arm::fp32::Activation activation_delegator_;
#else
ref::Activation activation_delegator_;
#endif // MACE_ENABLE_NEON
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
class ActivationOp<DeviceType::GPU, T> : public Operation {
......@@ -114,7 +121,6 @@ class ActivationOp<DeviceType::GPU, T> : public Operation {
};
#endif // MACE_ENABLE_OPENCL
void RegisterActivation(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Activation", ActivationOp,
DeviceType::CPU, float);
......
......@@ -20,8 +20,8 @@
#include <string>
#include "mace/core/types.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/arm/activation_neon.h"
#include "mace/utils/logging.h"
namespace mace {
......@@ -41,118 +41,39 @@ inline ActivationType StringToActivationType(const std::string type) {
} else if (type == "NOOP") {
return ActivationType::NOOP;
} else if (type == "LEAKYRELU") {
return ActivationType ::LEAKYRELU;
return ActivationType::LEAKYRELU;
} else {
LOG(FATAL) << "Unknown activation type: " << type;
}
return ActivationType::NOOP;
}
template <typename T>
void DoActivation(const T *input_ptr,
T *output_ptr,
const index_t size,
const ActivationType type,
const float relux_max_limit,
const float leakyrelu_coefficient) {
MACE_CHECK(DataTypeToEnum<T>::value != DataType::DT_HALF);
switch (type) {
case NOOP:
break;
case RELU:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0));
}
break;
case RELUX:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min(std::max(input_ptr[i], static_cast<T>(0)),
static_cast<T>(relux_max_limit));
}
break;
case TANH:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::tanh(input_ptr[i]);
}
break;
case SIGMOID:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i]));
}
break;
case LEAKYRELU:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0))
+ leakyrelu_coefficient * std::min(input_ptr[i], static_cast<T>(0));
}
break;
default:
LOG(FATAL) << "Unknown activation type: " << type;
}
}
template<>
inline void DoActivation(const float *input_ptr,
float *output_ptr,
const index_t size,
const ActivationType type,
const float relux_max_limit,
const float leakyrelu_coefficient) {
switch (type) {
case NOOP:
break;
case RELU:
ReluNeon(input_ptr, size, output_ptr);
break;
case RELUX:
ReluxNeon(input_ptr, relux_max_limit, size, output_ptr);
break;
case TANH:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::tanh(input_ptr[i]);
}
break;
case SIGMOID:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i]));
}
break;
case LEAKYRELU:
LeakyReluNeon(input_ptr, leakyrelu_coefficient, size, output_ptr);
break;
default:
LOG(FATAL) << "Unknown activation type: " << type;
}
}
template <typename T>
void PReLUActivation(const T *input_ptr,
template<typename T>
void PReLUActivation(const OpContext *context,
const T *input_ptr,
const index_t outer_size,
const index_t input_chan,
const index_t inner_size,
const T *alpha_ptr,
T *output_ptr) {
#pragma omp parallel for collapse(3) schedule(runtime)
for (index_t i = 0; i < outer_size; ++i) {
for (index_t chan_idx = 0; chan_idx < input_chan; ++chan_idx) {
for (index_t j = 0; j < inner_size; ++j) {
index_t idx = i * input_chan * inner_size + chan_idx * inner_size + j;
if (input_ptr[idx] < 0) {
output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
} else {
output_ptr[idx] = input_ptr[idx];
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t chan_idx = start1; chan_idx < end1; chan_idx += step1) {
for (index_t j = 0; j < inner_size; ++j) {
index_t idx = i * input_chan * inner_size + chan_idx * inner_size + j;
if (input_ptr[idx] < 0) {
output_ptr[idx] = input_ptr[idx] * alpha_ptr[chan_idx];
} else {
output_ptr[idx] = input_ptr[idx];
}
}
}
}
}
}, 0, outer_size, 1, 0, input_chan, 1);
}
} // namespace ops
......
......@@ -42,61 +42,23 @@ class AddNOp<DeviceType::CPU, float> : public Operation {
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
Tensor *output_tensor = this->Output(0);
size_t input_size = this->inputs_.size();
MACE_RETURN_IF_ERROR(output_tensor->ResizeLike(inputs_[0]));
index_t size = output_tensor->size();
Tensor::MappingGuard output_map(output_tensor);
float *output_data = output_tensor->mutable_data<float>();
memset(output_data, 0, size * sizeof(float));
int64_t cost = size * input_size;
int64_t groups = 1;
if (cost > kCostPerGroup) {
groups = cost / kCostPerGroup;
}
int64_t element_per_group = size / groups;
Tensor *output = this->Output(0);
MACE_RETURN_IF_ERROR(output->ResizeLike(inputs_[0]));
const index_t size = output->size();
std::vector<Tensor::MappingGuard> mappers;
for (size_t i = 0; i < input_size; ++i) {
MACE_CHECK(inputs_[0]->dim_size() == inputs_[i]->dim_size());
MACE_CHECK(inputs_[0]->size() == inputs_[i]->size())
<< "Input 0: " << MakeString(inputs_[0]->shape())
<< ", size: " << inputs_[0]->size() << ". Input " << i << ": "
<< MakeString(inputs_[i]->shape()) << ", size: " << inputs_[i]->size();
mappers.emplace_back(Tensor::MappingGuard(inputs_[i]));
}
Tensor::MappingGuard output_guard(output);
auto output_data = output->mutable_data<float>();
memset(output_data, 0, size * sizeof(float));
#pragma omp parallel for
for (int64_t i = 0; i < size; i += element_per_group) {
int64_t count = std::min(element_per_group, size - i);
int nn = count >> 2;
int remain = count - (nn << 2);
for (size_t j = 0; j < input_size; ++j) {
const float *input_data = inputs_[j]->data<float>();
const float *input_ptr = input_data + i;
float *output_ptr = output_data + i;
for (int k = 0; k < nn; ++k) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
float32x4_t in = vld1q_f32(input_ptr);
float32x4_t out = vld1q_f32(output_ptr);
out = vaddq_f32(out, in);
vst1q_f32(output_ptr, out);
#else
for (int m = 0; m < 4; ++m) {
output_ptr[m] += input_ptr[m];
}
#endif
for (auto &input : inputs_) {
Tensor::MappingGuard input_guard(input);
auto input_data = input->data<float>();
input_ptr += 4;
output_ptr += 4;
}
for (int k = 0; k < remain; ++k) {
*output_ptr += *input_ptr;
++input_ptr;
++output_ptr;
}
for (index_t j = 0; j < size; ++j) {
output_data[j] += input_data[j];
}
}
return MaceStatus::MACE_SUCCESS;
}
};
......
......@@ -71,7 +71,6 @@ class ArgMaxOp : public Operation {
index_t inner_size = input->dim(axis_value);
if (argmin_) {
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < outer_size; ++i) {
int idx = 0;
T min_value = std::numeric_limits<T>::max();
......@@ -85,7 +84,6 @@ class ArgMaxOp : public Operation {
output_data[i] = idx;
}
} else {
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < outer_size; ++i) {
int idx = 0;
T max_value = std::numeric_limits<T>::lowest();
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include <algorithm>
#include "mace/ops/arm/activation_neon.h"
namespace mace {
namespace ops {
void ReluNeon(const float *input, const index_t size, float *output) {
#if defined(MACE_ENABLE_NEON)
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i <= size - 4; i += 4) {
float32x4_t v = vld1q_f32(input + i);
v = vmaxq_f32(v, vzero);
vst1q_f32(output + i, v);
}
// remain
for (index_t i = (size >> 2) << 2; i < size; ++i) {
output[i] = std::max(input[i], 0.f);
}
#else
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output[i] = std::max(input[i], 0.f);
}
#endif
}
void ReluxNeon(const float *input, const float limit,
const index_t size, float *output) {
#if defined(MACE_ENABLE_NEON)
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vlimit = vdupq_n_f32(limit);
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i <= size - 4; i += 4) {
float32x4_t v = vld1q_f32(input + i);
v = vmaxq_f32(v, vzero);
v = vminq_f32(v, vlimit);
vst1q_f32(output + i, v);
}
// remain
for (index_t i = (size >> 2) << 2; i < size; ++i) {
output[i] = std::min(std::max(input[i], 0.f), limit);
}
#else
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output[i] = std::min(std::max(input[i], 0.f), limit);
}
#endif
}
void LeakyReluNeon(const float *input, const float alpha,
const index_t size, float *output) {
#if defined(MACE_ENABLE_NEON)
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t valpha = vdupq_n_f32(alpha);
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i <= size - 4; i += 4) {
float32x4_t v = vld1q_f32(input + i);
float32x4_t u = vminq_f32(v, vzero);;
v = vmaxq_f32(v, vzero);
v = vmlaq_f32(v, valpha, u);
vst1q_f32(output + i, v);
}
// remain
for (index_t i = (size >> 2) << 2; i < size; ++i) {
output[i] = std::max(input[i], 0.f) + std::min(input[i], 0.f) * alpha;
}
#else
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output[i] = std::max(input[i], 0.f) + std::min(input[i], 0.f) * alpha;
}
#endif
}
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/utils/macros.h"
#include "mace/ops/arm/deconv_2d_neon.h"
namespace mace {
namespace ops {
void Deconv2dNeonK2x2S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t inch = in_shape[1];
const index_t h = in_shape[2];
const index_t w = in_shape[3];
const index_t outch = out_shape[1];
const index_t outh = out_shape[2];
const index_t outw = out_shape[3];
const index_t out_img_size = outh * outw;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t oc = 0; oc < outch; oc += 2) {
if (oc + 1 < outch) {
float *out_base0 = output + (b * outch + oc) * out_img_size;
float *out_base1 = out_base0 + out_img_size;
for (index_t ic = 0; ic < inch; ++ic) {
const float *input_base = input + (b * inch + ic) * h * w;
const float *kernel_base0 = filter + (oc * inch + ic) * 4;
const float *kernel_base1 = kernel_base0 + inch * 4;
const float *in = input_base;
// output channel 0
const float *k0 = kernel_base0;
// output channel 1
const float *k1 = kernel_base1;
#if defined(MACE_ENABLE_NEON)
// load filter
float32x4_t k0_vec = vld1q_f32(k0);
float32x4_t k1_vec = vld1q_f32(k1);
#endif
for (index_t i = 0; i < h; ++i) {
float *out_row_base0 = out_base0 + i * outw;
float *out_row0_0 = out_row_base0;
float *out_row0_1 = out_row_base0 + outw;
float *out_row_base1 = out_base1 + i * outw;
float *out_row1_0 = out_row_base1;
float *out_row1_1 = out_row_base1 + outw;
index_t j = 0;
#if defined(MACE_ENABLE_NEON)
for (; j + 3 < w; j += 4) {
float32x4_t in_vec = vld1q_f32(in);
float32x4_t out00, out01, out02, out03;
float32x4_t out10, out11, out12, out13;
out00 = vld1q_f32(out_row0_0);
out00 = neon_vfma_lane_0(out00, in_vec, k0_vec);
vst1q_f32(out_row0_0, out00);
out01 = vld1q_f32(out_row0_0 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k0_vec);
vst1q_f32(out_row0_0 + 1, out01);
out02 = vld1q_f32(out_row0_1);
out02 = neon_vfma_lane_2(out02, in_vec, k0_vec);
vst1q_f32(out_row0_1, out02);
out03 = vld1q_f32(out_row0_1 + 1);
out03 = neon_vfma_lane_3(out03, in_vec, k0_vec);
vst1q_f32(out_row0_1 + 1, out03);
out10 = vld1q_f32(out_row1_0);
out10 = neon_vfma_lane_0(out10, in_vec, k1_vec);
vst1q_f32(out_row1_0, out10);
out11 = vld1q_f32(out_row1_0 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k1_vec);
vst1q_f32(out_row1_0 + 1, out11);
out12 = vld1q_f32(out_row1_1);
out12 = neon_vfma_lane_2(out12, in_vec, k1_vec);
vst1q_f32(out_row1_1, out12);
out13 = vld1q_f32(out_row1_1 + 1);
out13 = neon_vfma_lane_3(out13, in_vec, k1_vec);
vst1q_f32(out_row1_1 + 1, out13);
in += 4;
out_row0_0 += 4;
out_row0_1 += 4;
out_row1_0 += 4;
out_row1_1 += 4;
}
#endif
for (; j < w; ++j) {
float val = in[0];
for (int k = 0; k < 2; ++k) {
out_row0_0[k] += val * k0[k];
out_row0_1[k] += val * k0[k + 2];
out_row1_0[k] += val * k1[k];
out_row1_1[k] += val * k1[k + 2];
}
in++;
out_row0_0++;
out_row0_1++;
out_row1_0++;
out_row1_1++;
}
}
}
} else {
float *out_base0 = output + (b * outch + oc) * outh * outw;
for (index_t ic = 0; ic < inch; ++ic) {
const float *input_base = input + (b * inch + ic) * h * w;
const float *kernel_base0 = filter + (oc * inch + ic) * 4;
const float *in = input_base;
const float *k0 = kernel_base0;
#if defined(MACE_ENABLE_NEON)
// load filter
float32x4_t k0_vec = vld1q_f32(k0);
#endif
for (index_t i = 0; i < h; ++i) {
float *out_row_base0 = out_base0 + i * outw;
float *out_row0_0 = out_row_base0;
float *out_row0_1 = out_row_base0 + outw;
index_t j = 0;
#if defined(MACE_ENABLE_NEON)
for (; j + 3 < w; j += 4) {
float32x4_t in_vec = vld1q_f32(in);
float32x4_t out00, out01, out02, out03;
out00 = vld1q_f32(out_row0_0);
out00 = neon_vfma_lane_0(out00, in_vec, k0_vec);
vst1q_f32(out_row0_0, out00);
out01 = vld1q_f32(out_row0_0 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k0_vec);
vst1q_f32(out_row0_0 + 1, out01);
out02 = vld1q_f32(out_row0_1);
out02 = neon_vfma_lane_2(out02, in_vec, k0_vec);
vst1q_f32(out_row0_1, out02);
out03 = vld1q_f32(out_row0_1 + 1);
out03 = neon_vfma_lane_3(out03, in_vec, k0_vec);
vst1q_f32(out_row0_1 + 1, out03);
in += 4;
out_row0_0 += 4;
out_row0_1 += 4;
}
#endif
for (; j < w; ++j) {
float val = in[0];
for (int k = 0; k < 2; ++k) {
out_row0_0[k] += val * k0[k];
out_row0_1[k] += val * k0[k + 2];
}
in++;
out_row0_0++;
out_row0_1++;
}
}
}
}
}
}
}
void Deconv2dNeonK2x2S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t inch = in_shape[1];
const index_t h = in_shape[2];
const index_t w = in_shape[3];
const index_t outch = out_shape[1];
const index_t outh = out_shape[2];
const index_t outw = out_shape[3];
const index_t out_img_size = outh * outw;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t oc = 0; oc < outch; ++oc) {
float *out_base = output + (b * outch + oc) * out_img_size;
for (index_t ic = 0; ic < inch; ++ic) {
const float *input_base = input + (b * inch + ic) * h * w;
const float *kernel_base = filter + (oc * inch + ic) * 4;
const float *in = input_base;
const float *k0 = kernel_base;
#if defined(MACE_ENABLE_NEON)
float32x4_t k0_vec = vld1q_f32(k0);
#endif
for (index_t i = 0; i < h; ++i) {
float *out_row_base = out_base + i * 2 * outw;
float *out_row_0 = out_row_base;
float *out_row_1 = out_row_0 + outw;
index_t j = 0;
#if defined(MACE_ENABLE_NEON)
for (; j + 3 < w; j += 4) {
float32x4_t in_vec = vld1q_f32(in);
// out row 0
float32x4x2_t out00 = vld2q_f32(out_row_0);
out00.val[0] =
neon_vfma_lane_0(out00.val[0], in_vec, k0_vec);
out00.val[1] =
neon_vfma_lane_1(out00.val[1], in_vec, k0_vec);
vst2q_f32(out_row_0, out00);
// out row 1
float32x4x2_t out10 = vld2q_f32(out_row_1);
out10.val[0] =
neon_vfma_lane_2(out10.val[0], in_vec, k0_vec);
out10.val[1] =
neon_vfma_lane_3(out10.val[1], in_vec, k0_vec);
vst2q_f32(out_row_1, out10);
in += 4;
out_row_0 += 8;
out_row_1 += 8;
}
#endif
for (; j < w; ++j) {
float val = in[0];
for (int k = 0; k < 2; ++k) {
out_row_0[k] += val * k0[k];
out_row_1[k] += val * k0[k + 2];
}
in++;
out_row_0 += 2;
out_row_1 += 2;
}
}
}
}
}
}
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/utils/macros.h"
#include "mace/ops/arm/deconv_2d_neon.h"
namespace mace {
namespace ops {
void Deconv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t inch = in_shape[1];
const index_t h = in_shape[2];
const index_t w = in_shape[3];
const index_t outch = out_shape[1];
const index_t outh = out_shape[2];
const index_t outw = out_shape[3];
const index_t out_img_size = outh * outw;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t oc = 0; oc < outch; oc += 2) {
if (oc + 1 < outch) {
float *out_base0 = output + (b * outch + oc) * out_img_size;
float *out_base1 = out_base0 + out_img_size;
for (index_t ic = 0; ic < inch; ++ic) {
const float *input_base = input + (b * inch + ic) * h * w;
const float *kernel_base0 = filter + (oc * inch + ic) * 9;
const float *kernel_base1 = kernel_base0 + inch * 9;
const float *in = input_base;
// output channel 0
const float *k0_0 = kernel_base0;
const float *k0_1 = kernel_base0 + 3;
const float *k0_2 = kernel_base0 + 5;
// output channel 1
const float *k1_0 = kernel_base1;
const float *k1_1 = kernel_base1 + 3;
const float *k1_2 = kernel_base1 + 5;
#if defined(MACE_ENABLE_NEON)
// load filter
float32x4_t k00_vec, k01_vec, k02_vec;
float32x4_t k10_vec, k11_vec, k12_vec;
k00_vec = vld1q_f32(k0_0);
k01_vec = vld1q_f32(k0_1);
k02_vec = vld1q_f32(k0_2);
k10_vec = vld1q_f32(k1_0);
k11_vec = vld1q_f32(k1_1);
k12_vec = vld1q_f32(k1_2);
#endif
for (index_t i = 0; i < h; ++i) {
float *out_row_base0 = out_base0 + i * outw;
float *out_row0_0 = out_row_base0;
float *out_row0_1 = out_row_base0 + outw;
float *out_row0_2 = out_row_base0 + 2 * outw;
float *out_row_base1 = out_base1 + i * outw;
float *out_row1_0 = out_row_base1;
float *out_row1_1 = out_row_base1 + outw;
float *out_row1_2 = out_row_base1 + 2 * outw;
index_t j = 0;
#if defined(MACE_ENABLE_NEON)
for (; j + 3 < w; j += 4) {
float32x4_t in_vec = vld1q_f32(in);
float32x4_t out00, out01, out02;
float32x4_t out10, out11, out12;
float32x4_t out20, out21, out22;
out00 = vld1q_f32(out_row0_0);
out00 = neon_vfma_lane_0(out00, in_vec, k00_vec);
vst1q_f32(out_row0_0, out00);
out01 = vld1q_f32(out_row0_0 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k00_vec);
vst1q_f32(out_row0_0 + 1, out01);
out02 = vld1q_f32(out_row0_0 + 2);
out02 = neon_vfma_lane_2(out02, in_vec, k00_vec);
vst1q_f32(out_row0_0 + 2, out02);
out10 = vld1q_f32(out_row0_1 + 0);
out10 = neon_vfma_lane_0(out10, in_vec, k01_vec);
vst1q_f32(out_row0_1 + 0, out10);
out11 = vld1q_f32(out_row0_1 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k01_vec);
vst1q_f32(out_row0_1 + 1, out11);
out12 = vld1q_f32(out_row0_1 + 2);
out12 = neon_vfma_lane_2(out12, in_vec, k01_vec);
vst1q_f32(out_row0_1 + 2, out12);
out20 = vld1q_f32(out_row0_2 + 0);
out20 = neon_vfma_lane_1(out20, in_vec, k02_vec);
vst1q_f32(out_row0_2 + 0, out20);
out21 = vld1q_f32(out_row0_2 + 1);
out21 = neon_vfma_lane_2(out21, in_vec, k02_vec);
vst1q_f32(out_row0_2 + 1, out21);
out22 = vld1q_f32(out_row0_2 + 2);
out22 = neon_vfma_lane_3(out22, in_vec, k02_vec);
vst1q_f32(out_row0_2 + 2, out22);
out00 = vld1q_f32(out_row1_0 + 0);
out00 = neon_vfma_lane_0(out00, in_vec, k10_vec);
vst1q_f32(out_row1_0 + 0, out00);
out01 = vld1q_f32(out_row1_0 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k10_vec);
vst1q_f32(out_row1_0 + 1, out01);
out02 = vld1q_f32(out_row1_0 + 2);
out02 = neon_vfma_lane_2(out02, in_vec, k10_vec);
vst1q_f32(out_row1_0 + 2, out02);
out10 = vld1q_f32(out_row1_1 + 0);
out10 = neon_vfma_lane_0(out10, in_vec, k11_vec);
vst1q_f32(out_row1_1 + 0, out10);
out11 = vld1q_f32(out_row1_1 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k11_vec);
vst1q_f32(out_row1_1 + 1, out11);
out12 = vld1q_f32(out_row1_1 + 2);
out12 = neon_vfma_lane_2(out12, in_vec, k11_vec);
vst1q_f32(out_row1_1 + 2, out12);
out20 = vld1q_f32(out_row1_2 + 0);
out20 = neon_vfma_lane_1(out20, in_vec, k12_vec);
vst1q_f32(out_row1_2 + 0, out20);
out21 = vld1q_f32(out_row1_2 + 1);
out21 = neon_vfma_lane_2(out21, in_vec, k12_vec);
vst1q_f32(out_row1_2 + 1, out21);
out22 = vld1q_f32(out_row1_2 + 2);
out22 = neon_vfma_lane_3(out22, in_vec, k12_vec);
vst1q_f32(out_row1_2 + 2, out22);
in += 4;
out_row0_0 += 4;
out_row0_1 += 4;
out_row0_2 += 4;
out_row1_0 += 4;
out_row1_1 += 4;
out_row1_2 += 4;
}
#endif
for (; j < w; ++j) {
float val = in[0];
for (int k = 0; k < 3; ++k) {
out_row0_0[k] += val * k0_0[k];
out_row0_1[k] += val * k0_1[k];
out_row0_2[k] += val * k0_2[k + 1];
out_row1_0[k] += val * k1_0[k];
out_row1_1[k] += val * k1_1[k];
out_row1_2[k] += val * k1_2[k + 1];
}
in++;
out_row0_0++;
out_row0_1++;
out_row0_2++;
out_row1_0++;
out_row1_1++;
out_row1_2++;
}
}
}
} else {
float *out_base0 = output + (b * outch + oc) * outh * outw;
for (index_t ic = 0; ic < inch; ++ic) {
const float *input_base = input + (b * inch + ic) * h * w;
const float *kernel_base0 = filter + (oc * inch + ic) * 9;
const float *in = input_base;
const float *k0_0 = kernel_base0;
const float *k0_1 = kernel_base0 + 3;
const float *k0_2 = kernel_base0 + 5;
#if defined(MACE_ENABLE_NEON)
// load filter
float32x4_t k00_vec = vld1q_f32(k0_0);
float32x4_t k01_vec = vld1q_f32(k0_1);
float32x4_t k02_vec = vld1q_f32(k0_2);
#endif
for (index_t i = 0; i < h; ++i) {
float *out_row_base0 = out_base0 + i * outw;
float *out_row0_0 = out_row_base0;
float *out_row0_1 = out_row_base0 + outw;
float *out_row0_2 = out_row_base0 + 2 * outw;
index_t j = 0;
#if defined(MACE_ENABLE_NEON)
for (; j + 3 < w; j += 4) {
float32x4_t in_vec = vld1q_f32(in);
float32x4_t out00, out01, out02;
float32x4_t out10, out11, out12;
float32x4_t out20, out21, out22;
out00 = vld1q_f32(out_row0_0 + 0);
out00 = neon_vfma_lane_0(out00, in_vec, k00_vec);
vst1q_f32(out_row0_0 + 0, out00);
out01 = vld1q_f32(out_row0_0 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k00_vec);
vst1q_f32(out_row0_0 + 1, out01);
out02 = vld1q_f32(out_row0_0 + 2);
out02 = neon_vfma_lane_2(out02, in_vec, k00_vec);
vst1q_f32(out_row0_0 + 2, out02);
out10 = vld1q_f32(out_row0_1 + 0);
out10 = neon_vfma_lane_0(out10, in_vec, k01_vec);
vst1q_f32(out_row0_1 + 0, out10);
out11 = vld1q_f32(out_row0_1 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k01_vec);
vst1q_f32(out_row0_1 + 1, out11);
out12 = vld1q_f32(out_row0_1 + 2);
out12 = neon_vfma_lane_2(out12, in_vec, k01_vec);
vst1q_f32(out_row0_1 + 2, out12);
out20 = vld1q_f32(out_row0_2 + 0);
out20 = neon_vfma_lane_1(out20, in_vec, k02_vec);
vst1q_f32(out_row0_2 + 0, out20);
out21 = vld1q_f32(out_row0_2 + 1);
out21 = neon_vfma_lane_2(out21, in_vec, k02_vec);
vst1q_f32(out_row0_2 + 1, out21);
out22 = vld1q_f32(out_row0_2 + 2);
out22 = neon_vfma_lane_3(out22, in_vec, k02_vec);
vst1q_f32(out_row0_2 + 2, out22);
in += 4;
out_row0_0 += 4;
out_row0_1 += 4;
out_row0_2 += 4;
}
#endif
for (; j < w; ++j) {
float val = in[0];
for (int k = 0; k < 3; ++k) {
out_row0_0[k] += val * k0_0[k];
out_row0_1[k] += val * k0_1[k];
out_row0_2[k] += val * k0_2[k + 1];
}
in++;
out_row0_0++;
out_row0_1++;
out_row0_2++;
}
}
}
}
}
}
}
void Deconv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t inch = in_shape[1];
const index_t h = in_shape[2];
const index_t w = in_shape[3];
const index_t outch = out_shape[1];
const index_t outh = out_shape[2];
const index_t outw = out_shape[3];
const index_t out_img_size = outh * outw;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t oc = 0; oc < outch; ++oc) {
float *out_base = output + (b * outch + oc) * out_img_size;
for (index_t ic = 0; ic < inch; ++ic) {
const float *input_base = input + (b * inch + ic) * h * w;
const float *kernel_base = filter + (oc * inch + ic) * 9;
const float *in = input_base;
const float *k0 = kernel_base;
const float *k1 = kernel_base + 3;
const float *k2 = kernel_base + 5;
#if defined(MACE_ENABLE_NEON)
float32x4_t k0_vec = vld1q_f32(k0);
float32x4_t k1_vec = vld1q_f32(k1);
float32x4_t k2_vec = vld1q_f32(k2);
#endif
for (index_t i = 0; i < h; ++i) {
float *out_row_base = out_base + i * 2 * outw;
float *out_row_0 = out_row_base;
float *out_row_1 = out_row_0 + outw;
float *out_row_2 = out_row_1 + outw;
index_t j = 0;
#if defined(MACE_ENABLE_NEON)
for (index_t n = 0; n + 9 < outw; n += 8) {
float32x4_t in_vec = vld1q_f32(in);
// out row 0
float32x4x2_t out00 = vld2q_f32(out_row_0);
out00.val[0] =
neon_vfma_lane_0(out00.val[0], in_vec, k0_vec);
out00.val[1] =
neon_vfma_lane_1(out00.val[1], in_vec, k0_vec);
vst2q_f32(out_row_0, out00);
float32x4x2_t out01 = vld2q_f32(out_row_0 + 2);
out01.val[0] =
neon_vfma_lane_2(out01.val[0], in_vec, k0_vec);
vst2q_f32(out_row_0 + 2, out01);
// out row 1
float32x4x2_t out10 = vld2q_f32(out_row_1);
out10.val[0] =
neon_vfma_lane_0(out10.val[0], in_vec, k1_vec);
out10.val[1] =
neon_vfma_lane_1(out10.val[1], in_vec, k1_vec);
vst2q_f32(out_row_1, out10);
float32x4x2_t out11 = vld2q_f32(out_row_1 + 2);
out11.val[0] =
neon_vfma_lane_2(out11.val[0], in_vec, k1_vec);
vst2q_f32(out_row_1 + 2, out11);
// out row 2
float32x4x2_t out20 = vld2q_f32(out_row_2);
out20.val[0] =
neon_vfma_lane_1(out20.val[0], in_vec, k2_vec);
out20.val[1] =
neon_vfma_lane_2(out20.val[1], in_vec, k2_vec);
vst2q_f32(out_row_2, out20);
float32x4x2_t out21 = vld2q_f32(out_row_2 + 2);
out21.val[0] =
neon_vfma_lane_3(out21.val[0], in_vec, k2_vec);
vst2q_f32(out_row_2 + 2, out21);
in += 4;
out_row_0 += 8;
out_row_1 += 8;
out_row_2 += 8;
j += 4;
}
#endif
for (; j < w; ++j) {
float val = in[0];
for (int k = 0; k < 3; ++k) {
out_row_0[k] += val * k0[k];
out_row_1[k] += val * k1[k];
out_row_2[k] += val * k2[k + 1];
}
in++;
out_row_0 += 2;
out_row_1 += 2;
out_row_2 += 2;
}
}
}
}
}
}
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/utils/macros.h"
#include "mace/ops/arm/deconv_2d_neon.h"
namespace mace {
namespace ops {
void Deconv2dNeonK4x4S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t w = in_shape[3];
const index_t h = in_shape[2];
const index_t inch = in_shape[1];
const index_t outh = out_shape[2];
const index_t outw = out_shape[3];
const index_t outch = out_shape[1];
const index_t out_img_size = outh * outw;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t oc = 0; oc < outch; oc += 2) {
if (oc + 1 < outch) {
float *out_base = output + (b * outch + oc) * out_img_size;
float *out_base1 = out_base + out_img_size;
for (index_t q = 0; q < inch; q++) {
const float *input_base = input + (b * inch + q) * h * w;
const float *in = input_base;
const float *kernel_base = filter + (oc * inch + q) * 16;
const float *k0 = kernel_base;
const float *k1 = kernel_base + 4;
const float *k2 = kernel_base + 8;
const float *k3 = kernel_base + 12;
const float *kernel_base1 = kernel_base + inch * 16;
const float *k10 = kernel_base1;
const float *k11 = kernel_base1 + 4;
const float *k12 = kernel_base1 + 8;
const float *k13 = kernel_base1 + 12;
#if defined(MACE_ENABLE_NEON)
float32x4_t k0_vec = vld1q_f32(k0);
float32x4_t k1_vec = vld1q_f32(k1);
float32x4_t k2_vec = vld1q_f32(k2);
float32x4_t k3_vec = vld1q_f32(k3);
float32x4_t k10_vec = vld1q_f32(k10);
float32x4_t k11_vec = vld1q_f32(k11);
float32x4_t k12_vec = vld1q_f32(k12);
float32x4_t k13_vec = vld1q_f32(k13);
#endif
for (index_t i = 0; i < h; i++) {
float *out_row = out_base + i * outw;
float *out_row_0 = out_row;
float *out_row_1 = out_row_0 + outw;
float *out_row_2 = out_row_1 + outw;
float *out_row_3 = out_row_2 + outw;
float *out_row1 = out_base1 + i * outw;
float *out_row1_0 = out_row1;
float *out_row1_1 = out_row1_0 + outw;
float *out_row1_2 = out_row1_1 + outw;
float *out_row1_3 = out_row1_2 + outw;
index_t j = 0;
#if defined(MACE_ENABLE_NEON)
for (; j + 3 < w; j += 4) {
float32x4_t in_vec = vld1q_f32(in);
float32x4_t out00, out01, out02, out03;
float32x4_t out10, out11, out12, out13;
out00 = vld1q_f32(out_row_0);
out00 = neon_vfma_lane_0(out00, in_vec, k0_vec);
vst1q_f32(out_row_0, out00);
out10 = vld1q_f32(out_row1_0);
out10 = neon_vfma_lane_0(out10, in_vec, k10_vec);
vst1q_f32(out_row1_0, out10);
out01 = vld1q_f32(out_row_0 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k0_vec);
vst1q_f32(out_row_0 + 1, out01);
out11 = vld1q_f32(out_row1_0 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k10_vec);
vst1q_f32(out_row1_0 + 1, out11);
out02 = vld1q_f32(out_row_0 + 2);
out02 = neon_vfma_lane_2(out02, in_vec, k0_vec);
vst1q_f32(out_row_0 + 2, out02);
out12 = vld1q_f32(out_row1_0 + 2);
out12 = neon_vfma_lane_2(out12, in_vec, k10_vec);
vst1q_f32(out_row1_0 + 2, out12);
out03 = vld1q_f32(out_row_0 + 3);
out03 = neon_vfma_lane_3(out03, in_vec, k0_vec);
vst1q_f32(out_row_0 + 3, out03);
out13 = vld1q_f32(out_row1_0 + 3);
out13 = neon_vfma_lane_3(out13, in_vec, k10_vec);
vst1q_f32(out_row1_0 + 3, out13);
//
out00 = vld1q_f32(out_row_1);
out00 = neon_vfma_lane_0(out00, in_vec, k1_vec);
vst1q_f32(out_row_1, out00);
out10 = vld1q_f32(out_row1_1);
out10 = neon_vfma_lane_0(out10, in_vec, k11_vec);
vst1q_f32(out_row1_1, out10);
out01 = vld1q_f32(out_row_1 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k1_vec);
vst1q_f32(out_row_1 + 1, out01);
out11 = vld1q_f32(out_row1_1 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k11_vec);
vst1q_f32(out_row1_1 + 1, out11);
out02 = vld1q_f32(out_row_1 + 2);
out02 = neon_vfma_lane_2(out02, in_vec, k1_vec);
vst1q_f32(out_row_1 + 2, out02);
out12 = vld1q_f32(out_row1_1 + 2);
out12 = neon_vfma_lane_2(out12, in_vec, k11_vec);
vst1q_f32(out_row1_1 + 2, out12);
out03 = vld1q_f32(out_row_1 + 3);
out03 = neon_vfma_lane_3(out03, in_vec, k1_vec);
vst1q_f32(out_row_1 + 3, out03);
out13 = vld1q_f32(out_row1_1 + 3);
out13 = neon_vfma_lane_3(out13, in_vec, k11_vec);
vst1q_f32(out_row1_1 + 3, out13);
//
out00 = vld1q_f32(out_row_2 + 0);
out00 = neon_vfma_lane_0(out00, in_vec, k2_vec);
vst1q_f32(out_row_2 + 0, out00);
out10 = vld1q_f32(out_row1_2 + 0);
out10 = neon_vfma_lane_0(out10, in_vec, k12_vec);
vst1q_f32(out_row1_2 + 0, out10);
out01 = vld1q_f32(out_row_2 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k2_vec);
vst1q_f32(out_row_2 + 1, out01);
out11 = vld1q_f32(out_row1_2 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k12_vec);
vst1q_f32(out_row1_2 + 1, out11);
out02 = vld1q_f32(out_row_2 + 2);
out02 = neon_vfma_lane_2(out02, in_vec, k2_vec);
vst1q_f32(out_row_2 + 2, out02);
out12 = vld1q_f32(out_row1_2 + 2);
out12 = neon_vfma_lane_2(out12, in_vec, k12_vec);
vst1q_f32(out_row1_2 + 2, out12);
out03 = vld1q_f32(out_row_2 + 3);
out03 = neon_vfma_lane_3(out03, in_vec, k2_vec);
vst1q_f32(out_row_2 + 3, out03);
out13 = vld1q_f32(out_row1_2 + 3);
out13 = neon_vfma_lane_3(out13, in_vec, k12_vec);
vst1q_f32(out_row1_2 + 3, out13);
//
out00 = vld1q_f32(out_row_3 + 0);
out00 = neon_vfma_lane_0(out00, in_vec, k3_vec);
vst1q_f32(out_row_3 + 0, out00);
out10 = vld1q_f32(out_row1_3 + 0);
out10 = neon_vfma_lane_0(out10, in_vec, k13_vec);
vst1q_f32(out_row1_3 + 0, out10);
out01 = vld1q_f32(out_row_3 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k3_vec);
vst1q_f32(out_row_3 + 1, out01);
out11 = vld1q_f32(out_row1_3 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k13_vec);
vst1q_f32(out_row1_3 + 1, out11);
out02 = vld1q_f32(out_row_3 + 2);
out02 = neon_vfma_lane_2(out02, in_vec, k3_vec);
vst1q_f32(out_row_3 + 2, out02);
out12 = vld1q_f32(out_row1_3 + 2);
out12 = neon_vfma_lane_2(out12, in_vec, k13_vec);
vst1q_f32(out_row1_3 + 2, out12);
out03 = vld1q_f32(out_row_3 + 3);
out03 = neon_vfma_lane_3(out03, in_vec, k3_vec);
vst1q_f32(out_row_3 + 3, out03);
out13 = vld1q_f32(out_row1_3 + 3);
out13 = neon_vfma_lane_3(out13, in_vec, k13_vec);
vst1q_f32(out_row1_3 + 3, out13);
in += 4;
out_row_0 += 4;
out_row_1 += 4;
out_row_2 += 4;
out_row_3 += 4;
out_row1_0 += 4;
out_row1_1 += 4;
out_row1_2 += 4;
out_row1_3 += 4;
}
#endif
for (; j < w; j++) {
float val = in[0];
for (int k = 0; k < 4; ++k) {
out_row_0[k] += val * k0[k];
out_row_1[k] += val * k1[k];
out_row_2[k] += val * k2[k];
out_row_3[k] += val * k3[k];
out_row1_0[k] += val * k10[k];
out_row1_1[k] += val * k11[k];
out_row1_2[k] += val * k12[k];
out_row1_3[k] += val * k13[k];
}
in++;
out_row_0++;
out_row_1++;
out_row_2++;
out_row_3++;
out_row1_0++;
out_row1_1++;
out_row1_2++;
out_row1_3++;
}
}
}
} else {
float *out_base = output + (b * outch + oc) * out_img_size;
for (index_t q = 0; q < inch; q++) {
const float *input_base = input + (b * inch + q) * h * w;
const float *kernel_base = filter + (oc * inch + q) * 16;
const float *in = input_base;
const float *k0 = kernel_base;
const float *k1 = kernel_base + 4;
const float *k2 = kernel_base + 8;
const float *k3 = kernel_base + 12;
#if defined(MACE_ENABLE_NEON)
float32x4_t k0_vec = vld1q_f32(k0);
float32x4_t k1_vec = vld1q_f32(k1);
float32x4_t k2_vec = vld1q_f32(k2);
float32x4_t k3_vec = vld1q_f32(k3);
#endif
for (index_t i = 0; i < h; i++) {
float *out_row = out_base + i * outw;
float *out_row_0 = out_row;
float *out_row_1 = out_row_0 + outw;
float *out_row_2 = out_row_1 + outw;
float *out_row_3 = out_row_2 + outw;
int j = 0;
#if defined(MACE_ENABLE_NEON)
for (; j + 3 < w; j += 4) {
float32x4_t in_vec = vld1q_f32(in);
float32x4_t out00 = vld1q_f32(out_row_0);
out00 = neon_vfma_lane_0(out00, in_vec, k0_vec);
vst1q_f32(out_row_0, out00);
float32x4_t out01 = vld1q_f32(out_row_0 + 1);
out01 = neon_vfma_lane_1(out01, in_vec, k0_vec);
vst1q_f32(out_row_0 + 1, out01);
float32x4_t out02 = vld1q_f32(out_row_0 + 2);
out02 = neon_vfma_lane_2(out02, in_vec, k0_vec);
vst1q_f32(out_row_0 + 2, out02);
float32x4_t out03 = vld1q_f32(out_row_0 + 3);
out03 = neon_vfma_lane_3(out03, in_vec, k0_vec);
vst1q_f32(out_row_0 + 3, out03);
//
float32x4_t out10 = vld1q_f32(out_row_1);
out10 = neon_vfma_lane_0(out10, in_vec, k1_vec);
vst1q_f32(out_row_1, out10);
float32x4_t out11 = vld1q_f32(out_row_1 + 1);
out11 = neon_vfma_lane_1(out11, in_vec, k1_vec);
vst1q_f32(out_row_1 + 1, out11);
float32x4_t out12 = vld1q_f32(out_row_1 + 2);
out12 = neon_vfma_lane_2(out12, in_vec, k1_vec);
vst1q_f32(out_row_1 + 2, out12);
float32x4_t out13 = vld1q_f32(out_row_1 + 3);
out13 = neon_vfma_lane_3(out13, in_vec, k1_vec);
vst1q_f32(out_row_1 + 3, out13);
//
float32x4_t out20 = vld1q_f32(out_row_2 + 0);
out20 = neon_vfma_lane_0(out20, in_vec, k2_vec);
vst1q_f32(out_row_2 + 0, out20);
float32x4_t out21 = vld1q_f32(out_row_2 + 1);
out21 = neon_vfma_lane_1(out21, in_vec, k2_vec);
vst1q_f32(out_row_2 + 1, out21);
float32x4_t out22 = vld1q_f32(out_row_2 + 2);
out22 = neon_vfma_lane_2(out22, in_vec, k2_vec);
vst1q_f32(out_row_2 + 2, out22);
float32x4_t out23 = vld1q_f32(out_row_2 + 3);
out23 = neon_vfma_lane_3(out23, in_vec, k2_vec);
vst1q_f32(out_row_2 + 3, out23);
//
float32x4_t out30 = vld1q_f32(out_row_3 + 0);
out30 = neon_vfma_lane_0(out30, in_vec, k3_vec);
vst1q_f32(out_row_3 + 0, out30);
float32x4_t out31 = vld1q_f32(out_row_3 + 1);
out31 = neon_vfma_lane_1(out31, in_vec, k3_vec);
vst1q_f32(out_row_3 + 1, out31);
float32x4_t out32 = vld1q_f32(out_row_3 + 2);
out32 = neon_vfma_lane_2(out32, in_vec, k3_vec);
vst1q_f32(out_row_3 + 2, out32);
float32x4_t out33 = vld1q_f32(out_row_3 + 3);
out33 = neon_vfma_lane_3(out33, in_vec, k3_vec);
vst1q_f32(out_row_3 + 3, out33);
in += 4;
out_row_0 += 4;
out_row_1 += 4;
out_row_2 += 4;
out_row_3 += 4;
}
#endif
for (; j < w; j++) {
float val = in[0];
for (int k = 0; k < 4; ++k) {
out_row_0[k] += val * k0[k];
out_row_1[k] += val * k1[k];
out_row_2[k] += val * k2[k];
out_row_3[k] += val * k3[k];
}
in++;
out_row_0++;
out_row_1++;
out_row_2++;
out_row_3++;
}
}
}
}
}
}
}
void Deconv2dNeonK4x4S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output) {
const index_t w = in_shape[3];
const index_t h = in_shape[2];
const index_t inch = in_shape[1];
const index_t outh = out_shape[2];
const index_t outw = out_shape[3];
const index_t outch = out_shape[1];
const index_t out_img_size = outh * outw;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < out_shape[0]; ++b) {
for (index_t p = 0; p < outch; p++) {
float *out_base = output + (b * outch + p) * out_img_size;
for (index_t q = 0; q < inch; q++) {
const float *input_base = input + (b * inch + q) * h * w;
const float *kernel_base = filter + (p * inch + q) * 16;
const float *in = input_base;
const float *k0 = kernel_base;
const float *k1 = kernel_base + 4;
const float *k2 = kernel_base + 8;
const float *k3 = kernel_base + 12;
#if defined(MACE_ENABLE_NEON)
float32x4_t k0_vec = vld1q_f32(k0);
float32x4_t k1_vec = vld1q_f32(k1);
float32x4_t k2_vec = vld1q_f32(k2);
float32x4_t k3_vec = vld1q_f32(k3);
#endif
for (index_t i = 0; i < h; i++) {
float *out_row = out_base + 2 * i * outw;
float *out_row_0 = out_row;
float *out_row_1 = out_row_0 + outw;
float *out_row_2 = out_row_1 + outw;
float *out_row_3 = out_row_2 + outw;
index_t j = 0;
#if defined(MACE_ENABLE_NEON)
for (index_t n = 0; n + 9 < outw; n += 8) {
float32x4_t in_vec = vld1q_f32(in);
// row 0
float32x4x2_t out0 = vld2q_f32(out_row_0);
out0.val[0] =
neon_vfma_lane_0(out0.val[0], in_vec, k0_vec);
out0.val[1] =
neon_vfma_lane_1(out0.val[1], in_vec, k0_vec);
vst2q_f32(out_row_0, out0);
out0 = vld2q_f32(out_row_0 + 2);
out0.val[0] =
neon_vfma_lane_2(out0.val[0], in_vec, k0_vec);
out0.val[1] =
neon_vfma_lane_3(out0.val[1], in_vec, k0_vec);
vst2q_f32(out_row_0 + 2, out0);
// row 1
float32x4x2_t out1 = vld2q_f32(out_row_1);
out1.val[0] =
neon_vfma_lane_0(out1.val[0], in_vec, k1_vec);
out1.val[1] =
neon_vfma_lane_1(out1.val[1], in_vec, k1_vec);
vst2q_f32(out_row_1, out1);
out1 = vld2q_f32(out_row_1 + 2);
out1.val[0] =
neon_vfma_lane_2(out1.val[0], in_vec, k1_vec);
out1.val[1] =
neon_vfma_lane_3(out1.val[1], in_vec, k1_vec);
vst2q_f32(out_row_1 + 2, out1);
// row 2
float32x4x2_t out2 = vld2q_f32(out_row_2);
out2.val[0] =
neon_vfma_lane_0(out2.val[0], in_vec, k2_vec);
out2.val[1] =
neon_vfma_lane_1(out2.val[1], in_vec, k2_vec);
vst2q_f32(out_row_2, out2);
out2 = vld2q_f32(out_row_2 + 2);
out2.val[0] =
neon_vfma_lane_2(out2.val[0], in_vec, k2_vec);
out2.val[1] =
neon_vfma_lane_3(out2.val[1], in_vec, k2_vec);
vst2q_f32(out_row_2 + 2, out2);
// row 3
float32x4x2_t out3 = vld2q_f32(out_row_3);
out3.val[0] =
neon_vfma_lane_0(out3.val[0], in_vec, k3_vec);
out3.val[1] =
neon_vfma_lane_1(out3.val[1], in_vec, k3_vec);
vst2q_f32(out_row_3, out3);
out3 = vld2q_f32(out_row_3 + 2);
out3.val[0] =
neon_vfma_lane_2(out3.val[0], in_vec, k3_vec);
out3.val[1] =
neon_vfma_lane_3(out3.val[1], in_vec, k3_vec);
vst2q_f32(out_row_3 + 2, out3);
in += 4;
out_row_0 += 8;
out_row_1 += 8;
out_row_2 += 8;
out_row_3 += 8;
j += 4;
}
#endif
for (; j < w; j++) {
float val = in[0];
for (int k = 0; k < 4; ++k) {
out_row_0[k] += val * k0[k];
out_row_1[k] += val * k1[k];
out_row_2[k] += val * k2[k];
out_row_3[k] += val * k3[k];
}
in++;
out_row_0 += 2;
out_row_1 += 2;
out_row_2 += 2;
out_row_3 += 2;
}
}
}
}
}
}
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_DEPTHWISE_DECONV2D_NEON_H_
#define MACE_OPS_ARM_DEPTHWISE_DECONV2D_NEON_H_
#include "mace/core/types.h"
#include "mace/ops/arm/common_neon.h"
namespace mace {
namespace ops {
void DepthwiseDeconv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void DepthwiseDeconv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void DepthwiseDeconv2dNeonK4x4S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void DepthwiseDeconv2dNeonK4x4S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void GroupDeconv2dNeonK3x3S1(const float *input,
const float *filter,
const int group,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void GroupDeconv2dNeonK3x3S2(const float *input,
const float *filter,
const int group,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void GroupDeconv2dNeonK4x4S1(const float *input,
const float *filter,
const int group,
const index_t *in_shape,
const index_t *out_shape,
float *output);
void GroupDeconv2dNeonK4x4S2(const float *input,
const float *filter,
const int group,
const index_t *in_shape,
const index_t *out_shape,
float *output);
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_DEPTHWISE_DECONV2D_NEON_H_
此差异已折叠。
此差异已折叠。
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/fp32/activation.h"
#include <arm_neon.h>
#include <algorithm>
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
Activation::Activation(ActivationType type,
const float limit,
const float leakyrelu_coefficient)
: type_(type),
limit_(limit),
leakyrelu_coefficient_(leakyrelu_coefficient) {}
MaceStatus Activation::Compute(const OpContext *context,
const Tensor *input,
Tensor *output) {
Tensor::MappingGuard input_guard(input);
if (input != output) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
Tensor::MappingGuard output_guard(output);
DoActivation(context, input, output);
} else {
DoActivation(context, input, output);
}
return MaceStatus::MACE_SUCCESS;
}
void Activation::DoActivation(const OpContext *context,
const Tensor *input,
Tensor *output) {
auto input_data = input->data<float>();
auto output_data = output->mutable_data<float>();
const index_t size = input->size();
utils::ThreadPool &thread_pool =
context->device()->cpu_runtime()->thread_pool();
switch (type_) {
case RELU: {
const float32x4_t vzero = vdupq_n_f32(0.f);
const index_t block_count = size / 4;
thread_pool.Compute1D(
[=](index_t start, index_t end, index_t step) {
auto input_ptr = input_data + start * 4;
auto output_ptr = output_data + start * 4;
for (index_t i = start; i < end; i += step) {
float32x4_t v = vld1q_f32(input_ptr);
v = vmaxq_f32(v, vzero);
vst1q_f32(output_ptr, v);
input_ptr += 4;
output_ptr += 4;
}
},
0, block_count, 1);
// remain
for (index_t i = block_count * 4; i < size; ++i) {
output_data[i] = std::max(0.f, input_data[i]);
}
break;
}
case RELUX: {
const float32x4_t vzero = vdupq_n_f32(0.f);
const float32x4_t vlimit = vdupq_n_f32(limit_);
const index_t block_count = size / 4;
thread_pool.Compute1D(
[=](index_t start, index_t end, index_t step) {
auto input_ptr = input_data + start * 4;
auto output_ptr = output_data + start * 4;
for (index_t i = start; i < end; i += step) {
float32x4_t v = vld1q_f32(input_ptr);
v = vmaxq_f32(v, vzero);
v = vminq_f32(v, vlimit);
vst1q_f32(output_ptr, v);
input_ptr += 4;
output_ptr += 4;
}
},
0, block_count, 1);
// remain
for (index_t i = block_count * 4; i < size; ++i) {
output_data[i] = std::max(0.f, std::min(limit_, input_data[i]));
}
break;
}
case LEAKYRELU: {
const float32x4_t vzero = vdupq_n_f32(0.f);
const float32x4_t valpha = vdupq_n_f32(leakyrelu_coefficient_);
const index_t block_count = size / 4;
thread_pool.Compute1D(
[=](index_t start, index_t end, index_t step) {
auto input_ptr = input_data + start * 4;
auto output_ptr = output_data + start * 4;
for (index_t i = start; i < end; i += step) {
float32x4_t v = vld1q_f32(input_ptr);
float32x4_t u = vminq_f32(v, vzero);
v = vmaxq_f32(v, vzero);
v = vmlaq_f32(v, valpha, u);
vst1q_f32(output_ptr, v);
input_ptr += 4;
output_ptr += 4;
}
},
0, block_count, 1);
// remain
for (index_t i = block_count * 4; i < size; ++i) {
output_data[i] = std::max(input_data[i], 0.f) +
std::min(input_data[i], 0.f) * leakyrelu_coefficient_;
}
break;
}
case TANH: {
thread_pool.Compute1D(
[=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = std::tanh(input_data[i]);
}
},
0, size, 1);
break;
}
case SIGMOID: {
thread_pool.Compute1D(
[=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = 1 / (1 + std::exp(-(input_data[i])));
}
},
0, size, 1);
break;
}
case NOOP:
break;
default:
MACE_NOT_IMPLEMENTED;
}
}
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_ACTIVATION_H_
#define MACE_OPS_ARM_FP32_ACTIVATION_H_
#include "mace/core/op_context.h"
#include "mace/ops/common/activation_type.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Activation {
public:
explicit Activation(ActivationType type,
const float limit,
const float leakyrelu_coefficient);
~Activation() = default;
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
Tensor *output);
private:
void DoActivation(const OpContext *context,
const Tensor *input,
Tensor *output);
ActivationType type_;
const float limit_;
const float leakyrelu_coefficient_;
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_ACTIVATION_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/fp32/bias_add.h"
#include <arm_neon.h>
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
MaceStatus BiasAdd::Compute(const OpContext *context,
const Tensor *input,
const Tensor *bias,
Tensor *output) {
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard bias_guard(bias);
if (input != output) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
if (bias == nullptr) {
output->Copy(*input);
} else {
Tensor::MappingGuard output_guard(output);
AddBias(context, input, bias, output);
}
} else {
if (bias != nullptr) {
AddBias(context, input, bias, output);
}
}
return MaceStatus::MACE_SUCCESS;
}
void BiasAdd::AddBias(const OpContext *context,
const Tensor *input,
const Tensor *bias,
mace::Tensor *output) {
auto input_data = input->data<float>();
auto bias_data = bias->data<float>();
auto output_data = output->mutable_data<float>();
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t height = output->dim(2);
const index_t width = output->dim(3);
const index_t image_size = height * width;
const index_t block_count = image_size / 4;
const index_t remain = image_size % 4;
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
for (index_t c = start1; c < end1; c += step1) {
const index_t offset = (b * channels + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[c];
float32x4_t vbias = vdupq_n_f32(bias);
for (index_t i = 0; i < block_count; ++i) {
float32x4_t v = vld1q_f32(input_ptr);
v = vaddq_f32(v, vbias);
vst1q_f32(output_ptr, v);
input_ptr += 4;
output_ptr += 4;
}
for (index_t i = 0; i < remain; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
}
}
}
}, 0, batch, 1, 0, channels, 1);
}
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_BIAS_ADD_H_
#define MACE_OPS_ARM_FP32_BIAS_ADD_H_
#include "mace/core/op_context.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class BiasAdd {
public:
BiasAdd() = default;
~BiasAdd() = default;
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *bias,
Tensor *output);
private:
void AddBias(const OpContext *context,
const Tensor *input,
const Tensor *bias,
Tensor *output);
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_BIAS_ADD_H_
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_COMMON_NEON_H_
#define MACE_OPS_ARM_COMMON_NEON_H_
#ifndef MACE_OPS_ARM_FP32_COMMON_NEON_H_
#define MACE_OPS_ARM_FP32_COMMON_NEON_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
......@@ -21,6 +21,8 @@
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
#ifdef MACE_ENABLE_NEON
inline float32x4_t neon_vfma_lane_0(float32x4_t a,
......@@ -64,7 +66,9 @@ inline float32x4_t neon_vfma_lane_3(float32x4_t a,
}
#endif
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_COMMON_NEON_H_
#endif // MACE_OPS_ARM_FP32_COMMON_NEON_H_
......@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/fp32/conv_2d.h"
#include <memory>
#include <utility>
#include <algorithm>
#include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/utils/memory.h"
namespace mace {
......@@ -195,7 +196,7 @@ MaceStatus Conv2dBase::ResizeOutAndPadInOut(const OpContext *context,
void Conv2dBase::PadInput(const Tensor &src,
const int pad_top,
const int pad_left,
mace::Tensor *dst) {
Tensor *dst) {
if (dst == &src) return;
const index_t batch = src.dim(0);
const index_t channels = src.dim(1);
......@@ -211,7 +212,6 @@ void Conv2dBase::PadInput(const Tensor &src,
const index_t img_size = height * width;
const index_t padded_img_size = padded_height * padded_width;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t bc = b * channels + c;
......@@ -238,7 +238,7 @@ void Conv2dBase::PadInput(const Tensor &src,
}
}
void Conv2dBase::UnPadOutput(const mace::Tensor &src, mace::Tensor *dst) {
void Conv2dBase::UnPadOutput(const Tensor &src, Tensor *dst) {
if (dst == &src) return;
const index_t batch = dst->dim(0);
const index_t channels = dst->dim(1);
......@@ -253,7 +253,6 @@ void Conv2dBase::UnPadOutput(const mace::Tensor &src, mace::Tensor *dst) {
const index_t img_size = height * width;
const index_t padded_img_size = padded_height * padded_width;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t bc = (b * channels + c);
......
......@@ -31,9 +31,9 @@ namespace fp32 {
class Conv2dBase {
public:
Conv2dBase(const std::vector<int> strides,
const std::vector<int> dilations,
const std::vector<int> paddings,
Conv2dBase(const std::vector<int> &strides,
const std::vector<int> &dilations,
const std::vector<int> &paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
......
......@@ -29,7 +29,7 @@ namespace fp32 {
class Conv2dK1x1 : public Conv2dBase {
public:
Conv2dK1x1(const std::vector<int> paddings, const Padding padding_type)
Conv2dK1x1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK1x1() {}
......@@ -37,7 +37,7 @@ class Conv2dK1x1 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
private:
Gemm gemm_;
......
此差异已折叠。
......@@ -28,7 +28,7 @@ namespace fp32 {
class Conv2dK1x7S1 : public Conv2dBase {
public:
Conv2dK1x7S1(const std::vector<int> paddings, const Padding padding_type)
Conv2dK1x7S1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK1x7S1() {}
......@@ -36,12 +36,12 @@ class Conv2dK1x7S1 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
class Conv2dK7x1S1 : public Conv2dBase {
public:
Conv2dK7x1S1(const std::vector<int> paddings, const Padding padding_type)
Conv2dK7x1S1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK7x1S1() {}
......@@ -49,12 +49,12 @@ class Conv2dK7x1S1 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
class Conv2dK1x15S1 : public Conv2dBase {
public:
Conv2dK1x15S1(const std::vector<int> paddings, const Padding padding_type)
Conv2dK1x15S1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK1x15S1() {}
......@@ -62,12 +62,12 @@ class Conv2dK1x15S1 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
class Conv2dK15x1S1 : public Conv2dBase {
public:
Conv2dK15x1S1(const std::vector<int> paddings, const Padding padding_type)
Conv2dK15x1S1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK15x1S1() {}
......@@ -75,7 +75,7 @@ class Conv2dK15x1S1 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
} // namespace fp32
......
此差异已折叠。
......@@ -28,7 +28,7 @@ namespace fp32 {
class Conv2dK3x3S1 : public Conv2dBase {
public:
Conv2dK3x3S1(const std::vector<int> paddings, const Padding padding_type)
Conv2dK3x3S1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK3x3S1() {}
......@@ -36,12 +36,12 @@ class Conv2dK3x3S1 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
class Conv2dK3x3S2 : public Conv2dBase {
public:
Conv2dK3x3S2(const std::vector<int> paddings, const Padding padding_type)
Conv2dK3x3S2(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({2, 2}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK3x3S2() {}
......@@ -49,7 +49,7 @@ class Conv2dK3x3S2 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
} // namespace fp32
......
......@@ -31,7 +31,7 @@ namespace fp32 {
class Conv2dK3x3Winograd : public Conv2dBase {
public:
Conv2dK3x3Winograd(const std::vector<int> paddings,
Conv2dK3x3Winograd(const std::vector<int> &paddings,
const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type),
gemm_(),
......@@ -44,20 +44,23 @@ class Conv2dK3x3Winograd : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
private:
void TransformFilter4x4(const float *filter,
void TransformFilter4x4(const OpContext *context,
const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output);
void TransformFilter8x8(const float *filter,
void TransformFilter8x8(const OpContext *context,
const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output);
void TransformInput4x4(const float *input,
void TransformInput4x4(const OpContext *context,
const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
......@@ -65,7 +68,8 @@ class Conv2dK3x3Winograd : public Conv2dBase {
const index_t tile_count,
float *output);
void TransformInput8x8(const float *input,
void TransformInput8x8(const OpContext *context,
const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
......@@ -73,7 +77,8 @@ class Conv2dK3x3Winograd : public Conv2dBase {
const index_t tile_count,
float *output);
void TransformOutput4x4(const float *input,
void TransformOutput4x4(const OpContext *context,
const float *input,
index_t batch,
index_t out_height,
index_t out_width,
......@@ -81,7 +86,8 @@ class Conv2dK3x3Winograd : public Conv2dBase {
index_t tile_count,
float *output);
void TransformOutput8x8(const float *input,
void TransformOutput8x8(const OpContext *context,
const float *input,
index_t batch,
index_t out_height,
index_t out_width,
......
此差异已折叠。
......@@ -28,7 +28,7 @@ namespace fp32 {
class Conv2dK5x5S1 : public Conv2dBase {
public:
Conv2dK5x5S1(const std::vector<int> paddings, const Padding padding_type)
Conv2dK5x5S1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK5x5S1() {}
......@@ -36,7 +36,7 @@ class Conv2dK5x5S1 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
......
此差异已折叠。
......@@ -28,7 +28,7 @@ namespace fp32 {
class Conv2dK7x7S1 : public Conv2dBase {
public:
Conv2dK7x7S1(const std::vector<int> paddings, const Padding padding_type)
Conv2dK7x7S1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK7x7S1() {}
......@@ -36,12 +36,12 @@ class Conv2dK7x7S1 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
class Conv2dK7x7S2 : public Conv2dBase {
public:
Conv2dK7x7S2(const std::vector<int> paddings, const Padding padding_type)
Conv2dK7x7S2(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({2, 2}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK7x7S2() {}
......@@ -49,12 +49,12 @@ class Conv2dK7x7S2 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
class Conv2dK7x7S3 : public Conv2dBase {
public:
Conv2dK7x7S3(const std::vector<int> paddings, const Padding padding_type)
Conv2dK7x7S3(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({3, 3}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK7x7S3() {}
......@@ -62,7 +62,7 @@ class Conv2dK7x7S3 : public Conv2dBase {
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output);
Tensor *output) override;
};
} // namespace fp32
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -56,7 +56,6 @@ class ChannelShuffleOp<DeviceType::CPU, T> : public Operation {
index_t batch_size = channels * image_size;
index_t channels_per_group = channels / groups_;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
index_t g = c % groups_;
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册