Commit d043ef53 authored by 卢旭辉

Merge branch 'fp16' into 'master'

Add armv8.2 fp16 for MobileNet

See merge request applied-machine-learning/sysml/mace!1303
......@@ -11,6 +11,7 @@ option(MACE_ENABLE_HEXAGON_DSP "whether to enable Hexagon DSP support" OFF)
option(MACE_ENABLE_HEXAGON_HTA "whether to enable Hexagon HTA support" OFF)
option(MACE_ENABLE_MTK_APU "whether to enable MTK APU support" OFF)
option(MACE_ENABLE_BFLOAT16 "whether to enable bfloat16 support" OFF)
option(MACE_ENABLE_FP16 "whether to enable armv8.2 fp16 support" OFF)
option(MACE_ENABLE_TESTS "whether to build c++ unit tests" OFF)
option(MACE_ENABLE_BENCHMARKS "whether to build c++ micro benchmarks" OFF)
option(MACE_ENABLE_OPT_SIZE "whether to build with optimized binary size" ON)
......@@ -121,6 +122,10 @@ if(MACE_ENABLE_BFLOAT16)
add_definitions(-DMACE_ENABLE_BFLOAT16)
endif(MACE_ENABLE_BFLOAT16)
if(MACE_ENABLE_FP16)
add_definitions(-DMACE_ENABLE_FP16)
endif(MACE_ENABLE_FP16)
if(MACE_ENABLE_OBFUSCATE)
add_definitions(-DMACE_OBFUSCATE_LITERALS)
endif(MACE_ENABLE_OBFUSCATE)
......
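The new option wires fp16 into the build the same way bfloat16 is wired in: enabling it adds a global `-DMACE_ENABLE_FP16` compile definition, and source files guard the armv8.2-only paths on that macro. A minimal sketch of how such a guard is consumed (illustrative only, not part of the patch; assumes a toolchain targeting `armv8.2-a+fp16`):

```cpp
// Sketch: code compiled only when MACE_ENABLE_FP16 is defined and the
// target supports native half-precision arithmetic (armv8.2-a+fp16).
#ifdef MACE_ENABLE_FP16
#include <arm_neon.h>

inline float16_t HalfSum(const float16_t *v, int n) {
  float16_t acc = 0;
  for (int i = 0; i < n; ++i) {
    acc += v[i];  // native fp16 add, no widening through float
  }
  return acc;
}
#endif  // MACE_ENABLE_FP16
```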
......@@ -85,7 +85,7 @@ in one deployment file.
* - runtime
- The running device, one of [cpu, gpu, dsp, cpu+gpu]. cpu+gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type
- [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU; [fp16_fp32, bf16_fp32, fp32_fp32] for CPU, default is fp16_fp32.
- [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU; [fp16_fp32, bf16_fp32, fp32_fp32, fp16_fp16] for CPU, default is fp16_fp32.
* - input_data_types
- [optional] The input data type for specific ops (e.g. gather), which can be [int32, float32]; defaults to float32.
* - input_data_formats
......@@ -584,9 +584,10 @@ Therefore, the default storage type for a regular model in MACE is half. However
if the model is very sensitive to accuracy, storage type can be changed to float.
In the deployment file, ``data_type`` is ``fp16_fp32`` by default and can be changed to ``fp32_fp32``,
for CPU it can also be changed to ``bf16_fp32``.
for CPU it can also be changed to ``bf16_fp32`` or ``fp16_fp16`` (``fp16_fp16`` can only be used on armv8.2 or higher).
For CPU, ``fp16_fp32`` means that the weights are saved in half and actual inference is in float; while ``bf16_fp32`` means that the weights are saved in bfloat16 and actual inference is in float.
For CPU, ``fp16_fp32`` means that the weights are saved in half and actual inference is in float; while ``bf16_fp32`` means that the weights are saved in bfloat16 and actual inference is in float,
and ``fp16_fp16`` means that the weights are saved in half and actual inference is in half.
For GPU, ``fp16_fp32`` means that the ops in GPU take half as inputs and outputs while kernel execution is in float.
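To make the difference between the CPU modes concrete, the sketch below (illustrative only, not from the docs) contrasts the arithmetic of ``fp16_fp32`` and the new ``fp16_fp16`` path; it assumes an armv8.2 target where ``float16_t`` is available:

```cpp
#include <arm_neon.h>  // float16_t is available on armv8.2-a+fp16 targets

void ContrastCpuPaths(const float16_t w, const float x) {
  // fp16_fp32: weight stored in half, widened to float, multiply in 32-bit.
  float y_fp16_fp32 = static_cast<float>(w) * x;
  // fp16_fp16: weight stays half and the multiply itself runs in half,
  // which is faster on armv8.2 but rounds more aggressively.
  float16_t y_fp16_fp16 = w * static_cast<float16_t>(x);
  (void)y_fp16_fp32;
  (void)y_fp16_fp16;
}
```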
......
......@@ -63,7 +63,7 @@ There are many advanced options supported.
* - runtime
- The running device, one of [cpu, gpu, dsp, cpu+gpu]. cpu+gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type
- [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU; [fp16_fp32, bf16_fp32, fp32_fp32] for CPU, default is fp16_fp32.
- [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU; [fp16_fp32, bf16_fp32, fp32_fp32, fp16_fp16] for CPU, default is fp16_fp32.
* - input_data_types
- [optional] The input data type for specific ops (e.g. gather), which can be [int32, float32]; defaults to float32.
* - input_data_formats
......@@ -439,9 +439,12 @@ Therefore, the default storage type for a regular model in MACE is half. However
if the model is very sensitive to accuracy, storage type can be changed to float.
In the deployment file, ``data_type`` is ``fp16_fp32`` by default and can be changed to ``fp32_fp32``,
for CPU it can also be changed to ``bf16_fp32``.
for CPU it can also be changed to ``bf16_fp32`` or ``fp16_fp16`` (``fp16_fp16`` can only be used on armv8.2 or higher).
For CPU, ``fp16_fp32`` means that the weights are saved in half and actual inference is in float,
while ``bf16_fp32`` means that the weights are saved in bfloat16 and actual inference is in float,
and ``fp16_fp16`` means that the weights are saved in half and actual inference is in half.
For CPU, ``fp16_fp32`` means that the weights are saved in half and actual inference is in float; while ``bf16_fp32`` means that the weights are saved in bfloat16 and actual inference is in float.
For GPU, ``fp16_fp32`` means that the ops in GPU take half as inputs and outputs while kernel execution is in float.
......
......@@ -132,6 +132,14 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "fp16_enabled",
define_values = {
"fp16": "true",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "rpcmem_enabled",
define_values = {
......
......@@ -10,6 +10,7 @@ load(
"if_android_armv7",
"if_apu_enabled",
"if_bfloat16_enabled",
"if_fp16_enabled",
"if_hexagon_enabled",
"if_hexagon_or_hta_enabled",
"if_hta_enabled",
......@@ -86,6 +87,9 @@ cc_library(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]) + if_hta_enabled([
......
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_FP16_H_
#define MACE_CORE_FP16_H_
#ifdef MACE_ENABLE_FP16
#include <arm_neon.h>
#include <algorithm>
#include <cmath>
#include <sstream>
namespace std {
inline float fabs(const float16_t &value) {
return fabs(static_cast<float>(value));
}
inline float abs(const float16_t &value) {
return abs(static_cast<float>(value));
}
inline float sqrt(const float16_t &value) {
return sqrt(static_cast<float>(value));
}
inline float log(const float16_t &value) {
return log(static_cast<float>(value));
}
inline float tanh(const float16_t &value) {
return tanh(static_cast<float>(value));
}
inline float exp(const float16_t &value) {
return exp(static_cast<float>(value));
}
inline int ceil(const float16_t &value) {
return ceil(static_cast<float>(value));
}
inline int floor(const float16_t &value) {
return floor(static_cast<float>(value));
}
inline float max(const float16_t &a, const float &b) {
return max(static_cast<float>(a), b);
}
inline float max(const float &a, const float16_t &b) {
return max(a, static_cast<float>(b));
}
inline float min(const float16_t &a, const float &b) {
return min(static_cast<float>(a), b);
}
inline float min(const float &a, const float16_t &b) {
return min(a, static_cast<float>(b));
}
inline float pow(const float16_t &a, const float16_t &b) {
return pow(static_cast<float>(a), static_cast<float>(b));
}
inline float pow(const float16_t &a, const float &b) {
return pow(static_cast<float>(a), b);
}
inline float pow(const float &a, const float16_t &b) {
return pow(a, static_cast<float>(b));
}
inline ostream &operator<<(ostream &ss, // NOLINT
const float16_t &value) {
return ss << static_cast<float>(value);
}
} // namespace std
#endif // MACE_ENABLE_FP16
#endif // MACE_CORE_FP16_H_
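These ``std`` overloads exist so that MACE's templated reference kernels, which call ``std::max``, ``std::exp`` and friends on the element type, also instantiate for ``float16_t``: each overload widens to float, reuses the float routine, and lets the caller narrow the result on assignment. A minimal sketch of a generic kernel that relies on them (hypothetical example, assuming the header above is included):

```cpp
#include <cstdint>

#include "mace/core/fp16.h"  // supplies the std:: overloads shown above

// Sketch: one kernel body serves float and float16_t because std::max
// has an overload for each mixed (float16_t, float) argument pair.
template <typename T>
void ReluInPlace(T *data, int64_t size) {
  for (int64_t i = 0; i < size; ++i) {
    // For T = float16_t this resolves to std::max(const float16_t&, const float&),
    // computes in float, and narrows back to float16_t on assignment.
    data[i] = std::max(data[i], 0.f);
  }
}
```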
......@@ -22,6 +22,7 @@
#include <vector>
#include "mace/core/bfloat16.h"
#include "mace/core/fp16.h"
#include "mace/core/ops/op_delegator.h"
#include "mace/core/types.h"
#include "mace/proto/mace.pb.h"
......@@ -101,6 +102,15 @@ class OpDelegatorRegistry {
#endif // MACE_ENABLE_BFLOAT16
#endif // MACE_REGISTER_BF16_DELEGATOR
#ifndef MACE_REGISTER_FP16_DELEGATOR
#ifdef MACE_ENABLE_FP16
#define MACE_REGISTER_FP16_DELEGATOR(registry, class_name, param_name, key) \
MACE_REGISTER_DELEGATOR(registry, class_name, param_name, key)
#else
#define MACE_REGISTER_FP16_DELEGATOR(registry, class_name, param_name, key)
#endif // MACE_ENABLE_FP16
#endif // MACE_REGISTER_FP16_DELEGATOR
#ifndef MACE_DEFINE_DELEGATOR_CREATOR
#define MACE_DEFINE_DELEGATOR_CREATOR(class_name) \
static std::unique_ptr<class_name> Create( \
......
......@@ -23,6 +23,7 @@
#include <vector>
#include "mace/core/bfloat16.h"
#include "mace/core/fp16.h"
#include "mace/core/types.h"
#include "mace/core/ops/operator.h"
#include "mace/core/ops/op_condition_builder.h"
......@@ -102,6 +103,27 @@ class OpRegistry {
#endif // MACE_ENABLE_BFLOAT16
#endif // MACE_REGISTER_BF16_OP_BY_CLASS
#ifndef MACE_REGISTER_FP16_OP
#ifdef MACE_ENABLE_FP16
#define MACE_REGISTER_FP16_OP(op_registry, op_type, class_name, device) \
MACE_REGISTER_OP(op_registry, op_type, class_name, device, float16_t)
#else
#define MACE_REGISTER_FP16_OP(op_registry, op_type, class_name, device)
#endif // MACE_ENABLE_FP16
#endif // MACE_REGISTER_FP16_OP
#ifndef MACE_REGISTER_FP16_OP_BY_CLASS
#ifdef MACE_ENABLE_FP16
#define MACE_REGISTER_FP16_OP_BY_CLASS(op_registry, op_type, \
class_name, device) \
MACE_REGISTER_OP_BY_CLASS(op_registry, op_type, \
class_name, device, float16_t)
#else
#define MACE_REGISTER_FP16_OP_BY_CLASS(op_registry, op_type, \
class_name, device)
#endif // MACE_ENABLE_FP16
#endif // MACE_REGISTER_FP16_OP_BY_CLASS
#ifdef MACE_ENABLE_OPENCL
#define MACE_REGISTER_GPU_OP(op_registry, op_type, class_name) \
op_registry->Register( \
......
......@@ -46,7 +46,8 @@ namespace mace {
break; \
}
#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
#if (defined(MACE_ENABLE_NEON) && defined(__ANDROID__)) \
|| defined(MACE_ENABLE_FP16)
#define MACE_TYPE_ENUM_SWITCH_CASE_NEON(STATEMENTS) \
MACE_CASE(float16_t, MACE_SINGLE_ARG(STATEMENTS))
#else
......@@ -60,6 +61,13 @@ namespace mace {
#define MACE_TYPE_ENUM_SWITCH_CASE_BFLOAT16(STATEMENTS)
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
#define MACE_TYPE_ENUM_SWITCH_CASE_FP16(STATEMENTS) \
MACE_CASE(float16_t, MACE_SINGLE_ARG(STATEMENTS))
#else
#define MACE_TYPE_ENUM_SWITCH_CASE_FP16(STATEMENTS)
#endif // MACE_ENABLE_FP16
#if MACE_ENABLE_OPENCL
#define MACE_TYPE_ENUM_SWITCH_CASE_OPENCL(STATEMENTS) \
MACE_CASE(half, MACE_SINGLE_ARG(STATEMENTS))
......
......@@ -26,6 +26,7 @@ bool DataTypeCanUseMemcpy(DataType dt) {
case DT_UINT8:
case DT_INT32:
case DT_BFLOAT16:
case DT_FLOAT16:
return true;
default:
return false;
......@@ -38,7 +39,8 @@ std::string DataTypeToString(const DataType dt) {
{DT_HALF, "DT_HALF"},
{DT_UINT8, "DT_UINT8"},
{DT_INT32, "DT_INT32"},
{DT_BFLOAT16, "DT_BFLOAT16"}};
{DT_BFLOAT16, "DT_BFLOAT16"},
{DT_FLOAT16, "DT_FLOAT16"}};
MACE_CHECK(dt != DT_INVALID, "Invalid data type is not supported");
return dtype_string_map[dt];
}
......@@ -49,7 +51,8 @@ size_t GetEnumTypeSize(const DataType dt) {
return sizeof(float);
case DT_HALF:
return sizeof(half);
#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
#if (defined(MACE_ENABLE_NEON) && defined(__ANDROID__)) || \
defined(MACE_ENABLE_FP16)
case DT_FLOAT16:
return sizeof(float16_t);
#endif
......
......@@ -55,7 +55,8 @@ struct EnumToDataType;
};
MACE_MAPPING_DATA_TYPE_AND_ENUM(half, DT_HALF);
#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
#if (defined(MACE_ENABLE_NEON) && defined(__ANDROID__)) \
|| defined(MACE_ENABLE_FP16)
MACE_MAPPING_DATA_TYPE_AND_ENUM(float16_t, DT_FLOAT16);
#endif
#ifdef MACE_ENABLE_BFLOAT16
......
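With the mapping macro extended, the compile-time bridges between C++ types and proto enums now cover ``float16_t`` as well. A small illustrative check (hypothetical; assumes MACE's usual ``DataTypeToEnum``/``EnumToDataType`` interface and an fp16-enabled build):

```cpp
#include <type_traits>

#ifdef MACE_ENABLE_FP16
// DataTypeToEnum<T>::value maps a C++ type to its proto enum and
// EnumToDataType<E>::Type maps back; both now know about float16_t.
static_assert(mace::DataTypeToEnum<float16_t>::value == mace::DT_FLOAT16,
              "float16_t must map to DT_FLOAT16");
static_assert(std::is_same<mace::EnumToDataType<mace::DT_FLOAT16>::Type,
                           float16_t>::value,
              "DT_FLOAT16 must map back to float16_t");
#endif  // MACE_ENABLE_FP16
```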
......@@ -13,6 +13,7 @@ load(
"if_android_armv7",
"if_apu_enabled",
"if_bfloat16_enabled",
"if_fp16_enabled",
"if_darwin",
"if_hexagon_enabled",
"if_hta_enabled",
......@@ -44,6 +45,9 @@ cc_library(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]) + if_hta_enabled([
......
......@@ -846,9 +846,8 @@ MaceStatus MaceEngine::Impl::TransposeInput(
} else {
LOG(FATAL) << "Invalid net data type: " << net_data_type_;
}
#ifdef MACE_ENABLE_BFLOAT16 // todo(lichao): add float16 macro
} else if (input_dt == DataType::DT_FLOAT16 ||
input_dt == DataType::DT_BFLOAT16) {
#ifdef MACE_ENABLE_BFLOAT16
} else if (input_dt == DataType::DT_BFLOAT16) {
auto *input_data = input_tensor->mutable_data<BFloat16>();
return ops::Transpose(thread_pool_.get(),
input.second.data<float>().get(),
......@@ -856,6 +855,16 @@ MaceStatus MaceEngine::Impl::TransposeInput(
dst_dims,
input_data);
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
} else if (input_dt == DataType::DT_FLOAT16) {
auto *input_data = input_tensor->mutable_data<float16_t>();
return ops::Transpose(thread_pool_.get(),
input.second.data<float>().get(),
input.second.shape(),
dst_dims,
input_data);
#endif // MACE_ENABLE_FP16
} else if (input_dt == DataType::DT_INT32) {
auto input_data = input_tensor->mutable_data<int>();
return ops::Transpose(thread_pool_.get(),
......@@ -882,15 +891,23 @@ MaceStatus MaceEngine::Impl::TransposeInput(
} else {
LOG(FATAL) << "Invalid net data type: " << net_data_type_;
}
#ifdef MACE_ENABLE_BFLOAT16 // todo(lichao): add float16 macro
} else if (input_dt == DataType::DT_FLOAT16 ||
input_dt == DataType::DT_BFLOAT16) {
#ifdef MACE_ENABLE_BFLOAT16
} else if (input_dt == DataType::DT_BFLOAT16) {
auto input_data = input_tensor->mutable_data<BFloat16>();
const float *data = input.second.data().get();
for (index_t i = 0; i < input_tensor->size(); ++i) {
input_data[i] = data[i];
}
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
} else if (input_dt == DataType::DT_FLOAT16) {
auto input_data = input_tensor->mutable_data<float16_t>();
const float *data = input.second.data().get();
for (index_t i = 0; i < input_tensor->size(); ++i) {
input_data[i] = data[i];
}
#endif // MACE_ENABLE_FP16
} else if (input_dt == DataType::DT_INT32) {
auto input_data = input_tensor->mutable_data<int>();
memcpy(input_data, input.second.data().get(),
......@@ -963,6 +980,15 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
dst_dims,
output->second.data<float>().get());
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
} else if (output_dt == DataType::DT_FLOAT16) {
auto output_data = output_tensor->data<float16_t>();
return ops::Transpose(thread_pool_.get(),
output_data,
output_tensor->shape(),
dst_dims,
output->second.data<float>().get());
#endif // MACE_ENABLE_FP16
} else {
LOG(FATAL) << "MACE do not support the output data type: " << output_dt;
return MaceStatus::MACE_INVALID_ARGS;
......@@ -993,6 +1019,14 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
data[i] = output_data[i];
}
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
} else if (output_dt == DataType::DT_FLOAT16) {
const auto *output_data = output_tensor->data<float16_t>();
float *data = output->second.data<float>().get();
for (index_t i = 0; i < output_tensor->size(); ++i) {
data[i] = output_data[i];
}
#endif // MACE_ENABLE_FP16
} else {
LOG(FATAL) << "MACE do not support the output data type: " << output_dt;
}
......
......@@ -109,6 +109,12 @@ def if_bfloat16_enabled(a):
"//conditions:default": [],
})
def if_fp16_enabled(a):
return select({
"//mace:fp16_enabled": a,
"//conditions:default": [],
})
def if_rpcmem_enabled(a, default_value = []):
return select({
"//mace:rpcmem_enabled": a,
......
......@@ -11,6 +11,7 @@ load(
"if_android",
"if_android_armv7",
"if_bfloat16_enabled",
"if_fp16_enabled",
"if_hexagon_enabled",
"if_neon_enabled",
"if_opencl_enabled",
......@@ -46,6 +47,9 @@ cc_library(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
......@@ -85,6 +89,9 @@ cc_library(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
......@@ -112,6 +119,10 @@ cc_library(
[
"arm/bf16/*.cc",
],
)) + if_fp16_enabled(glob(
[
"arm/fp16/*.cc",
],
)),
hdrs = glob(
[
......@@ -126,6 +137,10 @@ cc_library(
[
"arm/bf16/*.h",
],
)) + if_fp16_enabled(glob(
[
"arm/fp16/*.h",
],
)),
copts = [
"-Werror",
......@@ -142,6 +157,9 @@ cc_library(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
......@@ -225,6 +243,9 @@ cc_library(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
......@@ -265,6 +286,9 @@ cc_library(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
......
......@@ -14,6 +14,9 @@ file(GLOB OPS_ARM_NEON_FP32_KERNELS_SRCS
file(GLOB OPS_ARM_NEON_BF16_KERNELS_SRCS
arm/bf16/*.cc
)
file(GLOB OPS_ARM_NEON_FP16_KERNELS_SRCS
arm/fp16/*.cc
)
file(GLOB OPS_ARM_NEON_Q8_KERNELS_SRCS
arm/q8/*.cc
)
......@@ -45,6 +48,9 @@ if(MACE_ENABLE_NEON)
if(MACE_ENABLE_BFLOAT16)
set(OPS_SRCS ${OPS_SRCS} ${OPS_ARM_NEON_BF16_KERNELS_SRCS})
endif(MACE_ENABLE_BFLOAT16)
if(MACE_ENABLE_FP16)
set(OPS_SRCS ${OPS_SRCS} ${OPS_ARM_NEON_FP16_KERNELS_SRCS})
endif(MACE_ENABLE_FP16)
endif(MACE_ENABLE_NEON)
if(MACE_ENABLE_OPENCL)
......
......@@ -152,6 +152,40 @@ inline void vst1o(float *ptr, float32x8_t v) {
vst1q_f32(ptr + 4, v.val[1]);
}
#if defined(MACE_ENABLE_FP16)
// load of 4D vector
inline float16x4_t vld1(const float16_t *ptr) {
return vld1_f16(ptr);
}
// store of 4D vector
inline void vst1(float16_t *ptr, float16x4_t v) {
vst1_f16(ptr, v);
}
// load of 8D vector
inline float16x8_t vld1q(const float16_t *ptr) {
return vld1q_f16(ptr);
}
// load of 2 8D vectors and perform de-interleaving
inline float16x8x2_t vld2q(const float16_t *ptr) {
return vld2q_f16(ptr);
}
// store of 8D vector
inline void vst1q(float16_t *ptr, const float16x8_t v) {
vst1q_f16(ptr, v);
}
// store of 2 8D vectors and perform interleaving
inline void vst2q(float16_t *ptr, const float16x8x2_t v) {
vst2q_f16(ptr, v);
}
#endif // MACE_ENABLE_FP16
#if defined(MACE_ENABLE_BFLOAT16)
// load of 2D vector
......
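The point of the overloaded ``vld1``/``vst1`` wrappers is that one templated kernel body can load and store float, bfloat16, or now float16_t without naming the type-specific intrinsic at every call site. A sketch of the pattern (hypothetical kernel; assumes the matching float overloads exist in this header alongside the fp16 ones above):

```cpp
#include <cstdint>

// Sketch: type-agnostic 4-wide copy. With the overloads in scope the same
// body instantiates for float16_t (vld1_f16/vst1_f16) and float (vld1q-style
// f32 wrappers), picked purely by overload resolution.
template <typename T>
void CopyVec4(const T *src, T *dst, int64_t size) {
  int64_t i = 0;
  for (; i + 3 < size; i += 4) {
    vst1(dst + i, vld1(src + i));  // resolves to the right intrinsic per T
  }
  for (; i < size; ++i) {
    dst[i] = src[i];  // scalar tail
  }
}
```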
......@@ -101,6 +101,10 @@ void RegisterConv2dK1x1Delegator(OpDelegatorRegistry *registry) {
registry, Conv2dK1x1<BFloat16>, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU,
BFloat16, ImplType::NEON, K1x1));
MACE_REGISTER_FP16_DELEGATOR(
registry, Conv2dK1x1<float16_t>, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU,
float16_t, ImplType::NEON, K1x1));
}
} // namespace arm
......
......@@ -36,6 +36,14 @@ void RegisterConv2dK3x3Delegator(OpDelegatorRegistry *registry) {
registry, Conv2dK3x3S2<BFloat16>, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU,
BFloat16, ImplType::NEON, K3x3S2));
MACE_REGISTER_FP16_DELEGATOR(
registry, Conv2dK3x3S1<float16_t>, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU,
float16_t, ImplType::NEON, K3x3S1));
MACE_REGISTER_FP16_DELEGATOR(
registry, Conv2dK3x3S2<float16_t>, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU,
float16_t, ImplType::NEON, K3x3S2));
}
} // namespace arm
......
......@@ -20,6 +20,16 @@ namespace mace {
namespace ops {
namespace arm {
extern template
MaceStatus DepthwiseConv2dK3x3S1<float16_t>::DoCompute(
const DepthwiseConvComputeParam &p, const float16_t *filter_data,
const float16_t *input_data, float16_t *output_data);
extern template
MaceStatus DepthwiseConv2dK3x3S2<float16_t>::DoCompute(
const DepthwiseConvComputeParam &p, const float16_t *filter_data,
const float16_t *input_data, float16_t *output_data);
namespace {
template<typename T>
void DepthwiseConv2d3x3Pixel(const T *in_base,
......@@ -464,6 +474,16 @@ void RegisterDepthwiseConv2dK3x3Delegator(OpDelegatorRegistry *registry) {
delegator::DepthwiseConv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU,
BFloat16, ImplType::NEON, K3x3S2));
MACE_REGISTER_FP16_DELEGATOR(
registry, DepthwiseConv2dK3x3S1<float16_t>,
delegator::DepthwiseConv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU,
float16_t, ImplType::NEON, K3x3S1));
MACE_REGISTER_FP16_DELEGATOR(
registry, DepthwiseConv2dK3x3S2<float16_t>,
delegator::DepthwiseConv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU,
float16_t, ImplType::NEON, K3x3S2));
}
} // namespace arm
......
......@@ -27,6 +27,7 @@ namespace mace {
namespace ops {
namespace arm {
template<typename T>
class DepthwiseConv2dK3x3S1 : public DepthwiseConv2dKMxN<T> {
public:
......
......@@ -23,6 +23,24 @@ namespace mace {
namespace ops {
namespace arm {
extern template void Gemm<float16_t>::Pack8x4(
const MatrixMap<const float16_t> &matrix,
MatrixMajor dst_major, float16_t *packed_matrix);
extern template void Gemm<float16_t>::Unpack8x8(
const float16_t *packed_output, MatrixMap<float16_t> *output);
extern template void Gemm<float16_t>::PackLhs(
const MatrixMap<const float16_t> &lhs, float16_t *packed_lhs);
extern template void Gemm<float16_t>::PackRhs(
const MatrixMap<const float16_t> &rhs, float16_t *packed_rhs);
extern template void Gemm<float16_t>::UnpackOutput(
const float16_t *packed_output, MatrixMap<float16_t> *output);
extern template MaceStatus Gemm<float16_t>::Compute(
const OpContext *context, const Tensor *lhs, const Tensor *rhs,
const index_t batch, const index_t rows, const index_t cols,
const index_t depth, const MatrixMajor lhs_major,
const MatrixMajor rhs_major, const MatrixMajor output_major,
const bool lhs_batched, const bool rhs_batched, Tensor *output);
template<typename T>
void Gemm<T>::Pack4x4(const MatrixMap<const T> &matrix,
MatrixMajor dst_major, T *packed_matrix) {
......@@ -681,9 +699,9 @@ MaceStatus Gemm<T>::Compute(
depth_padded,
packed_output_data_block);
MatrixMap<T> output_block = output_matrix.block(start_row,
start_col,
row_block_len,
col_block_len);
start_col,
row_block_len,
col_block_len);
UnpackOutput(packed_output_data_block, &output_block);
} // col_block_idx
} // row_block_idx
......@@ -701,6 +719,10 @@ void RegisterGemmDelegator(OpDelegatorRegistry *registry) {
MACE_REGISTER_BF16_DELEGATOR(
registry, Gemm<BFloat16>, delegator::GemmParam,
MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, BFloat16, ImplType::NEON));
MACE_REGISTER_FP16_DELEGATOR(
registry, Gemm<float16_t>, delegator::GemmParam,
MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, float16_t, ImplType::NEON));
}
} // namespace arm
......
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include <memory>
#include "mace/ops/arm/base/conv_2d_3x3.h"
#include "mace/ops/delegator/conv_2d.h"
namespace mace {
namespace ops {
namespace arm {
template<>
MaceStatus Conv2dK3x3S1<float16_t>::DoCompute(
const ConvComputeParam &p, const float16_t *filter_data,
const float16_t *input_data, float16_t *output_data) {
p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
for (index_t m = start1; m < end1; m += step1) {
if (m + 1 < p.out_channels) {
float16_t *out_ptr0_base =
output_data + b * p.out_batch_size + m * p.out_image_size;
float16_t *out_ptr1_base =
output_data + b * p.out_batch_size + (m + 1) * p.out_image_size;
for (index_t c = 0; c < p.in_channels; ++c) {
const float16_t *in_ptr0 =
input_data + b * p.in_batch_size + c * p.in_image_size;
const float16_t
*filter_ptr0 = filter_data + m * p.in_channels * 9 + c * 9;
float16_t *out_ptr1 = out_ptr1_base;
const float16_t *in_ptr1 =
input_data + b * p.in_batch_size + c * p.in_image_size
+ 1 * p.in_width;
const float16_t *in_ptr2 =
input_data + b * p.in_batch_size + c * p.in_image_size
+ 2 * p.in_width;
const float16_t *in_ptr3 =
input_data + b * p.in_batch_size + c * p.in_image_size
+ 3 * p.in_width;
const float16_t *filter_ptr1 =
filter_data + (m + 1) * p.in_channels * 9 + c * 9;
float16_t *out_ptr0 = out_ptr0_base;
// load filter (2 outch x 3 height x 3 width): vf_outch_height
float16x8_t vf00, vf01;
float16x8_t vf10, vf11;
vf00 = vld1q_f16(filter_ptr0);
vf01 = vld1q_f16(filter_ptr0 + 8);
vf10 = vld1q_f16(filter_ptr1);
vf11 = vld1q_f16(filter_ptr1 + 8);
for (index_t h = 0; h + 1 < p.out_height; h += 2) {
for (index_t w = 0; w + 7 < p.out_width; w += 8) {
// input (4 height x 3 slide): vi_height_slide
float16x8_t vi00, vi01, vi02; // reg count: 14
float16x8_t vi10, vi11, vi12;
float16x8_t vi20, vi21, vi22;
float16x8_t vi30, vi31, vi32;
float16x8_t vo20, vo30; // tmp use
// output (2 outch x 2 height x 8 width): vo_outch_height
float16x8_t vo00, vo01;
float16x8_t vo10, vo11;
// load input
vi00 = vld1q_f16(in_ptr0);
vo00 = vld1q_f16(in_ptr0 + 8); // reuse vo00: vi0n
vi10 = vld1q_f16(in_ptr1);
vo10 = vld1q_f16(in_ptr1 + 8);
vi20 = vld1q_f16(in_ptr2);
vo20 = vld1q_f16(in_ptr2 + 8);
vi30 = vld1q_f16(in_ptr3);
vo30 = vld1q_f16(in_ptr3 + 8);
vi01 = vextq_f16(vi00, vo00, 1);
vi02 = vextq_f16(vi00, vo00, 2);
vi11 = vextq_f16(vi10, vo10, 1);
vi12 = vextq_f16(vi10, vo10, 2);
vi21 = vextq_f16(vi20, vo20, 1);
vi22 = vextq_f16(vi20, vo20, 2);
vi31 = vextq_f16(vi30, vo30, 1);
vi32 = vextq_f16(vi30, vo30, 2);
// load output
vo00 = vld1q_f16(out_ptr0);
vo01 = vld1q_f16(out_ptr0 + p.out_width);
vo10 = vld1q_f16(out_ptr1);
vo11 = vld1q_f16(out_ptr1 + p.out_width);
// outch 0, height 0
vo00 = vfmaq_laneq_f16(vo00, vi00, vf00, 0); // reg count: 18
vo00 = vfmaq_laneq_f16(vo00, vi01, vf00, 1);
vo00 = vfmaq_laneq_f16(vo00, vi02, vf00, 2);
vo00 = vfmaq_laneq_f16(vo00, vi10, vf00, 3);
vo00 = vfmaq_laneq_f16(vo00, vi11, vf00, 4);
vo00 = vfmaq_laneq_f16(vo00, vi12, vf00, 5);
vo00 = vfmaq_laneq_f16(vo00, vi20, vf00, 6);
vo00 = vfmaq_laneq_f16(vo00, vi21, vf00, 7);
vo00 = vfmaq_laneq_f16(vo00, vi22, vf01, 0);
// outch 0, height 1
vo01 = vfmaq_laneq_f16(vo01, vi10, vf00, 0);
vo01 = vfmaq_laneq_f16(vo01, vi11, vf00, 1);
vo01 = vfmaq_laneq_f16(vo01, vi12, vf00, 2);
vo01 = vfmaq_laneq_f16(vo01, vi20, vf00, 3);
vo01 = vfmaq_laneq_f16(vo01, vi21, vf00, 4);
vo01 = vfmaq_laneq_f16(vo01, vi22, vf00, 5);
vo01 = vfmaq_laneq_f16(vo01, vi30, vf00, 6);
vo01 = vfmaq_laneq_f16(vo01, vi31, vf00, 7);
vo01 = vfmaq_laneq_f16(vo01, vi32, vf01, 0);
// outch 1, height 0
vo10 = vfmaq_laneq_f16(vo10, vi00, vf10, 0);
vo10 = vfmaq_laneq_f16(vo10, vi01, vf10, 1);
vo10 = vfmaq_laneq_f16(vo10, vi02, vf10, 2);
vo10 = vfmaq_laneq_f16(vo10, vi10, vf10, 3);
vo10 = vfmaq_laneq_f16(vo10, vi11, vf10, 4);
vo10 = vfmaq_laneq_f16(vo10, vi12, vf10, 5);
vo10 = vfmaq_laneq_f16(vo10, vi20, vf10, 6);
vo10 = vfmaq_laneq_f16(vo10, vi21, vf10, 7);
vo10 = vfmaq_laneq_f16(vo10, vi22, vf11, 0);
// outch 1, height 1
vo11 = vfmaq_laneq_f16(vo11, vi10, vf10, 0);
vo11 = vfmaq_laneq_f16(vo11, vi11, vf10, 1);
vo11 = vfmaq_laneq_f16(vo11, vi12, vf10, 2);
vo11 = vfmaq_laneq_f16(vo11, vi20, vf10, 3);
vo11 = vfmaq_laneq_f16(vo11, vi21, vf10, 4);
vo11 = vfmaq_laneq_f16(vo11, vi22, vf10, 5);
vo11 = vfmaq_laneq_f16(vo11, vi30, vf10, 6);
vo11 = vfmaq_laneq_f16(vo11, vi31, vf10, 7);
vo11 = vfmaq_laneq_f16(vo11, vi32, vf11, 0);
vst1q_f16(out_ptr0, vo00);
vst1q_f16(out_ptr0 + p.out_width, vo01);
vst1q_f16(out_ptr1, vo10);
vst1q_f16(out_ptr1 + p.out_width, vo11);
in_ptr0 += 8;
in_ptr1 += 8;
in_ptr2 += 8;
in_ptr3 += 8;
out_ptr0 += 8;
out_ptr1 += 8;
} // w
in_ptr0 += 2 + p.in_width;
in_ptr1 += 2 + p.in_width;
in_ptr2 += 2 + p.in_width;
in_ptr3 += 2 + p.in_width;
out_ptr0 += p.out_width;
out_ptr1 += p.out_width;
} // h
} // c
} else {
for (index_t mm = m; mm < p.out_channels; ++mm) {
float16_t *out_ptr0_base =
output_data + b * p.out_batch_size + mm * p.out_image_size;
for (index_t c = 0; c < p.in_channels; ++c) {
const float16_t *in_ptr0 =
input_data + b * p.in_batch_size + c * p.in_image_size;
const float16_t *in_ptr1 =
input_data + b * p.in_batch_size + c * p.in_image_size
+ 1 * p.in_width;
const float16_t *in_ptr2 =
input_data + b * p.in_batch_size + c * p.in_image_size
+ 2 * p.in_width;
const float16_t *in_ptr3 =
input_data + b * p.in_batch_size + c * p.in_image_size
+ 3 * p.in_width;
const float16_t
*filter_ptr0 = filter_data + mm * p.in_channels * 9 + c * 9;
float16_t *out_ptr0 = out_ptr0_base;
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float16x8_t vf00, vf01;
vf00 = vld1q_f16(filter_ptr0);
vf01 = vld1q_f16(filter_ptr0 + 8);
for (index_t h = 0; h + 1 < p.out_height; h += 2) {
for (index_t w = 0; w + 7 < p.out_width; w += 8) {
// input (4 height x 3 slide): vi_height_slide
float16x8_t vi00, vi01, vi02, vi0n;
float16x8_t vi10, vi11, vi12, vi1n;
float16x8_t vi20, vi21, vi22, vi2n;
float16x8_t vi30, vi31, vi32, vi3n;
// output (1 outch x 2 height x 8 width): vo_outch_height
float16x8_t vo00, vo01;
// load input
vi00 = vld1q_f16(in_ptr0);
vi0n = vld1q_f16(in_ptr0 + 8);
vi10 = vld1q_f16(in_ptr1);
vi1n = vld1q_f16(in_ptr1 + 8);
vi20 = vld1q_f16(in_ptr2);
vi2n = vld1q_f16(in_ptr2 + 8);
vi30 = vld1q_f16(in_ptr3);
vi3n = vld1q_f16(in_ptr3 + 8);
vi01 = vextq_f16(vi00, vi0n, 1);
vi02 = vextq_f16(vi00, vi0n, 2);
vi11 = vextq_f16(vi10, vi1n, 1);
vi12 = vextq_f16(vi10, vi1n, 2);
vi21 = vextq_f16(vi20, vi2n, 1);
vi22 = vextq_f16(vi20, vi2n, 2);
vi31 = vextq_f16(vi30, vi3n, 1);
vi32 = vextq_f16(vi30, vi3n, 2);
// load output
vo00 = vld1q_f16(out_ptr0);
vo01 = vld1q_f16(out_ptr0 + p.out_width);
// outch 0, height 0
vo00 = vfmaq_laneq_f16(vo00, vi00, vf00, 0);
vo00 = vfmaq_laneq_f16(vo00, vi01, vf00, 1);
vo00 = vfmaq_laneq_f16(vo00, vi02, vf00, 2);
vo00 = vfmaq_laneq_f16(vo00, vi10, vf00, 3);
vo00 = vfmaq_laneq_f16(vo00, vi11, vf00, 4);
vo00 = vfmaq_laneq_f16(vo00, vi12, vf00, 5);
vo00 = vfmaq_laneq_f16(vo00, vi20, vf00, 6);
vo00 = vfmaq_laneq_f16(vo00, vi21, vf00, 7);
vo00 = vfmaq_laneq_f16(vo00, vi22, vf01, 0);
// outch 0, height 1
vo01 = vfmaq_laneq_f16(vo01, vi10, vf00, 0);
vo01 = vfmaq_laneq_f16(vo01, vi11, vf00, 1);
vo01 = vfmaq_laneq_f16(vo01, vi12, vf00, 2);
vo01 = vfmaq_laneq_f16(vo01, vi20, vf00, 3);
vo01 = vfmaq_laneq_f16(vo01, vi21, vf00, 4);
vo01 = vfmaq_laneq_f16(vo01, vi22, vf00, 5);
vo01 = vfmaq_laneq_f16(vo01, vi30, vf00, 6);
vo01 = vfmaq_laneq_f16(vo01, vi31, vf00, 7);
vo01 = vfmaq_laneq_f16(vo01, vi32, vf01, 0);
vst1q_f16(out_ptr0, vo00);
vst1q_f16(out_ptr0 + p.out_width, vo01);
in_ptr0 += 8;
in_ptr1 += 8;
in_ptr2 += 8;
in_ptr3 += 8;
out_ptr0 += 8;
} // w
in_ptr0 += 2 + p.in_width;
in_ptr1 += 2 + p.in_width;
in_ptr2 += 2 + p.in_width;
in_ptr3 += 2 + p.in_width;
out_ptr0 += p.out_width;
} // h
} // c
} // mm
} // if
} // m
} // b
}, 0, p.batch, 1, 0, p.out_channels, 2);
return MaceStatus::MACE_SUCCESS;
}
template<>
MaceStatus Conv2dK3x3S2<float16_t>::DoCompute(
const ConvComputeParam &p, const float16_t *filter_data,
const float16_t *input_data, float16_t *output_data) {
p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
for (index_t m = start1; m < end1; m += step1) {
for (index_t c = 0; c < p.in_channels; ++c) {
const float16_t
*in_base = input_data + b * p.in_batch_size + c * p.in_image_size;
const float16_t *filter_ptr =
filter_data + m * p.in_channels * 9 + c * 9;
float16_t *out_base =
output_data + b * p.out_batch_size + m * p.out_image_size;
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float16x8_t vf00, vf01;
vf00 = vld1q_f16(filter_ptr);
vf01 = vld1q_f16(filter_ptr + 8);
for (index_t h = 0; h < p.out_height; ++h) {
for (index_t w = 0; w + 7 < p.out_width; w += 8) {
float16x8x2_t vi0, vi1, vi2;
float16x8_t vi0n, vi1n, vi2n;
// input (3 height x 3 slide): vi_height_slide
float16x8_t vi00, vi01, vi02;
float16x8_t vi10, vi11, vi12;
float16x8_t vi20, vi21, vi22;
// output (1 outch x 1 height x 8 width): vo
float16x8_t vo;
// load input
index_t in_h = h * 2;
index_t in_w = w * 2;
index_t in_offset = in_h * p.in_width + in_w;
vi0 = vld2q_f16(in_base + in_offset); // [0.2.4...14 | 1.3.5...15]
vi1 = vld2q_f16(in_base + in_offset + p.in_width);
vi2 = vld2q_f16(in_base + in_offset + 2 * p.in_width);
vi0n = vld1q_f16(in_base + in_offset + 16); // [16.17.18...23]
vi1n = vld1q_f16(in_base + in_offset + p.in_width + 16);
vi2n = vld1q_f16(in_base + in_offset + 2 * p.in_width + 16);
// load output
index_t out_offset = h * p.out_width + w;
vo = vld1q_f16(out_base + out_offset);
vi00 = vi0.val[0]; // [0.2.4...14]
vi01 = vi0.val[1]; // [1.3.5...15]
vi02 = vextq_f16(vi00, vi0n, 1); // [2.4.6...16]
vi10 = vi1.val[0];
vi11 = vi1.val[1];
vi12 = vextq_f16(vi10, vi1n, 1);
vi20 = vi2.val[0];
vi21 = vi2.val[1];
vi22 = vextq_f16(vi20, vi2n, 1);
// outch 0, height 0
vo = vfmaq_laneq_f16(vo, vi00, vf00, 0);
vo = vfmaq_laneq_f16(vo, vi01, vf00, 1);
vo = vfmaq_laneq_f16(vo, vi02, vf00, 2);
vo = vfmaq_laneq_f16(vo, vi10, vf00, 3);
vo = vfmaq_laneq_f16(vo, vi11, vf00, 4);
vo = vfmaq_laneq_f16(vo, vi12, vf00, 5);
vo = vfmaq_laneq_f16(vo, vi20, vf00, 6);
vo = vfmaq_laneq_f16(vo, vi21, vf00, 7);
vo = vfmaq_laneq_f16(vo, vi22, vf01, 0);
vst1q_f16(out_base + out_offset, vo);
} // w
} // h
} // c
} // m
} // b
}, 0, p.batch, 1, 0, p.out_channels, 1);
return MaceStatus::MACE_SUCCESS;
}
} // namespace arm
} // namespace ops
} // namespace mace
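The inner loops above lean entirely on ``vfmaq_laneq_f16``: each call is a fused multiply-accumulate of an 8-wide input slice against a single filter tap broadcast from lane ``K``, which is why a 3x3 kernel needs nine calls per output row. A scalar model of the intrinsic's semantics (explanatory sketch, not code from the patch):

```cpp
#include <arm_neon.h>

// Scalar model of vo = vfmaq_laneq_f16(vo, vi, vf, K): all 8 lanes
// accumulate the same filter tap vf[K] (fused in hardware, unfused here).
void FmaLaneModel(float16_t vo[8], const float16_t vi[8],
                  const float16_t vf[8], const int K) {
  for (int j = 0; j < 8; ++j) {
    vo[j] = vo[j] + vi[j] * vf[K];
  }
}
```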
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include "mace/ops/arm/base/depthwise_conv_2d_3x3.h"
namespace mace {
namespace ops {
namespace arm {
// float16_t here is the concrete NEON type from <arm_neon.h>, so this
// helper needs no template parameter.
static void DepthwiseConv2d3x3Pixel(const float16_t *in_base,
const float16_t *filter,
const index_t out_h,
const index_t out_w,
const index_t in_h_start,
const index_t in_w_start,
const index_t out_width,
const index_t in_height,
const index_t in_width,
float16_t *out_base) {
const index_t filter_width = 3;
float sum = 0.0f;
index_t in_h = in_h_start;
const float16_t *in = in_base + in_h * in_width;
const float16_t *filter_ptr = filter;
if (in_h >= 0 && in_h < in_height) {
index_t in_w = in_w_start;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[0];
}
in_w++;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[1];
}
in_w++;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[2];
}
}
in_h++;
in += in_width;
filter_ptr += filter_width;
if (in_h >= 0 && in_h < in_height) {
index_t in_w = in_w_start;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[0];
}
in_w++;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[1];
}
in_w++;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[2];
}
}
in_h++;
in += in_width;
filter_ptr += filter_width;
if (in_h >= 0 && in_h < in_height) {
index_t in_w = in_w_start;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[0];
}
in_w++;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[1];
}
in_w++;
if (in_w >= 0 && in_w < in_width) {
sum += in[in_w] * filter_ptr[2];
}
}
out_base[out_h * out_width + out_w] = static_cast<float16_t>(sum);
}
template<>
MaceStatus DepthwiseConv2dK3x3S1<float16_t>::DoCompute(
const DepthwiseConvComputeParam &p, const float16_t *filter_data,
const float16_t *input_data, float16_t *output_data) {
p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
for (index_t m = start1; m < end1; m += step1) {
const index_t c = m / p.multiplier;
const index_t multi_index = m % p.multiplier;
auto filter_ptr = filter_data + multi_index * p.in_channels * 9 + c * 9;
auto in_base = input_data + b * p.in_batch_size + c * p.in_image_size;
auto out_base = output_data + b * p.out_batch_size +
m * p.out_image_size;
index_t h, w;
// top
for (h = 0; h < p.valid_h_start; ++h) {
for (w = 0; w < p.out_width; ++w) {
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h,
w,
h - p.pad_top,
w - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
}
}
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float16x8_t vf00, vf01;
vf00 = vld1q_f16(filter_ptr);
vf01 = vld1q_f16(filter_ptr + 8);
for (h = p.valid_h_start; h + 1 < p.valid_h_stop; h += 2) {
// left
for (w = 0; w < p.valid_w_start; ++w) {
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h,
w,
h - p.pad_top,
w - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h + 1,
w,
h + 1 - p.pad_top,
w - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
}
for (w = p.valid_w_start; w + 7 < p.valid_w_stop; w += 8) {
// input (4 height x 3 slide): vi_height_slide
float16x8_t vi00, vi01, vi02, vi0n;
float16x8_t vi10, vi11, vi12, vi1n;
float16x8_t vi20, vi21, vi22, vi2n;
float16x8_t vi30, vi31, vi32, vi3n;
// output (1 outch x 2 height x 8 width): vo_outch_height
float16x8_t vo00, vo01;
// load input
index_t in_h = h - p.pad_top;
index_t in_w = w - p.pad_left;
index_t in_offset = in_h * p.in_width + in_w;
vi00 = vld1q_f16(in_base + in_offset);
vi0n = vld1q_f16(in_base + in_offset + 8);
vi10 = vld1q_f16(in_base + in_offset + p.in_width);
vi1n = vld1q_f16(in_base + in_offset + p.in_width + 8);
vi20 = vld1q_f16(in_base + in_offset + 2 * p.in_width);
vi2n = vld1q_f16(in_base + in_offset + 2 * p.in_width + 8);
vi30 = vld1q_f16(in_base + in_offset + 3 * p.in_width);
vi3n = vld1q_f16(in_base + in_offset + 3 * p.in_width + 8);
vi01 = vextq_f16(vi00, vi0n, 1);
vi02 = vextq_f16(vi00, vi0n, 2);
vi11 = vextq_f16(vi10, vi1n, 1);
vi12 = vextq_f16(vi10, vi1n, 2);
vi21 = vextq_f16(vi20, vi2n, 1);
vi22 = vextq_f16(vi20, vi2n, 2);
vi31 = vextq_f16(vi30, vi3n, 1);
vi32 = vextq_f16(vi30, vi3n, 2);
// load output
index_t out_offset = h * p.out_width + w;
vo00 = vld1q_f16(out_base + out_offset);
vo01 = vld1q_f16(out_base + out_offset + p.out_width);
// outch 0, height 0
vo00 = vfmaq_laneq_f16(vo00, vi00, vf00, 0);
vo00 = vfmaq_laneq_f16(vo00, vi01, vf00, 1);
vo00 = vfmaq_laneq_f16(vo00, vi02, vf00, 2);
vo00 = vfmaq_laneq_f16(vo00, vi10, vf00, 3);
vo00 = vfmaq_laneq_f16(vo00, vi11, vf00, 4);
vo00 = vfmaq_laneq_f16(vo00, vi12, vf00, 5);
vo00 = vfmaq_laneq_f16(vo00, vi20, vf00, 6);
vo00 = vfmaq_laneq_f16(vo00, vi21, vf00, 7);
vo00 = vfmaq_laneq_f16(vo00, vi22, vf01, 0);
// outch 0, height 1
vo01 = vfmaq_laneq_f16(vo01, vi10, vf00, 0);
vo01 = vfmaq_laneq_f16(vo01, vi11, vf00, 1);
vo01 = vfmaq_laneq_f16(vo01, vi12, vf00, 2);
vo01 = vfmaq_laneq_f16(vo01, vi20, vf00, 3);
vo01 = vfmaq_laneq_f16(vo01, vi21, vf00, 4);
vo01 = vfmaq_laneq_f16(vo01, vi22, vf00, 5);
vo01 = vfmaq_laneq_f16(vo01, vi30, vf00, 6);
vo01 = vfmaq_laneq_f16(vo01, vi31, vf00, 7);
vo01 = vfmaq_laneq_f16(vo01, vi32, vf01, 0);
vst1q_f16(out_base + out_offset, vo00);
vst1q_f16(out_base + out_offset + p.out_width, vo01);
} // w
// right
for (; w < p.out_width; ++w) {
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h,
w,
h - p.pad_top,
w - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h + 1,
w,
h + 1 - p.pad_top,
w - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
}
} // h
// bottom
for (; h < p.out_height; ++h) {
for (w = 0; w < p.out_width; ++w) {
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h,
w,
h - p.pad_top,
w - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
}
}
} // m
} // b
}, 0, p.batch, 1, 0, p.out_channels, 1); // threadpool
return MaceStatus::MACE_SUCCESS;
}
template<>
MaceStatus DepthwiseConv2dK3x3S2<float16_t>::DoCompute(
const DepthwiseConvComputeParam &p, const float16_t *filter_data,
const float16_t *input_data, float16_t *output_data) {
p.thread_pool.Compute2D(
[=](index_t start0, index_t end0, index_t step0, index_t start1,
index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
for (index_t m = start1; m < end1; m += step1) {
index_t c = m / p.multiplier;
index_t multi_index = m % p.multiplier;
auto filter_ptr = filter_data + multi_index * p.in_channels * 9 +
c * 9;
auto in_base = input_data + b * p.in_batch_size +
c * p.in_image_size;
auto out_base = output_data + b * p.out_batch_size +
m * p.out_image_size;
index_t h, w;
// top
for (h = 0; h < p.valid_h_start; ++h) {
for (w = 0; w < p.out_width; ++w) {
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h,
w,
h * 2 - p.pad_top,
w * 2 - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
}
}
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float16x8_t vf00, vf01;
vf00 = vld1q_f16(filter_ptr);
vf01 = vld1q_f16(filter_ptr + 8);
for (h = p.valid_h_start; h < p.valid_h_stop; ++h) {
// left
for (w = 0; w < p.valid_w_start; ++w) {
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h,
w,
h * 2 - p.pad_top,
w * 2 - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
}
for (w = p.valid_w_start; w + 7 < p.valid_w_stop; w += 8) {
float16x8x2_t vi0, vi1, vi2;
float16x8_t vi0n, vi1n, vi2n;
// input (3 height x 3 slide): vi_height_slide
float16x8_t vi00, vi01, vi02;
float16x8_t vi10, vi11, vi12;
float16x8_t vi20, vi21, vi22;
// output (1 outch x 1 height x 8 width): vo
float16x8_t vo;
// load input
index_t in_h = h * 2 - p.pad_top;
index_t in_w = w * 2 - p.pad_left;
index_t in_offset = in_h * p.in_width + in_w;
vi0 = vld2q_f16(in_base + in_offset); // [0.2.4...14 | 1.3.5...15]
vi1 = vld2q_f16(in_base + in_offset + p.in_width);
vi2 = vld2q_f16(in_base + in_offset + 2 * p.in_width);
vi0n = vld1q_f16(in_base + in_offset + 16); // [16.17.18...23]
vi1n = vld1q_f16(in_base + in_offset + p.in_width + 16);
vi2n = vld1q_f16(in_base + in_offset + 2 * p.in_width + 16);
// load output
index_t out_offset = h * p.out_width + w;
vo = vld1q_f16(out_base + out_offset);
vi00 = vi0.val[0]; // [0.2.4...14]
vi01 = vi0.val[1]; // [1.3.5...15]
vi02 = vextq_f16(vi00, vi0n, 1); // [2.4.6...16]
vi10 = vi1.val[0];
vi11 = vi1.val[1];
vi12 = vextq_f16(vi10, vi1n, 1);
vi20 = vi2.val[0];
vi21 = vi2.val[1];
vi22 = vextq_f16(vi20, vi2n, 1);
// outch 0, height 0
vo = vfmaq_laneq_f16(vo, vi00, vf00, 0);
vo = vfmaq_laneq_f16(vo, vi01, vf00, 1);
vo = vfmaq_laneq_f16(vo, vi02, vf00, 2);
vo = vfmaq_laneq_f16(vo, vi10, vf00, 3);
vo = vfmaq_laneq_f16(vo, vi11, vf00, 4);
vo = vfmaq_laneq_f16(vo, vi12, vf00, 5);
vo = vfmaq_laneq_f16(vo, vi20, vf00, 6);
vo = vfmaq_laneq_f16(vo, vi21, vf00, 7);
vo = vfmaq_laneq_f16(vo, vi22, vf01, 0);
vst1q_f16(out_base + out_offset, vo);
} // w
// right
for (; w < p.out_width; ++w) {
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h,
w,
h * 2 - p.pad_top,
w * 2 - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
}
} // h
// bottom
for (; h < p.out_height; ++h) {
for (w = 0; w < p.out_width; ++w) {
DepthwiseConv2d3x3Pixel(in_base,
filter_ptr,
h,
w,
h * 2 - p.pad_top,
w * 2 - p.pad_left,
p.out_width,
p.in_height,
p.in_width,
out_base);
}
}
} // m
} // b
},
0, p.batch, 1, 0, p.out_channels, 1);
return MaceStatus::MACE_SUCCESS;
}
} // namespace arm
} // namespace ops
} // namespace mace
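The stride-2 kernels above use a de-interleaving trick: ``vld2q_f16`` loads 16 consecutive halfs and splits them into even lanes (``val[0]``) and odd lanes (``val[1]``), which are exactly the first two filter columns for eight stride-2 outputs; the third column comes from ``vextq_f16`` plus one extra load at offset 16. A scalar model of that pattern (explanatory sketch):

```cpp
#include <arm_neon.h>

// Scalar model of the stride-2 load pattern: for eight outputs w = 0..7 the
// three kernel columns at input positions 2w, 2w+1, 2w+2 are the even lanes,
// the odd lanes, and the even lanes shifted by one extra element.
void Stride2ColumnsModel(const float16_t *in, float16_t col0[8],
                         float16_t col1[8], float16_t col2[8]) {
  for (int j = 0; j < 8; ++j) {
    col0[j] = in[2 * j];      // vld2q_f16(in).val[0]
    col1[j] = in[2 * j + 1];  // vld2q_f16(in).val[1]
    col2[j] = in[2 * j + 2];  // vextq_f16(val[0], vld1q_f16(in + 16), 1)
  }
}
```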
......@@ -507,7 +507,7 @@ class Conv2dOp<DeviceType::GPU, float> : public ConvPool2dOpBase {
void RegisterConv2D(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Conv2D", Conv2dOp, DeviceType::CPU, float);
MACE_REGISTER_BF16_OP(op_registry, "Conv2D", Conv2dOp, DeviceType::CPU);
MACE_REGISTER_FP16_OP(op_registry, "Conv2D", Conv2dOp, DeviceType::CPU);
#ifdef MACE_ENABLE_QUANTIZE
MACE_REGISTER_OP(op_registry, "Conv2D", Conv2dOp,
DeviceType::CPU, uint8_t);
......
......@@ -406,6 +406,8 @@ void RegisterDepthwiseConv2d(OpRegistry *op_registry) {
DepthwiseConv2dOp, DeviceType::CPU, float);
MACE_REGISTER_BF16_OP(op_registry, "DepthwiseConv2d",
DepthwiseConv2dOp, DeviceType::CPU);
MACE_REGISTER_FP16_OP(op_registry, "DepthwiseConv2d",
DepthwiseConv2dOp, DeviceType::CPU);
#ifdef MACE_ENABLE_QUANTIZE
MACE_REGISTER_OP(op_registry, "DepthwiseConv2d",
......
......@@ -518,6 +518,8 @@ void RegisterPooling(OpRegistry *op_registry) {
DeviceType::CPU, float);
MACE_REGISTER_BF16_OP(op_registry, "Pooling", PoolingOp,
DeviceType::CPU);
MACE_REGISTER_FP16_OP(op_registry, "Pooling", PoolingOp,
DeviceType::CPU);
#ifdef MACE_ENABLE_QUANTIZE
MACE_REGISTER_OP(op_registry, "Pooling", PoolingOp,
......
......@@ -130,6 +130,10 @@ void RegisterActivationDelegator(OpDelegatorRegistry *registry) {
MACE_REGISTER_BF16_DELEGATOR(
registry, Activation<BFloat16>, delegator::ActivationParam,
MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, BFloat16, ImplType::REF));
MACE_REGISTER_FP16_DELEGATOR(
registry, Activation<float16_t>, delegator::ActivationParam,
MACE_DELEGATOR_KEY(Activation, DeviceType::CPU,
float16_t, ImplType::REF));
}
} // namespace ref
......
......@@ -152,6 +152,9 @@ void RegisterBiasAddDelegator(OpDelegatorRegistry *registry) {
MACE_REGISTER_BF16_DELEGATOR(
registry, BiasAdd<BFloat16>, DelegatorParam,
MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, BFloat16, ImplType::REF));
MACE_REGISTER_FP16_DELEGATOR(
registry, BiasAdd<float16_t>, DelegatorParam,
MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, float16_t, ImplType::REF));
}
} // namespace ref
......
......@@ -131,6 +131,9 @@ void RegisterConv2dDelegator(OpDelegatorRegistry *registry) {
MACE_REGISTER_BF16_DELEGATOR(
registry, Conv2d<BFloat16>, delegator::Conv2dParam,
MACE_DELEGATOR_KEY(Conv2d, DeviceType::CPU, BFloat16, ImplType::REF));
MACE_REGISTER_FP16_DELEGATOR(
registry, Conv2d<float16_t>, delegator::Conv2dParam,
MACE_DELEGATOR_KEY(Conv2d, DeviceType::CPU, float16_t, ImplType::REF));
}
} // namespace ref
......
......@@ -137,6 +137,10 @@ void RegisterDepthwiseConv2dDelegator(OpDelegatorRegistry *registry) {
registry, DepthwiseConv2d<BFloat16>, delegator::DepthwiseConv2dParam,
MACE_DELEGATOR_KEY(DepthwiseConv2d, DeviceType::CPU,
BFloat16, ImplType::REF));
MACE_REGISTER_FP16_DELEGATOR(
registry, DepthwiseConv2d<float16_t>, delegator::DepthwiseConv2dParam,
MACE_DELEGATOR_KEY(DepthwiseConv2d, DeviceType::CPU,
float16_t, ImplType::REF));
}
} // namespace ref
......
......@@ -156,6 +156,9 @@ void RegisterGemmDelegator(OpDelegatorRegistry *registry) {
MACE_REGISTER_BF16_DELEGATOR(
registry, Gemm<BFloat16>, delegator::GemmParam,
MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, BFloat16, ImplType::REF));
MACE_REGISTER_FP16_DELEGATOR(
registry, Gemm<float16_t>, delegator::GemmParam,
MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, float16_t, ImplType::REF));
}
} // namespace ref
......
......@@ -92,6 +92,9 @@ void RegisterGemvDelegator(OpDelegatorRegistry *registry) {
MACE_REGISTER_BF16_DELEGATOR(
registry, Gemv<BFloat16>, DelegatorParam,
MACE_DELEGATOR_KEY(Gemv, DeviceType::CPU, BFloat16, ImplType::REF));
MACE_REGISTER_FP16_DELEGATOR(
registry, Gemv<float16_t>, DelegatorParam,
MACE_DELEGATOR_KEY(Gemv, DeviceType::CPU, float16_t, ImplType::REF));
}
} // namespace ref
......
......@@ -526,6 +526,8 @@ void RegisterSoftmax(OpRegistry *op_registry) {
DeviceType::CPU, float);
MACE_REGISTER_BF16_OP(op_registry, "Softmax", SoftmaxOp,
DeviceType::CPU);
MACE_REGISTER_FP16_OP(op_registry, "Softmax", SoftmaxOp,
DeviceType::CPU);
#ifdef MACE_ENABLE_QUANTIZE
MACE_REGISTER_OP(op_registry, "Softmax", SoftmaxOp,
......
......@@ -81,6 +81,7 @@ class SqueezeOp : public SqueezeOpRaw {
void RegisterSqueeze(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::CPU, float);
MACE_REGISTER_BF16_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::CPU);
MACE_REGISTER_FP16_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::CPU);
#ifdef MACE_ENABLE_QUANTIZE
MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::CPU, uint8_t);
#endif // MACE_ENABLE_QUANTIZE
......
......@@ -11,6 +11,7 @@ load(
"if_hexagon_enabled",
"if_neon_enabled",
"if_bfloat16_enabled",
"if_fp16_enabled",
"if_opencl_enabled",
"if_quantize_enabled",
)
......@@ -61,6 +62,9 @@ cc_test(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
......
......@@ -181,6 +181,12 @@ void Conv2d<CPU, uint8_t>(int iters,
#else
#define MACE_BM_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, D, P, OC)
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
#define MACE_BM_CONV_2D_FP16_MACRO(N, C, H, W, KH, KW, S, D, P, OC) \
MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float16_t, CPU)
#else
#define MACE_BM_CONV_2D_FP16_MACRO(N, C, H, W, KH, KW, S, D, P, OC)
#endif // MACE_ENABLE_FP16
#ifdef MACE_ENABLE_OPENCL
#define MACE_BM_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, D, P, OC) \
MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, GPU); \
......@@ -193,6 +199,7 @@ void Conv2d<CPU, uint8_t>(int iters,
MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \
MACE_BM_CONV_2D_Q8_MACRO(N, C, H, W, KH, KW, S, D, P, OC); \
MACE_BM_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, D, P, OC); \
MACE_BM_CONV_2D_FP16_MACRO(N, C, H, W, KH, KW, S, D, P, OC); \
MACE_BM_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, D, P, OC)
// Filter sizes and data alignments
......
......@@ -140,6 +140,12 @@ void DepthwiseConv2d(int iters,
#else
#define MACE_BM_DEPTHWISE_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, P, M)
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
#define MACE_BM_DEPTHWISE_CONV_2D_FP16_MACRO(N, C, H, W, KH, KW, S, P, M) \
MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float16_t, CPU)
#else
#define MACE_BM_DEPTHWISE_CONV_2D_FP16_MACRO(N, C, H, W, KH, KW, S, P, M)
#endif // MACE_ENABLE_FP16
#ifdef MACE_ENABLE_OPENCL
#define MACE_BM_DEPTHWISE_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, P, M) \
MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, GPU); \
......@@ -152,6 +158,7 @@ void DepthwiseConv2d(int iters,
MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU); \
MACE_BM_DEPTHWISE_CONV_2D_Q8_MACRO(N, C, H, W, KH, KW, S, P, M); \
MACE_BM_DEPTHWISE_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, P, M); \
MACE_BM_DEPTHWISE_CONV_2D_FP16_MACRO(N, C, H, W, KH, KW, S, P, M); \
MACE_BM_DEPTHWISE_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, P, M)
MACE_BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 1, SAME, 1);
......
......@@ -12,6 +12,7 @@ load(
"if_hta_enabled",
"if_neon_enabled",
"if_bfloat16_enabled",
"if_fp16_enabled",
"if_opencl_enabled",
"if_quantize_enabled",
)
......@@ -42,6 +43,10 @@ cc_test(
[
"mace/ops/arm/bf16/*.cc",
]
)) + if_fp16_enabled(glob(
[
"mace/ops/arm/fp16/*.cc",
]
)) + if_opencl_enabled(glob(
[
"mace/ops/opencl/*.cc",
......@@ -66,6 +71,9 @@ cc_test(
"-DMACE_ENABLE_QUANTIZE",
]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16",
]) + if_fp16_enabled([
"-DMACE_ENABLE_FP16",
"-march=armv8.2-a+fp16",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]) + if_hta_enabled([
......
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/delegator/gemm.h"
#include <gtest/gtest.h>
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h"
#include "mace/ops/ops_test_util.h"
#include "mace/ops/testing/test_utils.h"
namespace mace {
namespace ops {
namespace test {
void TestGemmFloat16(const index_t batch,
const index_t rows,
const index_t cols,
const index_t depth,
const MatrixMajor lhs_major,
const MatrixMajor rhs_major,
const MatrixMajor output_major,
const bool lhs_batched,
const bool rhs_batched) {
Tensor lhs(GetCPUAllocator(), DT_FLOAT16);
Tensor rhs(GetCPUAllocator(), DT_FLOAT16);
Tensor output(GetCPUAllocator(), DT_FLOAT16);
lhs.Resize({lhs_batched ? batch : 1, rows, depth});
rhs.Resize({rhs_batched ? batch : 1, depth, cols});
output.Resize({batch, rows, cols});
{
Tensor::MappingGuard lhs_guard(&lhs);
Tensor::MappingGuard rhs_guard(&rhs);
auto lhs_data = lhs.mutable_data<float16_t>();
auto rhs_data = rhs.mutable_data<float16_t>();
auto output_data = output.mutable_data<float16_t>();
GenerateRandomRealTypeData<float16_t>(lhs.shape(), lhs_data);
GenerateRandomRealTypeData<float16_t>(rhs.shape(), rhs_data);
GenerateRandomRealTypeData<float16_t>(output.shape(), output_data);
}
utils::ThreadPool thread_pool(1, AFFINITY_NONE);
thread_pool.Init();
CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool);
OpsTestNet net;
OpContext context(net.ws(), &cpu_device);
std::unique_ptr<delegator::Gemm> gemm = delegator::Gemm::Create(
context.workspace(),
MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, float16_t, ImplType::NEON),
delegator::GemmParam());
gemm->Compute(&context, &lhs, &rhs, batch, rows, cols, depth, lhs_major,
rhs_major, output_major, lhs_batched, rhs_batched, &output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT16);
expected_output.Resize({batch, rows, cols});
std::unique_ptr<delegator::Gemm> gemm_ref = delegator::Gemm::Create(
context.workspace(),
MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, float16_t, ImplType::REF),
delegator::GemmParam());
gemm_ref->Compute(&context, &lhs, &rhs, batch, rows, cols, depth, lhs_major,
rhs_major, output_major, lhs_batched, rhs_batched,
&expected_output);
ExpectTensorSimilar<float16_t>(expected_output, output, 1e-4);
}
TEST(ArmGemm, TestGemmFP16) {
TestGemmFloat16(1, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true);
TestGemmFloat16(1, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true);
TestGemmFloat16(1, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true);
TestGemmFloat16(1, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true);
TestGemmFloat16(1, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true);
TestGemmFloat16(1, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true);
TestGemmFloat16(1, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true);
TestGemmFloat16(1, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true);
TestGemmFloat16(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, false);
TestGemmFloat16(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, false, true);
TestGemmFloat16(16, 31, 61, 67, RowMajor, ColMajor, RowMajor, true, true);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -1435,6 +1435,74 @@ TEST_F(Conv2dOpTest, BFloat16) {
TestBFloat16(1, 128, 64, 32, 32, 7, 7, SAME, {3, 3});
}
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
namespace {
void TestFloat16(const index_t batch,
const index_t out_channels,
const index_t in_channels,
const index_t in_height,
const index_t in_width,
const index_t k_height,
const index_t k_width,
enum Padding padding_type,
const std::vector<int> &strides) {
OpsTestNet net;
net.AddRandomInput<CPU, float>(
"Input", {batch, in_channels, in_height, in_width});
net.AddRandomInput<CPU, float>(
"Filter", {out_channels, in_channels, k_height, k_width}, true);
net.AddRandomInput<CPU, float>("Bias", {out_channels}, true);
net.Cast<CPU, float, float16_t>("Input", "FP16Input");
net.Cast<CPU, float, float16_t>("Filter", "FP16Filter");
net.Cast<CPU, float, float16_t>("Bias", "FP16Bias");
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding_type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DT_FLOAT))
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
OpDefBuilder("Conv2D", "FP16Conv2dTest")
.Input("FP16Input")
.Input("FP16Filter")
.Input("FP16Bias")
.Output("FP16Output")
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding_type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DT_FLOAT16))
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
net.Cast<CPU, float16_t, float>("FP16Output", "CastOutput");
ExpectTensorSimilar<float>(*net.GetOutput("Output"),
*net.GetTensor("CastOutput"), 1e-4);
}
} // namespace
TEST_F(Conv2dOpTest, float16_t) {
TestFloat16(1, 128, 64, 32, 32, 1, 1, VALID, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 3, 3, VALID, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 3, 3, SAME, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 3, 3, FULL, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 3, 3, SAME, {2, 2});
TestFloat16(1, 129, 63, 33, 31, 3, 3, SAME, {1, 1});
TestFloat16(9, 128, 64, 32, 32, 3, 3, SAME, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 1, 5, SAME, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 5, 5, SAME, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 5, 1, SAME, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 7, 7, SAME, {1, 1});
TestFloat16(1, 128, 64, 32, 32, 7, 7, SAME, {2, 2});
TestFloat16(1, 128, 64, 32, 32, 7, 7, SAME, {3, 3});
}
#endif // MACE_ENABLE_FP16
} // namespace test
} // namespace ops
} // namespace mace
......@@ -560,6 +560,72 @@ TEST_F(DepthwiseConv2dOpTest, BFloat16) {
#endif // MACE_ENABLE_BFLOAT16
#ifdef MACE_ENABLE_FP16
namespace {
void TestFloat16(const index_t batch,
const index_t multiplier,
const index_t in_channels,
const index_t in_height,
const index_t in_width,
const index_t k_height,
const index_t k_width,
enum Padding padding_type,
const std::vector<int> &strides) {
OpsTestNet net;
const index_t out_channels = multiplier * in_channels;
net.AddRandomInput<CPU, float>(
"Input", {batch, in_channels, in_height, in_width}, false, false);
net.AddRandomInput<CPU, float>(
"Filter", {multiplier, in_channels, k_height, k_width}, true, false);
net.AddRandomInput<CPU, float>("Bias", {out_channels}, true);
net.Cast<CPU, float, float16_t>("Input", "FP16Input");
net.Cast<CPU, float, float16_t>("Filter", "FP16Filter");
net.Cast<CPU, float, float16_t>("Bias", "FP16Bias");
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding_type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DT_FLOAT))
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
OpDefBuilder("DepthwiseConv2d", "FP16DepthwiseConv2DTest")
.Input("FP16Input")
.Input("FP16Filter")
.Input("FP16Bias")
.Output("FP16Output")
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding_type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DT_FLOAT16))
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
net.Cast<CPU, float16_t, float>("FP16Output", "CastOutput");
ExpectTensorSimilar<float>(*net.GetOutput("Output"),
*net.GetTensor("CastOutput"), 1e-4);
}
} // namespace
TEST_F(DepthwiseConv2dOpTest, float16_t) {
TestFloat16(1, 1, 1024, 7, 7, 3, 3, VALID, {1, 1});
TestFloat16(1, 1, 1024, 7, 7, 3, 3, SAME, {1, 1});
TestFloat16(1, 1, 1024, 7, 7, 3, 3, FULL, {1, 1});
TestFloat16(1, 2, 1024, 7, 7, 3, 3, SAME, {1, 1});
TestFloat16(1, 2, 1024, 7, 7, 3, 3, SAME, {2, 2});
TestFloat16(1, 1, 512, 14, 14, 3, 3, SAME, {1, 1});
TestFloat16(1, 1, 512, 14, 13, 5, 5, SAME, {2, 2});
TestFloat16(1, 1, 256, 28, 28, 3, 3, SAME, {1, 1});
TestFloat16(1, 1, 128, 56, 56, 3, 3, SAME, {2, 2});
TestFloat16(3, 1, 128, 56, 56, 3, 3, SAME, {2, 2});
}
#endif // MACE_ENABLE_FP16
} // namespace test
} // namespace ops
} // namespace mace
......@@ -100,6 +100,11 @@ def parse_args():
type=str2bool,
default=True,
help="Whether to use bfloat16")
parser.add_argument(
"--enable_fp16",
type=str2bool,
default=False,
help="Whether to use armv8.2")
parser.add_argument(
"--enable_rpcmem",
type=str2bool,
......@@ -180,6 +185,7 @@ def main(unused_args):
enable_neon=FLAGS.enable_neon,
enable_quantize=FLAGS.enable_quantize,
enable_bfloat16=FLAGS.enable_bfloat16,
enable_fp16=FLAGS.enable_fp16,
enable_rpcmem=FLAGS.enable_rpcmem,
enable_hta=FLAGS.enable_hta,
address_sanitizer=FLAGS.address_sanitizer,
......
......@@ -89,6 +89,7 @@ FPDataTypeStrs = [
"fp16_fp32",
"fp32_fp32",
"bf16_fp32",
"fp16_fp16",
]
FPDataType = Enum('GPUDataType', [(ele, ele) for ele in FPDataTypeStrs],
......@@ -184,6 +185,15 @@ def bfloat16_enabled(configs):
return False
def fp16_enabled(configs):
for model_name in configs[YAMLKeyword.models]:
model_config = configs[YAMLKeyword.models][model_name]
dtype = model_config.get(YAMLKeyword.data_type, FPDataType.fp16_fp32)
if dtype == FPDataType.fp16_fp16:
return True
return False
def hexagon_enabled(configs):
runtime_list = []
for model_name in configs[YAMLKeyword.models]:
......@@ -765,6 +775,7 @@ def build_model_lib(configs, address_sanitizer, debug_mode):
enable_opencl=opencl_enabled(configs),
enable_quantize=quantize_enabled(configs),
enable_bfloat16=bfloat16_enabled(configs),
enable_fp16=fp16_enabled(configs),
address_sanitizer=address_sanitizer,
symbol_hidden=get_symbol_hidden_mode(debug_mode),
debug_mode=debug_mode
......@@ -927,6 +938,7 @@ def build_mace_run(configs, target_abi, toolchain,
enable_opencl=opencl_enabled(configs),
enable_quantize=quantize_enabled(configs),
enable_bfloat16=bfloat16_enabled(configs),
enable_fp16=fp16_enabled(configs),
address_sanitizer=address_sanitizer,
symbol_hidden=get_symbol_hidden_mode(debug_mode, mace_lib_type),
debug_mode=debug_mode,
......
......@@ -182,6 +182,8 @@ def parse_internal_data_type(str):
return mace_pb2.DT_FLOAT
elif str == 'bf16_fp32':
return mace_pb2.DT_BFLOAT16
elif str == 'fp16_fp16':
return mace_pb2.DT_FLOAT16
else:
return mace_pb2.DT_HALF
......
......@@ -271,6 +271,7 @@ def bazel_build(target,
enable_opencl=True,
enable_quantize=True,
enable_bfloat16=False,
enable_fp16=False,
enable_rpcmem=True,
address_sanitizer=False,
symbol_hidden=True,
......@@ -305,6 +306,8 @@ def bazel_build(target,
"--define",
"bfloat16=%s" % str(enable_bfloat16).lower(),
"--define",
"fp16=%s" % str(enable_fp16).lower(),
"--define",
"rpcmem=%s" % str(enable_rpcmem).lower(),
"--define",
"hexagon=%s" % str(enable_hexagon).lower(),
......