提交 2d650b67 编写于 作者: 許永健 提交者: Liangliang He

Integrate MediaTek APU Support (#440)

* Integrate APU

* fix code style issue
上级 9fbb9e17
......@@ -74,6 +74,18 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "apu_enabled",
define_values = {
"apu": "true",
},
values = {
"crosstool_top": "//external:android/crosstool",
"cpu": "arm64-v8a",
},
visibility = ["//visibility:public"],
)
config_setting(
name = "hexagon_enabled",
define_values = {
......
......@@ -11,10 +11,12 @@ load(
"//mace:mace.bzl",
"if_android",
"if_android_armv7",
"if_apu_enabled",
"if_hexagon_enabled",
"if_hexagon_or_hta_enabled",
"if_hta_enabled",
"if_neon_enabled",
"if_not_apu_enabled",
"if_not_hexagon_enabled",
"if_opencl_enabled",
"if_openmp_enabled",
......@@ -39,7 +41,9 @@ cc_library(
"runtime/hexagon/hexagon_dsp_wrapper.cc",
]) + if_hta_enabled([
"runtime/hexagon/hexagon_hta_wrapper.cc",
]),
]) + if_apu_enabled(glob([
"runtime/apu/*.cc",
])),
hdrs = glob([
"*.h",
"runtime/cpu/*.h",
......@@ -52,6 +56,8 @@ cc_library(
"runtime/hexagon/*dsp*.h",
])) + if_hta_enabled(glob([
"runtime/hexagon/*hta*.h",
])) + if_apu_enabled(glob([
"runtime/apu/*.h"
])),
copts = [
"-Werror",
......@@ -68,6 +74,8 @@ cc_library(
"-DMACE_ENABLE_HEXAGON",
]) + if_hta_enabled([
"-DMACE_ENABLE_HTA",
]) + if_apu_enabled([
"-DMACE_ENABLE_APU",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
......@@ -90,6 +98,8 @@ cc_library(
"//third_party/nnlib:libhexagon",
]) + if_hta_enabled([
"//third_party/hta",
]) + if_apu_enabled([
"//third_party/apu:libapu-frontend",
]),
)
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_RUNTIME_APU_APU_DEVICE_H_
#define MACE_CORE_RUNTIME_APU_APU_DEVICE_H_
#include "mace/core/device.h"
namespace mace {
class ApuDevice : public CPUDevice {
public:
explicit ApuDevice(utils::ThreadPool *thread_pool)
: CPUDevice(0, AFFINITY_NONE, thread_pool) {}
DeviceType device_type() const override {
return DeviceType::APU;
};
};
} // namespace mace
#endif // MACE_CORE_RUNTIME_APU_APU_DEVICE_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/runtime/apu/apu_wrapper.h"
#include <algorithm>
#include "mace/core/quantize.h"
namespace mace {
ApuWrapper::ApuWrapper(Device *device)
: quantize_util_(&device->cpu_runtime()->thread_pool()) {
}
apu_data_type ApuWrapper::MapToApuDataType(DataType mace_type) {
switch (mace_type) {
case DT_FLOAT:
return APU_DATA_TYPE_FLOAT;
case DT_INT32:
return APU_DATA_TYPE_INT32;
case DT_HALF:
return APU_DATA_TYPE_HALF;
case DT_UINT8:
return APU_DATA_TYPE_UINT8;
default:
MACE_CHECK(true, "unsupport mace data type");
break;
}
return APU_DATA_TYPE_UNDEFINED;
}
apu_pooling_mode ApuWrapper::MapToApuPoolingMode(int mace_mode) {
switch (mace_mode) {
case 1:
return APU_POOLING_AVG;
case 2:
return APU_POOLING_MAX;
default:
MACE_CHECK(true, "unsupport mace pooling mode");
break;
}
return APU_POOLING_UNDEFINED;
}
apu_eltwise_mode ApuWrapper::MapToApuEltwiseMode(int mace_mode) {
switch (mace_mode) {
case 0:
return APU_ELTWISE_ADD;
case 1:
return APU_ELTWISE_SUB;
case 2:
return APU_ELTWISE_MUL;
case 4:
return APU_ELTWISE_MIN;
case 5:
return APU_ELTWISE_MAX;
default:
MACE_CHECK(true, "unsupport mace eltwise mode");
break;
}
return APU_ELTWISE_UNDEFINED;
}
bool ApuWrapper::Init(const NetDef &net_def,
unsigned const char *model_data) {
frontend = new ApuFrontend();
// parse model argument
int const_data_num = 0;
for (auto arg : net_def.arg()) {
if (arg.name().compare("const_data_num") == 0) {
const_data_num = arg.i();
}
}
// const tensors
std::vector<apu_tensor> const_tensors;
for (auto const_tensor : net_def.tensors()) {
apu_tensor tensor;
tensor.tensor_id = const_tensor.node_id();
tensor.tensor_type = (tensor.tensor_id < const_data_num) ?
APU_TENSOR_CONST_DATA :
APU_TENSOR_CONST_ARGUMENT;
tensor.data_type = MapToApuDataType(const_tensor.data_type());
tensor.scale = const_tensor.has_scale() ? const_tensor.scale() : 0.0f;
tensor.zero_point = const_tensor.has_zero_point() ?
const_tensor.zero_point() : 0;
tensor.dim_size = const_tensor.dims_size();
MACE_CHECK(tensor.dim_size <= APU_TENSOR_MAX_DIMS,
"tensor dimension size not supported");
for (auto i = 0 ; i < tensor.dim_size ; i++) {
tensor.dims[i] = const_tensor.dims(i);
}
tensor.data_buf =
const_cast<unsigned char*>(model_data + const_tensor.offset());
const_tensors.push_back(tensor);
}
// input tensors
std::vector<apu_tensor> input_tensors;
for (auto input_info : net_def.input_info()) {
apu_tensor tensor;
tensor.tensor_id = input_info.node_id();
tensor.tensor_type = APU_TENSOR_MODEL_INPUT;
tensor.data_type = APU_DATA_TYPE_UINT8; // will do quantize in Run()
tensor.scale = input_info.has_scale() ? input_info.scale() : 0.0f;
tensor.zero_point = input_info.has_zero_point() ?
input_info.zero_point() : 0;
tensor.dim_size = input_info.dims_size();
MACE_CHECK(tensor.dim_size <= APU_TENSOR_MAX_DIMS,
"tensor dimension size not supported");
tensor_info info;
info.name = input_info.name();
info.size = 1;
for (auto i = 0 ; i < tensor.dim_size ; i++) {
tensor.dims[i] = input_info.dims(i);
info.size *= input_info.dims(i);
info.shape.push_back(input_info.dims(i));
}
info.buf = std::shared_ptr<uint8_t>(new uint8_t[info.size],
std::default_delete<uint8_t[]>());
info.scale = tensor.scale;
info.zero_point = tensor.zero_point;
input_infos.push_back(info);
tensor.data_buf = info.buf.get();
input_tensors.push_back(tensor);
}
// output tensors
std::vector<int> output_tensor_ids;
std::vector<void*> output_buffers;
for (auto output_info : net_def.output_info()) {
output_tensor_ids.push_back(output_info.node_id());
tensor_info info;
info.name = output_info.name();
info.size = 1;
for (auto i = 0 ; i < output_info.dims().size() ; i++) {
info.size *= output_info.dims(i);
info.shape.push_back(output_info.dims(i));
}
info.buf = std::shared_ptr<uint8_t>(new uint8_t[info.size],
std::default_delete<uint8_t[]>());
for (auto op_def : net_def.op()) {
if (output_info.name() == op_def.output(0)) {
info.scale = op_def.quantize_info(0).scale();
info.zero_point = op_def.quantize_info(0).zero_point();
}
}
output_infos.push_back(info);
output_buffers.push_back(info.buf.get());
}
// operators
std::vector<apu_operator> ops;
std::vector<std::vector<int>> cached_op_inputs;
for (auto op_def : net_def.op()) {
apu_operator op;
strncpy(op.type, op_def.type().c_str(), APU_OP_TYPE_MAX_SIZE);
op.input_size = op_def.node_input_size();
std::vector<int> input_ids;
for (auto i = 0 ; i < op.input_size ; i++) {
input_ids.push_back(op_def.node_input(i).node_id());
}
cached_op_inputs.push_back(input_ids);
op.input_ids = cached_op_inputs.back().data();
op.output.tensor_id = op_def.node_id();
op.output.tensor_type = APU_TENSOR_OP_OUTPUT;
op.output.data_type = MapToApuDataType(op_def.output_type(0));
if (op.output.data_type == APU_DATA_TYPE_UINT8) {
op.output.scale = op_def.quantize_info(0).scale();
op.output.zero_point = op_def.quantize_info(0).zero_point();
} else {
op.output.scale = 0.0f;
op.output.zero_point = 0;
}
op.output.dim_size = op_def.output_shape(0).dims_size();
MACE_CHECK(op.output.dim_size <= APU_TENSOR_MAX_DIMS,
"tensor dimension size not supported");
for (auto i = 0 ; i < op.output.dim_size ; i++) {
op.output.dims[i] = op_def.output_shape(0).dims(i);
}
op.output.data_buf = nullptr;
// get op mode and activation mode
bool is_pooling = (strcmp(op.type, "Pooling") == 0);
bool is_eltwise = (strcmp(op.type, "Eltwise") == 0);
std::string activation;
float max_limit = 0.0f;
for (auto arg : op_def.arg()) {
if (arg.name().compare("activation") == 0) {
activation = arg.s();
}
if (arg.name().compare("max_limit") == 0) {
max_limit = arg.f();
}
if (is_pooling && arg.name().compare("pooling_type") == 0) {
op.op_mode = static_cast<int>(MapToApuPoolingMode(arg.i()));
}
if (is_eltwise && arg.name().compare("type") == 0) {
op.op_mode = static_cast<int>(MapToApuEltwiseMode(arg.i()));
}
}
if (activation.compare("RELU") == 0) {
op.act_mode = APU_ACT_RELU;
} else if (activation.compare("RELUX") == 0 && max_limit == 6.0) {
op.act_mode = APU_ACT_RELU6;
} else {
op.act_mode = APU_ACT_NONE;
}
ops.push_back(op);
}
bool print_model = false;
bool ret = frontend->InitGraph(
const_tensors.size(), const_tensors.data(),
input_tensors.size(), input_tensors.data(),
output_tensor_ids.size(), output_tensor_ids.data(),
output_buffers.data(),
ops.size(), ops.data(),
print_model);
cached_op_inputs.clear();
MACE_CHECK(ret == true, "apu init graph failed");
return ret;
}
bool ApuWrapper::Run(const std::map<std::string, Tensor *> &input_tensors,
std::map<std::string, Tensor *> *output_tensors) {
MACE_ASSERT(input_tensors.size() == input_infos.size(), "Wrong inputs num");
MACE_ASSERT(output_tensors.size() == output_infos.size(),
"Wrong outputs num");
// prepare input
for (int i = 0 ; i < static_cast<int>(input_tensors.size()) ; i++) {
Tensor* tensor = input_tensors.at(input_infos[i].name);
// check size
int size = input_infos[i].size;
MACE_ASSERT(size == static_cast<int>(tensor->size()), "Wrong input size");
// quantize
quantize_util_.QuantizeWithScaleAndZeropoint(
(const float*)tensor->raw_data(),
size,
input_infos[i].scale,
input_infos[i].zero_point,
input_infos[i].buf.get());
}
// run model
bool ret = frontend->RunGraph();
MACE_CHECK(ret == true, "neuron run model failed");
// process output
for (int i = 0 ; i < static_cast<int>(output_tensors->size()) ; i++) {
Tensor* tensor = output_tensors->at(output_infos[i].name);
// prepare out buffer
tensor->SetDtype(DT_FLOAT);
tensor->Resize(output_infos[i].shape);
int size = output_infos[i].size;
MACE_ASSERT(size == static_cast<int>(tensor->size()), "Wrong output size");
// dequantize
quantize_util_.Dequantize(
output_infos[i].buf.get(),
size,
output_infos[i].scale,
output_infos[i].zero_point,
reinterpret_cast<float*>(tensor->raw_mutable_data()));
}
return true;
}
bool ApuWrapper::Uninit() {
bool ret = frontend->UninitGraph();
frontend = nullptr;
input_infos.clear();
output_infos.clear();
return ret;
}
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_RUNTIME_APU_APU_WRAPPER_H_
#define MACE_CORE_RUNTIME_APU_APU_WRAPPER_H_
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "mace/proto/mace.pb.h"
#include "mace/core/tensor.h"
#include "mace/core/device.h"
#include "mace/core/quantize.h"
#include "third_party/apu/ApuFrontend.h"
namespace mace {
class ApuWrapper {
struct tensor_info {
std::string name;
std::shared_ptr<uint8_t> buf;
std::vector<index_t> shape;
int size;
float scale;
int zero_point;
};
public:
explicit ApuWrapper(Device *device);
bool Init(const NetDef& net_def, unsigned const char *model_data);
bool Run(const std::map<std::string, Tensor *> &input_tensors,
std::map<std::string, Tensor *> *output_tensors);
bool Uninit();
private:
apu_data_type MapToApuDataType(DataType mace_type);
apu_pooling_mode MapToApuPoolingMode(int mace_mode);
apu_eltwise_mode MapToApuEltwiseMode(int mace_mode);
private:
ApuFrontend* frontend;
std::vector<tensor_info> input_infos;
std::vector<tensor_info> output_infos;
QuantizeUtil<uint8_t> quantize_util_;
};
} // namespace mace
#endif // MACE_CORE_RUNTIME_APU_APU_WRAPPER_H_
......@@ -11,6 +11,7 @@ load(
"//mace:mace.bzl",
"if_android",
"if_android_armv7",
"if_apu_enabled",
"if_darwin",
"if_hexagon_enabled",
"if_hta_enabled",
......@@ -44,6 +45,8 @@ cc_library(
"-DMACE_ENABLE_HEXAGON",
]) + if_hta_enabled([
"-DMACE_ENABLE_HTA",
]) + if_apu_enabled([
"-DMACE_ENABLE_APU",
]),
deps = [
"//mace/ops",
......
......@@ -38,7 +38,13 @@
#include "mace/core/runtime/hexagon/hexagon_device.h"
#endif
#ifdef MACE_ENABLE_APU
#include "mace/core/runtime/apu/apu_wrapper.h"
#include "mace/core/runtime/apu/apu_device.h"
#endif // MACE_ENABLE_APU
namespace mace {
namespace {
#ifdef MACE_ENABLE_OPENCL
......@@ -398,6 +404,9 @@ class MaceEngine::Impl {
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
std::unique_ptr<HexagonControlWrapper> hexagon_controller_;
#endif
#ifdef MACE_ENABLE_APU
std::unique_ptr<ApuWrapper> apu_controller_;
#endif
MACE_DISABLE_COPY_AND_ASSIGN(Impl);
};
......@@ -415,6 +424,9 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
, hexagon_controller_(nullptr)
#endif
#ifdef MACE_ENABLE_APU
, apu_controller_(nullptr)
#endif
{
LOG(INFO) << "Creating MaceEngine, MACE version: " << MaceVersion();
thread_pool_->Init();
......@@ -441,6 +453,11 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
|| device_type_ == DeviceType::HTA) {
device_.reset(new HexagonDevice(device_type_, thread_pool_.get()));
}
#endif
#ifdef MACE_ENABLE_APU
if (device_type_ == DeviceType::APU) {
device_.reset(new ApuDevice(thread_pool_.get()));
}
#endif
MACE_CHECK_NOTNULL(device_);
}
......@@ -497,6 +514,11 @@ MaceStatus MaceEngine::Impl::Init(
Tensor *output_tensor =
ws_->CreateTensor(output_name, device_->allocator(), output_dt);
output_tensor->set_data_format(DataFormat::NHWC);
#endif
#if defined(MACE_ENABLE_APU)
Tensor *output_tensor =
ws_->CreateTensor(output_name, device_->allocator(), DT_FLOAT);
output_tensor->set_data_format(DataFormat::NHWC);
#endif
}
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
......@@ -512,6 +534,12 @@ MaceStatus MaceEngine::Impl::Init(
hexagon_controller_->PrintGraph();
}
} else {
#endif
#ifdef MACE_ENABLE_APU
if (device_type_ == APU) {
apu_controller_.reset(new ApuWrapper(device_.get()));
MACE_CHECK(apu_controller_->Init(*net_def, model_data), "apu init error");
} else {
#endif
MACE_RETURN_IF_ERROR(ws_->LoadModelTensor(*net_def,
device_.get(),
......@@ -542,6 +570,9 @@ MaceStatus MaceEngine::Impl::Init(
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
}
#endif
#ifdef MACE_ENABLE_APU
}
#endif
return MaceStatus::MACE_SUCCESS;
}
......@@ -580,6 +611,11 @@ MaceEngine::Impl::~Impl() {
MACE_CHECK(hexagon_controller_->Finalize(), "hexagon finalize error");
}
#endif
#ifdef MACE_ENABLE_APU
if (device_type_ == APU) {
MACE_CHECK(apu_controller_->Uninit(), "apu uninit error");
}
#endif
}
MaceStatus MaceEngine::Impl::TransposeInput(
......@@ -767,11 +803,20 @@ MaceStatus MaceEngine::Impl::Run(
}
hexagon_controller_->ExecuteGraphNew(input_tensors, &output_tensors);
} else {
#endif
#ifdef MACE_ENABLE_APU
if (device_type_ == APU) {
MACE_CHECK(apu_controller_->Run(input_tensors, &output_tensors),
"apu run error");
} else {
#endif
MACE_RETURN_IF_ERROR(net_->Run(run_metadata));
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
}
#endif
#ifdef MACE_ENABLE_APU
}
#endif
#ifdef MACE_ENABLE_OPENCL
if (device_type_ == GPU) {
......
......@@ -79,6 +79,18 @@ def if_hexagon_or_hta_enabled(a):
"//conditions:default": [],
})
def if_apu_enabled(a):
return select({
"//mace:apu_enabled": a,
"//conditions:default": [],
})
def if_not_apu_enabled(a):
return select({
"//mace:apu_enabled": [],
"//conditions:default": a,
})
def if_openmp_enabled(a):
return select({
"//mace:openmp_enabled": a,
......
......@@ -32,7 +32,7 @@ namespace mace {
class NetDef;
enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, HTA = 4 };
enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, HTA = 4, APU = 5 };
enum class DataFormat {
NONE = 0, NHWC = 1, NCHW = 2,
......
......@@ -16,6 +16,7 @@ py_library(
"converter_tool/onnx_converter.py",
"converter_tool/shape_inference.py",
"converter_tool/tensorflow_converter.py",
"converter_tool/apu_converter.py",
"converter_tool/transformer.py",
"graph_util.py",
],
......
......@@ -38,6 +38,7 @@ device_type_map = {'cpu': cvt.DeviceType.CPU.value,
'gpu': cvt.DeviceType.GPU.value,
'dsp': cvt.DeviceType.HEXAGON.value,
'hta': cvt.DeviceType.HTA.value,
'apu': cvt.DeviceType.APU.value,
'cpu+gpu': cvt.DeviceType.CPU.value}
data_format_map = {
......@@ -63,6 +64,8 @@ def parse_data_type(data_type, device_type):
elif device_type == cvt.DeviceType.HEXAGON.value or \
device_type == cvt.DeviceType.HTA.value:
return mace_pb2.DT_FLOAT
elif device_type == cvt.DeviceType.APU.value:
return mace_pb2.DT_FLOAT
else:
print("Invalid device type: " + str(device_type))
......@@ -129,7 +132,7 @@ def main(unused_args):
six.print_("platform %s is not supported." % FLAGS.platform,
file=sys.stderr)
sys.exit(-1)
if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'hta', 'cpu+gpu']:
if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', 'hta', 'apu', 'cpu+gpu']:
six.print_("runtime %s is not supported." % FLAGS.runtime,
file=sys.stderr)
sys.exit(-1)
......@@ -232,6 +235,13 @@ def main(unused_args):
converter = hexagon_converter.HexagonConverter(
option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run()
elif FLAGS.runtime == 'apu':
if FLAGS.platform != 'tensorflow':
raise Exception('apu only support model from tensorflow')
from mace.python.tools.converter_tool import apu_converter
converter = apu_converter.ApuConverter(
option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run()
try:
visualizer = visualize_model.ModelVisualizer(FLAGS.model_tag,
......@@ -287,7 +297,7 @@ def parse_args():
default="",
help="File to save the output graph to.")
parser.add_argument(
"--runtime", type=str, default="", help="Runtime: cpu/gpu/dsp")
"--runtime", type=str, default="", help="Runtime: cpu/gpu/dsp/apu")
parser.add_argument(
"--input_node",
type=str,
......
# Copyright 2018 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
from enum import Enum
from operator import mul
from mace.proto import mace_pb2
from mace.python.tools.converter_tool import base_converter
from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import PoolingType
from mace.python.tools.converter_tool.base_converter import ReduceType
from mace.python.tools.converter_tool.base_converter import DataFormat
from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.convert_util import mace_check
from mace.python.tools import graph_util
ApuSupportedOps = [
'Concat',
'Conv2D',
'DepthwiseConv2d',
'Eltwise',
'Pooling',
'Reduce',
'ResizeBilinear',
'Reshape',
'Softmax',
'Squeeze',
]
ApuOp = Enum('ApuOp', [(op, op) for op in ApuSupportedOps], type=str)
class ApuOps(object):
def __init__(self):
self.apu_ops = {
MaceOp.Concat.name: ApuOp.Concat.name,
MaceOp.Conv2D.name: ApuOp.Conv2D.name,
MaceOp.DepthwiseConv2d.name: ApuOp.DepthwiseConv2d.name,
MaceOp.Eltwise.name: ApuOp.Eltwise.name,
MaceOp.Pooling.name: ApuOp.Pooling.name,
MaceOp.Reduce.name: ApuOp.Reduce.name,
MaceOp.ResizeBilinear.name: ApuOp.ResizeBilinear.name,
MaceOp.Reshape.name: ApuOp.Reshape.name,
MaceOp.Softmax.name: ApuOp.Softmax.name,
MaceOp.Squeeze.name: ApuOp.Squeeze.name,
}
def has_op(self, op_name):
return op_name in self.apu_ops
def map_nn_op(self, op_name):
if op_name not in self.apu_ops:
raise Exception('Could not map nn op for: ', op_name)
return self.apu_ops[op_name]
class ApuConverter(base_converter.ConverterInterface):
def __init__(self, option, model, quantize_activation_info):
self._option = option
self._model = model
self._apu_ops = ApuOps()
def run(self):
self.use_uint8_in_out()
self.add_op_output_type()
self.ensure_bias_vector()
self.common_check()
if ConverterUtil.get_arg(self._model.op[0],
MaceKeyword.mace_framework_type_str).i == \
FrameworkType.TENSORFLOW.value:
self.add_tensorflow_padding_value()
const_data_num_arg = self._model.arg.add()
const_data_num_arg.name = MaceKeyword.mace_const_data_num_arg_str
const_data_num_arg.i = len(self._model.tensors)
self.convert_ops()
self.add_node_id()
return self._model
def common_check(self):
for op in self._model.op:
mace_check(len(op.input) >= 1,
op.name + ': apu does not support op with 0 input')
mace_check(len(op.output) == 1,
op.name + ': apu only support single output op')
mace_check(len(op.output) == len(op.output_shape),
op.name + ': length of output and output_shape not'
' match')
mace_check(len(op.output_shape[0].dims) <= 4,
op.name + ': apu only support 1D~4D tensor')
mace_check(len(op.output) == len(op.quantize_info),
op.name + ': length of output and quantize_info not'
' match')
data_format = ConverterUtil.data_format(op)
if data_format is not None and len(op.output_shape[0].dims) == 4:
mace_check((data_format == DataFormat.NHWC)
or (data_format == DataFormat.AUTO),
op.name + ': apu only support 4D tensor with NHWC'
' or AUTO format but find ' + str(data_format))
act_mode_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_activation_type_str)
if act_mode_arg is not None:
mace_check(act_mode_arg.s == b'RELU'
or act_mode_arg.s == b'RELUX',
op.name + ': apu only support activation RELU and'
' RELUX')
for tensor in self._model.tensors:
mace_check(len(tensor.dims) <= 4,
tensor.name + ': apu only support 1D~4D tensor')
for input_info in self._model.input_info:
mace_check(len(input_info.dims) <= 4,
input_info.name + ': apu only support 1D~4D tensor')
mace_check(input_info.data_type == mace_pb2.DT_FLOAT,
input_info.name + ': apu only support float input')
if len(input_info.dims) == 4:
mace_check(input_info.data_format == DataFormat.NHWC.value,
input_info.name + ': apu only support 4D tensor'
' with NHWC format')
def convert_ops(self):
print("Convert mace graph to apu.")
for op in self._model.op:
if not self._apu_ops.has_op(op.type):
raise Exception('Unsupported op: ', op)
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
mace_check(len(op.input) == 3,
op.name + ': apu only support ' + op.type + ' op'
' with 3 input')
self.add_size_tensor_from_arg(
op, MaceKeyword.mace_strides_str)
self.add_padding_tensor_from_arg(op)
self.add_size_tensor_from_arg(
op, MaceKeyword.mace_dilations_str)
if op.type == MaceOp.DepthwiseConv2d.name:
multiplier = self._model.tensors.add()
multiplier.name = op.name + '/multiplier:0'
multiplier.data_type = mace_pb2.DT_INT32
multiplier.dims.extend([1])
for tensor in self._model.tensors:
if tensor.name == op.input[1]:
multiplier.int32_data.extend([tensor.dims[0]])
break
op.input.extend([multiplier.name])
elif op.type == MaceOp.Eltwise.name:
mace_check(len(op.input) == 2,
op.name + ': apu only support eltwise op with 2'
' input')
eltwise_type = ConverterUtil.get_arg(
op, MaceKeyword.mace_element_type_str).i
mace_check(eltwise_type == EltwiseType.SUM.value,
op.name + ': apu only support eltwise type SUM')
elif op.type == MaceOp.Pooling.name:
mace_check(len(op.input) == 1,
op.name + ': apu only support pooling op with 1'
' input')
pooling_type_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_pooling_type_str)
mace_check(PoolingType(pooling_type_arg.i) == PoolingType.AVG,
op.name + ': apu only support pooling type AVG')
self.add_padding_tensor_from_arg(op)
self.add_size_tensor_from_arg(
op, MaceKeyword.mace_strides_str)
self.add_size_tensor_from_arg(op, MaceKeyword.mace_kernel_str)
elif op.type == MaceOp.Concat.name:
self.add_int_tensor_from_arg(op, MaceKeyword.mace_axis_str)
elif op.type == MaceOp.Reduce.name:
mace_check(len(op.input) == 1,
op.name + ': apu only support reduce op with 1'
' input')
self.add_int_list_tensor_from_arg(
op, MaceKeyword.mace_axis_str)
self.add_int_tensor_from_arg(
op, MaceKeyword.mace_keepdims_str)
elif op.type == MaceOp.ResizeBilinear.name:
mace_check(len(op.input) == 1,
op.name + ': apu only support resize bilinear op'
' with 1 input')
self.add_int_tensor_from_arg(
op, MaceKeyword.mace_align_corners_str)
elif op.type == MaceOp.Reshape.name:
mace_check(len(op.input) == 1 or len(op.input) == 2,
op.name + ': apu only support reshape op with 1 or'
' 2 input')
elif op.type == MaceOp.Softmax.name:
mace_check(len(op.input) == 1,
op.name + ': apu only support softmax op with 1'
' input')
beta_value_tensor = self._model.tensors.add()
beta_value_tensor.name = op.name + '/beta:0'
beta_value_tensor.data_type = mace_pb2.DT_FLOAT
beta_value_tensor.dims.extend([1])
beta_value_tensor.float_data.extend([1.0])
op.input.extend([beta_value_tensor.name])
elif op.type == MaceOp.Squeeze.name:
mace_check(len(op.input) == 1,
op.name + ': apu only support squeeze op with 1'
' input')
self.add_int_list_tensor_from_arg(
op, MaceKeyword.mace_axis_str)
op.type = self._apu_ops.map_nn_op(op.type)
def add_op_output_type(self):
type_map = {}
for input_info in self._model.input_info:
# will do input quantize in wrapper
type_map[input_info.name] = mace_pb2.DT_UINT8
for op in self._model.op:
if len(op.output_type) >= 1:
print([op.name, len(op.output), len(op.output_type)])
type_map[op.output[0]] = op.output_type[0]
continue
mace_check(op.input[0] in type_map,
op.input[0] + ' not in type_map')
op.output_type.extend([type_map[op.input[0]]])
type_map[op.output[0]] = op.output_type[0]
for op in self._model.op:
mace_check(len(op.output) == len(op.output_type),
op.name + ': length of output and output_type not'
' match')
mace_check(op.output_type[0] == mace_pb2.DT_UINT8
or op.output_type[0] == mace_pb2.DT_INT32,
op.name + ': apu only support quantized node')
def add_node_id(self):
node_id_counter = 0
node_id_map = {}
for tensor in self._model.tensors:
tensor.node_id = node_id_counter
node_id_counter += 1
node_id_map[tensor.name] = tensor.node_id
for input_info in self._model.input_info:
input_info.node_id = node_id_counter
node_id_counter += 1
node_id_map[input_info.name] = input_info.node_id
for op in self._model.op:
op.node_id = node_id_counter
node_id_counter += 1
node_id_map[op.output[0]] = op.node_id
for op in self._model.op:
del op.node_input[:]
for input_tensor in op.input:
node_input = op.node_input.add()
node_input.node_id = node_id_map[input_tensor]
for output_info in self._model.output_info:
output_info.node_id = node_id_map[output_info.name]
def add_padding_tensor_from_arg(self, op):
padding_value_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_values_str)
mace_check(len(padding_value_arg.ints) == 4,
op.name + ': padding value does not have size 4')
padding_value_tensor = self._model.tensors.add()
padding_value_tensor.name = op.name + '/padding:0'
padding_value_tensor.data_type = mace_pb2.DT_INT32
padding_value_tensor.dims.extend([4])
padding_value_tensor.int32_data.extend(padding_value_arg.ints)
op.input.extend([padding_value_tensor.name])
def add_size_tensor_from_arg(self, op, keyword):
size_value_arg = ConverterUtil.get_arg(op, keyword)
mace_check(len(size_value_arg.ints) == 2,
op.name + ': ' + keyword + ' value does not have size 2')
size_value_tensor = self._model.tensors.add()
size_value_tensor.name = op.name + '/' + keyword + ':0'
size_value_tensor.data_type = mace_pb2.DT_INT32
size_value_tensor.dims.extend([2])
size_value_tensor.int32_data.extend(size_value_arg.ints)
op.input.extend([size_value_tensor.name])
def add_int_tensor_from_arg(self, op, keyword):
int_value_arg = ConverterUtil.get_arg(op, keyword)
mace_check(int_value_arg.i is not None,
op.name + ': ' + keyword + ' value i should not be None')
int_value_tensor = self._model.tensors.add()
int_value_tensor.name = op.name + '/' + keyword + ':0'
int_value_tensor.data_type = mace_pb2.DT_INT32
int_value_tensor.dims.extend([1])
int_value_tensor.int32_data.extend([int_value_arg.i])
op.input.extend([int_value_tensor.name])
def add_int_list_tensor_from_arg(self, op, keyword):
list_value_arg = ConverterUtil.get_arg(op, keyword)
mace_check(list_value_arg.ints is not None,
op.name + ': ' + keyword + ' value ints should not be None')
list_value_tensor = self._model.tensors.add()
list_value_tensor.name = op.name + '/' + keyword + ':0'
list_value_tensor.data_type = mace_pb2.DT_INT32
list_value_tensor.dims.extend([len(list_value_arg.ints)])
list_value_tensor.int32_data.extend(list_value_arg.ints)
op.input.extend([list_value_tensor.name])
def add_tensorflow_padding_value(self):
for op in self._model.op:
padding_type = ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_str)
if padding_type is None:
continue
padding_arg = op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str
if padding_type.i == PaddingMode.VALID.value:
padding_arg.ints.extend([0, 0, 0, 0])
elif padding_type.i == PaddingMode.SAME.value:
stride = ConverterUtil.get_arg(
op, MaceKeyword.mace_strides_str).ints
kernel = []
dilation = [1, 1]
if op.type == MaceOp.Conv2D.name or \
op.type == MaceOp.DepthwiseConv2d.name:
if ConverterUtil.get_arg(
op, MaceKeyword.mace_dilations_str) is not None:
dilation = ConverterUtil.get_arg(
op, MaceKeyword.mace_dilations_str).ints
for tensor in self._model.tensors:
if tensor.name == op.input[1]:
kernel = tensor.dims[1:3]
break
else:
kernel = ConverterUtil.get_arg(
op, MaceKeyword.mace_kernel_str).ints
in_size = []
for input_info in self._model.input_info:
if input_info.name == op.input[0]:
in_size = input_info.dims[1:3]
break
for _op in self._model.op:
for out in _op.output:
if out == op.input[0]:
in_size = _op.output_shape[0].dims[1:3]
break
if len(in_size) > 0:
break
out_size = op.output_shape[0].dims[1:3]
h = (out_size[0] - 1) * stride[0] \
+ ((kernel[0] - 1) * dilation[0] + 1) - in_size[0]
w = (out_size[1] - 1) * stride[1] \
+ ((kernel[1] - 1) * dilation[1] + 1) - in_size[1]
top = int(np.floor(h/2))
left = int(np.floor(w/2))
bottom = h - top
right = w - left
padding_arg.ints.extend([top, right, bottom, left])
def ensure_bias_vector(self):
for _op in self._model.op:
if _op.type != MaceOp.Conv2D.name and \
_op.type != MaceOp.DepthwiseConv2d.name:
continue
if len(_op.input) != 2:
continue
tensor = self._model.tensors.add()
tensor.name = _op.name + '/add/bias_add'
tensor.dims.extend([_op.output_shape[0].dims[-1]])
if _op.output_type[0] == mace_pb2.DT_UINT8:
tensor.data_type = mace_pb2.DT_INT32
input_name = _op.input[0]
for input_op in self._model.op:
if input_op.output[0] == input_name:
scale_input = input_op.quantize_info[0].scale
break
filter_name = _op.input[1]
for filter_tensor in self._model.tensors:
if filter_tensor.name == filter_name:
scale_filter = filter_tensor.scale
break
tensor.scale = scale_input * scale_filter
tensor.zero_point = 0
tensor.quantized = True
tensor.int32_data.extend([0] * tensor.dims[0])
elif _op.output_type[0] == mace_pb2.DT_FLOAT:
tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend([0.0] * tensor.dims[0])
_op.input.extend([tensor.name])
def use_uint8_in_out(self):
for input_info in self._model.input_info:
if input_info.data_type == mace_pb2.DT_FLOAT:
for op in self._model.op:
if op.input[0] == input_info.name \
and op.type == MaceOp.Quantize.name:
input_info.name = op.output[0]
input_info.scale = op.quantize_info[0].scale
input_info.zero_point = op.quantize_info[0].zero_point
break
self._model.op.remove(op)
for output_info in self._model.output_info:
if output_info.data_type == mace_pb2.DT_FLOAT:
for op in self._model.op:
if op.output[0] == output_info.name \
and op.type == MaceOp.Dequantize.name:
output_info.name = op.input[0]
break
self._model.op.remove(op)
......@@ -23,6 +23,7 @@ class DeviceType(Enum):
GPU = 2
HEXAGON = 3
HTA = 4
APU = 5
class DataFormat(Enum):
......@@ -271,6 +272,7 @@ class MaceKeyword(object):
mace_pad_type_str = 'pad_type'
mace_exclusive_str = 'exclusive'
mace_reverse_str = 'reverse'
mace_const_data_num_arg_str = 'const_data_num'
class TransformerRule(Enum):
......@@ -518,6 +520,9 @@ class ConverterOption(object):
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
# For StoB -> conv -> BtoS -> BN pattern
# Insert flatten_atrous_conv before fold_xxx_and_bn
TransformerRule.FLATTEN_ATROUS_CONV,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DECONV_AND_BN,
......
......@@ -115,6 +115,7 @@ TFSupportedOps = [
'ArgMax',
'Split',
'FakeQuantWithMinMaxVars',
'FakeQuantWithMinMaxArgs',
'FloorDiv',
'Sqrt',
'MirrorPad',
......@@ -261,6 +262,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.Split.name: self.convert_split,
TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
TFOpType.FakeQuantWithMinMaxArgs.name: self.convert_fake_quantize,
TFOpType.FloorDiv.name: self.convert_elementwise,
TFOpType.Sqrt.name: self.convert_elementwise,
TFOpType.MirrorPad.name: self.convert_pad,
......@@ -1034,10 +1036,14 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
min_arg = op.arg.add()
min_arg.name = 'min'
min_arg.f = tf_op.inputs[1].eval()
max_arg = op.arg.add()
max_arg.name = 'max'
max_arg.f = tf_op.inputs[2].eval()
if tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name:
min_arg.f = tf_op.inputs[1].eval()
max_arg.f = tf_op.inputs[2].eval()
elif tf_op.type == TFOpType.FakeQuantWithMinMaxArgs.name:
min_arg.f = float(tf_op.get_attr('min'))
max_arg.f = float(tf_op.get_attr('max'))
narrow_range_arg = op.arg.add()
narrow_range_arg.name = 'narrow_range'
narrow_range_arg.i = int(tf_op.get_attr('narrow_range'))
......@@ -1045,8 +1051,9 @@ class TensorflowConverter(base_converter.ConverterInterface):
num_bits_arg.name = 'num_bits'
num_bits_arg.i = int(tf_op.get_attr('num_bits'))
self._skip_tensor.add(tf_op.inputs[1].name)
self._skip_tensor.add(tf_op.inputs[2].name)
if tf_op.type == TFOpType.FakeQuantWithMinMaxVars.name:
self._skip_tensor.add(tf_op.inputs[1].name)
self._skip_tensor.add(tf_op.inputs[2].name)
def convert_cumsum(self, tf_op):
op = self.convert_general_op(tf_op)
......
......@@ -654,6 +654,7 @@ class Transformer(base_converter.ConverterInterface):
# remove bn
del consumer_op.input[:]
net.tensors.remove(scale)
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
......@@ -722,6 +723,7 @@ class Transformer(base_converter.ConverterInterface):
del consumer_op.input[:]
net.tensors.remove(scale)
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
......@@ -778,6 +780,7 @@ class Transformer(base_converter.ConverterInterface):
# remove bn
del consumer_op.input[:]
net.tensors.remove(scale)
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
......@@ -874,7 +877,8 @@ class Transformer(base_converter.ConverterInterface):
return False
def flatten_atrous_conv(self):
if self._option.device != DeviceType.GPU.value:
if self._option.device != DeviceType.GPU.value \
and self._option.device != DeviceType.APU.value:
return
net = self._model
......@@ -1070,7 +1074,8 @@ class Transformer(base_converter.ConverterInterface):
transposed_deconv_filter = set()
if self._option.quantize and \
self._option.device == DeviceType.CPU.value:
(self._option.device == DeviceType.CPU.value or
self._option.device == DeviceType.APU.value):
print("Transpose filters to OHWI")
if filter_format == DataFormat.HWIO:
transpose_order = [3, 0, 1, 2]
......@@ -1082,7 +1087,9 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op:
if (op.type == MaceOp.Conv2D.name or
op.type == MaceOp.Deconv2D.name) and\
op.type == MaceOp.Deconv2D.name or
(op.type == MaceOp.DepthwiseConv2d.name and
self._option.device == DeviceType.APU.value)) and\
op.input[1] not in transposed_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
......@@ -1572,7 +1579,8 @@ class Transformer(base_converter.ConverterInterface):
if len(ops[0].input) >= 4:
check_deconv = ops[0].input[3] == tensor.name
if check_conv or check_deconv:
if self._option.device == DeviceType.CPU.value:
if self._option.device == DeviceType.CPU.value \
or self._option.device == DeviceType.APU.value:
conv_op = ops[0]
scale_input = self._quantize_activation_info[
conv_op.input[0]].scale
......@@ -1648,13 +1656,16 @@ class Transformer(base_converter.ConverterInterface):
net = self._model
for op in net.op:
if op.type == 'FakeQuantWithMinMaxVars':
if op.type == 'FakeQuantWithMinMaxVars' or \
op.type == 'FakeQuantWithMinMaxArgs':
producer_op = self._producer[op.input[0]]
minval = ConverterUtil.get_arg(op, 'min').f
maxval = ConverterUtil.get_arg(op, 'max').f
quantize_info = \
self.add_quantize_info(producer_op, minval, maxval)
self._quantize_activation_info[op.input[0]] = quantize_info
# for add -> fakequant pattern
self._quantize_activation_info[op.output[0]] = quantize_info
op.type = MaceOp.Identity.name
return False
......
......@@ -78,6 +78,8 @@ DeviceType ParseDeviceType(const std::string &device_str) {
return DeviceType::HEXAGON;
} else if (device_str.compare("HTA") == 0) {
return DeviceType::HTA;
} else if (device_str.compare("APU") == 0) {
return DeviceType::APU;
} else {
return DeviceType::CPU;
}
......@@ -141,7 +143,7 @@ DEFINE_string(model_data_file,
DEFINE_string(model_file,
"",
"model file name, used when load mace model in pb");
DEFINE_string(device, "GPU", "CPU/GPU/HEXAGON");
DEFINE_string(device, "GPU", "CPU/GPU/HEXAGON/APU");
DEFINE_int32(round, 1, "round");
DEFINE_int32(restart_round, 1, "restart round");
DEFINE_int32(malloc_check_cycle, -1, "malloc debug check cycle, -1 to disable");
......
// Copyright 2019 MediaTek Inc. All rights reserved.
#pragma once
enum apu_act_mode {
APU_ACT_NONE = 0,
APU_ACT_RELU = 1,
APU_ACT_RELU6 = 2,
};
enum apu_pooling_mode {
APU_POOLING_UNDEFINED = 0,
APU_POOLING_AVG = 1,
APU_POOLING_MAX = 2,
};
enum apu_eltwise_mode {
APU_ELTWISE_UNDEFINED = 0,
APU_ELTWISE_ADD = 1,
APU_ELTWISE_SUB = 2,
APU_ELTWISE_MUL = 3,
APU_ELTWISE_MIN = 4,
APU_ELTWISE_MAX = 5,
};
enum apu_data_type {
APU_DATA_TYPE_UNDEFINED = 0,
APU_DATA_TYPE_FLOAT = 1,
APU_DATA_TYPE_UINT8 = 2,
APU_DATA_TYPE_HALF = 3,
APU_DATA_TYPE_INT32 = 4,
};
enum apu_tensor_type {
APU_TENSOR_UNDEFINED = 0,
APU_TENSOR_CONST_DATA = 1,
APU_TENSOR_CONST_ARGUMENT = 2,
APU_TENSOR_MODEL_INPUT = 3,
APU_TENSOR_OP_OUTPUT = 4,
};
#define APU_TENSOR_MAX_DIMS 4
struct apu_tensor {
int tensor_id;
apu_tensor_type tensor_type;
apu_data_type data_type;
float scale;
int zero_point;
int dims[APU_TENSOR_MAX_DIMS];
int dim_size;
void* data_buf;
};
#define APU_OP_TYPE_MAX_SIZE 32
struct apu_operator {
char type[APU_OP_TYPE_MAX_SIZE];
int* input_ids;
int input_size;
apu_tensor output;
int op_mode; // for pooling and eltwise
apu_act_mode act_mode;
};
class ApuFrontend {
public:
ApuFrontend();
~ApuFrontend();
bool InitGraph(int const_tensor_size, const apu_tensor* const_tensors,
int input_tensor_size, const apu_tensor* input_tensors,
int output_tensor_size, const int* output_tensor_ids,
void** output_buffers,
int operator_size, const apu_operator* operators,
bool print_model);
bool RunGraph();
bool UninitGraph();
private:
class Impl;
ApuFrontend::Impl* impl;
};
licenses(["notice"])
exports_files(["license.txt"])
cc_library(
name = "libapu-frontend",
srcs = [
"libapu-frontend.so",
],
hdrs = glob(["*.h"]),
copts = ["-DNN_TARGET_NDK"],
linkopts = ["-Wl,-unresolved-symbols=ignore-in-shared-libs"],
visibility = ["//visibility:public"],
)
License Agreement
PLEASE CAREFULLY READ ALL OF THE TERMS AND CONDITIONS SET FORTH IN THIS License Agreement ("AGREEMENT") BEFORE YOU ("YOU") ACCESS AND/OR USE the Software Package (as defined below) and/or Documentation (as defined below) from this website or other MediaTek Inc. ("MediaTek") website (collectively this "Website"). ANY ACCESS AND/OR USE OF THE SOFTWARE PACKAGE AND DOCUMENTATION ARE SUBJECT TO THE TERMS AND CONDITIONS SET FORTH IN THIS AGREEMENT. BY ACCESSING OR USING ANY PART OF THE SOFTWARE PACKAGE AND/OR DOCUMENTATION, YOU ACCEPT AND AGREE (ON BEHALF OF YOURSELF AND/OR YOUR COMPANY OR ORGANIZATION) TO BE BOUND BY THE TERMS AND CONDITIONS OF THIS AGREEMENT, WHICH THEN COMMENCES WITH EFFECT AS A LEGAL AGREEMENT BETWEEN YOU AND/OR YOUR COMPANY OR ORGANIZATION (AS APPLICABLE) AND MEDIATEK. IF YOU DO NOT OR CANNOT AGREE TO THE TERMS AND CONDITIONS OF THIS AGREEMENT, YOU MUST NOT ACCESS, OR USE THE SOFTWARE PACKAGE AND DOCUMENTATION.
MEDIATEK RESERVES ITS RIGHT, AT ANY TIME, TO CHANGE OR MODIFY THE TERMS AND CONDITIONS OF THIS AGREEMENT BY POSTING NEW OR REVISED TERMS AND CONDITIONS TO THIS WEBSITE. IF YOU DO NOT AGREE TO THE NEW OR MODIFIED TERMS AND CONDITIONS OF THIS AGREEMENT, YOU MAY NOT CONTINUE TO USE OR ACCESS THE SOFTWARE PACKAGE AND DOCUMENTATION. ANY ACCESS AND/OR USE OF THE SOFTWARE PACKAGE AND/OR DOCUMENTATION AFTER MEDIATEK POSTS NEW OR MODIFIED TERMS AND CONDITIONS INDICATES THAT YOU ACCEPT THE NEW OR MODIFIED TERMS AND CONDITIONS.
YOU ACKNOWLEDGE AND AGREE THAT, FROM TIME TO TIME, SOME OF THE SOFTWARE PACKAGE OR DOCUMENTATION ("MATERIALS") YOU DOWNLOAD OR HAVE ACCESS TO FROM THIS WEBSITE MAY CONTAIN AND BE SUBJECT TO THIRD PARTY OR OPEN SOURCE SOFTWARE LICENSES ("THIRD PARTY LICENSES"). BY DOWNLOADING, ACCESSING OR USING SUCH MATERIALS, YOU AGREE TO ABIDE BY THE TERMS AND CONDITIONS OF SUCH THIRD PARTY LICENSES. IN THE EVENT OF CONFLICTS BETWEEN THE TERMS AND CONDITIONS OF THIS AGREEMENT AND THE TERMS AND CONDITIONS OF THE THIRD PARTY LICENSES CONTAINED IN ANY MATERIALS, THE TERMS AND CONDITIONS OF SUCH THIRD PARTY LICENSES SHALL PREVAIL BUT ONLY WITH RESPECT TO THE MATERIALS CONTAINING SUCH THIRD PARTY LICENSES.
1. Definitions
1.1 “Affiliate” means a corporation, company or other entity which: (a) is controlled by MediaTek; (b) controls MediaTek; or (c) is under common control with MediaTek. For the purpose of this definition, “control” means that more than fifty percent (50%) of the shares or ownership interest representing the voting right for the election of directors or persons performing similar functions for such a corporation, company or entity are owned or controlled, directly or indirectly, by the controlling entity. Such corporation, company or entity shall be deemed to be an Affiliate so long as such ownership or control exists.
1.2 “Application” means a software or hardware developed by You using the Software Package or Documentation for specific use with the devices which incorporate MediaTek’s chipsets and/or system and under Your own trademark and/or brand, including, in respect of such software programs or hardware, all bug fixes, enhancements, modifications, new releases, new versions, revisions, supplements, updates and upgrades.
1.3 “Confidential Information” has the meaning given in Section 5.1.
1.4 “Documentation” means any technical specifications, development guideline, hardware schematics, hardware diagrams, technical layout and other specifications or documentation that MediaTek may make available or provide to You from this Website relating to or for use in connection with the Software Package or MediaTek’s chipsets.
1.5 “ MediaTek” means MediaTek Inc., a company organized and existing under the laws of the Republic of China, having its principal office at No. 1, Dusing Road 1, Science-based Industrial Park, Hsin-Chu City, Taiwan, R.O.C.
1.6 “Open Source Software” means any software or software component, module or package that contains, or is derived in any manner (in whole or in part) from, any software that is distributed as free software, open source software or similar licensing or distribution models, including, without limitation, software licensed or distributed under any of the following licenses or distribution models, or licenses or distribution models similar to any of the following: (a) GNU’s General Public License (GPL) or Lesser/Library GPL (LGPL); (b) the Artistic License (e.g., PERL); (c) the Mozilla Public License; (d) the Netscape Public License; (e) the Sun Community Source License (SCSL); (f) the Sun Industry Standards License (SISL); (g) the BSD License; and (h) the Apache License.
1.7 “Software Package” means the APIs (Application Programming Interface), applications, data, files, libraries, materials, IDE (Integrated Development Environment), sample code, software (source code and object code), simulators, and tools provided or made available to You from this Website for use in connection with the development of Applications, including any Updates that MediaTek may provide or make available.
1.8 “Updates” means, in respect of the Software Package or Documentation or any part of the Software Package or Documentation, bug fixes, enhancements, modifications, new releases, new versions, supplements, updates or upgrades.
1.9 “You” means the person(s) or entity using the Software Package and/or Documentation, or otherwise exercising rights under this Agreement.
2. Term
The term of this Agreement commences at the time it becomes effective in the manner described above and continues until terminated by either party in accordance with Section 9 ("Term").
3. Grant of License and Restrictions
3.1 Subject to the provisions of this Agreement, MediaTek grants and You accept, a limited, non-exclusive, non-transferable, non-sublicensable, and terminable (under Section 9 hereof ) license, under MediaTek’s intellectual property rights in and to the Software and Documentation, during the Term to:
(a) install a reasonable number of copies of the Software Package on computers that You owns, for use internally solely for the purpose of developing or testing Applications;
(b) use and copy the Software Packagesolely for the purpose of internally developing or testing Application for specific use with the devices which incorporates MediaTek’s chipsets (“Licensed Product”) ; and
(c) make a reasonable number of copies of the Documentation for use internally and solely for the purpose of developing or testing Applications for specific use with the Licensed Product.
Each copy will include all notices and legends embedded in the Software Package and Documentation.
3.2 You must ensure that MediaTek’s or any third party’s copyright disclaimers and other proprietary notices that appear in the Software Package and Documentation are retained and reproduced in full in all copies of the Software Package and Documentation that You make as permitted under this Agreement.
3.3 You must not sell, redistribute, rent, lease, lend all or any part of the Software Package and Documentation, or enable or allow others to do such things except as explicitly set forth in Sections 3.1(d) or 3.1(e) or in the Third Party Licenses. You must not use the Software Package and Documentation for any purpose that is not expressly permitted under this Agreement. As a condition to the license granted in Section 3.1 above, You shall not (and shall not allow any third party to) decompile, disassemble, reverse engineer or attempt to reconstruct, identify or discover any source code, underlying ideas, underlying algorithms of the Software Package provided to You in object code form by any means whatsoever, or disclose any of the foregoing, except to the extent such restriction is expressly prohibited by applicable laws and not waivable thereunder.
3.4 Except for the limited license granted to You in this Agreement, all rights, title, and interest in and to the Software Package and Documentation that are made available to You under this Agreement remain, at all times, the sole and exclusive property of MediaTek or, for third party software contained in the Software Package, such third party. You agree to cooperate with MediaTek to maintain MediaTek and such thirty party's ownership of the Software Package and Documentation, and You agree to promptly provide notice of any claims or threatened claims relating to the Software Package and Documentation. Apart from the license rights expressly set out in this Agreement, MediaTek does not grant to You and You does not receive, whether by implication, estoppel or otherwise, any ownership right, title or interest nor any security interest or other interest in any intellectual property rights relating to the Software Package and Documentation, nor in any copy of any part of the foregoing, nor any other licenses, immunity or rights, express or implied.
3.5 The foregoing is subject always to, among other things, the following consideration from You:
(a) You hereby grant MediaTek, its Affiliates, and/or subcontractor a non-exclusive, non-transferable, world-wide, perpetual, royalty-free, fully paid up, license, under Your intellectual property rights related thereto, to use, copy, modify, distribute and exploit in all possible ways such Modification. For the avoidance of doubt, You are not obliged to deliver the Modification to MediaTek; and
(b) You hereby covenant to defend, indemnify and hold MediaTek, its Affiliates and/or subcontractors harmless from any and all claims, losses and damages (including without limitation reasonable attorneys fees) arising from the Modification violating or infringing any third party intellectual property or proprietary rights.
3.6 Except as otherwise permitted under this Agreement, nothing in this Agreement grants You any right to use any of MediaTek’s trademarks, trade names, copyrights, service marks, logos, domain names, patents, trade secrets, other brand features distinctive to MediaTek and/or other intellectual property, which remain, at all times, the sole and exclusive property of MediaTek.
3.7 MediaTek may, at any time without notice, extend, enhance, or otherwise modify the Software Package and Documentation. If MediaTek makes available Updates, such Updates will be governed by this Agreement (unless a separate license is provided with the Updates, in which case the terms of that license will govern the Updates). You acknowledge that MediaTek has no obligation, whether express or implied, to announce or make available any Updates.
4. Conditions and Requirements
4.1 You acknowledge and agree that the Applications shall comply with the conditions and requirements set out below, as modified by MediaTek from time to time:
(a) Applications shall comply with all applicable laws and regulations (including the laws and regulations of any jurisdiction in which the Applications are offered or made available). You must not design or market Applications for the purpose of violating any legal rights of any person or legal entity (including but not limited to privacy rights).
(b) You shall also ensure that the Applications do not and will not violate, misappropriate, or infringe any copyright, patent, design, trademark, trade secret, privacy or publicity rights, or any proprietary, intellectual property or other legal right of MediaTek or any third party and the embodying of such content in any Application, does not infringe upon any proprietary or intellectual property rights of any third party. You should be solely responsible for such third party contents embodied in the Application and You agree and acknowledge that MediaTek has no responsibility or liability relating to such third party contents.
(c) Applications must not contain content or materials of any kind (including, but not limited to, text, graphics, images, photographs, sounds, etc.) that are illegal or objectionable (for example, materials that may be considered obscene, pornographic or defamatory).
(d) Applications must not contain any material, component or code which could damage, destroy, unduly burden or adversely affect software, firmware, hardware, data, systems, services, or networks.
(e) If an Application or Licensed Product includes any Open Source Software, You must comply with all licensing terms applicable to such Open Source Software. However, You shall separate the portion of the Open Source Software from the portion of the Software Package in the Application or Licensed Product and shall not cause the portion of the Software Package in the Application or Licensed Product to be subject to the licensing terms applicable to such Open Source Software.
4.2 You acknowledge and agree that use of the Software Package and Documentation is subject to the following conditions:
(a) You will only use the Software Package and Documentation for the purposes and in the manner expressly permitted under this Agreement;
(b) You will not use the Software Package and Documentation for any unlawful or illegal activity;
(c) You will not develop any Application that would constitute or facilitate the commission of any crime, or any tortious, unlawful, or illegal act;
(d) You will develop any Application in compliance with the Documentation and all other requirements set out in this Agreement;
(e) You will not create or enable others to create, whether by using the Software Package or otherwise, any Application or other program that would disable, hack or otherwise interfere with any authentication, content protection, digital signing, digital rights management, security or verification mechanisms implemented in or by the Software Package, or other MediaTek’s software, services or technology;
(f) You agree that any information or technology that You provide or disclose to MediaTek in connection with this Agreement, including without limitation, information about the Applications, will be freely used and disclosed by MediaTek without restriction and without notifying or providing compensation to You. You release MediaTek from all liabilities and obligations that may arise from the receipt, review, use, or disclosure of such information or any part of it. Any physical materials You submit to MediaTek will become the property of MediaTek and MediaTek is not obliged to return such materials to You or to certify the destruction of such materials; and
(g) You agree that You are solely responsible (and that MediaTek has no liability or responsibility to You or to any third party) for any breach of Your covenants and obligations under this Agreement or any applicable laws and regulations, or for the consequences of any such breach (including any loss or damage which MediaTek or any third party may suffer).
4.3. You shall warrant that You have full power and authority to enter into this Agreement.
5. Confidentiality and Press Release
5.1 You acknowledge and agree that the Software Package (including all test versions of the Software Package), Documentation and any other information provided by MediaTek constitute "Confidential Information" for the purposes of this Agreement, unless MediaTek expressly indicates otherwise. Notwithstanding the foregoing, Confidential Information does not include: (a) information that is generally and legitimately available to the public through no fault or breach by You, (b) information that MediaTek makes generally made available to the public without restriction, (c) information that You independently develop without use of any Confidential Information, or (d) information that You lawfully obtain from a third party who had the right to transfer or disclose the information to You without limitation, the licensing terms of which do not contain obligations of confidentiality.
5.2 You must protect Confidential Information using a degree of care that is no less than that which You use to protect Your own confidential information of the same or similar importance (and in any event, no less than a reasonable degree of care). You may use Confidential Information solely for the purpose of exercising Your rights and performing Your obligations under this Agreement and You must not use Confidential Information for any other purpose, or for Your own or any third party’s benefit, without the prior written consent of MediaTek. You may disclose Confidential Information to the extent required by law, provided that You take reasonable steps to notify MediaTek of the relevant requirement prior to disclosing the Confidential Information and You take reasonable steps to obtain protective treatment against disclosure of the Confidential Information.
5.3 Upon mutually agreeing to the terms and method of issuance of a written announcement with MediaTek, You may issue a press release relating to the Application and the relationship of the parties.
6. DISCLAIMER OF WARRANTY
6.1 MEDIATEK MAKES NO WARRANTIES WITH RESPECT TO THE LICENSE GRANTED TO YOU HEREUNDER, WHETHER EXPRESSED, IMPLIED, STATUTORY OR OTHERWISE AND MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES OF ANY KIND, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF MERCHANTIBILITY, FITNESS FOR A PARTICULAR PUPOSE AND NON-INFRINGEMENT. MEDIATEK SHALL NOT BE RESPONSIBLE FOR ANY SOFTWARE PACKAGE AND/OR DOCUMENTATION RELEASED MADE TO YOUR SPECIFICATION OR CONFORMING TO A PARTICULAR STANDARD OR OPEN FORUM. FURTHER, MEDIATEK DOES NOT REPRESENT OR WARRANT THAT ANY PORTION OF THE SOFTWARE PACKAGE AND DOCUMENTATION IS FREE OF INACCURACIES, ERRORS, BUGS OR INTERRUPTIONS, OR IS RELIABLE, ACCURATE, COMPLETE, OR OTHERWISE VALID. THE SOFTWARE PACKAGE AND DOCUMENTATION ARE PROVIDED "AS IS" AND "AS AVAILABLE", WITHOUT ANY WARRANTY OF ANY KIND FROM MEDIATEK. YOUR USE OF THE SOFTWARE PACKAGE AND/OR DOCUMENTATION IS AT YOUR OWN DISCRETION AND RISK.
6.2 YOU ACKNOWLEDGE THAT SOFTWARE PACKAGE AND DOCUMENTATION MAY BE SUBJECT TO IMPORT, EXPORT, AND/OR RE-EXPORT RESTRICTIONS UNDER THE LAWS AND REGULATIONS OF RELATED JURISDICTIONS. YOU SHALL NOT EXPORT, RE-EXPORT, IMPORT OR OTHERWISE SELL, TRANSFER, DIRECTLY OR INDIRECTLY, SOFTWARE PACKAGE AND/OR DOCUMENTATION ACQUIRED HEREUNDER EXCEPT IN STRICT COMPLIANCE WITH ALL SUCH APPLICABLE LAWS AND REGULATIONS. YOU EXPRESSLY AGREE THAT SOFTWARE PACKAGE AND DOCUMENTATION SHALL NOT BE DOWNLOADED, TRANSFERRED OR OTHERWISE EXPORTED/IMPORTED OR RE-EXPORTED INTO (OR TO A NATIONAL OR RESIDENT OF) ANY EMBARGOED COUNTRIES, NOR TO ANYONE ON RELATED DENIAL LISTS, INCLUDING BUT NOT LIMITED TO THE U.S. TREASURY DEPARTMENT’S LIST OF SPECIALLY DESIGNATED NATIONALS OR THE U.S. COMMERCE DEPARTMENT’S TABLE OF DENIAL ORDERS. YOU HEREBY REPRESENT AND WARRANT THAT YOU ARE NOT LOCATED IN, UNDER THE CONTROL OF, OR A NATIONAL OR RESIDENT OF, ANY SUCH COUNTRY, OR ON ANY SUCH LIST. WITHOUT LIMITING THE FOREGOING, YOU AGREE THAT SOFTWARE PACKAGE AND DOCUMENTATION PROVIDED HEREUNDER SHALL NOT BE EXPORTED, RE-EXPORTED, OR TRANSFERRED TO ANY END-USER ENGAGED IN ACTIVITIES, OR FOR ANY END-USE, DIRECTLY OR INDIRECTLY RELATED TO THE DESIGN, DEVELOPMENT, PRODUCTION, USE, OR STOCKPILING OF WEAPONS OF MASS DESTRUCTION (E.G., NUCLEAR, CHEMICAL, OR BIOLOGICAL WEAPONS, MILITARY AND THE MISSILE TECHNOLOGY TO DELIVER THEM). YOU FURTHER ACKNOWLEDGE AND AGREE THAT YOU WILL COOPERATE WITH MEDIATEK AND/OR RELATED APPLICABLE JURISDICTIONS TO PROVIDE ALL THE NECESSARY ASSISTANCE, INFORMATION AND DOCUMENTS TO PROVE YOUR COMPLIANCE WITH THIS SECTION.
6.3 You hereby acknowledge that the Software Package provided under this Agreement might include software from one or more third parties (e.g. open source or proprietary, collectively as "Third Party Software") and the use of such shall be in accordance with the terms and conditions of this Agreement unless otherwise specified in the third party software license agreement accompanying such Third Party Software. You expressly acknowledge that it is Your sole responsibility to obtain from any third party all proper licenses contained in the Software Package. NOTWITHSTANDING ANYTHING CONTAINED HEREIN TO THE CONTRARY, MEDIATEK HEREBY EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES, EXPRESS OR IMPLIED, TO THE EXTENT ALLOWED BY APPLICABLE LAWS, WITH RESPECT TO ANY THIRD PARTY SOFTWARE.
6.4 Notwithstanding anything contained herein to the contrary, You understand and acknowledge that the payment payable to MediaTek hereunder does not include royalties or fees payable based on adherence of any Application or Licensed Product to published standards, and any such fees are the sole responsibility of You and You have the sole responsibility to procure license of any intellectual property right for the Application or Licensed Product to comply with such published standards.
7. LIMITATION OF LIABILITY
MediaTek’s entire liability to You arising out of or in connection with a particular version of Software Package or Documentation shall not exceed the aggregate amount of license fee paid by You to MediaTek for such Software Package or Documentation. Notwithstanding the foregoing, MediaTek’s entire liability in the aggregate for its breach of the terms Agreement shall not exceed the aggregate amount of license fee paid by You to MediaTek for the twelve (12) months preceding the event giving rise to the first breach. TO THE FULLEST EXTENT ALLOWED AND PERMITTED BY APPLICABLE LAWS AND REGULATIONS, MEDIATEK SHALL NOT, UNDER ANY CIRCUMSTANCES, BE LIABLE TO YOU OR ANY THIRD PARTY THROUGH YOU FOR PERSONAL INJURY OR ANY CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, INDIRECT, PUNITIVE OR SPECIAL DAMAGES WHATSOEVER, INCLUDING, WITHOUT LIMITATION, DAMAGES FOR LOSS OF PROFITS, LOSS OF DATA, BUSINESS INTERRUPTION OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES, ARISING OUT OF OR IN RELATION TO THIS AGREEMENT, YOUR USE OF THE SOFTWARE PACKAGE AND DOCUMENTATION, OR YOUR DEVELOPMENT OF APPLICATIONS, WHETHER BASED ON BREACH OF CONTRACT, BREACH OF WARRANTY, TORT (INCLUDING NEGLIGENCE), PRODUCT LIABILITY OR OTHERWISE, EVEN IF MEDIATEK HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES AND NOTWITHSTANDING THE FAILURE OF ESSENTIAL PURPOSE OF ANY REMEDY. YOUR UNDERSTSANDING, ACKNOWLEDGEMENT AND ACCEPTANCE OF THIS AGREEMENT ARE THE LEGAL BASIS AND CONSIDERATION FOR THE LICENSES GRANTED UNDER IT.
SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES OR LIMITATIONS ON APPLICABLE STATUTORY RIGHTS, SO THESE EXCLUSIONS AND LIMITATIONS MAY NOT APPLY TO YOU IN SUCH JURISDICTIONS TO THE EXTENT PROHIBITED BY RELEVANT MANDATORY LAWS.
8. Indemnity
8.1 To the fullest extent permitted by law, You agree to indemnify, defend and hold harmless MediaTek, its Affiliates, directors, officers, employees, independent contractors and agents (each an "Indemnified Party") from any and all claims, losses, liabilities, damages, expenses and costs (including without limitation reasonable attorneys fees) (collectively "Losses") incurred by a Indemnified Party as a result of Your breach of this Agreement, any claims that the Applications and/or Modifications violate or infringe any third party intellectual property or proprietary rights, or otherwise related to or arising from Your use of the Software Package, Documentation, or the Applications or Your development or distribution of Applications.
8.2 You acknowledge that the Software Package and Documentation are not intended to be used in the development of any Application or Licensed Products where death, personal injury, or severe physical or environmental damage could result from errors or inaccuracies in the content, data or information provided by the Application or Licensed Products or the Application or Licensed Products failing. To the extent permitted by law, You agree to indemnify, defend and hold harmless each Indemnified Party from any Losses incurred by such Indemnified Party as a result of Your use of the Software Package and/or Documentation in the development of any such Applications or Licensed Products.
8.3 You must not enter into a settlement or like agreement with any third party that affects MediaTek's rights or binds MediaTek in any way related to or arising from Your use of the Software Package or Documentation without the prior written consent of MediaTek.
9. Termination
9.1 Right to Terminate.
This Agreement and all rights granted by MediaTek hereunder will automatically terminate without notice:
(a) by MediaTek: (i) if You have breached any terms of this Agreement; or (ii) if MediaTek is required by law to terminate this Agreement or the rights granted by MediaTek hereunder;
(b) by either party for any reason or no reason upon thirty (30) days prior written notice to the other party.
MediaTek will have no liability to pay compensation or damages, or to provide an indemnity, of any kind as a result of terminating this Agreement in accordance with its terms, and termination of this Agreement is without prejudice to any other right or remedy that MediaTek may have, now or at any time in the future.
9.2 Consequences of Termination
Upon the termination of this Agreement: (a) all rights granted to You in this Agreement will terminate; (b) You shall promptly stop using the Software Package and Documentation and return the Software Package and Documentation to MediaTek or destroy all electronic copies of the Software Package and Documentation and provide written certification of such destruction to MediaTek. The provisions of Sections 1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 4, 5, 6, 7, 8, 9.2 and 10 will survive the expiration or termination of this Agreement.
10. General
10.1 Assignment.
You may not assign this Agreement, in whole or in part, without MediaTek’s prior written consent, and any attempt to do so without such consent shall be void. MediaTek may assign this Agreement without Your consent. This Agreement shall be binding upon and shall inure to the benefit of the parties hereto and their respective successors and permitted assigns.
10.2 Relationship of Between You and MediaTek.
This Agreement will not be construed as creating an agency, partnership, joint venture, fiduciary duty, or any other form of legal relationship between You and MediaTek, and You must not represent the existence of any such relationship, whether expressly, by implication or otherwise.
10.3 Development by MediaTek.
Nothing in this Agreement limits or otherwise affects MediaTek's right to acquire, develop, license, market, promote, or distribute any product or technology that performs the same or similar functions as the Applications or any other products or technologies that You develops, markets, promotes or distributes, or that otherwise competes with the Applications or such products or technologies.
10.4 Audit.
In consideration for determining Your compliance with its obligation under this Agreement, MediaTek shall have the right to audit relevant records containing information regarding Your exercise of Your right or performance of Your obligation during regular business hours.
10.5 Notices.
Any notices or other communication to be made to MediaTek pursuant to this Agreement must be made in writing and will be deemed to have been fully given or made when: (a) personally delivered; or (b) three (3) days after being mailed via commercially reputable overnight delivery service, to the following address:
MediaTek: MediaTek Inc.
No. 1, Dusing Rd. 1, Hsinchu Science Park, Hsinchu City, Taiwan 300, R.O.C.
You agree to receive notices and other communications to be made to You pursuant to this Agreement by posting on this Website and You agree that any notices that MediaTek posts on this Website will satisfy any legal communication requirements.
10.6 No Waiver.
Failure by MediaTek to insist upon strict performance of any of the provisions contained in this Agreement shall in no way constitute a waiver of MediaTek’s rights as set forth in this Agreement, at law or in equity, or a waiver of any other provisions or the right to take action in respect of a subsequent default by You in the performance or compliance with any of the terms and conditions set forth in this Agreement.
10.7 Remedies.
You acknowledge that any disclosure, use or misappropriation of Confidential Information of MediaTek in violation of this Agreement would cause MediaTek irreparable harm for which there may be no adequate remedy at law. Accordingly, You agree that MediaTek shall have the right to apply to any court of competent jurisdiction for injunctive relief and specific performance, without prejudice to any remedies otherwise available to MediaTek at law or in equity.
10.8 Governing Law.
This Agreement shall be governed by and construed in accordance with the laws of the Republic of Singapore, without regard to any conflict-of-laws rules.
10.9 Entire Agreement.
This Agreement contains the entire agreement between You and MediaTek with respect to the use of the Software Package and Documentation licensed hereunder and supersedes all existing agreements and all other oral, written or other communications between You and MediaTek concerning this subject matter. If any provision of this Agreement (or any portion thereof) is invalid, illegal or unenforceable, the validity, legality and enforceability of the remainder of this Agreement shall not be affected or impaired.
......@@ -19,6 +19,7 @@ mkdir -p $LIB_DIR/armeabi-v7a/cpu_gpu
rm -rf $LIB_DIR/arm64-v8a
mkdir -p $LIB_DIR/arm64-v8a/cpu_gpu_dsp
mkdir -p $LIB_DIR/arm64-v8a/cpu_gpu
mkdir -p $LIB_DIR/arm64-v8a/cpu_gpu_apu
rm -rf $LIB_DIR/linux-x86-64
mkdir -p $LIB_DIR/linux-x86-64
......@@ -50,6 +51,11 @@ echo "build shared lib for arm64-v8a + cpu_gpu"
bazel build --config android --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=false --define opencl=true --define quantize=true --cpu=arm64-v8a
cp bazel-bin/mace/libmace/libmace.so $LIB_DIR/arm64-v8a/cpu_gpu/
echo "build shared lib for arm64-v8a + cpu_gpu_apu"
bazel build --config android --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=false --define opencl=true --define apu=true --define quantize=true --cpu=arm64-v8a
cp bazel-bin/mace/libmace/libmace.so $LIB_DIR/arm64-v8a/cpu_gpu_apu/
cp third_party/apu/libapu-frontend.so $LIB_DIR/arm64-v8a/cpu_gpu_apu/
echo "build shared lib for arm_linux_gnueabihf + cpu_gpu"
bazel build --config arm_linux_gnueabihf --config optimization mace/libmace:libmace_dynamic --define neon=true --define openmp=false --define opencl=true --define quantize=true
cp bazel-bin/mace/libmace/libmace.so $LIB_DIR/arm_linux_gnueabihf/cpu_gpu/
......@@ -83,6 +89,11 @@ echo "build static lib for arm64-v8a + cpu_gpu"
bazel build --config android --config optimization mace/libmace:libmace_static --config symbol_hidden --define neon=true --define openmp=false --define opencl=true --define quantize=true --cpu=arm64-v8a
cp bazel-genfiles/mace/libmace/libmace.a $LIB_DIR/arm64-v8a/cpu_gpu/
echo "build static lib for arm64-v8a + cpu_gpu_apu"
bazel build --config android --config optimization mace/libmace:libmace_static --config symbol_hidden --define neon=true --define openmp=false --define opencl=true --define apu=true --define quantize=true --cpu=arm64-v8a
cp bazel-genfiles/mace/libmace/libmace.a $LIB_DIR/arm64-v8a/cpu_gpu_apu/
cp third_party/apu/libapu-frontend.so $LIB_DIR/arm64-v8a/cpu_gpu_apu/
echo "build static lib for arm_linux_gnueabihf + cpu_gpu"
bazel build --config arm_linux_gnueabihf --config optimization mace/libmace:libmace_static --config symbol_hidden --define neon=true --define openmp=false --define opencl=true --define quantize=true
cp bazel-genfiles/mace/libmace/libmace.a $LIB_DIR/arm_linux_gnueabihf/cpu_gpu/
......
......@@ -130,6 +130,7 @@ class DeviceType(object):
GPU = 'GPU'
HEXAGON = 'HEXAGON'
HTA = 'HTA'
APU = 'APU'
class DataFormat(object):
......@@ -207,6 +208,8 @@ def parse_device_type(runtime):
device_type = DeviceType.GPU
elif runtime == RuntimeType.cpu:
device_type = DeviceType.CPU
elif runtime == RuntimeType.apu:
device_type = DeviceType.APU
return device_type
......@@ -520,6 +523,7 @@ class RuntimeType(object):
gpu = 'gpu'
dsp = 'dsp'
hta = 'hta'
apu = 'apu'
cpu_gpu = 'cpu+gpu'
......
......@@ -62,6 +62,7 @@ RuntimeTypeStrs = [
"gpu",
"dsp",
"hta",
"apu",
"cpu+gpu"
]
......@@ -89,6 +90,13 @@ DSPDataTypeStrs = [
DSPDataType = Enum('DSPDataType', [(ele, ele) for ele in DSPDataTypeStrs],
type=str)
APUDataTypeStrs = [
"uint8",
]
APUDataType = Enum('APUDataType', [(ele, ele) for ele in APUDataTypeStrs],
type=str)
WinogradParameters = [0, 2, 4]
DataFormatStrs = [
......@@ -150,6 +158,8 @@ def parse_device_type(runtime):
device_type = DeviceType.GPU
elif runtime == RuntimeType.cpu:
device_type = DeviceType.CPU
elif runtime == RuntimeType.apu:
device_type = DeviceType.APU
return device_type
......@@ -361,6 +371,15 @@ def format_model_config(flags):
else:
model_config[YAMLKeyword.data_type] = \
DSPDataType.uint8.value
elif runtime == RuntimeType.apu:
if len(data_type) > 0:
mace_check(data_type in APUDataTypeStrs,
ModuleName.YAML_CONFIG,
"'data_type' must be in " + str(APUDataTypeStrs)
+ " for apu runtime")
else:
model_config[YAMLKeyword.data_type] = \
APUDataType.uint8.value
else:
if len(data_type) > 0:
mace_check(data_type in FPDataTypeStrs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册