提交 3713852b 编写于 作者: 李寅

Merge branch 'refactor_target_deps' into 'master'

fix targets circular dependencies

See merge request !655
......@@ -52,6 +52,7 @@ api_test:
- if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi
- python tools/bazel_adb_run.py --target="//mace/test:mace_api_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
- python tools/bazel_adb_run.py --target="//mace/test:mace_api_mt_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
- python tools/bazel_adb_run.py --target="//mace/test:mace_api_exception_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
ops_benchmark:
stage: ops_benchmark
......
......@@ -58,7 +58,7 @@ Define the Ops registering function in `mace/ops/my_custom_op.cc`.
namespace mace {
namespace ops {
void Register_My_Custom_Op(OperatorRegistry *op_registry) {
void Register_My_Custom_Op(OperatorRegistryBase *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("my_custom_op")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -63,7 +63,7 @@ cc_binary(
linkstatic = 0,
deps = [
":mace_version_script.lds",
"//mace/ops",
"//mace/libmace",
],
)
......@@ -81,6 +81,7 @@ genrule(
"//mace/core",
"//mace/kernels",
"//mace/ops",
"//mace/libmace",
"//mace/utils",
"//mace/proto:mace_cc",
"@com_google_protobuf//:protobuf_lite",
......@@ -93,6 +94,7 @@ genrule(
"$(locations //mace/core:core) " +
"$(locations //mace/kernels:kernels) " +
"$(locations //mace/ops:ops) " +
"$(locations //mace/libmace:libmace) " +
"$(locations //mace/utils:utils) " +
"$(locations //mace/proto:mace_cc) " +
"$(locations @com_google_protobuf//:protobuf_lite) " +
......
......@@ -36,7 +36,7 @@ cc_binary(
"//external:gflags_nothreads",
"//mace/codegen:generated_models",
"//mace/codegen:generated_mace_engine_factory",
"//mace/ops:ops",
"//mace/libmace:libmace",
],
)
......
......@@ -70,7 +70,6 @@ cc_library(
]) + if_hexagon_enabled([
"//third_party/nnlib:libhexagon",
]),
alwayslink = 1,
)
cc_library(
......@@ -109,5 +108,4 @@ cc_library(
"//external:gflags_nothreads",
"//mace/utils",
],
alwayslink = 1,
)
......@@ -22,7 +22,7 @@
namespace mace {
NetBase::NetBase(const std::shared_ptr<const OperatorRegistry> op_registry,
NetBase::NetBase(const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type)
......@@ -31,11 +31,12 @@ NetBase::NetBase(const std::shared_ptr<const OperatorRegistry> op_registry,
MACE_UNUSED(type);
}
SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
const NetMode mode)
SerialNet::SerialNet(
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
const NetMode mode)
: NetBase(op_registry, net_def, ws, type), device_type_(type) {
MACE_LATENCY_LOGGER(1, "Constructing SerialNet ", net_def->name());
for (int idx = 0; idx < net_def->op_size(); ++idx) {
......@@ -130,7 +131,7 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
}
std::unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const NetDef &net_def,
Workspace *ws,
DeviceType type,
......@@ -140,7 +141,7 @@ std::unique_ptr<NetBase> CreateNet(
}
std::unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
......
......@@ -30,7 +30,7 @@ class Workspace;
class NetBase {
public:
NetBase(const std::shared_ptr<const OperatorRegistry> op_registry,
NetBase(const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type);
......@@ -42,14 +42,14 @@ class NetBase {
protected:
std::string name_;
const std::shared_ptr<const OperatorRegistry> op_registry_;
const std::shared_ptr<const OperatorRegistryBase> op_registry_;
MACE_DISABLE_COPY_AND_ASSIGN(NetBase);
};
class SerialNet : public NetBase {
public:
SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
SerialNet(const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
......@@ -65,13 +65,13 @@ class SerialNet : public NetBase {
};
std::unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const NetDef &net_def,
Workspace *ws,
DeviceType type,
const NetMode mode = NetMode::NORMAL);
std::unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
......
......@@ -50,7 +50,9 @@ const std::string OpKeyBuilder::Build() {
return ss.str();
}
std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
OperatorRegistryBase::~OperatorRegistryBase() {}
std::unique_ptr<OperatorBase> OperatorRegistryBase::CreateOperator(
const OperatorDef &operator_def,
Workspace *ws,
DeviceType type,
......@@ -72,102 +74,4 @@ std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
}
}
namespace ops {
// Keep in lexicographical order
extern void Register_Activation(OperatorRegistry *op_registry);
extern void Register_AddN(OperatorRegistry *op_registry);
extern void Register_ArgMax(OperatorRegistry *op_registry);
extern void Register_BatchNorm(OperatorRegistry *op_registry);
extern void Register_BatchToSpaceND(OperatorRegistry *op_registry);
extern void Register_BiasAdd(OperatorRegistry *op_registry);
extern void Register_Cast(OperatorRegistry *op_registry);
extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_Deconv2D(OperatorRegistry *op_registry);
extern void Register_DepthToSpace(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Dequantize(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_Gather(OperatorRegistry *op_registry);
extern void Register_Identity(OperatorRegistry *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pad(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_Quantize(OperatorRegistry *op_registry);
extern void Register_ReduceMean(OperatorRegistry *op_registry);
extern void Register_Requantize(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Shape(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_Stack(OperatorRegistry *op_registry);
extern void Register_StridedSlice(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_SpaceToDepth(OperatorRegistry *op_registry);
extern void Register_Squeeze(OperatorRegistry *op_registry);
extern void Register_Transpose(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
#ifdef MACE_ENABLE_OPENCL
extern void Register_BufferToImage(OperatorRegistry *op_registry);
extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
#endif // MACE_ENABLE_OPENCL
} // namespace ops
OperatorRegistry::OperatorRegistry() {
// Keep in lexicographical order
ops::Register_Activation(this);
ops::Register_AddN(this);
ops::Register_ArgMax(this);
ops::Register_BatchNorm(this);
ops::Register_BatchToSpaceND(this);
ops::Register_BiasAdd(this);
ops::Register_Cast(this);
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
ops::Register_Deconv2D(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
ops::Register_Eltwise(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_Gather(this);
ops::Register_Identity(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
ops::Register_Pooling(this);
ops::Register_Proposal(this);
ops::Register_Quantize(this);
ops::Register_ReduceMean(this);
ops::Register_Requantize(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
ops::Register_Shape(this);
ops::Register_Slice(this);
ops::Register_Softmax(this);
ops::Register_Stack(this);
ops::Register_StridedSlice(this);
ops::Register_SpaceToBatchND(this);
ops::Register_SpaceToDepth(this);
ops::Register_Squeeze(this);
ops::Register_Transpose(this);
ops::Register_WinogradInverseTransform(this);
ops::Register_WinogradTransform(this);
#ifdef MACE_ENABLE_OPENCL
ops::Register_BufferToImage(this);
ops::Register_ImageToBuffer(this);
#endif // MACE_ENABLE_OPENCL
}
} // namespace mace
......@@ -163,12 +163,12 @@ OpKeyBuilder &OpKeyBuilder::TypeConstraint(const char *attr_name) {
return this->TypeConstraint(attr_name, DataTypeToEnum<T>::value);
}
class OperatorRegistry {
class OperatorRegistryBase {
public:
typedef Registry<std::string, OperatorBase, const OperatorDef &, Workspace *>
RegistryType;
OperatorRegistry();
~OperatorRegistry() = default;
OperatorRegistryBase() = default;
virtual ~OperatorRegistryBase();
RegistryType *registry() { return &registry_; }
std::unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws,
......@@ -177,7 +177,7 @@ class OperatorRegistry {
private:
RegistryType registry_;
MACE_DISABLE_COPY_AND_ASSIGN(OperatorRegistry);
MACE_DISABLE_COPY_AND_ASSIGN(OperatorRegistryBase);
};
MACE_DECLARE_REGISTRY(OpRegistry,
......
......@@ -78,7 +78,6 @@ class Registerer {
#endif
#define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
Registry<SrcType, ObjectType, ##__VA_ARGS__> *RegistryName(); \
typedef Registerer<SrcType, ObjectType, ##__VA_ARGS__> \
Registerer##RegistryName;
......
......@@ -206,22 +206,5 @@ MaceStatus SetOpenMPThreadsAndAffinityPolicy(int omp_num_threads_hint,
return SetOpenMPThreadsAndAffinityCPUs(omp_num_threads_hint, use_cpu_ids);
}
MaceStatus SetOpenMPThreadPolicy(int num_threads_hint,
CPUAffinityPolicy policy) {
VLOG(1) << "Set OpenMP threads number hint: " << num_threads_hint
<< ", affinity policy: " << policy;
return SetOpenMPThreadsAndAffinityPolicy(num_threads_hint, policy);
}
MaceStatus SetOpenMPThreadAffinity(int num_threads,
const std::vector<int> &cpu_ids) {
return SetOpenMPThreadsAndAffinityCPUs(num_threads, cpu_ids);
}
MaceStatus GetBigLittleCoreIDs(std::vector<int> *big_core_ids,
std::vector<int> *little_core_ids) {
return GetCPUBigLittleCoreIDs(big_core_ids, little_core_ids);
}
} // namespace mace
......@@ -33,20 +33,12 @@
namespace mace {
extern const std::map<std::string, std::vector<unsigned char>>
kEncryptedProgramMap;
std::shared_ptr<KVStorageFactory> kStorageFactory = nullptr;
void SetGPUHints(GPUPerfHint gpu_perf_hint, GPUPriorityHint gpu_priority_hint) {
VLOG(1) << "Set GPU configurations, gpu_perf_hint: " << gpu_perf_hint
<< ", gpu_priority_hint: " << gpu_priority_hint;
OpenCLRuntime::Configure(gpu_perf_hint, gpu_priority_hint);
}
// Set OpenCL Compiled Binary paths, just call once. (Not thread-safe)
void SetOpenCLBinaryPaths(const std::vector<std::string> &paths) {
OpenCLRuntime::ConfigureOpenCLBinaryPath(paths);
}
std::string kOpenCLParameterPath; // NOLINT(runtime/string)
extern const std::map<std::string, std::vector<unsigned char>>
kEncryptedProgramMap;
const std::string OpenCLErrorToString(cl_int error) {
switch (error) {
......
......@@ -15,6 +15,8 @@
#include <iostream>
#include "gflags/gflags.h"
#include "mace/core/runtime/cpu/cpu_runtime.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/public/mace.h"
#include "mace/public/mace_runtime.h"
......@@ -34,13 +36,13 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
// config runtime
mace::MaceStatus status = mace::SetOpenMPThreadPolicy(
mace::MaceStatus status = mace::SetOpenMPThreadsAndAffinityPolicy(
FLAGS_omp_num_threads,
static_cast<mace::CPUAffinityPolicy >(FLAGS_cpu_affinity_policy));
if (status != mace::MACE_SUCCESS) {
LOG(WARNING) << "Set openmp or cpu affinity failed.";
}
mace::SetGPUHints(
mace::OpenCLRuntime::Configure(
static_cast<mace::GPUPerfHint>(FLAGS_gpu_perf_hint),
static_cast<mace::GPUPriorityHint>(FLAGS_gpu_priority_hint));
......
# Description:
# Mace libmace.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7", "if_hexagon_enabled")
cc_library(
name = "libmace",
srcs = glob(
["*.cc"],
),
copts = [
"-Werror",
"-Wextra",
] + if_openmp_enabled(["-fopenmp"]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
]) + if_android_armv7([
"-mfloat-abi=softfp",
]) + if_android([
"-DMACE_ENABLE_OPENCL",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
"//mace/public",
"//mace/ops",
],
alwayslink = 1,
)
......@@ -21,7 +21,7 @@
#include <memory>
#include "mace/core/net.h"
#include "mace/core/types.h"
#include "mace/ops/ops_register.h"
#include "mace/public/mace.h"
#ifdef MACE_ENABLE_OPENCL
......@@ -138,7 +138,7 @@ class MaceEngine::Impl {
private:
const unsigned char *model_data_;
size_t model_data_size_;
std::shared_ptr<OperatorRegistry> op_registry_;
std::shared_ptr<OperatorRegistryBase> op_registry_;
DeviceType device_type_;
std::unique_ptr<Workspace> ws_;
std::unique_ptr<NetBase> net_;
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/runtime/cpu/cpu_runtime.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/public/mace_runtime.h"
#include "mace/utils/logging.h"
namespace mace {
extern std::shared_ptr<KVStorageFactory> kStorageFactory;
void SetKVStorageFactory(std::shared_ptr<KVStorageFactory> storage_factory) {
VLOG(1) << "Set internal KV Storage Engine";
kStorageFactory = storage_factory;
}
// Set OpenCL Compiled Binary paths, just call once. (Not thread-safe)
void SetOpenCLBinaryPaths(const std::vector<std::string> &paths) {
OpenCLRuntime::ConfigureOpenCLBinaryPath(paths);
}
extern std::string kOpenCLParameterPath;
void SetOpenCLParameterPath(const std::string &path) {
kOpenCLParameterPath = path;
}
void SetGPUHints(GPUPerfHint gpu_perf_hint, GPUPriorityHint gpu_priority_hint) {
VLOG(1) << "Set GPU configurations, gpu_perf_hint: " << gpu_perf_hint
<< ", gpu_priority_hint: " << gpu_priority_hint;
OpenCLRuntime::Configure(gpu_perf_hint, gpu_priority_hint);
}
MaceStatus SetOpenMPThreadPolicy(int num_threads_hint,
CPUAffinityPolicy policy) {
VLOG(1) << "Set OpenMP threads number hint: " << num_threads_hint
<< ", affinity policy: " << policy;
return SetOpenMPThreadsAndAffinityPolicy(num_threads_hint, policy);
}
MaceStatus SetOpenMPThreadAffinity(int num_threads,
const std::vector<int> &cpu_ids) {
return SetOpenMPThreadsAndAffinityCPUs(num_threads, cpu_ids);
}
MaceStatus GetBigLittleCoreIDs(std::vector<int> *big_core_ids,
std::vector<int> *little_core_ids) {
return GetCPUBigLittleCoreIDs(big_core_ids, little_core_ids);
}
}; // namespace mace
......@@ -58,7 +58,6 @@ cc_library(
deps = [
"//mace/kernels",
],
alwayslink = 1,
)
cc_test(
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Activation(OperatorRegistry *op_registry) {
void Register_Activation(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_AddN(OperatorRegistry *op_registry) {
void Register_AddN(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ArgMax(OperatorRegistry *op_registry) {
void Register_ArgMax(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ArgMax")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_BatchNorm(OperatorRegistry *op_registry) {
void Register_BatchNorm(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_BatchToSpaceND(OperatorRegistry *op_registry) {
void Register_BatchToSpaceND(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchToSpaceND")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_BiasAdd(OperatorRegistry *op_registry) {
void Register_BiasAdd(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BiasAdd")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_BufferToImage(OperatorRegistry *op_registry) {
void Register_BufferToImage(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BufferToImage")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Cast(OperatorRegistry *op_registry) {
void Register_Cast(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Cast")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ChannelShuffle(OperatorRegistry *op_registry) {
void Register_ChannelShuffle(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ChannelShuffle")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Concat(OperatorRegistry *op_registry) {
void Register_Concat(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Conv2D(OperatorRegistry *op_registry) {
void Register_Conv2D(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -51,7 +51,7 @@ TEST(CoreTest, INIT_MODE) {
for (auto &op_def : op_defs) {
net_def.add_op()->CopyFrom(op_def);
}
std::shared_ptr<OperatorRegistry> op_registry(new OperatorRegistry());
std::shared_ptr<OperatorRegistryBase> op_registry(new OperatorRegistry());
auto net =
CreateNet(op_registry, net_def, &ws, DeviceType::GPU, NetMode::INIT);
net->Run();
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Deconv2D(OperatorRegistry *op_registry) {
void Register_Deconv2D(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_DepthToSpace(OperatorRegistry *op_registry) {
void Register_DepthToSpace(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_DepthwiseConv2d(OperatorRegistry *op_registry) {
void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Eltwise(OperatorRegistry *op_registry) {
void Register_Eltwise(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_FullyConnected(OperatorRegistry *op_registry) {
void Register_FullyConnected(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("FullyConnected")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Gather(OperatorRegistry *op_registry) {
void Register_Gather(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Gather")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Identity(OperatorRegistry *op_registry) {
void Register_Identity(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ImageToBuffer(OperatorRegistry *op_registry) {
void Register_ImageToBuffer(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ImageToBuffer")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_LocalResponseNorm(OperatorRegistry *op_registry) {
void Register_LocalResponseNorm(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("LocalResponseNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_MatMul(OperatorRegistry *op_registry) {
void Register_MatMul(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_register.h"
namespace mace {
namespace ops {
// Keep in lexicographical order
extern void Register_Activation(OperatorRegistryBase *op_registry);
extern void Register_AddN(OperatorRegistryBase *op_registry);
extern void Register_ArgMax(OperatorRegistryBase *op_registry);
extern void Register_BatchNorm(OperatorRegistryBase *op_registry);
extern void Register_BatchToSpaceND(OperatorRegistryBase *op_registry);
extern void Register_BiasAdd(OperatorRegistryBase *op_registry);
extern void Register_Cast(OperatorRegistryBase *op_registry);
extern void Register_ChannelShuffle(OperatorRegistryBase *op_registry);
extern void Register_Concat(OperatorRegistryBase *op_registry);
extern void Register_Conv2D(OperatorRegistryBase *op_registry);
extern void Register_Deconv2D(OperatorRegistryBase *op_registry);
extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry);
extern void Register_Dequantize(OperatorRegistryBase *op_registry);
extern void Register_Eltwise(OperatorRegistryBase *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
extern void Register_Gather(OperatorRegistryBase *op_registry);
extern void Register_Identity(OperatorRegistryBase *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistryBase *op_registry);
extern void Register_MatMul(OperatorRegistryBase *op_registry);
extern void Register_Pad(OperatorRegistryBase *op_registry);
extern void Register_Pooling(OperatorRegistryBase *op_registry);
extern void Register_Proposal(OperatorRegistryBase *op_registry);
extern void Register_Quantize(OperatorRegistryBase *op_registry);
extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
extern void Register_Requantize(OperatorRegistryBase *op_registry);
extern void Register_Reshape(OperatorRegistryBase *op_registry);
extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
extern void Register_Shape(OperatorRegistryBase *op_registry);
extern void Register_Slice(OperatorRegistryBase *op_registry);
extern void Register_Softmax(OperatorRegistryBase *op_registry);
extern void Register_Stack(OperatorRegistryBase *op_registry);
extern void Register_StridedSlice(OperatorRegistryBase *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistryBase *op_registry);
extern void Register_SpaceToDepth(OperatorRegistryBase *op_registry);
extern void Register_Squeeze(OperatorRegistryBase *op_registry);
extern void Register_Transpose(OperatorRegistryBase *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistryBase *op_registry); // NOLINT(whitespace/line_length)
extern void Register_WinogradTransform(OperatorRegistryBase *op_registry);
#ifdef MACE_ENABLE_OPENCL
extern void Register_BufferToImage(OperatorRegistryBase *op_registry);
extern void Register_ImageToBuffer(OperatorRegistryBase *op_registry);
#endif // MACE_ENABLE_OPENCL
} // namespace ops
OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
// Keep in lexicographical order
ops::Register_Activation(this);
ops::Register_AddN(this);
ops::Register_ArgMax(this);
ops::Register_BatchNorm(this);
ops::Register_BatchToSpaceND(this);
ops::Register_BiasAdd(this);
ops::Register_Cast(this);
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
ops::Register_Deconv2D(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
ops::Register_Eltwise(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_Gather(this);
ops::Register_Identity(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
ops::Register_Pooling(this);
ops::Register_Proposal(this);
ops::Register_Quantize(this);
ops::Register_ReduceMean(this);
ops::Register_Requantize(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
ops::Register_Shape(this);
ops::Register_Slice(this);
ops::Register_Softmax(this);
ops::Register_Stack(this);
ops::Register_StridedSlice(this);
ops::Register_SpaceToBatchND(this);
ops::Register_SpaceToDepth(this);
ops::Register_Squeeze(this);
ops::Register_Transpose(this);
ops::Register_WinogradInverseTransform(this);
ops::Register_WinogradTransform(this);
#ifdef MACE_ENABLE_OPENCL
ops::Register_BufferToImage(this);
ops::Register_ImageToBuffer(this);
#endif // MACE_ENABLE_OPENCL
}
} // namespace mace
......@@ -12,22 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/public/mace_runtime.h"
#include "mace/utils/logging.h"
#ifndef MACE_OPS_OPS_REGISTER_H_
#define MACE_OPS_OPS_REGISTER_H_
namespace mace {
std::shared_ptr<KVStorageFactory> kStorageFactory = nullptr;
#include "mace/core/operator.h"
void SetKVStorageFactory(std::shared_ptr<KVStorageFactory> storage_factory) {
VLOG(1) << "Set internal KV Storage Engine";
kStorageFactory = storage_factory;
}
namespace mace {
std::string kOpenCLParameterPath; // NOLINT(runtime/string)
class OperatorRegistry : public OperatorRegistryBase {
public:
OperatorRegistry();
~OperatorRegistry() = default;
};
void SetOpenCLParameterPath(const std::string &path) {
kOpenCLParameterPath = path;
}
} // namespace mace
}; // namespace mace
#endif // MACE_OPS_OPS_REGISTER_H_
......@@ -29,6 +29,7 @@
#include "mace/core/tensor.h"
#include "mace/core/workspace.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/ops/ops_register.h"
#include "mace/utils/utils.h"
namespace mace {
......@@ -397,7 +398,7 @@ class OpsTestNet {
}
public:
std::shared_ptr<OperatorRegistry> op_registry_;
std::shared_ptr<OperatorRegistryBase> op_registry_;
Workspace ws_;
std::vector<OperatorDef> op_defs_;
std::unique_ptr<NetBase> net_;
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Pad(OperatorRegistry *op_registry) {
void Register_Pad(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Pooling(OperatorRegistry *op_registry) {
void Register_Pooling(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Proposal(OperatorRegistry *op_registry) {
void Register_Proposal(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Proposal")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Quantize(OperatorRegistry *op_registry) {
void Register_Quantize(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
......@@ -25,7 +25,7 @@ void Register_Quantize(OperatorRegistry *op_registry) {
QuantizeOp<DeviceType::CPU, uint8_t>);
}
void Register_Dequantize(OperatorRegistry *op_registry) {
void Register_Dequantize(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
......@@ -33,7 +33,7 @@ void Register_Dequantize(OperatorRegistry *op_registry) {
DequantizeOp<DeviceType::CPU, uint8_t>);
}
void Register_Requantize(OperatorRegistry *op_registry) {
void Register_Requantize(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ReduceMean(OperatorRegistry *op_registry) {
void Register_ReduceMean(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Reshape(OperatorRegistry *op_registry) {
void Register_Reshape(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ResizeBilinear(OperatorRegistry *op_registry) {
void Register_ResizeBilinear(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBilinear")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Shape(OperatorRegistry *op_registry) {
void Register_Shape(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Shape")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Slice(OperatorRegistry *op_registry) {
void Register_Slice(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Softmax(OperatorRegistry *op_registry) {
void Register_Softmax(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_SpaceToBatchND(OperatorRegistry *op_registry) {
void Register_SpaceToBatchND(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToBatchND")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_SpaceToDepth(OperatorRegistry *op_registry) {
void Register_SpaceToDepth(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Squeeze(OperatorRegistry *op_registry) {
void Register_Squeeze(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Stack(OperatorRegistry *op_registry) {
void Register_Stack(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_StridedSlice(OperatorRegistry *op_registry) {
void Register_StridedSlice(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Transpose(OperatorRegistry *op_registry) {
void Register_Transpose(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Transpose")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_WinogradInverseTransform(OperatorRegistry *op_registry) {
void Register_WinogradInverseTransform(OperatorRegistryBase *op_registry) {
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradInverseTransform")
.Device(DeviceType::GPU)
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_WinogradTransform(OperatorRegistry *op_registry) {
void Register_WinogradTransform(OperatorRegistryBase *op_registry) {
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradTransform")
.Device(DeviceType::GPU)
......
......@@ -24,8 +24,7 @@ cc_test(
linkstatic = 1,
deps = [
"//mace/ops:test",
"//mace/kernels:kernels",
"//mace/ops:ops",
"//mace/libmace:libmace",
"@gtest//:gtest_main",
],
)
......@@ -45,8 +44,7 @@ cc_test(
linkstatic = 1,
deps = [
"//mace/ops:test",
"//mace/kernels:kernels",
"//mace/ops:ops",
"//mace/libmace:libmace",
"@gtest//:gtest_main",
],
)
......@@ -66,8 +64,7 @@ cc_test(
linkstatic = 1,
deps = [
"//mace/ops:test",
"//mace/kernels:kernels",
"//mace/ops:ops",
"//mace/libmace:libmace",
"@gtest//:gtest_main",
],
)
......@@ -16,7 +16,7 @@ cc_binary(
"//external:gflags_nothreads",
"//mace/codegen:generated_mace_engine_factory",
"//mace/codegen:generated_models",
"//mace/ops:ops",
"//mace/libmace",
],
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册