提交 d9275275 编写于 作者: 叶剑武

Merge branch 'micro' into 'master'

add mace micro

See merge request deep-computing/mace!1257
......@@ -22,6 +22,9 @@ mace/codegen/version/
mace/codegen/engine/
mace/codegen/lib/
micro/codegen/models/
micro/codegen/engines/
examples/android/macelibrary/src/main/cpp/mace/
examples/android/macelibrary/src/main/cpp/include/
examples/android/macelibrary/src/main/cpp/lib/arm64-v8a/
......
......@@ -80,12 +80,14 @@ mace_cc_test:
DEVICE_CONF_FILE=generic-mobile-devices/devices.yml
fi
- python tools/bazel_adb_run.py --target="//test/ccunit:mace_cc_test" --device_yml=${DEVICE_CONF_FILE} --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a,arm64 --target_socs=$TARGET_SOCS
- python tools/bazel_adb_run.py --target="//micro/test/ccunit:micro_ops_test" --run_target=True --stdout_processor=ops_benchmark_stdout_processor --target_abis=arm64-v8a
mace_cc_benchmark:
stage: test
script:
- if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi
- python tools/bazel_adb_run.py --target="//test/ccbenchmark:mace_cc_benchmark" --run_target=True --stdout_processor=ops_benchmark_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS --args="--filter=.*SIGMOID.*"
- python tools/bazel_adb_run.py --target="//micro/test/ccbenchmark:micro_cc_benchmark" --run_target=True --stdout_processor=ops_benchmark_stdout_processor --target_abis=arm64-v8a
only:
- triggers
......@@ -112,6 +114,13 @@ model_tests:
- python tools/converter.py convert --config=${CONF_FILE} --target_socs=$TARGET_SOCS --model_graph_format=code --model_data_format=file
- python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --round=1 --validate --model_graph_format=code --model_data_format=file
- python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --round=5 --model_graph_format=code --model_data_format=file --benchmark
- CONF_FILE=mace-models/micro-models/har-cnn/har-cnn.yml
- python tools/converter.py convert --config=${CONF_FILE} --enable_micro
- python tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name har_cnn
- python tools/python/run_micro.py --config $CONF_FILE --model_name har_cnn --build --benchmark
- CONF_FILE=mace-models/micro-models/har-cnn/har-cnn-bf16.yml
- python tools/converter.py convert --config=${CONF_FILE} --enable_micro
- python tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name har_cnn
- rm -rf mace-models
quantization_tests:
......
......@@ -3,6 +3,7 @@ workspace(name = "mace")
# generate version and opencl kernel code.
load("//repository/git:git_configure.bzl", "git_version_repository")
load("//repository/opencl-kernel:opencl_kernel_configure.bzl", "encrypt_opencl_kernel_repository")
load("//micro:micro.bzl", "new_local_repository_env")
git_version_repository(name = "local_version_config")
......@@ -161,3 +162,15 @@ new_http_archive(
"https://releases.linaro.org/components/toolchain/binaries/7.3-2018.05/aarch64-linux-gnu/gcc-linaro-7.3.1-2018.05-x86_64_aarch64-linux-gnu.tar.xz",
],
)
new_local_repository_env(
name = "hexagon_sdk",
build_file = "third_party/hexagon/hexagon_sdk.BUILD",
path = "${HEXAGON_SDK_ROOT}",
)
new_local_repository_env(
name = "hexagon_tools",
build_file = "third_party/hexagon/hexagon_tools.BUILD",
path = "${HL_HEXAGON_TOOLS}",
)
......@@ -46,6 +46,13 @@ The main documentation is organized into the following sections:
development/data_format
development/dynamic_lstm
.. toctree::
:maxdepth: 1
:caption: Micro Controllers
:name: sec-micro
micro-controllers/basic_usage.rst
.. toctree::
:maxdepth: 1
:caption: FAQ
......
Basic usage for Micro Controllers
==================================
Build and run an example model
-------------------------------
At first, make sure the environment has been set up correctly already (refer to :doc:`../installation/env_requirement`).
The followings are instructions about how to quickly build and run a provided model in
`MACE Model Zoo <https://github.com/XiaoMi/mace-models>`__.
Here we use the har-cnn model as an example.
**Commands**
1. Pull `MACE <https://github.com/XiaoMi/mace>`__ project.
.. code-block:: sh
git clone https://github.com/XiaoMi/mace.git
cd mace/
git fetch --all --tags --prune
# Checkout the latest tag (i.e. release version)
tag_name=`git describe --abbrev=0 --tags`
git checkout tags/${tag_name}
.. note::
It's highly recommended to use a release version instead of master branch.
2. Pull `MACE Model Zoo <https://github.com/XiaoMi/mace-models>`__ project.
.. code-block:: sh
git clone https://github.com/XiaoMi/mace-models.git
3. Convert the pre-trained har-cnn model to c++ code.
.. code-block:: sh
cd path/to/mace
# output lib path: build/har-cnn/model/har_cnn_micro.tar.gz
CONF_FILE=/path/to/mace-models/micro-models/har-cnn/har-cnn.yml
python tools/converter.py convert --config=$CONF_FILE --enable_micro
4. Build Micro-Controllers engine and models to library on host.
.. code-block:: sh
# copy convert result to micro dir ``path/to/micro``
cp build/har-cnn/model/har_cnn_micro.tar.gz path/to/micro/
cd path/to/micro
tar zxvf har_cnn_micro.tar.gz
bazel build //micro/codegen:micro_engine
.. note::
- This step can be skipped if you just want to run a model using ``tools/python/run_micro.py``, such as commands in step 5.
- The build result ``bazel-bin/micro/codegen/libmicro_engine.so``'s abi is host, if you want to run the model on micro controllers, you should build the code with the target abi.
5. Run the model on host.
.. code-block:: sh
CONF_FILE=/path/to/mace-models/micro-models/har-cnn/har-cnn.yml
# Run
python tools/python/run_micro.py --config $CONF_FILE --model_name har_cnn --build
# Test model run time
python tools/python/run_micro.py --config $CONF_FILE --model_name har_cnn --build --round=100
# Validate the correctness by comparing the results against the
# original model and framework, measured with cosine distance for similarity.
python tools/python/run_micro.py --config $CONF_FILE --model_name har_cnn --build --validate
# Validate the layers' correctness.
python tools/python/run_micro.py --config $CONF_FILE --model_name har_cnn --build --validate --layers 0:-1
Deploy your model into applications
------------------------------------
Please refer to \ ``/mace/micro/tools/micro_run.cc`` for full usage. The following list the key steps.
.. code-block:: cpp
// Include the headers
#include "micro/include/public/micro.h"
// 1. Create MaceMicroEngine instance
MaceMicroEngine *micro_engine = nullptr;
MaceStatus status = har_cnn::GetMicroEngineSingleton(&micro_engine);
// 1. Create and register Input buffers
std::vector<std::shared_ptr<char>> inputs;
std::vector<int32_t> input_sizes;
for (size_t i = 0; i < input_shapes.size(); ++i) {
input_sizes.push_back(std::accumulate(input_shapes[i].begin(),
input_shapes[i].end(), sizeof(float),
std::multiplies<int32_t>()));
inputs.push_back(std::shared_ptr<char>(new char[input_sizes[i]],
std::default_delete<char[]>()));
}
// TODO: fill data into input buffers
for (size_t i = 0; i < input_names.size(); ++i) {
micro_engine->RegisterInputData(i, inputs[i].get(),
input_shapes[i].data());
}
// 3. Run the model
MaceStatus status = micro_engine->Run();
// 4. Get the results
for (size_t i = 0; i < output_names.size(); ++i) {
void *output_buffer = nullptr;
const int32_t *output_dims = nullptr;
uint32_t dim_size = 0;
MaceStatus status =
micro_engine->GetOutputData(i, &output_buffer, &output_dims, &dim_size);
// TODO: the result data is in output_buffer, you can not delete output_buffer.
}
......@@ -53,10 +53,14 @@ cat <<EOF > ${OUTPUT_FILENAME}
// This is a generated file. DO NOT EDIT!
namespace mace {
namespace {
#ifndef _MSC_VER
__attribute__((visibility("default")))
#endif
const char *MaceVersion() { return "MACEVER-${GIT_VERSION}" + 8; }
const char *kMaceVersion = "MACEVER-${GIT_VERSION}";
} // namespace
const char *MaceVersion() { return kMaceVersion + 8; }
} // namespace mace
EOF
......@@ -322,7 +322,8 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
.TypeConstraint("T", key_dtype)
.Build();
if (registry_.at(op_type)->creators.count(key) == 0) {
LOG(FATAL) << "Key not registered: " << key;
LOG(FATAL) << "Key not registered: " << key
<< ", op type is: " << operator_def->type();
}
return registry_.at(op_type)->creators.at(key)(context);
}
......
......@@ -8,9 +8,11 @@ package(
licenses(["notice"]) # Apache 2.0
load("@com_google_protobuf//:protobuf.bzl",
"py_proto_library",
"cc_proto_library")
load(
"@com_google_protobuf//:protobuf.bzl",
"cc_proto_library",
"py_proto_library",
)
py_proto_library(
name = "mace_py",
......@@ -27,3 +29,14 @@ cc_proto_library(
default_runtime = "@com_google_protobuf//:protobuf_lite",
protoc = "@com_google_protobuf//:protoc",
)
py_proto_library(
name = "micro_mem_py",
srcs = ["micro_mem.proto"],
default_runtime = "@com_google_protobuf//:protobuf_python",
protoc = "@com_google_protobuf//:protoc",
srcs_version = "PY2AND3",
deps = [
"@com_google_protobuf//:protobuf_python",
],
)
......@@ -14,6 +14,7 @@ enum DataType {
DT_HALF = 3;
DT_INT32 = 4;
DT_FLOAT16 = 5;
DT_BFLOAT16 = 6;
}
enum MemoryType {
......@@ -76,6 +77,7 @@ message OperatorDef {
repeated DataType output_type = 8;
repeated QuantizeActivationInfo quantize_info = 9;
// for mace it is mem_id, for micro, it is mem_offset
repeated int32 mem_id = 10;
// for hexagon mace-nnlib
......
syntax = "proto2";
package micro;
message OutputShape {
repeated int64 dims = 1;
}
message OpContext {
optional int32 op_idx = 1;
// The input info of downstream operator is the output info of upstream
// operator, so there is no output info defined here
repeated uint32 input_infos = 2;
repeated OutputShape output_resize_shapes = 3;
}
message Graph {
repeated OpContext op_contexts = 1;
repeated uint32 input_op_idxs = 2;
// The output info of the last operator, which is not recorded in opcontext,
// is the output of graph
repeated uint32 output_infos = 3;
}
config_setting(
name = "hexagon_enabled",
define_values = {
"hexagon": "true",
},
visibility = ["//visibility:public"],
)
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "base_hdrs",
hdrs = glob([
"*.h",
]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"//micro/include",
"//micro/port",
],
)
cc_library(
name = "base",
srcs = glob(
[
"*.cc",
],
),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"base_hdrs",
"//micro/port",
],
)
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/base/logger.h"
#include "micro/base/value_to_str.h"
#include "micro/port/api.h"
namespace micro {
namespace base {
namespace {
const int32_t kInt64ValueBufferLength = 21;
const int32_t kInt32ValueBufferLength = 12;
const int32_t kInt16ValueBufferLength = 6;
const int32_t kInt8ValueBufferLength = 4;
const int32_t kFloatValueBufferLength = 21;
inline bool IsValidLogLevel(const LogLevel level) {
return level >= CLEAN && level < INVALID_MAX;
}
char LogLevelToShortStr(LogLevel level) {
if (!IsValidLogLevel(level)) {
level = INFO;
}
return "CIWEF"[static_cast<int>(level)];
}
} // namespace
Logger::Logger(const char *fname, uint32_t line,
LogLevel severity) : severity_(severity) {
if (severity == CLEAN) {
return;
}
char buffer[15] = {0};
char *end = buffer + 15;
buffer[0] = LogLevelToShortStr(severity);
buffer[1] = ' ';
micro::port::api::DebugLog(buffer);
micro::port::api::DebugLog(fname);
buffer[0] = ':';
ToString("] ", ToString(line, buffer + 1, end), end);
micro::port::api::DebugLog(buffer);
}
Logger::~Logger() {
micro::port::api::DebugLog("\n");
if (severity_ == FATAL) {
micro::port::api::Abort();
}
}
const Logger &Logger::operator<<(const char *str) const {
micro::port::api::DebugLog(str);
return *this;
}
const Logger &Logger::operator<<(const char c) const {
char buffer[2] = {0};
buffer[0] = c;
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const float value) const {
char buffer[kFloatValueBufferLength] = {0};
ToString(value, buffer, buffer + kFloatValueBufferLength);
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const int64_t value) const {
char buffer[kInt64ValueBufferLength] = {0};
ToString(value, buffer, buffer + kInt64ValueBufferLength);
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const int32_t value) const {
char buffer[kInt32ValueBufferLength] = {0};
ToString(value, buffer, buffer + kInt32ValueBufferLength);
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const uint32_t value) const {
char buffer[kInt32ValueBufferLength] = {0};
ToString(value, buffer, buffer + kInt32ValueBufferLength);
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const int16_t value) const {
char buffer[kInt16ValueBufferLength] = {0};
ToString(value, buffer, buffer + kInt16ValueBufferLength);
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const uint16_t value) const {
char buffer[kInt16ValueBufferLength] = {0};
ToString(value, buffer, buffer + kInt16ValueBufferLength);
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const int8_t value) const {
char buffer[kInt8ValueBufferLength] = {0};
ToString(value, buffer, buffer + kInt8ValueBufferLength);
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const uint8_t value) const {
char buffer[kInt8ValueBufferLength] = {0};
ToString(value, buffer, buffer + kInt8ValueBufferLength);
micro::port::api::DebugLog(buffer);
return *this;
}
const Logger &Logger::operator<<(const bool value) const {
if (value) {
micro::port::api::DebugLog("true");
} else {
micro::port::api::DebugLog("false");
}
return *this;
}
} // namespace base
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_BASE_LOGGER_H_
#define MICRO_BASE_LOGGER_H_
#include <stdint.h>
namespace micro {
enum LogLevel {
CLEAN = 0,
INFO = 1,
WARNING = 2,
ERROR = 3,
FATAL = 4,
INVALID_MAX,
};
namespace base {
class Logger {
public:
Logger(const char *fname, uint32_t line, LogLevel severity);
~Logger();
const Logger &operator<<(const char *str) const;
const Logger &operator<<(const char c) const;
const Logger &operator<<(const float value) const;
const Logger &operator<<(const int64_t value) const;
const Logger &operator<<(const int32_t value) const;
const Logger &operator<<(const uint32_t value) const;
const Logger &operator<<(const int16_t value) const;
const Logger &operator<<(const uint16_t value) const;
const Logger &operator<<(const int8_t value) const;
const Logger &operator<<(const uint8_t value) const;
const Logger &operator<<(const bool value) const;
private:
LogLevel severity_;
};
} // namespace base
} // namespace micro
#endif // MICRO_BASE_LOGGER_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_BASE_LOGGING_H_
#define MICRO_BASE_LOGGING_H_
#include <stdint.h>
#include "micro/base/logger.h"
#include "micro/include/port/define.h"
namespace micro {
namespace log {
#define LOG(severity) \
micro::base::Logger(__FILE__, __LINE__, micro::severity)
#ifndef NDEBUG
#define LOG1(severity, value) LOG(severity) << value
#define LOG2(severity, value1, value2) LOG(severity) << value1 << value2
#define LOG3(severity, value1, value2, value3) \
LOG(severity) << value1 << value2 << value3
#define LOG4(severity, value1, value2, value3, value4) \
LOG(severity) << value1 << value2 << value3 << value4
#define LOG5(severity, value1, value2, value3, value4, value5) \
LOG(severity) << value1 << value2 << value3 << value4 << value5
#else
#define LOG1(severity, value)
#define LOG2(severity, value1, value2)
#define LOG3(severity, value1, value2, value3)
#define LOG4(severity, value1, value2, value3, value4)
#define LOG5(severity, value1, value2, value3, value4, value5)
#endif // NDEBUG
#ifndef NDEBUG
#define MACE_ASSERT(condition) \
if (!(condition)) LOG(FATAL) << "Assert failed: "#condition // NOLINT
#define MACE_ASSERT1(condition, str) \
if (!(condition)) LOG(FATAL) << "Assert failed: "#condition " " << str // NOLINT
#define MACE_ASSERT2(condition, str1, str2) \
if (!(condition)) LOG(FATAL) << "Assert failed: "#condition " " << str1 << str2 // NOLINT
#else
#define MACE_ASSERT(condition)
#define MACE_ASSERT1(condition, string)
#define MACE_ASSERT2(condition, string1, string2)
#endif // NDEBUG
#define MACE_NOT_IMPLEMENTED MACE_ASSERT1(false, "not implemented")
#define MACE_CHECK_SUCCESS(stmt) \
{ \
MaceStatus status = (stmt); \
if (status != MACE_SUCCESS) { \
LOG(FATAL) << #stmt << " failed with error: " \
<< status; \
} \
}
#define MACE_RETURN_IF_ERROR(stmt) \
{ \
MaceStatus status = (stmt); \
if (status != MACE_SUCCESS) { \
LOG(INFO) << static_cast<int32_t>(stmt) \
<< " failed with error: " \
<< static_cast<int32_t>(status); \
return status; \
} \
}
} // namespace log
} // namespace micro
#endif // MICRO_BASE_LOGGING_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/base/serialize.h"
#include "micro/base/logging.h"
#include "micro/base/utils.h"
namespace micro {
#ifdef MACE_WRITE_MAGIC
SerialUint32 Serialize::GetMagic() const {
return magic_;
}
SerialUint32 Serialize::Magic(const char *bytes4) const {
MACE_ASSERT1(micro::base::strlen(bytes4) >= 4, "The magic bytes must >= 4.");
SerialUint32 magic = 0;
for (int32_t i = 0; i < 32 && (*bytes4) != '\0'; i += 8, ++bytes4) {
magic += (*bytes4) << i;
}
return magic;
}
MaceStatus Serialize::MagicToString(SerialUint32 magic,
char (&array)[5]) const {
char *buffer = array;
for (int32_t i = 0; i <32; i += 8, ++buffer) {
*buffer = (magic >> i) & 0x000000ff;
}
*buffer = '\0';
return MACE_SUCCESS;
}
#endif // MACE_WRITE_MAGIC
void Serialize::Uint2OpIOInfo(const OpIOInfo *info) const {
OpIOInfo *io_info = const_cast<OpIOInfo *>(info);
uint32_t info_data = *(reinterpret_cast<uint32_t *>(io_info));
io_info->op_def_idx_ = (info_data & 0xffff0000) >> 16;
io_info->output_idx_ = (info_data & 0x0000ffff);
}
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_BASE_SERIALIZE_H_
#define MICRO_BASE_SERIALIZE_H_
#include <stdint.h>
#include "micro/base/serialize_type.h"
#include "micro/include/public/micro.h"
namespace micro {
#ifdef MACE_WRITE_MAGIC
#ifndef MACE_DEFINE_HARD_CODE_MAGIC
#define MACE_DEFINE_HARD_CODE_MAGIC(CLASS_NAME) \
SerialUint32 GetHardCodeMagic() const { \
return Magic(#CLASS_NAME); \
}
#endif // MACE_DEFINE_HARD_CODE_MAGIC
#else
#ifndef MACE_DEFINE_HARD_CODE_MAGIC
#define MACE_DEFINE_HARD_CODE_MAGIC(CLASS_NAME)
#endif // MACE_DEFINE_HARD_CODE_MAGIC
#endif // MACE_WRITE_MAGIC
// We describe a tensor as an output tensor, but it can also
// be used to represent an input tensor.
struct OpIOInfo {
uint16_t op_def_idx_;
uint16_t output_idx_;
};
class Serialize {
#ifdef MACE_WRITE_MAGIC
public:
SerialUint32 GetMagic() const;
MaceStatus MagicToString(SerialUint32 magic, char (&array)[5]) const;
protected:
SerialUint32 magic_;
protected:
SerialUint32 Magic(const char *bytes4) const;
#endif // MACE_WRITE_MAGIC
public:
void Uint2OpIOInfo(const OpIOInfo *output_info) const;
};
} // namespace micro
#endif // MICRO_BASE_SERIALIZE_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_BASE_SERIALIZE_TYPE_H_
#define MICRO_BASE_SERIALIZE_TYPE_H_
#include <stdint.h>
#include "micro/include/public/micro.h"
namespace micro {
#ifdef MACE_OFFSET_USE_16
typedef uint16_t offset_size_t;
#else
typedef uint32_t offset_size_t;
#endif // MACE_OFFSET_USE_16
template<typename T>
struct SerialArray {
offset_size_t size_;
offset_size_t offset_;
SerialArray() : size_(0), offset_(0) {}
};
struct SerialString {
offset_size_t packed_length_;
offset_size_t offset_;
SerialString() : packed_length_(0), offset_(0) {}
};
struct SerialBytes {
offset_size_t packed_length_;
offset_size_t offset_;
SerialBytes() : packed_length_(0), offset_(0) {}
};
typedef float SerialFloat;
typedef int32_t SerialInt32;
typedef uint32_t SerialUint32;
typedef uint32_t SerialBool;
typedef int32_t SerialInt16;
typedef uint32_t SerialUint16;
typedef int32_t SerialInt8;
typedef uint32_t SerialUint8;
#ifndef MACE_DECLARE_OBJECT_FUNC
#define MACE_DECLARE_OBJECT_FUNC(T, OBJECT_NAME) \
T OBJECT_NAME() const;
#endif // MACE_DECLARE_OBJECT_FUNC
#ifndef MACE_DEFINE_OBJECT_FUNC
#define MACE_DEFINE_OBJECT_FUNC(CLASS_NAME, T, OBJECT_NAME) \
T CLASS_NAME::OBJECT_NAME() const { \
return OBJECT_NAME##_; \
}
#endif // MACE_DEFINE_OBJECT_FUNC
#ifndef MACE_MACE_DECLARE_PTR_FUNC
#define MACE_DECLARE_PTR_FUNC(T, OBJECT_NAME) \
const T *OBJECT_NAME() const;
#endif // MACE_DECLARE_PTR_FUNC
#ifndef MACE_DEFINE_PTR_FUNC
#define MACE_DEFINE_PTR_FUNC(CLASS_NAME, T, OBJECT_NAME) \
const T *CLASS_NAME::OBJECT_NAME() const { \
return &OBJECT_NAME##_; \
}
#endif // MACE_DEFINE_PTR_FUNC
#ifndef MACE_DECLARE_ARRAY_FUNC
#define MACE_DECLARE_ARRAY_FUNC(T, OBJECT_NAME) \
T OBJECT_NAME(uint32_t index) const; \
uint32_t OBJECT_NAME##_size() const
#endif // MACE_DECLARE_ARRAY_FUNC
#ifndef MACE_DECLARE_ARRAY_BASE_PTR_FUNC
#define MACE_DECLARE_ARRAY_BASE_PTR_FUNC(T, OBJECT_NAME) \
const T * OBJECT_NAME() const
#endif // MACE_DECLARE_ARRAY_BASE_PTR_FUNC
#ifndef MACE_DEFINE_ARRAY_BASE_PTR_FUNC
#define MACE_DEFINE_ARRAY_BASE_PTR_FUNC( \
CLASS_NAME, T, OBJECT_NAME, ARRAY_NAME) \
const T *CLASS_NAME::OBJECT_NAME() const { \
const T *array = reinterpret_cast<const T *>( \
reinterpret_cast<const uint8_t *>(this) + ARRAY_NAME.offset_); \
return array; \
}
#endif // MACE_DEFINE_ARRAY_BASE_PTR_FUNC
#ifndef MACE_DEFINE_ARRAY_FUNC
#define MACE_DEFINE_ARRAY_FUNC(CLASS_NAME, T, OBJECT_NAME, ARRAY_NAME) \
T CLASS_NAME::OBJECT_NAME(uint32_t index) const { \
const T *array = reinterpret_cast<const T *>( \
reinterpret_cast<const uint8_t *>(this) + ARRAY_NAME.offset_); \
return *(array + index); \
} \
uint32_t CLASS_NAME::OBJECT_NAME##_size() const { \
return ARRAY_NAME.size_; \
}
#endif // MACE_DEFINE_ARRAY_FUNC
#ifndef MACE_DECLARE_PTR_ARRAY_FUNC
#define MACE_DECLARE_PTR_ARRAY_FUNC(T, OBJECT_NAME) \
const T *OBJECT_NAME(uint32_t index) const; \
uint32_t OBJECT_NAME##_size() const
#endif // MACE_DECLARE_PTR_ARRAY_FUNC
#ifndef MACE_DEFINE_PTR_ARRAY_FUNC
#define MACE_DEFINE_PTR_ARRAY_FUNC(CLASS_NAME, T, OBJECT_NAME, ARRAY_NAME) \
const T *CLASS_NAME::OBJECT_NAME(uint32_t index) const { \
const T *array = reinterpret_cast<const T *>( \
reinterpret_cast<const uint8_t *>(this) + ARRAY_NAME.offset_); \
return (array + index); \
} \
\
uint32_t CLASS_NAME::OBJECT_NAME##_size() const { \
return ARRAY_NAME.size_; \
}
#endif // MACE_DEFINE_PTR_ARRAY_FUNC
#ifndef MACE_DECLARE_STRING_FUNC
#define MACE_DECLARE_STRING_FUNC(OBJECT_NAME) \
const char *OBJECT_NAME() const;
#endif // MACE_DECLARE_STRING_FUNC
#ifndef MACE_DEFINE_STRING_FUNC
#define MACE_DEFINE_STRING_FUNC(CLASS_NAME, OBJECT_NAME, STRING_NAME) \
const char *CLASS_NAME::OBJECT_NAME() const { \
if (STRING_NAME.packed_length_ == 0) { \
return NULL; \
} else { \
return reinterpret_cast<const char *>(this) + STRING_NAME.offset_; \
} \
}
#endif // MACE_DEFINE_STRING_FUNC
#ifndef MACE_DECLARE_BYTES_FUNC
#define MACE_DECLARE_BYTES_FUNC(OBJECT_NAME) \
const uint8_t *OBJECT_NAME() const; \
uint32_t OBJECT_NAME##_size() const
#endif // MACE_DECLARE_BYTES_FUNC
#ifndef MACE_DEFINE_BYTES_FUNC
#define MACE_DEFINE_BYTES_FUNC(CLASS_NAME, OBJECT_NAME, BYTES_NAME) \
const uint8_t *CLASS_NAME::OBJECT_NAME() const { \
if (BYTES_NAME.packed_length_ == 0) { \
return NULL; \
} else { \
return reinterpret_cast<const uint8_t *>(this) + BYTES_NAME.offset_; \
} \
} \
\
uint32_t CLASS_NAME::OBJECT_NAME##_size() const { \
return BYTES_NAME.packed_length_; \
}
#endif // MACE_DEFINE_BYTES_FUNC
#ifndef MACE_DECLARE_STRING_ARRAY_FUNC
#define MACE_DECLARE_STRING_ARRAY_FUNC(OBJECT_NAME) \
const char *OBJECT_NAME(uint32_t index) const; \
uint32_t OBJECT_NAME##_size() const
#endif
#ifndef MACE_DEFINE_STRING_ARRAY_FUNC
#define MACE_DEFINE_STRING_ARRAY_FUNC(CLASS_NAME, OBJECT_NAME, ARRAY_NAME) \
const char *CLASS_NAME::OBJECT_NAME(uint32_t index) const { \
const SerialString *array = reinterpret_cast<const SerialString *>( \
reinterpret_cast<const char *>(this) + ARRAY_NAME.offset_); \
const SerialString *serial_str = array + index; \
const char *str = reinterpret_cast<const char *>(serial_str) + \
serial_str->offset_; \
return str; \
} \
\
uint32_t CLASS_NAME::OBJECT_NAME##_size() const { \
return ARRAY_NAME.size_; \
}
#endif // MACE_DEFINE_STRING_ARRAY_FUNC
} // namespace micro
#endif // MICRO_BASE_SERIALIZE_TYPE_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_BASE_TYPES_H_
#define MICRO_BASE_TYPES_H_
#include "micro/include/public/micro.h"
#include "micro/include/utils/bfloat16.h"
namespace micro {
#ifdef MACE_ENABLE_BFLOAT16
typedef BFloat16 mifloat;
#else
typedef float mifloat;
#endif // MACE_ENABLE_BFLOAT16
template<class T>
struct DataTypeToEnum;
template<DataType VALUE>
struct EnumToDataType;
#ifndef MACE_MAPPING_DATA_TYPE_AND_ENUM
#define MACE_MAPPING_DATA_TYPE_AND_ENUM(DATA_TYPE, ENUM_VALUE) \
template <> \
struct DataTypeToEnum<DATA_TYPE> { \
static DataType v() { return ENUM_VALUE; } \
static const DataType value = ENUM_VALUE; \
}; \
template <> \
struct EnumToDataType<ENUM_VALUE> { \
typedef DATA_TYPE Type; \
};
#endif // MACE_MAPPING_DATA_TYPE_AND_ENUM
MACE_MAPPING_DATA_TYPE_AND_ENUM(float, DT_FLOAT);
MACE_MAPPING_DATA_TYPE_AND_ENUM(uint8_t, DT_UINT8);
MACE_MAPPING_DATA_TYPE_AND_ENUM(int32_t, DT_INT32);
#ifdef MACE_ENABLE_BFLOAT16
MACE_MAPPING_DATA_TYPE_AND_ENUM(BFloat16, DT_BFLOAT16);
#endif
} // namespace micro
#endif // MICRO_BASE_TYPES_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/base/utils.h"
#include <math.h>
#include "micro/base/logging.h"
namespace micro {
namespace base {
uint32_t strlen(const char *str) {
MACE_ASSERT1(str != NULL, "str can not be NULL.");
uint32_t length = 0;
while (*str++ != '\0') {
++length;
}
return length;
}
int32_t strcmp(const char *str1, const char *str2) {
MACE_ASSERT1(str1 != NULL && str2 != NULL,
"strcmp str can not be NULL.");
while (*str1 == *str2) {
if (*str1 == '\0') {
return 0;
}
++str1;
++str2;
}
return (*str1) - (*str2);
}
void memcpy(void *dst, const void *src, uint32_t bytes) {
MACE_ASSERT1(dst != NULL && src != NULL && bytes > 0,
"Invalid params.");
uint8_t *dst_mem = static_cast<uint8_t *>(dst);
const uint8_t *src_mem = static_cast<const uint8_t *>(src);
while (bytes-- > 0) {
*dst_mem++ = *src_mem++;
}
}
int32_t GetShapeSize(uint32_t dim_size, const int32_t *dims) {
return accumulate_multi(dims, 0, dim_size);
}
float sqrt(float x) {
return ::sqrt(x);
}
int32_t ceil(float f) {
int32_t i = (int32_t) f;
return (f == static_cast<float>(i)) ? i : i + 1;
}
int32_t floor(float f) {
return ::floor(f);
}
float fabs(float x) {
if (x < 0.0f) {
return -x;
} else if (x > 0.0f) {
return x;
} else {
return 0.0f;
}
}
float lowest() {
return -3.402823466e+38F;
}
float highest() {
return 3.402823466e+38F;
}
float tanh(float x) {
return ::tanh(x);
}
float exp(float x) {
return ::exp(x);
}
float pow(float x, float y) {
return ::pow(x, y);
}
float log(float x) {
return ::log(x);
}
} // namespace base
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_BASE_UTILS_H_
#define MICRO_BASE_UTILS_H_
#include <stdint.h>
#include "micro/base/logging.h"
namespace micro {
namespace base {
uint32_t strlen(const char *str);
int32_t strcmp(const char *str1, const char *str2);
void memcpy(void *dst, const void *src, uint32_t bytes);
int32_t GetShapeSize(uint32_t dim_size, const int32_t *dims);
float sqrt(float x);
int32_t ceil(float f);
int32_t floor(float f);
float fabs(float x);
float lowest();
float highest();
float tanh(float x);
float exp(float x);
float pow(float x, float y);
float log(float x);
template<typename T>
void memset(T *src, T value, uint32_t size) {
for (uint32_t i = 0; i < size; ++i) {
src[i] = value;
}
}
template<typename T>
T accumulate_multi(const T *array, uint32_t array_start, uint32_t array_end) {
MACE_ASSERT(array_start >= 0 && array_start <= array_end);
if (array == NULL || array_start == array_end) {
return 1;
}
T total = array[array_start];
for (uint32_t i = array_start + 1; i < array_end; ++i) {
total *= array[i];
}
return total;
}
template<typename T>
T abs(T x) {
return x > 0 ? x : -x;
}
template<typename T>
T max(T a, T b) {
return a > b ? a : b;
}
template<typename T>
T min(T a, T b) {
return a < b ? a : b;
}
template<typename T>
void swap(T *a, T *b) { // NOLINT
T c = *a;
*a = *b;
*b = c;
}
template<typename T>
T clamp(T in, T low, T high) {
return max<T>(low, min<T>(in, high)); // NOLINT
}
} // namespace base
} // namespace micro
#endif // MICRO_BASE_UTILS_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/base/value_to_str.h"
namespace micro {
namespace base {
#ifndef MACE_SIGNED_TO_STRING
#define MACE_SIGNED_TO_STRING(T, UNSIGNED_T) \
template<> \
char *ToString(T value, char *buffer, char *end) { \
if (value < 0) { \
value = -value; \
*buffer++ = '-'; \
} \
return ToString(static_cast<UNSIGNED_T>(value), buffer, end); \
}
#endif // MACE_SIGNED_TO_STRING
void ReverseInplace(char *start, char *end) {
end--;
while (start < end) {
char tmp = *start;
*start++ = *end;
*end-- = tmp;
}
}
MACE_SIGNED_TO_STRING(int64_t, uint64_t)
MACE_SIGNED_TO_STRING(int32_t, uint32_t)
MACE_SIGNED_TO_STRING(int16_t, uint16_t)
MACE_SIGNED_TO_STRING(int8_t, uint8_t)
template<>
char *ToString(const char *str, char *buffer, char *end) {
end--;
while (*str != '\0' && buffer < end) {
*buffer++ = *str++;
}
*buffer = '\0';
return buffer;
}
template<>
char *ToString(float value, char *buffer, char *end) {
if (value <= -1e-8) {
*buffer++ = '-';
}
int32_t int_part = (int32_t) value;
buffer = ToString(int_part, buffer, end);
float deci_part = value - int_part;
if (deci_part < 1e-8 && deci_part > -1e-8) {
return buffer;
}
if (deci_part < 0.0) {
deci_part = -deci_part;
}
end--;
*buffer++ = '.';
do {
deci_part *= 10;
int32_t remainder = (int32_t) deci_part;
*buffer++ = '0' + remainder;
deci_part -= remainder;
} while (deci_part > 0 && buffer < end);
*buffer = '\0';
return buffer;
}
} // namespace base
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_BASE_VALUE_TO_STR_H_
#define MICRO_BASE_VALUE_TO_STR_H_
#include <stdint.h>
namespace micro {
namespace base {
void ReverseInplace(char *start, char *end);
// for uint64_t/uint32_t/uint16_t/uint8_t
template<typename T>
char *ToString(T value, char *buffer, char *end) {
char *start = buffer;
end--;
do {
*buffer++ = '0' + (value % 10);
value /= 10;
} while (value > 0 && buffer < end);
ReverseInplace(start, buffer);
*buffer = '\0';
return buffer;
}
template<>
char *ToString(int64_t value, char *buffer, char *end);
template<>
char *ToString(int32_t value, char *buffer, char *end);
template<>
char *ToString(int16_t value, char *buffer, char *end);
template<>
char *ToString(int8_t value, char *buffer, char *end);
template<>
char *ToString(const char *str, char *buffer, char *end);
template<>
char *ToString(float value, char *buffer, char *end);
} // namespace base
} // namespace micro
#endif // MICRO_BASE_VALUE_TO_STR_H_
# Description:
# Generated model and runtime code.
#
package(
default_visibility = ["//visibility:public"],
)
cc_library(
name = "generated_models",
srcs = glob(["models/**/*.cc"]),
hdrs = glob(["models/**/*.h"]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"//micro/framework",
"//micro/include",
"//micro/model",
"//micro/ops",
],
)
cc_library(
name = "micro_engine_c",
srcs = glob(["micro/codegen/engines/**/micro_engine_c_interface.cc"]),
hdrs = glob(["micro/codegen/engines/**/micro_engine_c_interface.cc"]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
":micro_engine",
],
alwayslink = 1,
)
cc_library(
name = "micro_engine",
srcs = glob(
["engines/**/*.cc"],
exclude = ["micro/codegen/engines/**/micro_engine_c_interface.cc"],
),
hdrs = glob(
[
"engines/**/*.h",
],
exclude = ["micro/codegen/engines/**/micro_engine_c_interface.cc"],
),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"generated_models",
"//micro/framework",
"//micro/model",
"//micro/ops",
],
alwayslink = 1,
)
cc_binary(
name = "libmicro.so",
linkshared = 1,
linkstatic = 1,
deps = [
":micro_engine",
],
)
cc_binary(
name = "libmicro.lo",
linkshared = False,
linkstatic = True,
deps = [
":micro_engine",
],
)
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "framework",
srcs = glob(["*.cc"]),
hdrs = glob(["*.h"]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"//micro/base",
"//micro/include",
"//micro/model",
],
)
cc_library(
name = "framework_for_optest",
srcs = glob(
["*.cc"],
exclude = ["operator.cc"],
),
hdrs = glob(["*.h"]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"//micro/base",
"//micro/include",
"//micro/model",
],
)
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/framework/graph.h"
#include "micro/base/logging.h"
#include "micro/base/serialize.h"
#include "micro/base/utils.h"
#include "micro/framework/operator.h"
#include "micro/include/public/micro.h"
#include "micro/model/net_def.h"
namespace micro {
namespace framework {
MACE_DEFINE_PTR_ARRAY_FUNC(Graph, OpContext, op_context, op_contexts_)
MACE_DEFINE_ARRAY_FUNC(Graph, uint32_t, input_op_idx, input_op_idxs_);
MACE_DEFINE_PTR_ARRAY_FUNC(Graph, OpIOInfo, output_info, output_infos_);
MaceStatus Graph::Init(MaceMicroEngineConfig *engine_config) {
MACE_ASSERT(engine_config->net_def_->op_size() == op_context_size());
uint32_t output_info_size = this->output_info_size();
for (uint32_t i = 0; i < output_info_size; ++i) {
Uint2OpIOInfo(this->output_info(i));
}
uint32_t op_size = engine_config->net_def_->op_size();
for (uint32_t i = 0; i < op_size; ++i) {
OpContext *op_ctx = const_cast<OpContext *>(op_context(i));
MACE_RETURN_IF_ERROR(op_ctx->Init(
engine_config, engine_config->net_def_->op(i)));
}
return MACE_SUCCESS;
}
MaceStatus Graph::RegisterInputData(MaceMicroEngineConfig *engine_config,
uint32_t idx,
const void *input_buffer,
const int32_t *input_dims) {
engine_config->input_buffers_[idx] = input_buffer;
engine_config->input_shapes_[idx] = input_dims;
// update the op's input buffers
uint32_t op_idx = input_op_idx(idx);
framework::Operator *input_op = engine_config->op_array_[op_idx];
return input_op->OnInit();
}
MaceStatus Graph::Run(MaceMicroEngineConfig *engine_config) {
uint32_t op_size = engine_config->net_def_->op_size();
for (uint32_t i = 0; i < op_size; ++i) {
OpContext *op_ctx = const_cast<OpContext *>(op_context(i));
MACE_RETURN_IF_ERROR(op_ctx->Run(engine_config));
}
return MACE_SUCCESS;
}
MaceStatus Graph::GetOutputData(MaceMicroEngineConfig *engine_config,
const uint32_t idx,
void **output_data,
const int32_t **output_dims,
uint32_t *output_dim_size) {
MACE_ASSERT(idx < output_info_size());
const OpIOInfo *o_info = output_info(idx);
return GetOpOutputData(engine_config, o_info->op_def_idx_,
o_info->output_idx_, output_data,
output_dims, output_dim_size);
}
MaceStatus Graph::GetOpOutputData(MaceMicroEngineConfig *engine_config,
const uint32_t op_def_idx,
const uint32_t output_idx,
void **output_data,
const int32_t **output_dims,
uint32_t *output_dim_size) {
MACE_ASSERT(engine_config != NULL);
MACE_ASSERT(output_data != NULL);
MACE_ASSERT(output_dims != NULL);
MACE_ASSERT(output_dim_size != NULL);
const model::OperatorDef *op_def = engine_config->net_def_->op(op_def_idx);
*output_data = engine_config->tensor_mem_ + op_def->mem_offset(output_idx);
const model::OutputShape *output_shape =
op_context(op_def_idx)->output_resize_shape(output_idx);
*output_dims = output_shape->dim();
*output_dim_size = output_shape->dim_size();
return MACE_SUCCESS;
}
} // namespace framework
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_FRAMEWORK_GRAPH_H_
#define MICRO_FRAMEWORK_GRAPH_H_
#include "micro/base/serialize.h"
#include "micro/framework/op_context.h"
namespace micro {
struct MaceMicroEngineConfig;
namespace framework {
class Graph : public Serialize {
public:
MACE_DEFINE_HARD_CODE_MAGIC(Graph)
MACE_DECLARE_PTR_ARRAY_FUNC(OpContext, op_context);
MACE_DECLARE_ARRAY_FUNC(uint32_t, input_op_idx);
MACE_DECLARE_PTR_ARRAY_FUNC(OpIOInfo, output_info);
MaceStatus Init(MaceMicroEngineConfig *engine_config);
MaceStatus RegisterInputData(MaceMicroEngineConfig *engine_config,
uint32_t idx,
const void *input_buffer,
const int32_t *input_dims);
MaceStatus Run(MaceMicroEngineConfig *engine_config);
MaceStatus GetOutputData(MaceMicroEngineConfig *engine_config,
const uint32_t idx,
void **output_data,
const int32_t **output_dims,
uint32_t *output_dim_size);
MaceStatus GetOpOutputData(MaceMicroEngineConfig *engine_config,
const uint32_t op_def_idx,
const uint32_t output_idx,
void **output_data,
const int32_t **output_dims,
uint32_t *output_dim_size);
protected:
SerialArray<OpContext> op_contexts_;
SerialArray<SerialUint32> input_op_idxs_;
SerialArray<OpIOInfo> output_infos_;
};
} // namespace framework
} // namespace micro
#endif // MICRO_FRAMEWORK_GRAPH_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/base/logging.h"
#include "micro/base/utils.h"
#include "micro/framework/graph.h"
#include "micro/framework/scratch_buffer.h"
#include "micro/include/public/micro.h"
#include "micro/model/net_def.h"
#include "micro/model/operator_def.h"
#include "micro/port/api.h"
namespace micro {
MaceStatus MaceMicroEngine::Init(MaceMicroEngineConfig *engine_config) {
MACE_ASSERT(engine_config != NULL && engine_config->net_def_ != NULL
&& engine_config->model_data_ != NULL
&& engine_config->graph_ != NULL
&& engine_config->op_array_ != NULL
&& engine_config->tensor_mem_ != NULL);
engine_config_ = engine_config;
MACE_RETURN_IF_ERROR(engine_config_->graph_->Init(engine_config_));
return MACE_SUCCESS;
}
MaceStatus MaceMicroEngine::RegisterInputData(uint32_t idx,
const void *input_buffer,
const int32_t *input_dims) {
MACE_ASSERT(idx < engine_config_->net_def_->input_info_size());
MACE_ASSERT(input_buffer != NULL);
MACE_ASSERT(input_dims != NULL);
return engine_config_->graph_->RegisterInputData(engine_config_, idx,
input_buffer, input_dims);
}
MaceStatus MaceMicroEngine::Run() {
return engine_config_->graph_->Run(engine_config_);
}
MaceStatus MaceMicroEngine::GetOutputData(const uint32_t idx,
void **output_data,
const int32_t **output_dims,
uint32_t *output_dim_size) {
return engine_config_->graph_->GetOutputData(engine_config_, idx,
output_data, output_dims,
output_dim_size);
}
MaceStatus MaceMicroEngine::GetOpOutputData(const uint32_t op_def_idx,
const uint32_t output_idx,
void **output_data,
const int32_t **output_dims,
uint32_t *output_dim_size) {
return engine_config_->graph_->GetOpOutputData(engine_config_, op_def_idx,
output_idx, output_data,
output_dims, output_dim_size);
}
MaceMicroEngine::MaceMicroEngine(const MaceMicroEngine &) {
MACE_NOT_IMPLEMENTED;
}
MaceMicroEngine &MaceMicroEngine::operator=(const MaceMicroEngine &) {
MACE_NOT_IMPLEMENTED;
return *this;
}
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/framework/op_context.h"
#include "micro/framework/operator.h"
#include "micro/model/net_def.h"
#include "micro/model/operator_def.h"
#include "micro/include/public/micro.h"
namespace micro {
namespace framework {
MACE_DEFINE_OBJECT_FUNC(OpContext, uint32_t, op_idx)
MACE_DEFINE_PTR_ARRAY_FUNC(OpContext, OpIOInfo, input_info, input_infos_)
MACE_DEFINE_PTR_ARRAY_FUNC(OpContext, model::OutputShape,
output_resize_shape, output_resize_shapes_)
MaceStatus OpContext::Init(MaceMicroEngineConfig *engine_config,
const model::OperatorDef *op_def) {
// init OpContext
uint32_t input_info_size = this->input_info_size();
for (uint32_t i = 0; i < input_info_size; ++i) {
Uint2OpIOInfo(this->input_info(i));
}
// init Op
uint32_t op_i = op_idx();
MACE_RETURN_IF_ERROR(
engine_config->op_array_[op_i]->Init(engine_config, this, op_def));
return MACE_SUCCESS;
}
MaceStatus OpContext::Run(MaceMicroEngineConfig *engine_config) {
return engine_config->op_array_[op_idx()]->Run();
}
} // namespace framework
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_FRAMEWORK_OP_CONTEXT_H_
#define MICRO_FRAMEWORK_OP_CONTEXT_H_
#include "micro/base/serialize.h"
#include "micro/model/operator_def.h"
#include "micro/model/output_shape.h"
namespace micro {
struct MaceMicroEngineConfig;
namespace framework {
class Operator;
class OpContext : public Serialize {
public:
MACE_DEFINE_HARD_CODE_MAGIC(OpContext)
MACE_DECLARE_OBJECT_FUNC(uint32_t, op_idx);
MACE_DECLARE_PTR_ARRAY_FUNC(OpIOInfo, input_info);
MACE_DECLARE_PTR_ARRAY_FUNC(model::OutputShape, output_resize_shape);
MaceStatus Init(MaceMicroEngineConfig *engine_config,
const model::OperatorDef *op_def);
MaceStatus Run(MaceMicroEngineConfig *engine_config);
protected:
SerialUint32 op_idx_;
SerialArray<OpIOInfo> input_infos_;
SerialArray<model::OutputShape> output_resize_shapes_;
};
} // namespace framework
} // namespace micro
#endif // MICRO_FRAMEWORK_OP_CONTEXT_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/framework/operator.h"
#include "micro/base/utils.h"
#include "micro/framework/op_context.h"
#include "micro/include/port/define.h"
#include "micro/include/public/micro.h"
#include "micro/model/const_tensor.h"
#include "micro/model/input_output_info.h"
#include "micro/model/net_def.h"
#include "micro/model/operator_def.h"
namespace micro {
namespace framework {
namespace {
const uint16_t kIdxConstTensor = 0xffff;
const uint16_t kIdxModelInput = 0xfffe;
} // namespace
Operator::~Operator() {}
MaceStatus Operator::Init(MaceMicroEngineConfig *engine_config,
framework::OpContext *op_context,
const model::OperatorDef *op_def) {
engine_config_ = engine_config;
op_context_ = op_context;
op_def_ = op_def;
MACE_ASSERT1(op_def_->input_size() == op_context_->input_info_size(),
"op_def_'s input dosen't match the op_context_'s");
MACE_ASSERT1(
op_def_->output_size() == op_context_->output_resize_shape_size(),
"op_def_'s output dosen't match the op_context_'s");
return OnInit();
}
MaceStatus Operator::Run() {
MACE_NOT_IMPLEMENTED;
return MACE_SUCCESS;
}
MaceStatus Operator::OnInit() {
return MACE_SUCCESS;
}
const model::Argument *Operator::GetArgByName(const char *name) const {
MACE_ASSERT(op_def_ != NULL);
for (uint32_t i = 0; i < op_def_->arg_size(); ++i) {
const model::Argument *argument = op_def_->arg(i);
if (base::strcmp(name, argument->name()) == 0) {
return argument;
}
}
return NULL;
}
uint32_t Operator::GetInputSize() {
return op_def_->input_size();
}
const void *Operator::DoGetInputData(uint32_t idx) {
const void *data = NULL;
const OpIOInfo *input_info = op_context_->input_info(idx);
const uint32_t op_def_idx = input_info->op_def_idx_;
if (kIdxConstTensor == op_def_idx) {
const model::ConstTensor *const_tensor =
engine_config_->net_def_->tensor(input_info->output_idx_);
data = engine_config_->model_data_ + const_tensor->offset();
} else if (kIdxModelInput == op_def_idx) {
data = engine_config_->input_buffers_[input_info->output_idx_];
} else {
const model::OperatorDef *pre_op_def =
engine_config_->net_def_->op(op_def_idx);
data = engine_config_->tensor_mem_ +
pre_op_def->mem_offset(input_info->output_idx_);
}
return data;
}
uint32_t Operator::GetInputShapeDimSize(uint32_t idx) {
uint32_t dim_size = 0;
const OpIOInfo *input_info = op_context_->input_info(idx);
const uint32_t op_def_idx = input_info->op_def_idx_;
if (kIdxConstTensor == op_def_idx) {
const model::ConstTensor *const_tensor =
engine_config_->net_def_->tensor(input_info->output_idx_);
dim_size = const_tensor->dim_size();
} else if (kIdxModelInput == op_def_idx) {
const model::InputOutputInfo *info =
engine_config_->net_def_->input_info(input_info->output_idx_);
dim_size = info->dim_size();
} else {
const model::OperatorDef *op_def = engine_config_->net_def_->op(op_def_idx);
const model::OutputShape *output_shape =
op_def->output_shape(input_info->output_idx_);
dim_size = output_shape->dim_size();
}
return dim_size;
}
const int32_t *Operator::GetInputShapeDims(uint32_t idx) {
const int32_t *dims = NULL;
const OpIOInfo *input_info = op_context_->input_info(idx);
const uint32_t op_def_idx = input_info->op_def_idx_;
if (kIdxConstTensor == op_def_idx) {
const model::ConstTensor *const_tensor =
engine_config_->net_def_->tensor(input_info->output_idx_);
dims = const_tensor->dim();
} else if (kIdxModelInput == op_def_idx) {
dims = engine_config_->input_shapes_[input_info->output_idx_];
} else {
const model::OperatorDef *op_def = engine_config_->net_def_->op(op_def_idx);
const model::OutputShape *output_shape =
op_def->output_shape(input_info->output_idx_);
dims = output_shape->dim();
}
return dims;
}
uint32_t Operator::GetOutputSize() {
return op_def_->output_size();
}
DataType Operator::GetOutputDataType(uint32_t idx) {
return op_def_->output_type(idx);
}
void *Operator::DoGetOutputData(uint32_t idx) {
return engine_config_->tensor_mem_ + op_def_->mem_offset(idx);
}
uint32_t Operator::GetOutputShapeDimSize(uint32_t idx) {
uint32_t dim_size = 0;
model::OutputShape *output_shape =
const_cast<model::OutputShape *>(op_context_->output_resize_shape(idx));
if (output_shape != NULL) {
dim_size = output_shape->dim_size();
}
return dim_size;
}
const int32_t *Operator::GetOutputShapeDims(uint32_t idx) {
const int32_t *dims = NULL;
model::OutputShape *output_shape =
const_cast<model::OutputShape *>(op_context_->output_resize_shape(idx));
if (output_shape != NULL) {
dims = output_shape->dim();
}
return dims;
}
MaceStatus Operator::ResizeOutputShape(uint32_t idx, uint32_t dim_size,
const int32_t *dims) {
model::OutputShape *output_shape =
const_cast<model::OutputShape *>(op_context_->output_resize_shape(idx));
#ifndef NDEBUG
if (op_def_->output_shape(idx)->dim_size() < dim_size
|| output_shape->dim_size() < dim_size) {
LOG(FATAL) << "Can not support dynamic dim_size. op_def_dim_size = "
<< op_def_->output_shape(idx)->dim_size()
<< ", output_shape_dim_size = " << output_shape->dim_size()
<< ", dim_size = " << dim_size;
}
int32_t def_output_shape_size =
base::GetShapeSize(output_shape->dim_size(), output_shape->dim());
int32_t input_shape_size = base::GetShapeSize(dim_size, dims);
if (def_output_shape_size < input_shape_size) {
LOG(INFO) << op_def_->name() << " resize failed, because "
<< def_output_shape_size << " < " << input_shape_size;
LOG(INFO) << "input: ";
for (uint32_t i = 0; i < dim_size; ++i) {
LOG(INFO) << dims[i] << ", ";
}
LOG(INFO) << "old output: ";
for (uint32_t i = 0; i < output_shape->dim_size(); ++i) {
LOG(INFO) << output_shape->dim(i) << ", ";
}
MACE_ASSERT(def_output_shape_size >= input_shape_size);
}
#endif // NDEBUG
if (dim_size > 0) {
base::memcpy(output_shape->mutable_dim(), dims, dim_size * sizeof(int32_t));
}
return MACE_SUCCESS;
}
#ifndef MACE_DEFINE_GET_ARG_BY_NAME_FUNC
#define MACE_DEFINE_GET_ARG_BY_NAME_FUNC(T, FUNC) \
template <> \
T Operator::GetArgByName(const char *name, T default_value) const { \
const model::Argument *arg = GetArgByName(name); \
if (arg == NULL) { \
return default_value; \
} else { \
return arg->FUNC(); \
} \
}
#endif // MACE_DEFINE_GET_ARG_BY_NAME_FUNC
MACE_DEFINE_GET_ARG_BY_NAME_FUNC(bool, i)
MACE_DEFINE_GET_ARG_BY_NAME_FUNC(int32_t, i)
MACE_DEFINE_GET_ARG_BY_NAME_FUNC(float, f)
#ifndef MACE_DEFINE_GET_ARRAY_ARG_BY_NAME_FUNC
#define MACE_DEFINE_GET_ARRAY_ARG_BY_NAME_FUNC(T, FUNC) \
template <> \
const T *Operator::GetRepeatArgByName(const char *name, \
uint32_t *size) const { \
const model::Argument *arg = GetArgByName(name); \
if (arg == NULL) { \
return NULL; \
} \
if (size != NULL) { \
*size = arg->FUNC##_size(); \
} \
return arg->FUNC(); \
}
#endif // MACE_DEFINE_GET_ARRAY_ARG_BY_NAME_FUNC
MACE_DEFINE_GET_ARRAY_ARG_BY_NAME_FUNC(int32_t, ints)
MACE_DEFINE_GET_ARRAY_ARG_BY_NAME_FUNC(float, floats)
MACE_DEFINE_GET_ARRAY_ARG_BY_NAME_FUNC(uint8_t, s)
} // namespace framework
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_FRAMEWORK_OPERATOR_H_
#define MICRO_FRAMEWORK_OPERATOR_H_
#include "micro/base/logging.h"
#include "micro/base/types.h"
#include "micro/include/public/micro.h"
#include "micro/framework/scratch_buffer.h"
namespace micro {
struct MaceMicroEngineConfig;
namespace model {
class Argument;
class OperatorDef;
class OutputShape;
} // namespace model
namespace ops {
typedef framework::ScratchBuffer ScratchBuffer;
}
namespace framework {
#ifndef MACE_OP_INPUT_TAGS
#define MACE_OP_INPUT_TAGS(first_input, ...) \
enum _InputTags { first_input = 0, __VA_ARGS__ }
#endif // MACE_OP_INPUT_TAGS
#ifndef MACE_OP_OUTPUT_TAGS
#define MACE_OP_OUTPUT_TAGS(first_input, ...) \
enum _OutputTags { first_input = 0, __VA_ARGS__ }
#endif // MACE_OP_OUTPUT_TAGS
class OpContext;
class Operator {
public:
Operator() {}
// Note: This func should be virtual, but if we make it virtual,
// the operator delete will be needed, which is in c++ runtime library.
// For we don't use the Operator pointer to point sub-classes, the
// virtual ~Operator() is not needed.
~Operator();
MaceStatus Init(MaceMicroEngineConfig *engine_config,
OpContext *op_context,
const model::OperatorDef *op_def);
virtual MaceStatus OnInit();
virtual MaceStatus Run();
template<typename T>
T GetArgByName(const char *name, T default_value) const;
template<typename T>
const T *GetRepeatArgByName(const char *name,
uint32_t *size = NULL) const;
protected:
uint32_t GetInputSize();
const void *DoGetInputData(uint32_t idx);
uint32_t GetInputShapeDimSize(uint32_t idx);
const int32_t *GetInputShapeDims(uint32_t idx);
uint32_t GetOutputSize();
DataType GetOutputDataType(uint32_t idx);
void *DoGetOutputData(uint32_t idx);
uint32_t GetOutputShapeDimSize(uint32_t idx);
const int32_t *GetOutputShapeDims(uint32_t idx);
MaceStatus ResizeOutputShape(uint32_t idx, uint32_t input_dim_size,
const int32_t *input_dims);
MaceStatus ReuseInputBufferForOutput(uint32_t output_idx, uint32_t input_idx);
template<typename T>
const T *GetInputData(uint32_t idx) {
return static_cast<const T *>(DoGetInputData(idx));
}
template<typename T>
T *GetOutputData(uint32_t idx) {
return static_cast<T *>(DoGetOutputData(idx));
}
private:
const model::Argument *GetArgByName(const char *name) const;
protected:
const model::OperatorDef *op_def_;
MaceMicroEngineConfig *engine_config_;
private:
OpContext *op_context_;
};
} // namespace framework
} // namespace micro
#endif // MICRO_FRAMEWORK_OPERATOR_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/framework/scratch_buffer.h"
#include "micro/base/logging.h"
#include "micro/include/public/micro.h"
namespace micro {
namespace framework {
#ifndef NDEBUG
namespace {
int64_t kDetectHandle = -1;
}
#endif
ScratchBuffer::ScratchBuffer(MaceMicroEngineConfig *engine_config) :
engine_config_(engine_config), offset_(0) {
#ifndef NDEBUG
int64_t cur_handle = reinterpret_cast<int64_t>(engine_config);
MACE_ASSERT1(cur_handle != kDetectHandle, "Detect scratch buffer error.");
kDetectHandle = cur_handle;
#endif
}
ScratchBuffer::~ScratchBuffer() {
#ifndef NDEBUG
kDetectHandle = -1;
#endif
}
void *ScratchBuffer::DoGetBuffer(uint32_t size) {
if (size % 4 != 0) {
size = (size + 3) / 4 * 4;
}
if (offset_ + size > engine_config_->scratch_buffer_size_) {
LOG(FATAL) << "The scratch buffer is not enough."
<< "offset_: " << offset_ << ", size: " << size
<< ", engine_config_->scratch_buffer_size_: "
<< engine_config_->scratch_buffer_size_;
}
void *ptr = engine_config_->scratch_buffer_ + offset_;
offset_ += size;
return ptr;
}
} // namespace framework
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_FRAMEWORK_SCRATCH_BUFFER_H_
#define MICRO_FRAMEWORK_SCRATCH_BUFFER_H_
#include "micro/base/logging.h"
#include "micro/include/public/micro.h"
namespace micro {
namespace framework {
class ScratchBuffer {
public:
explicit ScratchBuffer(MaceMicroEngineConfig *engine_config);
~ScratchBuffer();
template<typename T>
T *GetBuffer(int32_t size) {
MACE_ASSERT(size > 0);
return static_cast<T *>(
DoGetBuffer(static_cast<uint32_t>(size) * sizeof(T)));
}
template<typename T>
T *GetBuffer(uint32_t size) {
return static_cast<T *>(DoGetBuffer(size * sizeof(T)));
}
private:
void *DoGetBuffer(uint32_t size);
private:
const MaceMicroEngineConfig *engine_config_;
uint32_t offset_;
};
} // namespace framework
} // namespace micro
#endif // MICRO_FRAMEWORK_SCRATCH_BUFFER_H_
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "include",
hdrs = glob([
"public/*.h",
"port/*.h",
"utils/*.h",
]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
)
cc_library(
name = "public_headers",
hdrs = glob([
"public/*.h",
]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
)
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_INCLUDE_PORT_DEFINE_H_
#define MICRO_INCLUDE_PORT_DEFINE_H_
#define MACE_API
#define MACE_DEPRECATED
#ifndef __FILE__
#define __FILE__ ""
#endif
#ifndef __LINE__
#define __LINE__ 0
#endif
#ifndef NULL
#define NULL 0
#endif
#endif // MICRO_INCLUDE_PORT_DEFINE_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_INCLUDE_PUBLIC_MICRO_H_
#define MICRO_INCLUDE_PUBLIC_MICRO_H_
#include <stdint.h>
#include "micro/include/port/define.h"
namespace micro {
enum DataFormat {
NONE = 0, NHWC = 1, NCHW = 2,
HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103,
AUTO = 1000,
};
enum PerfHint {
PERF_DEFAULT = 0,
PERF_LOW = 1,
PERF_NORMAL = 2,
PERF_HIGH = 3
};
enum DataType {
DT_INVALID = 0,
DT_FLOAT = 1,
DT_UINT8 = 2,
DT_HALF = 3,
DT_INT32 = 4,
DT_FLOAT16 = 5,
DT_BFLOAT16 = 6,
};
enum MaceStatus {
MACE_SUCCESS = 0,
MACE_INVALID_ARGS = 1,
MACE_OUT_OF_RESOURCES = 2,
MACE_UNSUPPORTED = 3,
MACE_RUNTIME_ERROR = 4,
};
namespace model {
class NetDef;
} // namespace model
namespace framework {
class Graph;
class Operator;
} // namespace framework
struct MACE_API MaceMicroEngineConfig {
model::NetDef *net_def_;
const uint8_t *model_data_;
framework::Graph *graph_;
framework::Operator **op_array_;
uint8_t *tensor_mem_;
const void **input_buffers_;
const int32_t **input_shapes_;
uint8_t *scratch_buffer_;
uint32_t scratch_buffer_size_;
};
class MACE_API MaceMicroEngine {
public:
MaceMicroEngine() {}
~MaceMicroEngine() {}
MaceStatus Init(MaceMicroEngineConfig *engine_config);
MaceStatus RegisterInputData(uint32_t idx, const void *input_buffer,
const int32_t *input_dims);
MaceStatus Run();
MaceStatus GetOutputData(const uint32_t idx, void **output_data,
const int32_t **output_dims,
uint32_t *output_dim_size);
MaceStatus GetOpOutputData(const uint32_t op_def_idx,
const uint32_t output_idx,
void **output_data,
const int32_t **output_dims,
uint32_t *output_dim_size);
private:
MaceMicroEngineConfig *engine_config_;
MaceMicroEngine(const MaceMicroEngine &);
MaceMicroEngine &operator=(const MaceMicroEngine &);
};
} // namespace micro
#endif // MICRO_INCLUDE_PUBLIC_MICRO_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_INCLUDE_UTILS_BFLOAT16_H_
#define MICRO_INCLUDE_UTILS_BFLOAT16_H_
#include <stdint.h>
#ifdef MACE_ENABLE_BFLOAT16
namespace micro {
union Sphinx {
uint32_t i;
float f;
Sphinx(uint32_t value) : i(value) {}
Sphinx(float value) : f(value) {}
};
class BFloat16 {
public:
BFloat16();
operator float() const {
return Sphinx(static_cast<uint32_t>(data_ << 16)).f;
}
void operator=(const BFloat16 &value) {
data_ = value.data_;
}
void operator=(float value) {
data_ = Sphinx(value).i >> 16;
}
public:
uint16_t data_;
};
} // namespace micro
#endif // MACE_ENABLE_BFLOAT16
#endif // MICRO_INCLUDE_UTILS_BFLOAT16_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_INCLUDE_UTILS_MACROS_H_
#define MICRO_INCLUDE_UTILS_MACROS_H_
#include "micro/include/public/micro.h"
namespace micro {
#ifndef MACE_EMPTY_VIRTUAL_DESTRUCTOR
#define MACE_EMPTY_VIRTUAL_DESTRUCTOR(CLASSNAME) \
public: \
virtual ~CLASSNAME() {}
#endif // MACE_EMPTY_VIRTUAL_DESTRUCTOR
#define MACE_UNUSED(var) (void)(var)
} // namespace micro
#endif // MICRO_INCLUDE_UTILS_MACROS_H_
def if_hexagon_enabled(a):
return select({
"//micro:hexagon_enabled": a,
"//conditions:default": [],
})
def if_not_hexagon_enabled(a):
return select({
"//micro:hexagon_enabled": [],
"//conditions:default": a,
})
def new_local_repository_env_impl(repository_ctx):
echo_cmd = "echo " + repository_ctx.attr.path
echo_result = repository_ctx.execute(["bash", "-c", echo_cmd])
src_path_str = echo_result.stdout.splitlines()[0]
source_path = repository_ctx.path(src_path_str)
work_path = repository_ctx.path(".")
child_list = source_path.readdir()
for child in child_list:
child_name = child.basename
repository_ctx.symlink(child, work_path.get_child(child_name))
build_file_babel = Label("//:" + repository_ctx.attr.build_file)
build_file_path = repository_ctx.path(build_file_babel)
repository_ctx.symlink(build_file_path, work_path.get_child("BUILD"))
# a new_local_repository support environment variable
new_local_repository_env = repository_rule(
implementation = new_local_repository_env_impl,
local = True,
attrs = {
"path": attr.string(mandatory = True),
"build_file": attr.string(mandatory = True),
},
)
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "model",
srcs = glob(["*.cc"]),
hdrs = glob(["*.h"]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"//micro/base",
"//micro/include",
],
)
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/model/argument.h"
namespace micro {
namespace model {
MACE_DEFINE_STRING_FUNC(Argument, name, name_)
MACE_DEFINE_OBJECT_FUNC(Argument, float, f)
MACE_DEFINE_OBJECT_FUNC(Argument, int32_t, i)
MACE_DEFINE_BYTES_FUNC(Argument, s, s_)
MACE_DEFINE_ARRAY_FUNC(Argument, float, floats, floats_)
MACE_DEFINE_ARRAY_BASE_PTR_FUNC(Argument, float, floats, floats_)
MACE_DEFINE_ARRAY_FUNC(Argument, int32_t, ints, ints_)
MACE_DEFINE_ARRAY_BASE_PTR_FUNC(Argument, int32_t, ints, ints_)
} // namespace model
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_MODEL_ARGUMENT_H_
#define MICRO_MODEL_ARGUMENT_H_
#include "micro/base/serialize.h"
namespace micro {
namespace model {
class Argument : public Serialize {
public:
MACE_DEFINE_HARD_CODE_MAGIC(Argument)
MACE_DECLARE_STRING_FUNC(name);
MACE_DECLARE_OBJECT_FUNC(float, f);
MACE_DECLARE_OBJECT_FUNC(int32_t, i);
MACE_DECLARE_BYTES_FUNC(s);
MACE_DECLARE_ARRAY_FUNC(float, floats);
MACE_DECLARE_ARRAY_BASE_PTR_FUNC(float, floats);
MACE_DECLARE_ARRAY_FUNC(int32_t, ints);
MACE_DECLARE_ARRAY_BASE_PTR_FUNC(int32_t, ints);
private:
SerialString name_;
SerialFloat f_;
SerialInt32 i_;
SerialBytes s_;
SerialArray<SerialFloat> floats_;
SerialArray<SerialInt32> ints_;
};
} // namespace model
} // namespace micro
#endif // MICRO_MODEL_ARGUMENT_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/model/const_tensor.h"
namespace micro {
namespace model {
MACE_DEFINE_ARRAY_FUNC(ConstTensor, int32_t, dim, dims_)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, DataType, data_type)
MACE_DEFINE_ARRAY_FUNC(ConstTensor, float, float_data, float_datas_)
MACE_DEFINE_ARRAY_FUNC(ConstTensor, int32_t, int32_data, int32_datas_)
MACE_DEFINE_STRING_FUNC(ConstTensor, name, name_)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, int32_t, offset)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, int32_t, data_size)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, float, scale)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, int32_t, zero_point)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, float, minval)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, float, maxval)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, bool, quantized)
MACE_DEFINE_OBJECT_FUNC(ConstTensor, uint32_t, node_id)
const int32_t *ConstTensor::dim() const {
const int32_t *array = reinterpret_cast<const int32_t *>(
reinterpret_cast<const uint8_t *>(this) + dims_.offset_);
return array;
}
} // namespace model
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_MODEL_CONST_TENSOR_H_
#define MICRO_MODEL_CONST_TENSOR_H_
#include "micro/base/serialize.h"
#include "micro/include/public/micro.h"
namespace micro {
namespace model {
class ConstTensor : public Serialize {
public:
MACE_DEFINE_HARD_CODE_MAGIC(ConstTensor)
MACE_DECLARE_ARRAY_FUNC(int32_t, dim);
MACE_DECLARE_OBJECT_FUNC(DataType, data_type);
MACE_DECLARE_ARRAY_FUNC(float, float_data);
MACE_DECLARE_ARRAY_FUNC(int32_t, int32_data);
MACE_DECLARE_STRING_FUNC(name);
MACE_DECLARE_OBJECT_FUNC(int32_t, offset);
MACE_DECLARE_OBJECT_FUNC(int32_t, data_size);
MACE_DECLARE_OBJECT_FUNC(float, scale);
MACE_DECLARE_OBJECT_FUNC(int32_t, zero_point);
MACE_DECLARE_OBJECT_FUNC(float, minval);
MACE_DECLARE_OBJECT_FUNC(float, maxval);
MACE_DECLARE_OBJECT_FUNC(bool, quantized);
MACE_DECLARE_OBJECT_FUNC(uint32_t, node_id);
const int32_t *dim() const;
private:
SerialArray<SerialInt32> dims_;
DataType data_type_;
SerialArray<SerialFloat> float_datas_;
SerialArray<SerialInt32> int32_datas_;
SerialString name_;
SerialInt32 offset_;
SerialInt32 data_size_;
SerialFloat scale_;
SerialInt32 zero_point_;
SerialFloat minval_;
SerialFloat maxval_;
SerialBool quantized_;
SerialUint32 node_id_;
};
} // namespace model
} // namespace micro
#endif // MICRO_MODEL_CONST_TENSOR_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/model/input_output_info.h"
namespace micro {
namespace model {
MACE_DEFINE_STRING_FUNC(InputOutputInfo, name, name_)
MACE_DEFINE_OBJECT_FUNC(InputOutputInfo, int32_t, node_id)
MACE_DEFINE_ARRAY_FUNC(InputOutputInfo, int32_t, dim, dims_)
MACE_DEFINE_OBJECT_FUNC(InputOutputInfo, int32_t, max_byte_size)
MACE_DEFINE_OBJECT_FUNC(InputOutputInfo, int32_t, data_type)
MACE_DEFINE_OBJECT_FUNC(InputOutputInfo, int32_t, data_format)
MACE_DEFINE_OBJECT_FUNC(InputOutputInfo, float, scale)
MACE_DEFINE_OBJECT_FUNC(InputOutputInfo, int32_t, zero_point)
} // namespace model
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_MODEL_INPUT_OUTPUT_INFO_H_
#define MICRO_MODEL_INPUT_OUTPUT_INFO_H_
#include "micro/base/serialize.h"
namespace micro {
namespace model {
class InputOutputInfo : public Serialize {
public:
MACE_DEFINE_HARD_CODE_MAGIC(InputOutputInfo)
MACE_DECLARE_STRING_FUNC(name);
MACE_DECLARE_OBJECT_FUNC(int32_t, node_id);
MACE_DECLARE_ARRAY_FUNC(int32_t, dim);
MACE_DECLARE_OBJECT_FUNC(int32_t, max_byte_size);
MACE_DECLARE_OBJECT_FUNC(int32_t, data_type);
MACE_DECLARE_OBJECT_FUNC(int32_t, data_format);
MACE_DECLARE_OBJECT_FUNC(float, scale);
MACE_DECLARE_OBJECT_FUNC(int32_t, zero_point);
private:
SerialString name_;
SerialInt32 node_id_;
SerialArray<SerialInt32> dims_;
SerialInt32 max_byte_size_;
SerialInt32 data_type_;
SerialInt32 data_format_;
SerialFloat scale_;
SerialInt32 zero_point_;
};
} // namespace model
} // namespace micro
#endif // MICRO_MODEL_INPUT_OUTPUT_INFO_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/model/net_def.h"
namespace micro {
namespace model {
MACE_DEFINE_PTR_ARRAY_FUNC(NetDef, OperatorDef, op, ops_)
MACE_DEFINE_PTR_ARRAY_FUNC(NetDef, Argument, arg, args_)
MACE_DEFINE_PTR_ARRAY_FUNC(NetDef, ConstTensor, tensor, tensors_)
MACE_DEFINE_OBJECT_FUNC(NetDef, int32_t, data_type)
MACE_DEFINE_PTR_ARRAY_FUNC(NetDef, InputOutputInfo, input_info, input_infos_)
MACE_DEFINE_PTR_ARRAY_FUNC(NetDef, InputOutputInfo, output_info, output_infos_)
} // namespace model
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_MODEL_NET_DEF_H_
#define MICRO_MODEL_NET_DEF_H_
#include "micro/base/serialize.h"
#include "micro/model/argument.h"
#include "micro/model/const_tensor.h"
#include "micro/model/input_output_info.h"
#include "micro/model/operator_def.h"
namespace micro {
namespace model {
class NetDef : public Serialize {
public:
MACE_DEFINE_HARD_CODE_MAGIC(NetDef)
MACE_DECLARE_PTR_ARRAY_FUNC(OperatorDef, op);
MACE_DECLARE_PTR_ARRAY_FUNC(Argument, arg);
MACE_DECLARE_PTR_ARRAY_FUNC(ConstTensor, tensor);
MACE_DECLARE_OBJECT_FUNC(int32_t, data_type);
MACE_DECLARE_PTR_ARRAY_FUNC(InputOutputInfo, input_info);
MACE_DECLARE_PTR_ARRAY_FUNC(InputOutputInfo, output_info);
private:
SerialArray<OperatorDef> ops_;
SerialArray<Argument> args_;
SerialArray<ConstTensor> tensors_;
SerialInt32 data_type_;
SerialArray<InputOutputInfo> input_infos_;
SerialArray<InputOutputInfo> output_infos_;
};
} // namespace model
} // namespace micro
#endif // MICRO_MODEL_NET_DEF_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/model/operator_def.h"
namespace micro {
namespace model {
MACE_DEFINE_STRING_ARRAY_FUNC(OperatorDef, input, inputs_)
MACE_DEFINE_STRING_ARRAY_FUNC(OperatorDef, output, outputs_)
MACE_DEFINE_STRING_FUNC(OperatorDef, name, name_)
MACE_DEFINE_STRING_FUNC(OperatorDef, type, type_)
MACE_DEFINE_OBJECT_FUNC(OperatorDef, int32_t, device_type)
MACE_DEFINE_PTR_ARRAY_FUNC(OperatorDef, Argument, arg, args_)
MACE_DEFINE_PTR_ARRAY_FUNC(OperatorDef, OutputShape,
output_shape, output_shapes_)
MACE_DEFINE_ARRAY_FUNC(OperatorDef, DataType, output_type, output_types_)
// the mem_offset is the mem_id in proto file
MACE_DEFINE_ARRAY_FUNC(OperatorDef, int32_t, mem_offset, mem_offsets_)
} // namespace model
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_MODEL_OPERATOR_DEF_H_
#define MICRO_MODEL_OPERATOR_DEF_H_
#include "micro/base/serialize.h"
#include "micro/include/public/micro.h"
#include "micro/model/argument.h"
#include "micro/model/output_shape.h"
namespace micro {
namespace model {
class OperatorDef : public Serialize {
public:
MACE_DEFINE_HARD_CODE_MAGIC(OperatorDef)
MACE_DECLARE_STRING_ARRAY_FUNC(input);
MACE_DECLARE_STRING_ARRAY_FUNC(output);
MACE_DECLARE_STRING_FUNC(name);
MACE_DECLARE_STRING_FUNC(type);
MACE_DECLARE_OBJECT_FUNC(int32_t, device_type);
MACE_DECLARE_PTR_ARRAY_FUNC(Argument, arg);
MACE_DECLARE_PTR_ARRAY_FUNC(OutputShape, output_shape);
MACE_DECLARE_ARRAY_FUNC(DataType, output_type);
// the mem_offset is the mem_id in proto file
MACE_DECLARE_ARRAY_FUNC(int32_t, mem_offset);
private:
SerialArray<SerialString> inputs_;
SerialArray<SerialString> outputs_;
SerialString name_;
SerialString type_;
// device_type_ is not used currently, for future;
SerialInt32 device_type_;
SerialArray<Argument> args_;
SerialArray<OutputShape> output_shapes_;
SerialArray<DataType> output_types_;
SerialArray<SerialInt32> mem_offsets_;
};
} // namespace model
} // namespace micro
#endif // MICRO_MODEL_OPERATOR_DEF_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/model/output_shape.h"
namespace micro {
namespace model {
MACE_DEFINE_ARRAY_FUNC(OutputShape, int32_t, dim, dims_)
const int32_t *OutputShape::dim() const {
const int32_t *array = reinterpret_cast<const int32_t *>(
reinterpret_cast<const char *>(this) + dims_.offset_);
return array;
}
int32_t *OutputShape::mutable_dim() {
char *base_addr = reinterpret_cast<char *>(const_cast<OutputShape *>(this));
int32_t *array = reinterpret_cast<int32_t *>(base_addr + dims_.offset_);
return array;
}
} // namespace model
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_MODEL_OUTPUT_SHAPE_H_
#define MICRO_MODEL_OUTPUT_SHAPE_H_
#include "micro/base/serialize.h"
namespace micro {
namespace model {
class OutputShape : public Serialize {
public:
MACE_DEFINE_HARD_CODE_MAGIC(OutputShape)
MACE_DECLARE_ARRAY_FUNC(int32_t, dim);
const int32_t *dim() const;
int32_t *mutable_dim();
private:
SerialArray<SerialInt32> dims_;
};
} // namespace model
} // namespace micro
#endif // MICRO_MODEL_OUTPUT_SHAPE_H_
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "ops",
srcs = glob(["**/*.cc"]),
hdrs = glob(["**/*.h"]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"//micro/base",
"//micro/framework",
],
)
cc_library(
name = "ops_for_test",
srcs = glob(["**/*.cc"]),
hdrs = glob(["**/*.h"]),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
],
deps = [
"//micro/base",
"//micro/framework:framework_for_optest",
],
alwayslink = 1,
)
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/ops/activation.h"
#include "micro/base/logging.h"
#include "micro/base/utils.h"
#include "micro/model/argument.h"
namespace micro {
namespace ops {
namespace {
template<typename T>
void PReLUActivation(const T *input_ptr, const int32_t outer_size,
const int32_t channel, const T *alpha_ptr,
T *output_ptr) {
for (int32_t i = 0; i < outer_size; ++i) {
const int32_t outer_base = i * channel;
for (int32_t c = 0; c < channel; ++c) {
const int32_t idx = outer_base + c;
if (input_ptr[idx] < 0) {
output_ptr[idx] = input_ptr[idx] * alpha_ptr[c];
} else {
output_ptr[idx] = input_ptr[idx];
}
}
}
}
} // namespace
MaceStatus ActivationOp::OnInit() {
input_ = GetInputData<mifloat>(INPUT);
input_dims_ = GetInputShapeDims(INPUT);
input_dim_size_ = GetInputShapeDimSize(INPUT);
output_ = GetOutputData<mifloat>(OUTPUT);
return activation_.Init(this);
}
MaceStatus ActivationOp::Run() {
MACE_RETURN_IF_ERROR(ResizeOutputShape(OUTPUT, input_dim_size_, input_dims_));
if (activation_.GetActivationType() == PRELU) {
MACE_ASSERT(GetInputSize() > 1);
const mifloat *alpha = GetInputData<mifloat>(ALPHA);
const int32_t outer_size =
base::accumulate_multi(input_dims_, 0, input_dim_size_ - 1);
const int32_t channel = input_dims_[input_dim_size_ - 1];
PReLUActivation(input_, outer_size, channel, alpha, output_);
return MACE_SUCCESS;
} else {
const int32_t input_size = base::GetShapeSize(input_dim_size_, input_dims_);
return activation_.Compute(input_, input_size, output_);
}
}
} // namespace ops
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_OPS_ACTIVATION_H_
#define MICRO_OPS_ACTIVATION_H_
#include "micro/framework/operator.h"
#include "micro/ops/utils/activation.h"
namespace micro {
namespace ops {
class ActivationOp : public framework::Operator {
public:
MaceStatus OnInit();
MaceStatus Run();
private:
const mifloat *input_;
const int32_t *input_dims_;
uint32_t input_dim_size_;
mifloat *output_;
Activation activation_;
MACE_OP_INPUT_TAGS(INPUT, ALPHA);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace micro
#endif // MICRO_OPS_ACTIVATION_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_OPS_ARGMAX_H_
#define MICRO_OPS_ARGMAX_H_
#include "micro/base/logging.h"
#include "micro/base/utils.h"
#include "micro/framework/operator.h"
#include "micro/framework/scratch_buffer.h"
#include "micro/include/utils/macros.h"
namespace micro {
namespace ops {
template<class T>
class ArgMaxOp : public framework::Operator {
public:
MaceStatus OnInit() {
axis_ = GetArgByName("axis", static_cast<int32_t>(0));
keep_dims_ = GetArgByName("keepdims", true);
MACE_ASSERT1(keep_dims_, "Mace only supports keep_dims ArgMax.");
argmin_ = GetArgByName("argmin", false);
input_ = GetInputData<T>(INPUT);
input_dims_ = GetInputShapeDims(INPUT);
input_dim_size_ = GetInputShapeDimSize(INPUT);
MACE_ASSERT1(input_dim_size_ > 0, "ArgMax input should not be a scalar");
output_ = GetOutputData<int32_t>(OUTPUT);
output_dims_ = GetOutputShapeDims(OUTPUT);
output_dim_size_ = GetOutputShapeDimSize(OUTPUT);
return MACE_SUCCESS;
}
MaceStatus Run() {
int32_t axis_value = 0;
const int32_t *axis = GetInputSize() == 2 ?
GetInputData<int32_t>(AXIS) : NULL;
if (axis != NULL) {
MACE_ASSERT1(GetInputShapeDimSize(AXIS) == 0,
"Mace argmax only supports scalar axis");
axis_value = axis[0];
} else {
axis_value = axis_;
}
if (axis_value < 0) {
axis_value += input_dim_size_;
}
MACE_ASSERT1(axis_value == static_cast<int32_t>(input_dim_size_) - 1,
"Mace argmax only supports last dimension as axis");
MACE_ASSERT1(output_dim_size_ >= input_dim_size_ - 1,
"Convert model error.");
int32_t *output_dims =
ScratchBuffer(engine_config_).GetBuffer<int32_t>(output_dim_size_);
for (int32_t d = 0; d < static_cast<int32_t>(output_dim_size_); ++d) {
output_dims[d] = input_dims_[d < axis_value ? d : d + 1];
}
ResizeOutputShape(OUTPUT, output_dim_size_, output_dims);
int32_t outer_size = base::GetShapeSize(output_dim_size_, output_dims_);
int32_t inner_size = input_dims_[axis_value];
if (argmin_) {
for (int32_t i = 0; i < outer_size; ++i) {
int32_t idx = 0;
T min_value = base::highest();
const T *input_ptr = input_ + i * inner_size;
for (int32_t j = 0; j < inner_size; ++j) {
float input = input_ptr[j];
if (input < min_value) {
min_value = input;
idx = j;
}
}
output_[i] = idx;
}
} else {
for (int32_t i = 0; i < outer_size; ++i) {
int32_t idx = 0;
T max_value = base::lowest();
const T *input_ptr = input_ + i * inner_size;
for (int32_t j = 0; j < inner_size; ++j) {
float input = input_ptr[j];
if (input > max_value) {
max_value = input;
idx = j;
}
}
output_[i] = idx;
}
}
return MaceStatus::MACE_SUCCESS;
}
private:
int32_t axis_;
bool keep_dims_;
bool argmin_;
const T *input_;
const int32_t *input_dims_;
uint32_t input_dim_size_;
int32_t *output_;
const int32_t *output_dims_;
uint32_t output_dim_size_;
MACE_OP_INPUT_TAGS(INPUT, AXIS);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace micro
#endif // MICRO_OPS_ARGMAX_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/ops/bias_add.h"
#include "micro/base/logging.h"
#include "micro/ops/utils/crumb_utils.h"
namespace micro {
namespace ops {
MaceStatus BiasAddOp::OnInit() {
MACE_ASSERT1(static_cast<DataFormat>(
GetArgByName("data_format", static_cast<int32_t>(NHWC)))
!= NCHW, "Now only support NHWC");
input_ = GetInputData<mifloat>(INPUT);
input_dims_ = GetInputShapeDims(INPUT);
input_dim_size_ = GetInputShapeDimSize(INPUT);
bias_ = GetInputData<mifloat>(BIAS);
bias_dims_ = GetInputShapeDims(BIAS);
bias_dim_size_ = GetInputShapeDimSize(BIAS);
output_ = GetOutputData<mifloat>(OUTPUT);
MACE_ASSERT1(bias_dim_size_ == 1, "Bias dim must be 1.");
MACE_ASSERT1(bias_dims_[0] == input_dims_[input_dim_size_ - 1],
"The bias's channel dim should be equal to the input's");
return ResizeOutputShape(OUTPUT, input_dim_size_, input_dims_);
}
MaceStatus BiasAddOp::Run() {
return crumb::ComputeBias(input_, input_dims_, input_dim_size_,
bias_, bias_dims_[0], output_);
}
} // namespace ops
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_OPS_BIAS_ADD_H_
#define MICRO_OPS_BIAS_ADD_H_
#include "micro/framework/operator.h"
namespace micro {
namespace ops {
class BiasAddOp : public framework::Operator {
public:
MaceStatus OnInit();
MaceStatus Run();
private:
const mifloat *input_;
const int32_t *input_dims_;
uint32_t input_dim_size_;
const mifloat *bias_;
const int32_t *bias_dims_;
uint32_t bias_dim_size_;
mifloat *output_;
MACE_OP_INPUT_TAGS(INPUT, BIAS);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace micro
#endif // MICRO_OPS_BIAS_ADD_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_OPS_CAST_H_
#define MICRO_OPS_CAST_H_
#include "micro/base/utils.h"
#include "micro/base/types.h"
#include "micro/framework/operator.h"
#include "micro/include/utils/bfloat16.h"
namespace micro {
namespace ops {
#ifndef MACE_CAST_OP_CAST_TENSOR
#define MACE_CAST_OP_CAST_TENSOR(SrcType, DstType) \
const SrcType *input = static_cast<const SrcType *>(input_); \
DstType *output = static_cast<DstType *>(output_); \
for (int32_t i = 0; i < tensor_size_; ++i) { \
output[i] = input[i]; \
}
#endif // MACE_CAST_OP_CAST_TENSOR
class CastOp : public framework::Operator {
public:
MaceStatus OnInit() {
input_ = GetInputData<void>(INPUT);
input_dt_ = static_cast<DataType>(
GetArgByName("T", static_cast<int32_t >(DT_FLOAT)));
const int32_t *input_dims = GetInputShapeDims(INPUT);
const uint32_t input_dim_size_ = GetInputShapeDimSize(INPUT);
tensor_size_ = base::GetShapeSize(input_dim_size_, input_dims);
MACE_ASSERT(tensor_size_ > 0);
output_ = GetOutputData<void>(OUTPUT);
output_dt_ = GetOutputDataType(OUTPUT);
return MACE_SUCCESS;
}
MaceStatus Run() {
if (input_dt_ == DT_FLOAT && output_dt_ == DT_BFLOAT16) {
#ifdef MACE_ENABLE_BFLOAT16
MACE_CAST_OP_CAST_TENSOR(float, BFloat16)
#else
MACE_NOT_IMPLEMENTED;
#endif
} else if (input_dt_ == DT_BFLOAT16 && output_dt_ == DT_FLOAT) {
#ifdef MACE_ENABLE_BFLOAT16
MACE_CAST_OP_CAST_TENSOR(BFloat16, float)
#else
MACE_NOT_IMPLEMENTED;
#endif
} else {
MACE_NOT_IMPLEMENTED;
}
return MACE_SUCCESS;
}
private:
const void *input_;
DataType input_dt_;
int32_t tensor_size_;
void *output_;
DataType output_dt_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace micro
#endif // MICRO_OPS_CAST_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/ops/eltwise.h"
#include "micro/base/logging.h"
namespace micro {
namespace ops {
namespace eltwise {
bool ShapeIsEqual(const int32_t *dims0,
const int32_t *dims1, uint32_t dim_size) {
while (--dim_size > 0) {
if (dims0[dim_size] != dims1[dim_size])
return false;
}
return true;
}
int32_t GetIndex(const int32_t *shape,
const int32_t *index, int32_t dim_size) {
int32_t idx = 0;
for (int32_t i = 0; i < dim_size; ++i) {
if (shape[i] > 1) {
idx = idx * shape[i] + index[i];
}
}
return idx;
}
void IncreaseIndex(const int32_t *shape, int32_t **index, int32_t dim_size) {
for (int32_t i = dim_size - 1; i >= 0; --i) {
++(*index)[i];
if ((*index)[i] >= shape[i]) {
(*index)[i] -= shape[i];
} else {
break;
}
}
}
} // namespace eltwise
} // namespace ops
} // namespace micro
此差异已折叠。
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/ops/expand_dims.h"
#include "micro/base/logging.h"
#include "micro/base/utils.h"
#include "micro/framework/scratch_buffer.h"
#include "micro/model/argument.h"
namespace micro {
namespace ops {
MaceStatus ExpandDimsOp::OnInit() {
input_ = GetInputData<mifloat>(INPUT);
input_dims_ = GetInputShapeDims(INPUT);
input_dim_size_ = GetInputShapeDimSize(INPUT);
output_ = GetOutputData<mifloat>(OUTPUT);
axis_ = GetArgByName("axis", static_cast<int32_t>(0));
if (axis_ < 0) {
axis_ += input_dim_size_ + 1;
}
MACE_ASSERT2(axis_ >= 0 && axis_ <= static_cast<int32_t>(input_dim_size_),
"axis is out of bound: ", axis_);
return MACE_SUCCESS;
}
MaceStatus ExpandDimsOp::Run() {
int32_t output_dim_size = input_dim_size_ + 1;
int32_t *output_dims =
ScratchBuffer(engine_config_).GetBuffer<int32_t>(output_dim_size);
for (int32_t i = 0; i < output_dim_size; ++i) {
if (i < axis_) {
output_dims[i] = input_dims_[i];
} else if (i == axis_) {
output_dims[i] = 1;
} else {
output_dims[i] = input_dims_[i - 1];
}
}
// TODO(luxuhui): optimize this method by reusing buffer
int32_t input_data_size = base::GetShapeSize(input_dim_size_, input_dims_);
base::memcpy(output_, input_, input_data_size * sizeof(mifloat));
return ResizeOutputShape(OUTPUT, output_dim_size, output_dims);
}
} // namespace ops
} // namespace micro
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MICRO_OPS_EXPAND_DIMS_H_
#define MICRO_OPS_EXPAND_DIMS_H_
#include "micro/base/types.h"
#include "micro/framework/operator.h"
namespace micro {
namespace ops {
class ExpandDimsOp : public framework::Operator {
public:
MaceStatus OnInit();
MaceStatus Run();
private:
const mifloat *input_;
const int32_t *input_dims_;
uint32_t input_dim_size_;
mifloat *output_;
int32_t axis_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace micro
#endif // MICRO_OPS_EXPAND_DIMS_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "micro/ops/matmul.h"
#include "micro/base/logging.h"
#include "micro/base/utils.h"
#include "micro/framework/scratch_buffer.h"
#include "micro/model/argument.h"
namespace micro {
namespace ops {
MaceStatus MatMulOp::OnInit() {
transpose_a_ = GetArgByName("transpose_a", false);
transpose_b_ = GetArgByName("transpose_b", false);
input_a_ = GetInputData<mifloat>(INPUT_A);
input_b_ = GetInputData<mifloat>(INPUT_B);
bias_ = GetInputSize() > 3 ? GetInputData<mifloat>(BIAS) : NULL;
output_ = GetOutputData<mifloat>(OUTPUT);
input_a_dim_size_ = GetInputShapeDimSize(INPUT_A);
input_b_dim_size_ = GetInputShapeDimSize(INPUT_B);
input_a_dims_ = GetInputShapeDims(INPUT_A);
input_b_dims_ = GetInputShapeDims(INPUT_B);
MACE_ASSERT1(input_a_dim_size_ >= 2 && input_b_dim_size_ >= 2,
"rank should be greater than or equal to 2");
return MACE_SUCCESS;
}
MaceStatus MatMulOp::Run() {
MACE_ASSERT(Validate());
const int32_t lhs_rank = input_a_dim_size_;
const int32_t lhs_rows = input_a_dims_[lhs_rank - 2];
const int32_t lhs_cols = input_a_dims_[lhs_rank - 1];
const int32_t rhs_rank = input_b_dim_size_;
const int32_t rhs_rows = input_b_dims_[rhs_rank - 2];
const int32_t rhs_cols = input_b_dims_[rhs_rank - 1];
const int32_t rows = transpose_a_ ? lhs_cols : lhs_rows;
const int32_t cols = transpose_b_ ? rhs_rows : rhs_cols;
const int32_t depth = transpose_a_ ? lhs_rows : lhs_cols;
const int32_t lhs_batch =
base::accumulate_multi(input_a_dims_, 0, input_a_dim_size_ - 2);
const int32_t rhs_batch =
base::accumulate_multi(input_b_dims_, 0, input_b_dim_size_ - 2);
int32_t *output_dims =
ScratchBuffer(engine_config_).GetBuffer<int32_t>(input_a_dim_size_);
int32_t batch = 1;
base::memcpy(output_dims, input_a_dims_, input_a_dim_size_);
if (lhs_rank >= rhs_rank) {
output_dims[lhs_rank - 2] = rows;
output_dims[lhs_rank - 1] = cols;
batch = lhs_batch;
} else {
output_dims[rhs_rank - 2] = rows;
output_dims[rhs_rank - 1] = cols;
batch = rhs_batch;
}
bool lhs_batched = true;
bool rhs_batched = true;
if (lhs_rank < rhs_rank) {
lhs_batched = false;
} else if (rhs_rank < lhs_rank) {
rhs_batched = false;
}
MACE_RETURN_IF_ERROR(
ResizeOutputShape(OUTPUT, input_a_dim_size_, output_dims));
if (rows == 1 && transpose_b_) {
return gemv_.Compute(input_b_,
input_a_,
bias_,
batch,
cols,
depth,
rhs_batched,
lhs_batched,
output_);
} else if (cols == 1 && !transpose_a_) {
return gemv_.Compute(input_a_,
input_b_,
bias_,
batch,
rows,
depth,
lhs_batched,
rhs_batched,
output_);
} else {
MaceStatus ret = gemm_.Compute(input_a_,
input_b_,
batch,
lhs_rows,
lhs_cols,
rhs_rows,
rhs_cols,
transpose_a_,
transpose_b_,
false,
lhs_batched,
rhs_batched,
output_);
if (bias_ != NULL) {
MACE_ASSERT1(bias_dim_size_ == 1 && bias_dims_[0] == cols,
"bias' dim should be <= 2.");
for (int32_t i = 0; i < batch * rows; ++i) {
for (int32_t w = 0; w < cols; ++w) {
int32_t idx = i * cols + w;
output_[idx] = output_[idx] + bias_[w];
}
}
}
return ret;
}
}
bool MatMulOp::Validate() {
const int32_t lhs_rank = input_a_dim_size_;
const int32_t rhs_rank = input_b_dim_size_;
if (input_a_dim_size_ == input_b_dim_size_) {
for (uint32_t i = 0; i < input_a_dim_size_ - 2; ++i) {
MACE_ASSERT1(input_a_dims_[i] == input_b_dims_[i],
"batch dimensions are not equal");
}
} else {
MACE_ASSERT1(input_a_dim_size_ == 2 || input_b_dim_size_ == 2,
"Either lhs or rhs matrix should has rank 2 "
"for non-batched matrix multiplication");
}
int32_t lhs_depth = transpose_a_ ? input_a_dims_[lhs_rank - 2] :
input_a_dims_[lhs_rank - 1];
int32_t rhs_depth = transpose_b_ ? input_b_dims_[rhs_rank - 1] :
input_b_dims_[rhs_rank - 2];
if (lhs_depth != rhs_depth) {
MACE_ASSERT1(false, "the number of A's column must be equal to B's row ");
return false;
}
return true;
}
} // namespace ops
} // namespace micro
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册