You need to sign in or sign up before continuing.
提交 209f1cb3 编写于 作者: Y yejianwu

fix targets circular dependencies

上级 bd942fe7
......@@ -63,7 +63,7 @@ cc_binary(
linkstatic = 0,
deps = [
":mace_version_script.lds",
"//mace/ops",
"//mace/libmace",
],
)
......@@ -81,6 +81,7 @@ genrule(
"//mace/core",
"//mace/kernels",
"//mace/ops",
"//mace/libmace",
"//mace/utils",
"//mace/proto:mace_cc",
"@com_google_protobuf//:protobuf_lite",
......@@ -93,6 +94,7 @@ genrule(
"$(locations //mace/core:core) " +
"$(locations //mace/kernels:kernels) " +
"$(locations //mace/ops:ops) " +
"$(locations //mace/libmace:libmace) " +
"$(locations //mace/utils:utils) " +
"$(locations //mace/proto:mace_cc) " +
"$(locations @com_google_protobuf//:protobuf_lite) " +
......
......@@ -22,7 +22,7 @@
namespace mace {
NetBase::NetBase(const std::shared_ptr<const OperatorRegistry> op_registry,
NetBase::NetBase(const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type)
......@@ -31,11 +31,12 @@ NetBase::NetBase(const std::shared_ptr<const OperatorRegistry> op_registry,
MACE_UNUSED(type);
}
SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
const NetMode mode)
SerialNet::SerialNet(
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
const NetMode mode)
: NetBase(op_registry, net_def, ws, type), device_type_(type) {
MACE_LATENCY_LOGGER(1, "Constructing SerialNet ", net_def->name());
for (int idx = 0; idx < net_def->op_size(); ++idx) {
......@@ -130,7 +131,7 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
}
std::unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const NetDef &net_def,
Workspace *ws,
DeviceType type,
......@@ -140,7 +141,7 @@ std::unique_ptr<NetBase> CreateNet(
}
std::unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
......
......@@ -30,7 +30,7 @@ class Workspace;
class NetBase {
public:
NetBase(const std::shared_ptr<const OperatorRegistry> op_registry,
NetBase(const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type);
......@@ -42,14 +42,14 @@ class NetBase {
protected:
std::string name_;
const std::shared_ptr<const OperatorRegistry> op_registry_;
const std::shared_ptr<const OperatorRegistryBase> op_registry_;
MACE_DISABLE_COPY_AND_ASSIGN(NetBase);
};
class SerialNet : public NetBase {
public:
SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
SerialNet(const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
......@@ -65,13 +65,13 @@ class SerialNet : public NetBase {
};
std::unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const NetDef &net_def,
Workspace *ws,
DeviceType type,
const NetMode mode = NetMode::NORMAL);
std::unique_ptr<NetBase> CreateNet(
const std::shared_ptr<const OperatorRegistry> op_registry,
const std::shared_ptr<const OperatorRegistryBase> op_registry,
const std::shared_ptr<const NetDef> net_def,
Workspace *ws,
DeviceType type,
......
......@@ -50,7 +50,9 @@ const std::string OpKeyBuilder::Build() {
return ss.str();
}
std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
OperatorRegistryBase::~OperatorRegistryBase() {}
std::unique_ptr<OperatorBase> OperatorRegistryBase::CreateOperator(
const OperatorDef &operator_def,
Workspace *ws,
DeviceType type,
......@@ -72,102 +74,4 @@ std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
}
}
namespace ops {
// Keep in lexicographical order
extern void Register_Activation(OperatorRegistry *op_registry);
extern void Register_AddN(OperatorRegistry *op_registry);
extern void Register_ArgMax(OperatorRegistry *op_registry);
extern void Register_BatchNorm(OperatorRegistry *op_registry);
extern void Register_BatchToSpaceND(OperatorRegistry *op_registry);
extern void Register_BiasAdd(OperatorRegistry *op_registry);
extern void Register_Cast(OperatorRegistry *op_registry);
extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_Deconv2D(OperatorRegistry *op_registry);
extern void Register_DepthToSpace(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Dequantize(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_Gather(OperatorRegistry *op_registry);
extern void Register_Identity(OperatorRegistry *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pad(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_Quantize(OperatorRegistry *op_registry);
extern void Register_ReduceMean(OperatorRegistry *op_registry);
extern void Register_Requantize(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Shape(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_Stack(OperatorRegistry *op_registry);
extern void Register_StridedSlice(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_SpaceToDepth(OperatorRegistry *op_registry);
extern void Register_Squeeze(OperatorRegistry *op_registry);
extern void Register_Transpose(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
#ifdef MACE_ENABLE_OPENCL
extern void Register_BufferToImage(OperatorRegistry *op_registry);
extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
#endif // MACE_ENABLE_OPENCL
} // namespace ops
OperatorRegistry::OperatorRegistry() {
// Keep in lexicographical order
ops::Register_Activation(this);
ops::Register_AddN(this);
ops::Register_ArgMax(this);
ops::Register_BatchNorm(this);
ops::Register_BatchToSpaceND(this);
ops::Register_BiasAdd(this);
ops::Register_Cast(this);
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
ops::Register_Deconv2D(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
ops::Register_Eltwise(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_Gather(this);
ops::Register_Identity(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
ops::Register_Pooling(this);
ops::Register_Proposal(this);
ops::Register_Quantize(this);
ops::Register_ReduceMean(this);
ops::Register_Requantize(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
ops::Register_Shape(this);
ops::Register_Slice(this);
ops::Register_Softmax(this);
ops::Register_Stack(this);
ops::Register_StridedSlice(this);
ops::Register_SpaceToBatchND(this);
ops::Register_SpaceToDepth(this);
ops::Register_Squeeze(this);
ops::Register_Transpose(this);
ops::Register_WinogradInverseTransform(this);
ops::Register_WinogradTransform(this);
#ifdef MACE_ENABLE_OPENCL
ops::Register_BufferToImage(this);
ops::Register_ImageToBuffer(this);
#endif // MACE_ENABLE_OPENCL
}
} // namespace mace
......@@ -163,12 +163,12 @@ OpKeyBuilder &OpKeyBuilder::TypeConstraint(const char *attr_name) {
return this->TypeConstraint(attr_name, DataTypeToEnum<T>::value);
}
class OperatorRegistry {
class OperatorRegistryBase {
public:
typedef Registry<std::string, OperatorBase, const OperatorDef &, Workspace *>
RegistryType;
OperatorRegistry();
~OperatorRegistry() = default;
OperatorRegistryBase() = default;
virtual ~OperatorRegistryBase();
RegistryType *registry() { return &registry_; }
std::unique_ptr<OperatorBase> CreateOperator(const OperatorDef &operator_def,
Workspace *ws,
......@@ -177,7 +177,7 @@ class OperatorRegistry {
private:
RegistryType registry_;
MACE_DISABLE_COPY_AND_ASSIGN(OperatorRegistry);
MACE_DISABLE_COPY_AND_ASSIGN(OperatorRegistryBase);
};
MACE_DECLARE_REGISTRY(OpRegistry,
......
......@@ -78,7 +78,6 @@ class Registerer {
#endif
#define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \
Registry<SrcType, ObjectType, ##__VA_ARGS__> *RegistryName(); \
typedef Registerer<SrcType, ObjectType, ##__VA_ARGS__> \
Registerer##RegistryName;
......
# Description:
# Mace libmace.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7", "if_hexagon_enabled")
cc_library(
name = "libmace",
srcs = glob(
["*.cc"],
),
copts = [
"-Werror",
"-Wextra",
] + if_openmp_enabled(["-fopenmp"]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
]) + if_android_armv7([
"-mfloat-abi=softfp",
]) + if_android([
"-DMACE_ENABLE_OPENCL",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
"//mace/public",
"//mace/ops",
],
alwayslink = 1,
)
......@@ -21,7 +21,7 @@
#include <memory>
#include "mace/core/net.h"
#include "mace/core/types.h"
#include "mace/ops/ops_register.h"
#include "mace/public/mace.h"
#ifdef MACE_ENABLE_OPENCL
......@@ -138,7 +138,7 @@ class MaceEngine::Impl {
private:
const unsigned char *model_data_;
size_t model_data_size_;
std::shared_ptr<OperatorRegistry> op_registry_;
std::shared_ptr<OperatorRegistryBase> op_registry_;
DeviceType device_type_;
std::unique_ptr<Workspace> ws_;
std::unique_ptr<NetBase> net_;
......
......@@ -58,7 +58,6 @@ cc_library(
deps = [
"//mace/kernels",
],
alwayslink = 1,
)
cc_test(
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Activation(OperatorRegistry *op_registry) {
void Register_Activation(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_AddN(OperatorRegistry *op_registry) {
void Register_AddN(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ArgMax(OperatorRegistry *op_registry) {
void Register_ArgMax(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ArgMax")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_BatchNorm(OperatorRegistry *op_registry) {
void Register_BatchNorm(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_BatchToSpaceND(OperatorRegistry *op_registry) {
void Register_BatchToSpaceND(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchToSpaceND")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_BiasAdd(OperatorRegistry *op_registry) {
void Register_BiasAdd(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BiasAdd")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_BufferToImage(OperatorRegistry *op_registry) {
void Register_BufferToImage(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BufferToImage")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Cast(OperatorRegistry *op_registry) {
void Register_Cast(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Cast")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ChannelShuffle(OperatorRegistry *op_registry) {
void Register_ChannelShuffle(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ChannelShuffle")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Concat(OperatorRegistry *op_registry) {
void Register_Concat(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Conv2D(OperatorRegistry *op_registry) {
void Register_Conv2D(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -51,7 +51,7 @@ TEST(CoreTest, INIT_MODE) {
for (auto &op_def : op_defs) {
net_def.add_op()->CopyFrom(op_def);
}
std::shared_ptr<OperatorRegistry> op_registry(new OperatorRegistry());
std::shared_ptr<OperatorRegistryBase> op_registry(new OperatorRegistryBase());
auto net =
CreateNet(op_registry, net_def, &ws, DeviceType::GPU, NetMode::INIT);
net->Run();
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Deconv2D(OperatorRegistry *op_registry) {
void Register_Deconv2D(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_DepthToSpace(OperatorRegistry *op_registry) {
void Register_DepthToSpace(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_DepthwiseConv2d(OperatorRegistry *op_registry) {
void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Eltwise(OperatorRegistry *op_registry) {
void Register_Eltwise(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_FullyConnected(OperatorRegistry *op_registry) {
void Register_FullyConnected(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("FullyConnected")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Gather(OperatorRegistry *op_registry) {
void Register_Gather(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Gather")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Identity(OperatorRegistry *op_registry) {
void Register_Identity(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ImageToBuffer(OperatorRegistry *op_registry) {
void Register_ImageToBuffer(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ImageToBuffer")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_LocalResponseNorm(OperatorRegistry *op_registry) {
void Register_LocalResponseNorm(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("LocalResponseNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_MatMul(OperatorRegistry *op_registry) {
void Register_MatMul(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_register.h"
namespace mace {
namespace ops {
// Keep in lexicographical order
extern void Register_Activation(OperatorRegistryBase *op_registry);
extern void Register_AddN(OperatorRegistryBase *op_registry);
extern void Register_ArgMax(OperatorRegistryBase *op_registry);
extern void Register_BatchNorm(OperatorRegistryBase *op_registry);
extern void Register_BatchToSpaceND(OperatorRegistryBase *op_registry);
extern void Register_BiasAdd(OperatorRegistryBase *op_registry);
extern void Register_Cast(OperatorRegistryBase *op_registry);
extern void Register_ChannelShuffle(OperatorRegistryBase *op_registry);
extern void Register_Concat(OperatorRegistryBase *op_registry);
extern void Register_Conv2D(OperatorRegistryBase *op_registry);
extern void Register_Deconv2D(OperatorRegistryBase *op_registry);
extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry);
extern void Register_Dequantize(OperatorRegistryBase *op_registry);
extern void Register_Eltwise(OperatorRegistryBase *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
extern void Register_Gather(OperatorRegistryBase *op_registry);
extern void Register_Identity(OperatorRegistryBase *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistryBase *op_registry);
extern void Register_MatMul(OperatorRegistryBase *op_registry);
extern void Register_Pad(OperatorRegistryBase *op_registry);
extern void Register_Pooling(OperatorRegistryBase *op_registry);
extern void Register_Proposal(OperatorRegistryBase *op_registry);
extern void Register_Quantize(OperatorRegistryBase *op_registry);
extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
extern void Register_Requantize(OperatorRegistryBase *op_registry);
extern void Register_Reshape(OperatorRegistryBase *op_registry);
extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
extern void Register_Shape(OperatorRegistryBase *op_registry);
extern void Register_Slice(OperatorRegistryBase *op_registry);
extern void Register_Softmax(OperatorRegistryBase *op_registry);
extern void Register_Stack(OperatorRegistryBase *op_registry);
extern void Register_StridedSlice(OperatorRegistryBase *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistryBase *op_registry);
extern void Register_SpaceToDepth(OperatorRegistryBase *op_registry);
extern void Register_Squeeze(OperatorRegistryBase *op_registry);
extern void Register_Transpose(OperatorRegistryBase *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistryBase *op_registry); // NOLINT(whitespace/line_length)
extern void Register_WinogradTransform(OperatorRegistryBase *op_registry);
#ifdef MACE_ENABLE_OPENCL
extern void Register_BufferToImage(OperatorRegistryBase *op_registry);
extern void Register_ImageToBuffer(OperatorRegistryBase *op_registry);
#endif // MACE_ENABLE_OPENCL
} // namespace ops
OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
// Keep in lexicographical order
ops::Register_Activation(this);
ops::Register_AddN(this);
ops::Register_ArgMax(this);
ops::Register_BatchNorm(this);
ops::Register_BatchToSpaceND(this);
ops::Register_BiasAdd(this);
ops::Register_Cast(this);
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
ops::Register_Deconv2D(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
ops::Register_Eltwise(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_Gather(this);
ops::Register_Identity(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
ops::Register_Pooling(this);
ops::Register_Proposal(this);
ops::Register_Quantize(this);
ops::Register_ReduceMean(this);
ops::Register_Requantize(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
ops::Register_Shape(this);
ops::Register_Slice(this);
ops::Register_Softmax(this);
ops::Register_Stack(this);
ops::Register_StridedSlice(this);
ops::Register_SpaceToBatchND(this);
ops::Register_SpaceToDepth(this);
ops::Register_Squeeze(this);
ops::Register_Transpose(this);
ops::Register_WinogradInverseTransform(this);
ops::Register_WinogradTransform(this);
#ifdef MACE_ENABLE_OPENCL
ops::Register_BufferToImage(this);
ops::Register_ImageToBuffer(this);
#endif // MACE_ENABLE_OPENCL
}
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPS_REGISTER_H_
#define MACE_OPS_OPS_REGISTER_H_
#include "mace/core/operator.h"
namespace mace {
class OperatorRegistry : public OperatorRegistryBase {
public:
OperatorRegistry();
~OperatorRegistry() = default;
};
} // namespace mace
#endif // MACE_OPS_OPS_REGISTER_H_
......@@ -110,7 +110,7 @@ class OpDefBuilder {
class OpsTestNet {
public:
OpsTestNet() : op_registry_(new OperatorRegistry()) {}
OpsTestNet() : op_registry_(new OperatorRegistryBase()) {}
template <DeviceType D, typename T>
void AddInputFromArray(const std::string &name,
......@@ -397,7 +397,7 @@ class OpsTestNet {
}
public:
std::shared_ptr<OperatorRegistry> op_registry_;
std::shared_ptr<OperatorRegistryBase> op_registry_;
Workspace ws_;
std::vector<OperatorDef> op_defs_;
std::unique_ptr<NetBase> net_;
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Pad(OperatorRegistry *op_registry) {
void Register_Pad(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Pooling(OperatorRegistry *op_registry) {
void Register_Pooling(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Proposal(OperatorRegistry *op_registry) {
void Register_Proposal(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Proposal")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Quantize(OperatorRegistry *op_registry) {
void Register_Quantize(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
......@@ -25,7 +25,7 @@ void Register_Quantize(OperatorRegistry *op_registry) {
QuantizeOp<DeviceType::CPU, uint8_t>);
}
void Register_Dequantize(OperatorRegistry *op_registry) {
void Register_Dequantize(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
......@@ -33,7 +33,7 @@ void Register_Dequantize(OperatorRegistry *op_registry) {
DequantizeOp<DeviceType::CPU, uint8_t>);
}
void Register_Requantize(OperatorRegistry *op_registry) {
void Register_Requantize(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ReduceMean(OperatorRegistry *op_registry) {
void Register_ReduceMean(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Reshape(OperatorRegistry *op_registry) {
void Register_Reshape(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_ResizeBilinear(OperatorRegistry *op_registry) {
void Register_ResizeBilinear(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBilinear")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Shape(OperatorRegistry *op_registry) {
void Register_Shape(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Shape")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Slice(OperatorRegistry *op_registry) {
void Register_Slice(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Softmax(OperatorRegistry *op_registry) {
void Register_Softmax(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_SpaceToBatchND(OperatorRegistry *op_registry) {
void Register_SpaceToBatchND(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToBatchND")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_SpaceToDepth(OperatorRegistry *op_registry) {
void Register_SpaceToDepth(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Squeeze(OperatorRegistry *op_registry) {
void Register_Squeeze(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Stack(OperatorRegistry *op_registry) {
void Register_Stack(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_StridedSlice(OperatorRegistry *op_registry) {
void Register_StridedSlice(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_Transpose(OperatorRegistry *op_registry) {
void Register_Transpose(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Transpose")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_WinogradInverseTransform(OperatorRegistry *op_registry) {
void Register_WinogradInverseTransform(OperatorRegistryBase *op_registry) {
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradInverseTransform")
.Device(DeviceType::GPU)
......
......@@ -17,7 +17,7 @@
namespace mace {
namespace ops {
void Register_WinogradTransform(OperatorRegistry *op_registry) {
void Register_WinogradTransform(OperatorRegistryBase *op_registry) {
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradTransform")
.Device(DeviceType::GPU)
......
......@@ -16,7 +16,7 @@ cc_binary(
"//external:gflags_nothreads",
"//mace/codegen:generated_mace_engine_factory",
"//mace/codegen:generated_models",
"//mace/ops:ops",
"//mace/libmace",
],
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册