Commit 2d650b67 authored by 許永健, committed by Liangliang He

Integrate MediaTek APU Support (#440)

* Integrate APU

* fix code style issue
Parent 9fbb9e17
......@@ -74,6 +74,18 @@ config_setting(
visibility = ["//visibility:public"],
)
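# MediaTek APU builds are 64-bit Android only: enabled with --define apu=true together with --cpu=arm64-v8a.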
config_setting(
name = "apu_enabled",
define_values = {
"apu": "true",
},
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "arm64-v8a",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "hexagon_enabled",
define_values = {
......
......@@ -11,10 +11,12 @@ load(
"//mace:mace.bzl",
"if_android",
"if_android_armv7",
"if_apu_enabled",
"if_hexagon_enabled",
"if_hexagon_or_hta_enabled",
"if_hta_enabled",
"if_neon_enabled",
"if_not_apu_enabled",
"if_not_hexagon_enabled",
"if_opencl_enabled",
"if_openmp_enabled",
......@@ -39,7 +41,9 @@ cc_library(
"runtime/hexagon/hexagon_dsp_wrapper.cc",
]) + if_hta_enabled([
"runtime/hexagon/hexagon_hta_wrapper.cc",
]),
]) + if_apu_enabled(glob([
"runtime/apu/*.cc",
])),
hdrs = glob([
"*.h",
"runtime/cpu/*.h",
......@@ -52,6 +56,8 @@ cc_library(
"runtime/hexagon/*dsp*.h",
])) + if_hta_enabled(glob([
"runtime/hexagon/*hta*.h",
])) + if_apu_enabled(glob([
"runtime/apu/*.h"
])),
copts = [
"-Werror",
......@@ -68,6 +74,8 @@ cc_library(
"-DMACE_ENABLE_HEXAGON",
]) + if_hta_enabled([
"-DMACE_ENABLE_HTA",
]) + if_apu_enabled([
"-DMACE_ENABLE_APU",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
......@@ -90,6 +98,8 @@ cc_library(
"//third_party/nnlib:libhexagon",
]) + if_hta_enabled([
"//third_party/hta",
]) + if_apu_enabled([
"//third_party/apu:libapu-frontend",
]),
)
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_RUNTIME_APU_APU_DEVICE_H_
#define MACE_CORE_RUNTIME_APU_APU_DEVICE_H_
#include "mace/core/device.h"
namespace mace {
class ApuDevice : public CPUDevice {
public:
explicit ApuDevice(utils::ThreadPool *thread_pool)
: CPUDevice(0, AFFINITY_NONE, thread_pool) {}
DeviceType device_type() const override {
return DeviceType::APU;
};
};
} // namespace mace
#endif // MACE_CORE_RUNTIME_APU_APU_DEVICE_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/runtime/apu/apu_wrapper.h"
#include <algorithm>
#include "mace/core/quantize.h"
namespace mace {
ApuWrapper::ApuWrapper(Device *device)
: quantize_util_(&device->cpu_runtime()->thread_pool()) {
}
apu_data_type ApuWrapper::MapToApuDataType(DataType mace_type) {
switch (mace_type) {
case DT_FLOAT:
return APU_DATA_TYPE_FLOAT;
case DT_INT32:
return APU_DATA_TYPE_INT32;
case DT_HALF:
return APU_DATA_TYPE_HALF;
case DT_UINT8:
return APU_DATA_TYPE_UINT8;
default:
MACE_CHECK(true, "unsupport mace data type");
break;
}
return APU_DATA_TYPE_UNDEFINED;
}
apu_pooling_mode ApuWrapper::MapToApuPoolingMode(int mace_mode) {
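// mace_mode follows MACE's PoolingType enum: AVG = 1, MAX = 2.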
switch (mace_mode) {
case 1:
return APU_POOLING_AVG;
case 2:
return APU_POOLING_MAX;
default:
MACE_CHECK(true, "unsupport mace pooling mode");
break;
}
return APU_POOLING_UNDEFINED;
}
apu_eltwise_mode ApuWrapper::MapToApuEltwiseMode(int mace_mode) {
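// mace_mode follows MACE's EltwiseType enum: SUM = 0, SUB = 1, PROD = 2, MIN = 4, MAX = 5.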
switch (mace_mode) {
case 0:
return APU_ELTWISE_ADD;
case 1:
return APU_ELTWISE_SUB;
case 2:
return APU_ELTWISE_MUL;
case 4:
return APU_ELTWISE_MIN;
case 5:
return APU_ELTWISE_MAX;
default:
MACE_CHECK(true, "unsupport mace eltwise mode");
break;
}
return APU_ELTWISE_UNDEFINED;
}
bool ApuWrapper::Init(const NetDef &net_def,
const unsigned char *model_data) {
frontend = new ApuFrontend();
// parse model argument
int const_data_num = 0;
for (auto arg : net_def.arg()) {
if (arg.name().compare("const_data_num") == 0) {
const_data_num = arg.i();
}
}
// const tensors
std::vector<apu_tensor> const_tensors;
for (auto const_tensor : net_def.tensors()) {
apu_tensor tensor;
tensor.tensor_id = const_tensor.node_id();
tensor.tensor_type = (tensor.tensor_id < const_data_num) ?
APU_TENSOR_CONST_DATA :
APU_TENSOR_CONST_ARGUMENT;
tensor.data_type = MapToApuDataType(const_tensor.data_type());
tensor.scale = const_tensor.has_scale() ? const_tensor.scale() : 0.0f;
tensor.zero_point = const_tensor.has_zero_point() ?
const_tensor.zero_point() : 0;
tensor.dim_size = const_tensor.dims_size();
MACE_CHECK(tensor.dim_size <= APU_TENSOR_MAX_DIMS,
"tensor dimension size not supported");
for (auto i = 0 ; i < tensor.dim_size ; i++) {
tensor.dims[i] = const_tensor.dims(i);
}
tensor.data_buf =
const_cast<unsigned char*>(model_data + const_tensor.offset());
const_tensors.push_back(tensor);
}
// input tensors
std::vector<apu_tensor> input_tensors;
for (auto input_info : net_def.input_info()) {
apu_tensor tensor;
tensor.tensor_id = input_info.node_id();
tensor.tensor_type = APU_TENSOR_MODEL_INPUT;
tensor.data_type = APU_DATA_TYPE_UINT8; // quantization happens in Run()
tensor.scale = input_info.has_scale() ? input_info.scale() : 0.0f;
tensor.zero_point = input_info.has_zero_point() ?
input_info.zero_point() : 0;
tensor.dim_size = input_info.dims_size();
MACE_CHECK(tensor.dim_size <= APU_TENSOR_MAX_DIMS,
"tensor dimension size not supported");
tensor_info info;
info.name = input_info.name();
info.size = 1;
for (auto i = 0 ; i < tensor.dim_size ; i++) {
tensor.dims[i] = input_info.dims(i);
info.size *= input_info.dims(i);
info.shape.push_back(input_info.dims(i));
}
info.buf = std::shared_ptr<uint8_t>(new uint8_t[info.size],
std::default_delete<uint8_t[]>());
info.scale = tensor.scale;
info.zero_point = tensor.zero_point;
input_infos.push_back(info);
tensor.data_buf = info.buf.get();
input_tensors.push_back(tensor);
}
// output tensors
std::vector<int> output_tensor_ids;
std::vector<void*> output_buffers;
for (auto output_info : net_def.output_info()) {
output_tensor_ids.push_back(output_info.node_id());
tensor_info info;
info.name = output_info.name();
info.size = 1;
for (auto i = 0 ; i < output_info.dims().size() ; i++) {
info.size *= output_info.dims(i);
info.shape.push_back(output_info.dims(i));
}
info.buf = std::shared_ptr<uint8_t>(new uint8_t[info.size],
std::default_delete<uint8_t[]>());
for (auto op_def : net_def.op()) {
if (output_info.name() == op_def.output(0)) {
info.scale = op_def.quantize_info(0).scale();
info.zero_point = op_def.quantize_info(0).zero_point();
}
}
output_infos.push_back(info);
output_buffers.push_back(info.buf.get());
}
// operators
std::vector<apu_operator> ops;
std::vector<std::vector<int>> cached_op_inputs;
for (auto op_def : net_def.op()) {
apu_operator op;
strncpy(op.type, op_def.type().c_str(), APU_OP_TYPE_MAX_SIZE - 1);
op.type[APU_OP_TYPE_MAX_SIZE - 1] = '\0'; // strncpy need not null-terminate
op.input_size = op_def.node_input_size();
std::vector<int> input_ids;
for (auto i = 0 ; i < op.input_size ; i++) {
input_ids.push_back(op_def.node_input(i).node_id());
}
cached_op_inputs.push_back(input_ids);
op.input_ids = cached_op_inputs.back().data();
op.output.tensor_id = op_def.node_id();
op.output.tensor_type = APU_TENSOR_OP_OUTPUT;
op.output.data_type = MapToApuDataType(op_def.output_type(0));
if (op.output.data_type == APU_DATA_TYPE_UINT8) {
op.output.scale = op_def.quantize_info(0).scale();
op.output.zero_point = op_def.quantize_info(0).zero_point();
} else {
op.output.scale = 0.0f;
op.output.zero_point = 0;
}
op.output.dim_size = op_def.output_shape(0).dims_size();
MACE_CHECK(op.output.dim_size <= APU_TENSOR_MAX_DIMS,
"tensor dimension size not supported");
for (auto i = 0 ; i < op.output.dim_size ; i++) {
op.output.dims[i] = op_def.output_shape(0).dims(i);
}
op.output.data_buf = nullptr;
// get op mode and activation mode
bool is_pooling = (strcmp(op.type, "Pooling") == 0);
bool is_eltwise = (strcmp(op.type, "Eltwise") == 0);
std::string activation;
float max_limit = 0.0f;
for (auto arg : op_def.arg()) {
if (arg.name().compare("activation") == 0) {
activation = arg.s();
}
if (arg.name().compare("max_limit") == 0) {
max_limit = arg.f();
}
if (is_pooling && arg.name().compare("pooling_type") == 0) {
op.op_mode = static_cast<int>(MapToApuPoolingMode(arg.i()));
}
if (is_eltwise && arg.name().compare("type") == 0) {
op.op_mode = static_cast<int>(MapToApuEltwiseMode(arg.i()));
}
}
if (activation.compare("RELU") == 0) {
op.act_mode = APU_ACT_RELU;
} else if (activation.compare("RELUX") == 0 && max_limit == 6.0) {
op.act_mode = APU_ACT_RELU6;
} else {
op.act_mode = APU_ACT_NONE;
}
ops.push_back(op);
}
bool print_model = false;
bool ret = frontend->InitGraph(
const_tensors.size(), const_tensors.data(),
input_tensors.size(), input_tensors.data(),
output_tensor_ids.size(), output_tensor_ids.data(),
output_buffers.data(),
ops.size(), ops.data(),
print_model);
cached_op_inputs.clear();
MACE_CHECK(ret == true, "apu init graph failed");
return ret;
}
bool ApuWrapper::Run(const std::map<std::string, Tensor *> &input_tensors,
std::map<std::string, Tensor *> *output_tensors) {
MACE_ASSERT(input_tensors.size() == input_infos.size(), "Wrong inputs num");
MACE_ASSERT(output_tensors.size() == output_infos.size(),
"Wrong outputs num");
// prepare input
for (int i = 0 ; i < static_cast<int>(input_tensors.size()) ; i++) {
Tensor* tensor = input_tensors.at(input_infos[i].name);
// check size
int size = input_infos[i].size;
MACE_ASSERT(size == static_cast<int>(tensor->size()), "Wrong input size");
// quantize
quantize_util_.QuantizeWithScaleAndZeropoint(
reinterpret_cast<const float*>(tensor->raw_data()),
size,
input_infos[i].scale,
input_infos[i].zero_point,
input_infos[i].buf.get());
}
// run model
bool ret = frontend->RunGraph();
MACE_CHECK(ret == true, "neuron run model failed");
// process output
for (int i = 0 ; i < static_cast<int>(output_tensors->size()) ; i++) {
Tensor* tensor = output_tensors->at(output_infos[i].name);
// prepare out buffer
tensor->SetDtype(DT_FLOAT);
tensor->Resize(output_infos[i].shape);
int size = output_infos[i].size;
MACE_ASSERT(size == static_cast<int>(tensor->size()), "Wrong output size");
// dequantize
quantize_util_.Dequantize(
output_infos[i].buf.get(),
size,
output_infos[i].scale,
output_infos[i].zero_point,
reinterpret_cast<float*>(tensor->raw_mutable_data()));
}
return true;
}
bool ApuWrapper::Uninit() {
bool ret = frontend->UninitGraph();
delete frontend; // Init() allocated it with new; release it on teardown
frontend = nullptr;
input_infos.clear();
output_infos.clear();
return ret;
}
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_RUNTIME_APU_APU_WRAPPER_H_
#define MACE_CORE_RUNTIME_APU_APU_WRAPPER_H_
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "mace/proto/mace.pb.h"
#include "mace/core/tensor.h"
#include "mace/core/device.h"
#include "mace/core/quantize.h"
#include "third_party/apu/ApuFrontend.h"
namespace mace {
class ApuWrapper {
struct tensor_info {
std::string name;
std::shared_ptr<uint8_t> buf;
std::vector<index_t> shape;
int size;
float scale;
int zero_point;
};
public:
explicit ApuWrapper(Device *device);
bool Init(const NetDef& net_def, const unsigned char *model_data);
bool Run(const std::map<std::string, Tensor *> &input_tensors,
std::map<std::string, Tensor *> *output_tensors);
bool Uninit();
private:
apu_data_type MapToApuDataType(DataType mace_type);
apu_pooling_mode MapToApuPoolingMode(int mace_mode);
apu_eltwise_mode MapToApuEltwiseMode(int mace_mode);
private:
ApuFrontend* frontend = nullptr;
std::vector<tensor_info> input_infos;
std::vector<tensor_info> output_infos;
QuantizeUtil<uint8_t> quantize_util_;
};
} // namespace mace
#endif // MACE_CORE_RUNTIME_APU_APU_WRAPPER_H_
......@@ -11,6 +11,7 @@ load(
"//mace:mace.bzl",
"if_android",
"if_android_armv7",
"if_apu_enabled",
"if_darwin",
"if_hexagon_enabled",
"if_hta_enabled",
......@@ -44,6 +45,8 @@ cc_library(
"-DMACE_ENABLE_HEXAGON",
]) + if_hta_enabled([
"-DMACE_ENABLE_HTA",
]) + if_apu_enabled([
"-DMACE_ENABLE_APU",
]),
deps = [
"//mace/ops",
......
......@@ -38,7 +38,13 @@
#include "mace/core/runtime/hexagon/hexagon_device.h"
#endif
#ifdef MACE_ENABLE_APU
#include "mace/core/runtime/apu/apu_wrapper.h"
#include "mace/core/runtime/apu/apu_device.h"
#endif // MACE_ENABLE_APU
namespace mace {
namespace {
#ifdef MACE_ENABLE_OPENCL
......@@ -398,6 +404,9 @@ class MaceEngine::Impl {
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
std::unique_ptr<HexagonControlWrapper> hexagon_controller_;
#endif
#ifdef MACE_ENABLE_APU
std::unique_ptr<ApuWrapper> apu_controller_;
#endif
MACE_DISABLE_COPY_AND_ASSIGN(Impl);
};
......@@ -415,6 +424,9 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
, hexagon_controller_(nullptr)
#endif
#ifdef MACE_ENABLE_APU
, apu_controller_(nullptr)
#endif
{
LOG(INFO) << "Creating MaceEngine, MACE version: " << MaceVersion();
thread_pool_->Init();
......@@ -441,6 +453,11 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
|| device_type_ == DeviceType::HTA) {
device_.reset(new HexagonDevice(device_type_, thread_pool_.get()));
}
#endif
#ifdef MACE_ENABLE_APU
if (device_type_ == DeviceType::APU) {
device_.reset(new ApuDevice(thread_pool_.get()));
}
#endif
MACE_CHECK_NOTNULL(device_);
}
......@@ -497,6 +514,11 @@ MaceStatus MaceEngine::Impl::Init(
Tensor *output_tensor =
ws_->CreateTensor(output_name, device_->allocator(), output_dt);
output_tensor->set_data_format(DataFormat::NHWC);
#endif
#if defined(MACE_ENABLE_APU)
Tensor *output_tensor =
ws_->CreateTensor(output_name, device_->allocator(), DT_FLOAT);
output_tensor->set_data_format(DataFormat::NHWC);
#endif
}
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
......@@ -512,6 +534,12 @@ MaceStatus MaceEngine::Impl::Init(
hexagon_controller_->PrintGraph();
}
} else {
#endif
#ifdef MACE_ENABLE_APU
if (device_type_ == APU) {
apu_controller_.reset(new ApuWrapper(device_.get()));
MACE_CHECK(apu_controller_->Init(*net_def, model_data), "apu init error");
} else {
#endif
MACE_RETURN_IF_ERROR(ws_->LoadModelTensor(*net_def,
device_.get(),
......@@ -542,6 +570,9 @@ MaceStatus MaceEngine::Impl::Init(
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
}
#endif
#ifdef MACE_ENABLE_APU
}
#endif
return MaceStatus::MACE_SUCCESS;
}
......@@ -580,6 +611,11 @@ MaceEngine::Impl::~Impl() {
MACE_CHECK(hexagon_controller_->Finalize(), "hexagon finalize error");
}
#endif
#ifdef MACE_ENABLE_APU
if (device_type_ == APU) {
MACE_CHECK(apu_controller_->Uninit(), "apu uninit error");
}
#endif
}
MaceStatus MaceEngine::Impl::TransposeInput(
......@@ -767,11 +803,20 @@ MaceStatus MaceEngine::Impl::Run(
}
hexagon_controller_->ExecuteGraphNew(input_tensors, &output_tensors);
} else {
#endif
#ifdef MACE_ENABLE_APU
if (device_type_ == APU) {
MACE_CHECK(apu_controller_->Run(input_tensors, &output_tensors),
"apu run error");
} else {
#endif
MACE_RETURN_IF_ERROR(net_->Run(run_metadata));
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
}
#endif
#ifdef MACE_ENABLE_APU
}
#endif
#ifdef MACE_ENABLE_OPENCL
if (device_type_ == GPU) {
......
......@@ -79,6 +79,18 @@ def if_hexagon_or_hta_enabled(a):
"//conditions:default": [],
})
def if_apu_enabled(a):
return select({
"//mace:apu_enabled": a,
"//conditions:default": [],
})
def if_not_apu_enabled(a):
return select({
"//mace:apu_enabled": [],
"//conditions:default": a,
})
def if_openmp_enabled(a):
return select({
"//mace:openmp_enabled": a,
......
......@@ -32,7 +32,7 @@ namespace mace {
class NetDef;
enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, HTA = 4, APU = 5 };
enum class DataFormat {
NONE = 0, NHWC = 1, NCHW = 2,
......
......@@ -16,6 +16,7 @@ py_library(
"converter_tool/onnx_converter.py",
"converter_tool/shape_inference.py",
"converter_tool/tensorflow_converter.py",
"converter_tool/apu_converter.py",
"converter_tool/transformer.py",
"graph_util.py",
],
......
......@@ -38,6 +38,7 @@ device_type_map = {'cpu': cvt.DeviceType.CPU.value,
'gpu': cvt.DeviceType.GPU.value,
'dsp': cvt.DeviceType.HEXAGON.value,
'hta': cvt.DeviceType.HTA.value,
'apu': cvt.DeviceType.APU.value,
'cpu+gpu': cvt.DeviceType.CPU.value}
data_format_map = {
......@@ -63,6 +64,8 @@ def parse_data_type(data_type, device_type):
elif device_type == cvt.DeviceType.HEXAGON.value or \
device_type == cvt.DeviceType.HTA.value:
return mace_pb2.DT_FLOAT
elif device_type == cvt.DeviceType.APU.value:
return mace_pb2.DT_FLOAT
else:
print("Invalid device type: " + str(device_type))
......@@ -129,7 +132,7 @@ def main(unused_args):
six.print_("platform %s is not supported." % FLAGS.platform,
file=sys.stderr)
sys.exit(-1)
if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'hta', 'apu', 'cpu+gpu']:
six.print_("runtime %s is not supported." % FLAGS.runtime,
file=sys.stderr)
sys.exit(-1)
......@@ -232,6 +235,13 @@ def main(unused_args):
converter = hexagon_converter.HexagonConverter(
option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run()
elif FLAGS.runtime == 'apu':
if FLAGS.platform != 'tensorflow':
raise Exception('apu only supports models from tensorflow')
from mace.python.tools.converter_tool import apu_converter
converter = apu_converter.ApuConverter(
option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run()
try:
visualizer = visualize_model.ModelVisualizer(FLAGS.model_tag,
......@@ -287,7 +297,7 @@ def parse_args():
default="",
help="File to save the output graph to.")
parser.add_argument(
"--runtime", type=str, default="", help="Runtime: cpu/gpu/dsp")
"--runtime", type=str, default="", help="Runtime: cpu/gpu/dsp/apu")
parser.add_argument(
"--input_node",
type=str,
......
# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
from enum import Enum
from operator import mul
from mace.proto import mace_pb2
from mace.python.tools.converter_tool import base_converter
from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import PoolingType
from mace.python.tools.converter_tool.base_converter import ReduceType
from mace.python.tools.converter_tool.base_converter import DataFormat
from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.convert_util import mace_check
from mace.python.tools import graph_util
ApuSupportedOps = [
'Concat',
'Conv2D',
'DepthwiseConv2d',
'Eltwise',
'Pooling',
'Reduce',
'ResizeBilinear',
'Reshape',
'Softmax',
'Squeeze',
]
ApuOp = Enum('ApuOp', [(op, op) for op in ApuSupportedOps], type=str)
class ApuOps(object):
def __init__(self):
self.apu_ops = {
MaceOp.Concat.name: ApuOp.Concat.name,
MaceOp.Conv2D.name: ApuOp.Conv2D.name,
MaceOp.DepthwiseConv2d.name: ApuOp.DepthwiseConv2d.name,
MaceOp.Eltwise.name: ApuOp.Eltwise.name,
MaceOp.Pooling.name: ApuOp.Pooling.name,
MaceOp.Reduce.name: ApuOp.Reduce.name,
MaceOp.ResizeBilinear.name: ApuOp.ResizeBilinear.name,
MaceOp.Reshape.name: ApuOp.Reshape.name,
MaceOp.Softmax.name: ApuOp.Softmax.name,
MaceOp.Squeeze.name: ApuOp.Squeeze.name,
}
def has_op(self, op_name):
return op_name in self.apu_ops
def map_nn_op(self, op_name):
if op_name not in self.apu_ops:
raise Exception('Could not map nn op for: ', op_name)
return self.apu_ops[op_name]
class ApuConverter(base_converter.ConverterInterface):
def __init__(self, option, model, quantize_activation_info):
self._option = option
self._model = model
self._apu_ops = ApuOps()
def run(self):
self.use_uint8_in_out()
self.add_op_output_type()
self.ensure_bias_vector()
self.common_check()
if ConverterUtil.get_arg(self._model.op[0],
MaceKeyword.mace_framework_type_str).i == \
FrameworkType.TENSORFLOW.value:
self.add_tensorflow_padding_value()
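# Record the tensor count before convert_ops so the runtime can tell
# const data (node_id < const_data_num) apart from the const argument
# tensors appended below (padding, strides, axes, ...); see the
# tensor_type decision in apu_wrapper.cc.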
const_data_num_arg = self._model.arg.add()
const_data_num_arg.name = MaceKeyword.mace_const_data_num_arg_str
const_data_num_arg.i = len(self._model.tensors)
self.convert_ops()
self.add_node_id()
return self._model
def common_check(self):
for op in self._model.op:
mace_check(len(op.input) >= 1,
op.name + ': apu does not support ops with 0 inputs')
mace_check(len(op.output) == 1,
op.name + ': apu only supports single-output ops')
mace_check(len(op.output) == len(op.output_shape),
op.name + ': lengths of output and output_shape do'
' not match')
mace_check(len(op.output_shape[0].dims) <= 4,
op.name + ': apu only supports 1D~4D tensors')
mace_check(len(op.output) == len(op.quantize_info),
op.name + ': lengths of output and quantize_info do'
' not match')
data_format = ConverterUtil.data_format(op)
if data_format is not None and len(op.output_shape[0].dims) == 4:
mace_check((data_format == DataFormat.NHWC)
or (data_format == DataFormat.AUTO),
op.name + ': apu only supports 4D tensors with NHWC'
' or AUTO format but found ' + str(data_format))
act_mode_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_activation_type_str)
if act_mode_arg is not None:
mace_check(act_mode_arg.s == b'RELU'
or act_mode_arg.s == b'RELUX',
op.name + ': apu only supports activations RELU and'
' RELUX')
for tensor in self._model.tensors:
mace_check(len(tensor.dims) <= 4,
tensor.name + ': apu only supports 1D~4D tensors')
for input_info in self._model.input_info:
mace_check(len(input_info.dims) <= 4,
input_info.name + ': apu only supports 1D~4D tensors')
mace_check(input_info.data_type == mace_pb2.DT_FLOAT,
input_info.name + ': apu only supports float input')
if len(input_info.dims) == 4:
mace_check(input_info.data_format == DataFormat.NHWC.value,
input_info.name + ': apu only supports 4D tensors'
' with NHWC format')
def convert_ops(self):
print("Convert mace graph to apu.")
for op in self._model.op:
if not self._apu_ops.has_op(op.type):
raise Exception('Unsupported op: ', op)
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
mace_check(len(op.input) == 3,
op.name + ': apu only supports ' + op.type + ' ops'
' with 3 inputs')
self.add_size_tensor_from_arg(
op, MaceKeyword.mace_strides_str)
self.add_padding_tensor_from_arg(op)
self.add_size_tensor_from_arg(
op, MaceKeyword.mace_dilations_str)
if op.type == MaceOp.DepthwiseConv2d.name:
multiplier = self._model.tensors.add()
multiplier.name = op.name + '/multiplier:0'
multiplier.data_type = mace_pb2.DT_INT32
multiplier.dims.extend([1])
for tensor in self._model.tensors:
if tensor.name == op.input[1]:
multiplier.int32_data.extend([tensor.dims[0]])
break
op.input.extend([multiplier.name])
elif op.type == MaceOp.Eltwise.name:
mace_check(len(op.input) == 2,
op.name + ': apu only supports eltwise ops with 2'
' inputs')
eltwise_type = ConverterUtil.get_arg(
op, MaceKeyword.mace_element_type_str).i
mace_check(eltwise_type == EltwiseType.SUM.value,
op.name + ': apu only supports eltwise type SUM')
elif op.type == MaceOp.Pooling.name:
mace_check(len(op.input) == 1,
op.name + ': apu only supports pooling ops with 1'
' input')
pooling_type_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_pooling_type_str)
mace_check(PoolingType(pooling_type_arg.i) == PoolingType.AVG,
op.name + ': apu only supports pooling type AVG')
self.add_padding_tensor_from_arg(op)
self.add_size_tensor_from_arg(
op, MaceKeyword.mace_strides_str)
self.add_size_tensor_from_arg(op, MaceKeyword.mace_kernel_str)
elif op.type == MaceOp.Concat.name:
self.add_int_tensor_from_arg(op, MaceKeyword.mace_axis_str)
elif op.type == MaceOp.Reduce.name:
mace_check(len(op.input) == 1,
op.name + ': apu only supports reduce ops with 1'
' input')
self.add_int_list_tensor_from_arg(
op, MaceKeyword.mace_axis_str)
self.add_int_tensor_from_arg(
op, MaceKeyword.mace_keepdims_str)
elif op.type == MaceOp.ResizeBilinear.name:
mace_check(len(op.input) == 1,
op.name + ': apu only supports resize bilinear ops'
' with 1 input')
self.add_int_tensor_from_arg(
op, MaceKeyword.mace_align_corners_str)
elif op.type == MaceOp.Reshape.name:
mace_check(len(op.input) == 1 or len(op.input) == 2,
op.name + ': apu only supports reshape ops with 1 or'
' 2 inputs')
elif op.type == MaceOp.Softmax.name:
mace_check(len(op.input) == 1,
op.name + ': apu only supports softmax ops with 1'
' input')
beta_value_tensor = self._model.tensors.add()
beta_value_tensor.name = op.name + '/beta:0'
beta_value_tensor.data_type = mace_pb2.DT_FLOAT
beta_value_tensor.dims.extend([1])
beta_value_tensor.float_data.extend([1.0])
op.input.extend([beta_value_tensor.name])
elif op.type == MaceOp.Squeeze.name:
mace_check(len(op.input) == 1,
op.name + ': apu only supports squeeze ops with 1'
' input')
self.add_int_list_tensor_from_arg(
op, MaceKeyword.mace_axis_str)
op.type = self._apu_ops.map_nn_op(op.type)
def add_op_output_type(self):
type_map = {}
for input_info in self._model.input_info:
# input quantization is done in the runtime wrapper
type_map[input_info.name] = mace_pb2.DT_UINT8
for op in self._model.op:
if len(op.output_type) >= 1:
print([op.name, len(op.output), len(op.output_type)])
type_map[op.output[0]] = op.output_type[0]
continue
mace_check(op.input[0] in type_map,
op.input[0] + ' not in type_map')
op.output_type.extend([type_map[op.input[0]]])
type_map[op.output[0]] = op.output_type[0]
for op in self._model.op:
mace_check(len(op.output) == len(op.output_type),
op.name + ': lengths of output and output_type do'
' not match')
mace_check(op.output_type[0] == mace_pb2.DT_UINT8
or op.output_type[0] == mace_pb2.DT_INT32,
op.name + ': apu only supports quantized nodes')
def add_node_id(self):
node_id_counter = 0
node_id_map = {}
for tensor in self._model.tensors:
tensor.node_id = node_id_counter
node_id_counter += 1
node_id_map[tensor.name] = tensor.node_id
for input_info in self._model.input_info:
input_info.node_id = node_id_counter
node_id_counter += 1
node_id_map[input_info.name] = input_info.node_id
for op in self._model.op:
op.node_id = node_id_counter
node_id_counter += 1
node_id_map[op.output[0]] = op.node_id
for op in self._model.op:
del op.node_input[:]
for input_tensor in op.input:
node_input = op.node_input.add()
node_input.node_id = node_id_map[input_tensor]
for output_info in self._model.output_info:
output_info.node_id = node_id_map[output_info.name]
def add_padding_tensor_from_arg(self, op):
padding_value_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_values_str)
mace_check(len(padding_value_arg.ints) == 4,
op.name + ': padding value does not have size 4')
padding_value_tensor = self._model.tensors.add()
padding_value_tensor.name = op.name + '/padding:0'
padding_value_tensor.data_type = mace_pb2.DT_INT32
padding_value_tensor.dims.extend([4])
padding_value_tensor.int32_data.extend(padding_value_arg.ints)
op.input.extend([padding_value_tensor.name])
def add_size_tensor_from_arg(self, op, keyword):
size_value_arg = ConverterUtil.get_arg(op, keyword)
mace_check(len(size_value_arg.ints) == 2,
op.name + ': ' + keyword + ' value does not have size 2')
size_value_tensor = self._model.tensors.add()
size_value_tensor.name = op.name + '/' + keyword + ':0'
size_value_tensor.data_type = mace_pb2.DT_INT32
size_value_tensor.dims.extend([2])
size_value_tensor.int32_data.extend(size_value_arg.ints)
op.input.extend([size_value_tensor.name])
def add_int_tensor_from_arg(self, op, keyword):
int_value_arg = ConverterUtil.get_arg(op, keyword)
mace_check(int_value_arg.i is not None,
op.name + ': ' + keyword + ' value i should not be None')
int_value_tensor = self._model.tensors.add()
int_value_tensor.name = op.name + '/' + keyword + ':0'
int_value_tensor.data_type = mace_pb2.DT_INT32
int_value_tensor.dims.extend([1])
int_value_tensor.int32_data.extend([int_value_arg.i])
op.input.extend([int_value_tensor.name])
def add_int_list_tensor_from_arg(self, op, keyword):
list_value_arg = ConverterUtil.get_arg(op, keyword)
mace_check(list_value_arg.ints is not None,
op.name + ': ' + keyword + ' value ints should not be None')
list_value_tensor = self._model.tensors.add()
list_value_tensor.name = op.name + '/' + keyword + ':0'
list_value_tensor.data_type = mace_pb2.DT_INT32
list_value_tensor.dims.extend([len(list_value_arg.ints)])
list_value_tensor.int32_data.extend(list_value_arg.ints)
op.input.extend([list_value_tensor.name])
def add_tensorflow_padding_value(self):
for op in self._model.op:
padding_type = ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_str)
if padding_type is None:
continue
padding_arg = op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str
if padding_type.i == PaddingMode.VALID.value:
padding_arg.ints.extend([0, 0, 0, 0])
elif padding_type.i == PaddingMode.SAME.value:
stride = ConverterUtil.get_arg(
op, MaceKeyword.mace_strides_str).ints
kernel = []
dilation = [1, 1]
if op.type == MaceOp.Conv2D.name or \
op.type == MaceOp.DepthwiseConv2d.name:
if ConverterUtil.get_arg(
op, MaceKeyword.mace_dilations_str) is not None:
dilation = ConverterUtil.get_arg(
op, MaceKeyword.mace_dilations_str).ints
for tensor in self._model.tensors:
if tensor.name == op.input[1]:
kernel = tensor.dims[1:3]
break
else:
kernel = ConverterUtil.get_arg(
op, MaceKeyword.mace_kernel_str).ints
in_size = []
for input_info in self._model.input_info:
if input_info.name == op.input[0]:
in_size = input_info.dims[1:3]
break
for _op in self._model.op:
for out in _op.output:
if out == op.input[0]:
in_size = _op.output_shape[0].dims[1:3]
break
if len(in_size) > 0:
break
out_size = op.output_shape[0].dims[1:3]
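# SAME padding: the total padding needed along each spatial axis is
# (out - 1) * stride + (kernel - 1) * dilation + 1 - in,
# split across the two sides with any odd pixel going to the
# bottom/right side, following TensorFlow's convention.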
h = (out_size[0] - 1) * stride[0] \
+ ((kernel[0] - 1) * dilation[0] + 1) - in_size[0]
w = (out_size[1] - 1) * stride[1] \
+ ((kernel[1] - 1) * dilation[1] + 1) - in_size[1]
top = int(np.floor(h/2))
left = int(np.floor(w/2))
bottom = h - top
right = w - left
padding_arg.ints.extend([top, right, bottom, left])
def ensure_bias_vector(self):
for _op in self._model.op:
if _op.type != MaceOp.Conv2D.name and \
_op.type != MaceOp.DepthwiseConv2d.name:
continue
if len(_op.input) != 2:
continue
tensor = self._model.tensors.add()
tensor.name = _op.name + '/add/bias_add'
tensor.dims.extend([_op.output_shape[0].dims[-1]])
if _op.output_type[0] == mace_pb2.DT_UINT8:
tensor.data_type = mace_pb2.DT_INT32
input_name = _op.input[0]
for input_op in self._model.op:
if input_op.output[0] == input_name:
scale_input = input_op.quantize_info[0].scale
break
filter_name = _op.input[1]
for filter_tensor in self._model.tensors:
if filter_tensor.name == filter_name:
scale_filter = filter_tensor.scale
break
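# Quantized conv bias is int32 with scale = input_scale * filter_scale
# and zero_point = 0, per the standard uint8 quantization scheme.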
tensor.scale = scale_input * scale_filter
tensor.zero_point = 0
tensor.quantized = True
tensor.int32_data.extend([0] * tensor.dims[0])
elif _op.output_type[0] == mace_pb2.DT_FLOAT:
tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend([0.0] * tensor.dims[0])
_op.input.extend([tensor.name])
def use_uint8_in_out(self):
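# Fold the boundary Quantize/Dequantize ops into the model I/O: the
# runtime quantizes inputs and dequantizes outputs itself (see
# ApuWrapper::Run), so the converted graph keeps uint8 at its edges.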
for input_info in self._model.input_info:
if input_info.data_type == mace_pb2.DT_FLOAT:
for op in self._model.op:
if op.input[0] == input_info.name \
and op.type == MaceOp.Quantize.name:
input_info.name = op.output[0]
input_info.scale = op.quantize_info[0].scale
input_info.zero_point = op.quantize_info[0].zero_point
break
self._model.op.remove(op)
for output_info in self._model.output_info:
if output_info.data_type == mace_pb2.DT_FLOAT:
for op in self._model.op:
if op.output[0] == output_info.name \
and op.type == MaceOp.Dequantize.name:
output_info.name = op.input[0]
break
self._model.op.remove(op)
......@@ -23,6 +23,7 @@ class DeviceType(Enum):
GPU = 2
HEXAGON = 3
HTA = 4
APU = 5
class DataFormat(Enum):
......@@ -271,6 +272,7 @@ class MaceKeyword(object):
mace_pad_type_str = 'pad_type'
mace_exclusive_str = 'exclusive'
mace_reverse_str = 'reverse'
mace_const_data_num_arg_str = 'const_data_num'
class TransformerRule(Enum):
......@@ -518,6 +520,9 @@ class ConverterOption(object):
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
# For StoB -> conv -> BtoS -> BN pattern
# Insert flatten_atrous_conv before fold_xxx_and_bn
TransformerRule.FLATTEN_ATROUS_CONV,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DECONV_AND_BN,
......
......@@ -115,6 +115,7 @@ TFSupportedOps = [
'ArgMax',
'Split',
'FakeQuantWithMinMaxVars',
'FakeQuantWithMinMaxArgs',
'FloorDiv',
'Sqrt',
'MirrorPad',
......@@ -261,6 +262,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.Split.name: self.convert_split,
TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
TFOpType.FakeQuantWithMinMaxArgs.name: self.convert_fake_quantize,
TFOpType.FloorDiv.name: self.convert_elementwise,
TFOpType.Sqrt.name: self.convert_elementwise,
TFOpType.MirrorPad.name: self.convert_pad,
......@@ -1034,10 +1036,14 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
min_arg = op.arg.add()
min_arg.name = 'min'
max_arg = op.arg.add()
max_arg.name = 'max'
if tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name:
min_arg.f = tf_op.inputs[1].eval()
max_arg.f = tf_op.inputs[2].eval()
elif tf_op.type == TFOpType.FakeQuantWithMinMaxArgs.name:
min_arg.f = float(tf_op.get_attr('min'))
max_arg.f = float(tf_op.get_attr('max'))
narrow_range_arg = op.arg.add()
narrow_range_arg.name = 'narrow_range'
narrow_range_arg.i = int(tf_op.get_attr('narrow_range'))
......@@ -1045,8 +1051,9 @@ class TensorflowConverter(base_converter.ConverterInterface):
num_bits_arg.name = 'num_bits'
num_bits_arg.i = int(tf_op.get_attr('num_bits'))
if tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name:
self._skip_tensor.add(tf_op.inputs[1].name)
self._skip_tensor.add(tf_op.inputs[2].name)
def convert_cumsum(self, tf_op):
op = self.convert_general_op(tf_op)
......
......@@ -654,6 +654,7 @@ class Transformer(base_converter.ConverterInterface):
# remove bn
del consumer_op.input[:]
net.tensors.remove(scale)
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
......@@ -722,6 +723,7 @@ class Transformer(base_converter.ConverterInterface):
del consumer_op.input[:]
net.tensors.remove(scale)
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
......@@ -778,6 +780,7 @@ class Transformer(base_converter.ConverterInterface):
# remove bn
del consumer_op.input[:]
net.tensors.remove(scale)
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
......@@ -874,7 +877,8 @@ class Transformer(base_converter.ConverterInterface):
return False
def flatten_atrous_conv(self):
if self._option.device != DeviceType.GPU.value \
and self._option.device != DeviceType.APU.value:
return
net = self._model
......@@ -1070,7 +1074,8 @@ class Transformer(base_converter.ConverterInterface):
transposed_deconv_filter = set()
if self._option.quantize and \
(self._option.device == DeviceType.CPU.value or
self._option.device == DeviceType.APU.value):
print("Transpose filters to OHWI")
if filter_format == DataFormat.HWIO:
transpose_order = [3, 0, 1, 2]
......@@ -1082,7 +1087,9 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op:
if (op.type == MaceOp.Conv2D.name or
op.type == MaceOp.Deconv2D.name or
(op.type == MaceOp.DepthwiseConv2d.name and
self._option.device == DeviceType.APU.value)) and\
op.input[1] not in transposed_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
......@@ -1572,7 +1579,8 @@ class Transformer(base_converter.ConverterInterface):
if len(ops[0].input) >= 4:
check_deconv = ops[0].input[3] == tensor.name
if check_conv or check_deconv:
if self._option.device == DeviceType.CPU.value \
or self._option.device == DeviceType.APU.value:
conv_op = ops[0]
scale_input = self._quantize_activation_info[
conv_op.input[0]].scale
......@@ -1648,13 +1656,16 @@ class Transformer(base_converter.ConverterInterface):
net = self._model
for op in net.op:
if op.type == 'FakeQuantWithMinMaxVars' or \
op.type == 'FakeQuantWithMinMaxArgs':
producer_op = self._producer[op.input[0]]
minval = ConverterUtil.get_arg(op, 'min').f
maxval = ConverterUtil.get_arg(op, 'max').f
quantize_info = \
self.add_quantize_info(producer_op, minval, maxval)
self._quantize_activation_info[op.input[0]] = quantize_info
# for add -> fakequant pattern
self._quantize_activation_info[op.output[0]] = quantize_info
op.type = MaceOp.Identity.name
return False
......
......@@ -78,6 +78,8 @@ DeviceType ParseDeviceType(const std::string &device_str) {
return DeviceType::HEXAGON;
} else if (device_str.compare("HTA") == 0) {
return DeviceType::HTA;
} else if (device_str.compare("APU") == 0) {
return DeviceType::APU;
} else {
return DeviceType::CPU;
}
......@@ -141,7 +143,7 @@ DEFINE_string(model_data_file,
DEFINE_string(model_file,
"",
"model file name, used when load mace model in pb");
DEFINE_string(device, "GPU", "CPU/GPU/HEXAGON");
DEFINE_string(device, "GPU", "CPU/GPU/HEXAGON/APU");
DEFINE_int32(round, 1, "round");
DEFINE_int32(restart_round, 1, "restart round");
DEFINE_int32(malloc_check_cycle, -1, "malloc debug check cycle, -1 to disable");
......
// Copyright 2019 MediaTek Inc. All rights reserved.
#pragma once
enum apu_act_mode {
APU_ACT_NONE = 0,
APU_ACT_RELU = 1,
APU_ACT_RELU6 = 2,
};
enum apu_pooling_mode {
APU_POOLING_UNDEFINED = 0,
APU_POOLING_AVG = 1,
APU_POOLING_MAX = 2,
};
enum apu_eltwise_mode {
APU_ELTWISE_UNDEFINED = 0,
APU_ELTWISE_ADD = 1,
APU_ELTWISE_SUB = 2,
APU_ELTWISE_MUL = 3,
APU_ELTWISE_MIN = 4,
APU_ELTWISE_MAX = 5,
};
enum apu_data_type {
APU_DATA_TYPE_UNDEFINED = 0,
APU_DATA_TYPE_FLOAT = 1,
APU_DATA_TYPE_UINT8 = 2,
APU_DATA_TYPE_HALF = 3,
APU_DATA_TYPE_INT32 = 4,
};
enum apu_tensor_type {
APU_TENSOR_UNDEFINED = 0,
APU_TENSOR_CONST_DATA = 1,
APU_TENSOR_CONST_ARGUMENT = 2,
APU_TENSOR_MODEL_INPUT = 3,
APU_TENSOR_OP_OUTPUT = 4,
};
#define APU_TENSOR_MAX_DIMS 4
struct apu_tensor {
int tensor_id;
apu_tensor_type tensor_type;
apu_data_type data_type;
float scale;
int zero_point;
int dims[APU_TENSOR_MAX_DIMS];
int dim_size;
void* data_buf;
};
#define APU_OP_TYPE_MAX_SIZE 32
struct apu_operator {
char type[APU_OP_TYPE_MAX_SIZE];
int* input_ids;
int input_size;
apu_tensor output;
int op_mode; // for pooling and eltwise
apu_act_mode act_mode;
};
class ApuFrontend {
public:
ApuFrontend();
~ApuFrontend();
bool InitGraph(int const_tensor_size, const apu_tensor* const_tensors,
int input_tensor_size, const apu_tensor* input_tensors,
int output_tensor_size, const int* output_tensor_ids,
void** output_buffers,
int operator_size, const apu_operator* operators,
bool print_model);
bool RunGraph();
bool UninitGraph();
private:
class Impl;
ApuFrontend::Impl* impl;
};
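For orientation, a minimal sketch of how this interface is driven, condensed from ApuWrapper::Init/Run/Uninit above (illustrative only, not part of the header):

// Build the graph once, run it per inference, then tear it down.
ApuFrontend *frontend = new ApuFrontend();
bool ok = frontend->InitGraph(
    const_tensors.size(), const_tensors.data(),          // weights and const arguments
    input_tensors.size(), input_tensors.data(),          // uint8 model inputs
    output_tensor_ids.size(), output_tensor_ids.data(),
    output_buffers.data(),                               // preallocated uint8 output buffers
    ops.size(), ops.data(),
    /*print_model=*/false);
// Per inference, the wrapper quantizes float inputs into the input
// buffers, runs the graph, then dequantizes the output buffers.
ok = ok && frontend->RunGraph();
ok = ok && frontend->UninitGraph();
delete frontend;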
licenses(["notice"])
exports_files(["license.txt"])
cc_library(
name = "libapu-frontend",
srcs = [
"libapu-frontend.so",
],
hdrs = glob(["*.h"]),
copts = ["-DNN_TARGET_NDK"],
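# Don't fail the link over symbols that libapu-frontend.so resolves only at load time on the device.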
linkopts = ["-Wl,-unresolved-symbols=ignore-in-shared-libs"],
visibility = ["//visibility:public"],
)
This diff is collapsed.
......@@ -19,6 +19,7 @@ mkdir -p $LIB_DIR/armeabi-v7a/cpu_gpu
rm -rf $LIB_DIR/arm64-v8a
mkdir -p $LIB_DIR/arm64-v8a/cpu_gpu_dsp
mkdir -p $LIB_DIR/arm64-v8a/cpu_gpu
mkdir -p $LIB_DIR/arm64-v8a/cpu_gpu_apu
rm -rf $LIB_DIR/linux-x86-64
mkdir -p $LIB_DIR/linux-x86-64
......@@ -50,6 +51,11 @@ echo "build shared lib for arm64-v8a + cpu_gpu"
bazel build --config android --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=false --define opencl=true --define quantize=true --cpu=arm64-v8a
cp bazel-bin/mace/libmace/libmace.so $LIB_DIR/arm64-v8a/cpu_gpu/
echo "build shared lib for arm64-v8a + cpu_gpu_apu"
bazel build --config android --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=false --define opencl=true --define apu=true --define quantize=true --cpu=arm64-v8a
cp bazel-bin/mace/libmace/libmace.so $LIB_DIR/arm64-v8a/cpu_gpu_apu/
cp third_party/apu/libapu-frontend.so $LIB_DIR/arm64-v8a/cpu_gpu_apu/
echo "build shared lib for arm_linux_gnueabihf + cpu_gpu"
bazel build --config arm_linux_gnueabihf --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=false --define opencl=true --define quantize=true
cp bazel-bin/mace/libmace/libmace.so $LIB_DIR/arm_linux_gnueabihf/cpu_gpu/
......@@ -83,6 +89,11 @@ echo "build static lib for arm64-v8a + cpu_gpu"
bazel build --config android --config optimization mace/libmace:libmace_static --config symbol_hidden --define neon=true --define openmp=false --define opencl=true --define quantize=true --cpu=arm64-v8a
cp bazel-genfiles/mace/libmace/libmace.a $LIB_DIR/arm64-v8a/cpu_gpu/
echo "build static lib for arm64-v8a + cpu_gpu_apu"
bazel build --config android --config optimization mace/libmace:libmace_static --config symbol_hidden --define neon=true --define openmp=false --define opencl=true --define apu=true --define quantize=true --cpu=arm64-v8a
cp bazel-genfiles/mace/libmace/libmace.a $LIB_DIR/arm64-v8a/cpu_gpu_apu/
cp third_party/apu/libapu-frontend.so $LIB_DIR/arm64-v8a/cpu_gpu_apu/
echo "build static lib for arm_linux_gnueabihf + cpu_gpu"
bazel build --config arm_linux_gnueabihf --config optimization mace/libmace:libmace_static --config symbol_hidden --define neon=true --define openmp=false --define opencl=true --define quantize=true
cp bazel-genfiles/mace/libmace/libmace.a $LIB_DIR/arm_linux_gnueabihf/cpu_gpu/
......
......@@ -130,6 +130,7 @@ class DeviceType(object):
GPU = 'GPU'
HEXAGON = 'HEXAGON'
HTA = 'HTA'
APU = 'APU'
class DataFormat(object):
......@@ -207,6 +208,8 @@ def parse_device_type(runtime):
device_type = DeviceType.GPU
elif runtime == RuntimeType.cpu:
device_type = DeviceType.CPU
elif runtime == RuntimeType.apu:
device_type = DeviceType.APU
return device_type
......@@ -520,6 +523,7 @@ class RuntimeType(object):
gpu = 'gpu'
dsp = 'dsp'
hta = 'hta'
apu = 'apu'
cpu_gpu = 'cpu+gpu'
......
......@@ -62,6 +62,7 @@ RuntimeTypeStrs = [
"gpu",
"dsp",
"hta",
"apu",
"cpu+gpu"
]
......@@ -89,6 +90,13 @@ DSPDataTypeStrs = [
DSPDataType = Enum('DSPDataType', [(ele, ele) for ele in DSPDataTypeStrs],
type=str)
APUDataTypeStrs = [
"uint8",
]
APUDataType = Enum('APUDataType', [(ele, ele) for ele in APUDataTypeStrs],
type=str)
WinogradParameters = [0, 2, 4]
DataFormatStrs = [
......@@ -150,6 +158,8 @@ def parse_device_type(runtime):
device_type = DeviceType.GPU
elif runtime == RuntimeType.cpu:
device_type = DeviceType.CPU
elif runtime == RuntimeType.apu:
device_type = DeviceType.APU
return device_type
......@@ -361,6 +371,15 @@ def format_model_config(flags):
else:
model_config[YAMLKeyword.data_type] = \
DSPDataType.uint8.value
elif runtime == RuntimeType.apu:
if len(data_type) > 0:
mace_check(data_type in APUDataTypeStrs,
ModuleName.YAML_CONFIG,
"'data_type' must be in " + str(APUDataTypeStrs)
+ " for apu runtime")
else:
model_config[YAMLKeyword.data_type] = \
APUDataType.uint8.value
else:
if len(data_type) > 0:
mace_check(data_type in FPDataTypeStrs,
......