diff --git a/mace/BUILD b/mace/BUILD index c1ad71119cfcfe85fda19010cde257c323d423ec..70d33ab2e7beafad8b83bcecf2c20541a8c47c12 100644 --- a/mace/BUILD +++ b/mace/BUILD @@ -63,7 +63,7 @@ cc_binary( linkstatic = 0, deps = [ ":mace_version_script.lds", - "//mace/ops", + "//mace/libmace", ], ) @@ -81,6 +81,7 @@ genrule( "//mace/core", "//mace/kernels", "//mace/ops", + "//mace/libmace", "//mace/utils", "//mace/proto:mace_cc", "@com_google_protobuf//:protobuf_lite", @@ -93,6 +94,7 @@ genrule( "$(locations //mace/core:core) " + "$(locations //mace/kernels:kernels) " + "$(locations //mace/ops:ops) " + + "$(locations //mace/libmace:libmace) " + "$(locations //mace/utils:utils) " + "$(locations //mace/proto:mace_cc) " + "$(locations @com_google_protobuf//:protobuf_lite) " + diff --git a/mace/core/net.cc b/mace/core/net.cc index 2f5703194ac38f5d871304be9e24f3540d2de8cb..6d8a751d16501ea678f6f6f71b700fab053a2687 100644 --- a/mace/core/net.cc +++ b/mace/core/net.cc @@ -22,7 +22,7 @@ namespace mace { -NetBase::NetBase(const std::shared_ptr op_registry, +NetBase::NetBase(const std::shared_ptr op_registry, const std::shared_ptr net_def, Workspace *ws, DeviceType type) @@ -31,11 +31,12 @@ NetBase::NetBase(const std::shared_ptr op_registry, MACE_UNUSED(type); } -SerialNet::SerialNet(const std::shared_ptr op_registry, - const std::shared_ptr net_def, - Workspace *ws, - DeviceType type, - const NetMode mode) +SerialNet::SerialNet( + const std::shared_ptr op_registry, + const std::shared_ptr net_def, + Workspace *ws, + DeviceType type, + const NetMode mode) : NetBase(op_registry, net_def, ws, type), device_type_(type) { MACE_LATENCY_LOGGER(1, "Constructing SerialNet ", net_def->name()); for (int idx = 0; idx < net_def->op_size(); ++idx) { @@ -130,7 +131,7 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) { } std::unique_ptr CreateNet( - const std::shared_ptr op_registry, + const std::shared_ptr op_registry, const NetDef &net_def, Workspace *ws, DeviceType type, @@ -140,7 +141,7 @@ std::unique_ptr CreateNet( } std::unique_ptr CreateNet( - const std::shared_ptr op_registry, + const std::shared_ptr op_registry, const std::shared_ptr net_def, Workspace *ws, DeviceType type, diff --git a/mace/core/net.h b/mace/core/net.h index e901188e75d5e2f5b43ccd45c378e596bfdc99ab..0cec40594c5a12924ff3ee82595b12af4b6f689c 100644 --- a/mace/core/net.h +++ b/mace/core/net.h @@ -30,7 +30,7 @@ class Workspace; class NetBase { public: - NetBase(const std::shared_ptr op_registry, + NetBase(const std::shared_ptr op_registry, const std::shared_ptr net_def, Workspace *ws, DeviceType type); @@ -42,14 +42,14 @@ class NetBase { protected: std::string name_; - const std::shared_ptr op_registry_; + const std::shared_ptr op_registry_; MACE_DISABLE_COPY_AND_ASSIGN(NetBase); }; class SerialNet : public NetBase { public: - SerialNet(const std::shared_ptr op_registry, + SerialNet(const std::shared_ptr op_registry, const std::shared_ptr net_def, Workspace *ws, DeviceType type, @@ -65,13 +65,13 @@ class SerialNet : public NetBase { }; std::unique_ptr CreateNet( - const std::shared_ptr op_registry, + const std::shared_ptr op_registry, const NetDef &net_def, Workspace *ws, DeviceType type, const NetMode mode = NetMode::NORMAL); std::unique_ptr CreateNet( - const std::shared_ptr op_registry, + const std::shared_ptr op_registry, const std::shared_ptr net_def, Workspace *ws, DeviceType type, diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 6389d1172210fd65db9d62ce029e95be0909e0b3..90013d7f30180f472fae9dba0f4385348ac9aac3 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -50,7 +50,9 @@ const std::string OpKeyBuilder::Build() { return ss.str(); } -std::unique_ptr OperatorRegistry::CreateOperator( +OperatorRegistryBase::~OperatorRegistryBase() {} + +std::unique_ptr OperatorRegistryBase::CreateOperator( const OperatorDef &operator_def, Workspace *ws, DeviceType type, @@ -72,102 +74,4 @@ std::unique_ptr OperatorRegistry::CreateOperator( } } -namespace ops { -// Keep in lexicographical order -extern void Register_Activation(OperatorRegistry *op_registry); -extern void Register_AddN(OperatorRegistry *op_registry); -extern void Register_ArgMax(OperatorRegistry *op_registry); -extern void Register_BatchNorm(OperatorRegistry *op_registry); -extern void Register_BatchToSpaceND(OperatorRegistry *op_registry); -extern void Register_BiasAdd(OperatorRegistry *op_registry); -extern void Register_Cast(OperatorRegistry *op_registry); -extern void Register_ChannelShuffle(OperatorRegistry *op_registry); -extern void Register_Concat(OperatorRegistry *op_registry); -extern void Register_Conv2D(OperatorRegistry *op_registry); -extern void Register_Deconv2D(OperatorRegistry *op_registry); -extern void Register_DepthToSpace(OperatorRegistry *op_registry); -extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); -extern void Register_Dequantize(OperatorRegistry *op_registry); -extern void Register_Eltwise(OperatorRegistry *op_registry); -extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); -extern void Register_FullyConnected(OperatorRegistry *op_registry); -extern void Register_Gather(OperatorRegistry *op_registry); -extern void Register_Identity(OperatorRegistry *op_registry); -extern void Register_LocalResponseNorm(OperatorRegistry *op_registry); -extern void Register_MatMul(OperatorRegistry *op_registry); -extern void Register_Pad(OperatorRegistry *op_registry); -extern void Register_Pooling(OperatorRegistry *op_registry); -extern void Register_Proposal(OperatorRegistry *op_registry); -extern void Register_Quantize(OperatorRegistry *op_registry); -extern void Register_ReduceMean(OperatorRegistry *op_registry); -extern void Register_Requantize(OperatorRegistry *op_registry); -extern void Register_Reshape(OperatorRegistry *op_registry); -extern void Register_ResizeBilinear(OperatorRegistry *op_registry); -extern void Register_Shape(OperatorRegistry *op_registry); -extern void Register_Slice(OperatorRegistry *op_registry); -extern void Register_Softmax(OperatorRegistry *op_registry); -extern void Register_Stack(OperatorRegistry *op_registry); -extern void Register_StridedSlice(OperatorRegistry *op_registry); -extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); -extern void Register_SpaceToDepth(OperatorRegistry *op_registry); -extern void Register_Squeeze(OperatorRegistry *op_registry); -extern void Register_Transpose(OperatorRegistry *op_registry); -extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry); -extern void Register_WinogradTransform(OperatorRegistry *op_registry); - -#ifdef MACE_ENABLE_OPENCL -extern void Register_BufferToImage(OperatorRegistry *op_registry); -extern void Register_ImageToBuffer(OperatorRegistry *op_registry); -#endif // MACE_ENABLE_OPENCL -} // namespace ops - -OperatorRegistry::OperatorRegistry() { - // Keep in lexicographical order - ops::Register_Activation(this); - ops::Register_AddN(this); - ops::Register_ArgMax(this); - ops::Register_BatchNorm(this); - ops::Register_BatchToSpaceND(this); - ops::Register_BiasAdd(this); - ops::Register_Cast(this); - ops::Register_ChannelShuffle(this); - ops::Register_Concat(this); - ops::Register_Conv2D(this); - ops::Register_Deconv2D(this); - ops::Register_DepthToSpace(this); - ops::Register_DepthwiseConv2d(this); - ops::Register_Dequantize(this); - ops::Register_Eltwise(this); - ops::Register_FoldedBatchNorm(this); - ops::Register_FullyConnected(this); - ops::Register_Gather(this); - ops::Register_Identity(this); - ops::Register_LocalResponseNorm(this); - ops::Register_MatMul(this); - ops::Register_Pad(this); - ops::Register_Pooling(this); - ops::Register_Proposal(this); - ops::Register_Quantize(this); - ops::Register_ReduceMean(this); - ops::Register_Requantize(this); - ops::Register_Reshape(this); - ops::Register_ResizeBilinear(this); - ops::Register_Shape(this); - ops::Register_Slice(this); - ops::Register_Softmax(this); - ops::Register_Stack(this); - ops::Register_StridedSlice(this); - ops::Register_SpaceToBatchND(this); - ops::Register_SpaceToDepth(this); - ops::Register_Squeeze(this); - ops::Register_Transpose(this); - ops::Register_WinogradInverseTransform(this); - ops::Register_WinogradTransform(this); - -#ifdef MACE_ENABLE_OPENCL - ops::Register_BufferToImage(this); - ops::Register_ImageToBuffer(this); -#endif // MACE_ENABLE_OPENCL -} - } // namespace mace diff --git a/mace/core/operator.h b/mace/core/operator.h index 3a2285d97d6948fbd6fe1b9a23d8ace9585a8c27..330f8002288badec78de4d6987caff0d0762cb05 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -163,12 +163,12 @@ OpKeyBuilder &OpKeyBuilder::TypeConstraint(const char *attr_name) { return this->TypeConstraint(attr_name, DataTypeToEnum::value); } -class OperatorRegistry { +class OperatorRegistryBase { public: typedef Registry RegistryType; - OperatorRegistry(); - ~OperatorRegistry() = default; + OperatorRegistryBase() = default; + virtual ~OperatorRegistryBase(); RegistryType *registry() { return ®istry_; } std::unique_ptr CreateOperator(const OperatorDef &operator_def, Workspace *ws, @@ -177,7 +177,7 @@ class OperatorRegistry { private: RegistryType registry_; - MACE_DISABLE_COPY_AND_ASSIGN(OperatorRegistry); + MACE_DISABLE_COPY_AND_ASSIGN(OperatorRegistryBase); }; MACE_DECLARE_REGISTRY(OpRegistry, diff --git a/mace/core/registry.h b/mace/core/registry.h index 0cc7ebf500bb6da3638167f3de5dc6269ba3c3e4..277cabb3ec3f3c854d80d4c095643d9f59f547e5 100644 --- a/mace/core/registry.h +++ b/mace/core/registry.h @@ -78,7 +78,6 @@ class Registerer { #endif #define MACE_DECLARE_TYPED_REGISTRY(RegistryName, SrcType, ObjectType, ...) \ - Registry *RegistryName(); \ typedef Registerer \ Registerer##RegistryName; diff --git a/mace/libmace/BUILD b/mace/libmace/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..cbbf16a8baed277161bc98705411c804a2b67224 --- /dev/null +++ b/mace/libmace/BUILD @@ -0,0 +1,36 @@ +# Description: +# Mace libmace. +# +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7", "if_hexagon_enabled") + +cc_library( + name = "libmace", + srcs = glob( + ["*.cc"], + ), + copts = [ + "-Werror", + "-Wextra", + ] + if_openmp_enabled(["-fopenmp"]) + if_neon_enabled([ + "-DMACE_ENABLE_NEON", + ]) + if_android_armv7([ + "-mfpu=neon", + ]) + if_android_armv7([ + "-mfloat-abi=softfp", + ]) + if_android([ + "-DMACE_ENABLE_OPENCL", + ]) + if_hexagon_enabled([ + "-DMACE_ENABLE_HEXAGON", + ]), + deps = [ + "//mace/public", + "//mace/ops", + ], + alwayslink = 1, +) diff --git a/mace/core/mace.cc b/mace/libmace/mace.cc similarity index 99% rename from mace/core/mace.cc rename to mace/libmace/mace.cc index db04fcc699c481f0e06c385a7b7cc8325a463aa7..93518f85da8413197befe2f115f785b49224d717 100644 --- a/mace/core/mace.cc +++ b/mace/libmace/mace.cc @@ -21,7 +21,7 @@ #include #include "mace/core/net.h" -#include "mace/core/types.h" +#include "mace/ops/ops_register.h" #include "mace/public/mace.h" #ifdef MACE_ENABLE_OPENCL @@ -138,7 +138,7 @@ class MaceEngine::Impl { private: const unsigned char *model_data_; size_t model_data_size_; - std::shared_ptr op_registry_; + std::shared_ptr op_registry_; DeviceType device_type_; std::unique_ptr ws_; std::unique_ptr net_; diff --git a/mace/core/mace_runtime.cc b/mace/libmace/mace_runtime.cc similarity index 100% rename from mace/core/mace_runtime.cc rename to mace/libmace/mace_runtime.cc diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 2bfc0b3d9f34f162edf076be48785e0897050d6c..f349b8b9c1fe5480ae3c753e39e7cea64f7454a6 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -58,7 +58,6 @@ cc_library( deps = [ "//mace/kernels", ], - alwayslink = 1, ) cc_test( diff --git a/mace/ops/activation.cc b/mace/ops/activation.cc index 37fd8117b83ab511e425e40eb185246a3856a172..44b2ba90c035d3dda95f4bcf31497f3a3b08f205 100644 --- a/mace/ops/activation.cc +++ b/mace/ops/activation.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Activation(OperatorRegistry *op_registry) { +void Register_Activation(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/addn.cc b/mace/ops/addn.cc index 6bfc4c09503501c0d4aa601335edbfb6fc86453c..a30cba48b4f2fb2bf6620cd23b807c6d4462f451 100644 --- a/mace/ops/addn.cc +++ b/mace/ops/addn.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_AddN(OperatorRegistry *op_registry) { +void Register_AddN(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/argmax.cc b/mace/ops/argmax.cc index 977cbbc6b238b1f909ca4e5ce06c5c81cc9ea36f..e14b7bb8c193b153c5a5f36c563b62b98d57607a 100644 --- a/mace/ops/argmax.cc +++ b/mace/ops/argmax.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_ArgMax(OperatorRegistry *op_registry) { +void Register_ArgMax(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ArgMax") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/batch_norm.cc b/mace/ops/batch_norm.cc index fe63559285dc91f061199d4747feee5e06f2d8c3..c1a6c0cf3c292df95c5f94fb42ff2ca2d1987577 100644 --- a/mace/ops/batch_norm.cc +++ b/mace/ops/batch_norm.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_BatchNorm(OperatorRegistry *op_registry) { +void Register_BatchNorm(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/batch_to_space.cc b/mace/ops/batch_to_space.cc index 50bc84ed2f56d46c62ccd3356f6023978373fc6b..b0ffd66bdf38fb0f96de601588e67cc32cb1874a 100644 --- a/mace/ops/batch_to_space.cc +++ b/mace/ops/batch_to_space.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_BatchToSpaceND(OperatorRegistry *op_registry) { +void Register_BatchToSpaceND(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchToSpaceND") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/bias_add.cc b/mace/ops/bias_add.cc index deb67368c9964e68c95ef411c16046e9be7506bc..bf082cf9286940858f7ef5eb9cfeae06b43252af 100644 --- a/mace/ops/bias_add.cc +++ b/mace/ops/bias_add.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_BiasAdd(OperatorRegistry *op_registry) { +void Register_BiasAdd(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BiasAdd") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/buffer_to_image.cc b/mace/ops/buffer_to_image.cc index 04cb9b8292340004600353c760a6dd43e8555104..83569ba3546e0b6a640f199b565d987f0486368e 100644 --- a/mace/ops/buffer_to_image.cc +++ b/mace/ops/buffer_to_image.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_BufferToImage(OperatorRegistry *op_registry) { +void Register_BufferToImage(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("BufferToImage") .Device(DeviceType::GPU) .TypeConstraint("T") diff --git a/mace/ops/cast.cc b/mace/ops/cast.cc index 556a79f81021909917ee7bf19b74203a6b8af8e6..87abfdd46eac3c4064ea448569d396005434970d 100644 --- a/mace/ops/cast.cc +++ b/mace/ops/cast.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Cast(OperatorRegistry *op_registry) { +void Register_Cast(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Cast") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/channel_shuffle.cc b/mace/ops/channel_shuffle.cc index f3311be64271876ebec1b7967d38faecdfe1f200..e13ac92a60390d5e76277a3afec4222e56336ab9 100644 --- a/mace/ops/channel_shuffle.cc +++ b/mace/ops/channel_shuffle.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_ChannelShuffle(OperatorRegistry *op_registry) { +void Register_ChannelShuffle(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ChannelShuffle") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/concat.cc b/mace/ops/concat.cc index bf82f796c88f93085a37cb803bb8279db0fd1814..c281f0cce2f6ce2600b93a769a94f89451d22a95 100644 --- a/mace/ops/concat.cc +++ b/mace/ops/concat.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Concat(OperatorRegistry *op_registry) { +void Register_Concat(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Concat") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc index 29d3ac7159c1ca952065c2b7f9bbf28c67fcf9dd..4377afb01e040eddf74f27e5d7b5e963e0246d26 100644 --- a/mace/ops/conv_2d.cc +++ b/mace/ops/conv_2d.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Conv2D(OperatorRegistry *op_registry) { +void Register_Conv2D(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/core_test.cc b/mace/ops/core_test.cc index d471a0f84bca0a3edaebf1fa299951d92790a144..e7d256f17c9f6d8dc23573ee6a74f727e1013cd6 100644 --- a/mace/ops/core_test.cc +++ b/mace/ops/core_test.cc @@ -51,7 +51,7 @@ TEST(CoreTest, INIT_MODE) { for (auto &op_def : op_defs) { net_def.add_op()->CopyFrom(op_def); } - std::shared_ptr op_registry(new OperatorRegistry()); + std::shared_ptr op_registry(new OperatorRegistryBase()); auto net = CreateNet(op_registry, net_def, &ws, DeviceType::GPU, NetMode::INIT); net->Run(); diff --git a/mace/ops/deconv_2d.cc b/mace/ops/deconv_2d.cc index 342e27aa13151837febf2256927787d5205585ab..af0d7232e3e42745c24241f915006acb5623c64e 100644 --- a/mace/ops/deconv_2d.cc +++ b/mace/ops/deconv_2d.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Deconv2D(OperatorRegistry *op_registry) { +void Register_Deconv2D(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Deconv2D") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/depth_to_space.cc b/mace/ops/depth_to_space.cc index 682a6770f9fa743f5cf17750ba18307d3eed4fb2..0da2bb00865d0a1b47a3295bf143e600fd392c6a 100644 --- a/mace/ops/depth_to_space.cc +++ b/mace/ops/depth_to_space.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_DepthToSpace(OperatorRegistry *op_registry) { +void Register_DepthToSpace(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/depthwise_conv2d.cc b/mace/ops/depthwise_conv2d.cc index cdb53595088bed8b163a74bf54707b3d0f129ab7..66396f6002b80219280b98015910eedab51ef0a6 100644 --- a/mace/ops/depthwise_conv2d.cc +++ b/mace/ops/depthwise_conv2d.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_DepthwiseConv2d(OperatorRegistry *op_registry) { +void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index 81050b16e4f5e030e6ff210f9022c3f866cdbe6c..b3d46025133bb9617d436bca2d02e8653323635a 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Eltwise(OperatorRegistry *op_registry) { +void Register_Eltwise(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/folded_batch_norm.cc b/mace/ops/folded_batch_norm.cc index ace0b857d2c3a8a8997559424e536974c6ae634b..f760075077396e699b16f3da15a8f57e4523623b 100644 --- a/mace/ops/folded_batch_norm.cc +++ b/mace/ops/folded_batch_norm.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_FoldedBatchNorm(OperatorRegistry *op_registry) { +void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/fully_connected.cc b/mace/ops/fully_connected.cc index 3147a598abf43682b0b599bd443c765b620f09ef..5ad8c4664a40fbbe6a9b7fe8cde20f1705a78d41 100644 --- a/mace/ops/fully_connected.cc +++ b/mace/ops/fully_connected.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_FullyConnected(OperatorRegistry *op_registry) { +void Register_FullyConnected(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("FullyConnected") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/gather.cc b/mace/ops/gather.cc index bc9687cf8dcf36f6a8a1cc004fdc81d852dc0c3a..12891c5d9ce00db7fd1dd25a4145a07b922f797b 100644 --- a/mace/ops/gather.cc +++ b/mace/ops/gather.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Gather(OperatorRegistry *op_registry) { +void Register_Gather(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Gather") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/identity.cc b/mace/ops/identity.cc index 628bfd2d593ec9b817221bc9e5852d3a2ceeef49..61a3335672e4d8b0f2e358dc40728d4271ea174e 100644 --- a/mace/ops/identity.cc +++ b/mace/ops/identity.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Identity(OperatorRegistry *op_registry) { +void Register_Identity(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/image_to_buffer.cc b/mace/ops/image_to_buffer.cc index 168f75b6dcbff8e233375cc547b89a7bc56e3d9f..cc60d146417069b03afa72a8ac8ea0b656212ba9 100644 --- a/mace/ops/image_to_buffer.cc +++ b/mace/ops/image_to_buffer.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_ImageToBuffer(OperatorRegistry *op_registry) { +void Register_ImageToBuffer(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ImageToBuffer") .Device(DeviceType::GPU) .TypeConstraint("T") diff --git a/mace/ops/local_response_norm.cc b/mace/ops/local_response_norm.cc index 8517c0140aba91bfbf79cfeaa4df4918b72d0f9b..f3e199706da84f9bb902eed8cb427a211e79261b 100644 --- a/mace/ops/local_response_norm.cc +++ b/mace/ops/local_response_norm.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_LocalResponseNorm(OperatorRegistry *op_registry) { +void Register_LocalResponseNorm(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("LocalResponseNorm") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/matmul.cc b/mace/ops/matmul.cc index fa342659ef4f0bf771a1edc644cc9a9d87932a0d..e1c5932c2626ff4211ce97bed22be05250ae11f4 100644 --- a/mace/ops/matmul.cc +++ b/mace/ops/matmul.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_MatMul(OperatorRegistry *op_registry) { +void Register_MatMul(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc new file mode 100644 index 0000000000000000000000000000000000000000..61ecb5af68ac56ae8e59dc1ee6370b64fc8e9689 --- /dev/null +++ b/mace/ops/ops_register.cc @@ -0,0 +1,118 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/ops_register.h" + +namespace mace { + +namespace ops { +// Keep in lexicographical order +extern void Register_Activation(OperatorRegistryBase *op_registry); +extern void Register_AddN(OperatorRegistryBase *op_registry); +extern void Register_ArgMax(OperatorRegistryBase *op_registry); +extern void Register_BatchNorm(OperatorRegistryBase *op_registry); +extern void Register_BatchToSpaceND(OperatorRegistryBase *op_registry); +extern void Register_BiasAdd(OperatorRegistryBase *op_registry); +extern void Register_Cast(OperatorRegistryBase *op_registry); +extern void Register_ChannelShuffle(OperatorRegistryBase *op_registry); +extern void Register_Concat(OperatorRegistryBase *op_registry); +extern void Register_Conv2D(OperatorRegistryBase *op_registry); +extern void Register_Deconv2D(OperatorRegistryBase *op_registry); +extern void Register_DepthToSpace(OperatorRegistryBase *op_registry); +extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry); +extern void Register_Dequantize(OperatorRegistryBase *op_registry); +extern void Register_Eltwise(OperatorRegistryBase *op_registry); +extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry); +extern void Register_FullyConnected(OperatorRegistryBase *op_registry); +extern void Register_Gather(OperatorRegistryBase *op_registry); +extern void Register_Identity(OperatorRegistryBase *op_registry); +extern void Register_LocalResponseNorm(OperatorRegistryBase *op_registry); +extern void Register_MatMul(OperatorRegistryBase *op_registry); +extern void Register_Pad(OperatorRegistryBase *op_registry); +extern void Register_Pooling(OperatorRegistryBase *op_registry); +extern void Register_Proposal(OperatorRegistryBase *op_registry); +extern void Register_Quantize(OperatorRegistryBase *op_registry); +extern void Register_ReduceMean(OperatorRegistryBase *op_registry); +extern void Register_Requantize(OperatorRegistryBase *op_registry); +extern void Register_Reshape(OperatorRegistryBase *op_registry); +extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry); +extern void Register_Shape(OperatorRegistryBase *op_registry); +extern void Register_Slice(OperatorRegistryBase *op_registry); +extern void Register_Softmax(OperatorRegistryBase *op_registry); +extern void Register_Stack(OperatorRegistryBase *op_registry); +extern void Register_StridedSlice(OperatorRegistryBase *op_registry); +extern void Register_SpaceToBatchND(OperatorRegistryBase *op_registry); +extern void Register_SpaceToDepth(OperatorRegistryBase *op_registry); +extern void Register_Squeeze(OperatorRegistryBase *op_registry); +extern void Register_Transpose(OperatorRegistryBase *op_registry); +extern void Register_WinogradInverseTransform(OperatorRegistryBase *op_registry); // NOLINT(whitespace/line_length) +extern void Register_WinogradTransform(OperatorRegistryBase *op_registry); + +#ifdef MACE_ENABLE_OPENCL +extern void Register_BufferToImage(OperatorRegistryBase *op_registry); +extern void Register_ImageToBuffer(OperatorRegistryBase *op_registry); +#endif // MACE_ENABLE_OPENCL +} // namespace ops + + +OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { + // Keep in lexicographical order + ops::Register_Activation(this); + ops::Register_AddN(this); + ops::Register_ArgMax(this); + ops::Register_BatchNorm(this); + ops::Register_BatchToSpaceND(this); + ops::Register_BiasAdd(this); + ops::Register_Cast(this); + ops::Register_ChannelShuffle(this); + ops::Register_Concat(this); + ops::Register_Conv2D(this); + ops::Register_Deconv2D(this); + ops::Register_DepthToSpace(this); + ops::Register_DepthwiseConv2d(this); + ops::Register_Dequantize(this); + ops::Register_Eltwise(this); + ops::Register_FoldedBatchNorm(this); + ops::Register_FullyConnected(this); + ops::Register_Gather(this); + ops::Register_Identity(this); + ops::Register_LocalResponseNorm(this); + ops::Register_MatMul(this); + ops::Register_Pad(this); + ops::Register_Pooling(this); + ops::Register_Proposal(this); + ops::Register_Quantize(this); + ops::Register_ReduceMean(this); + ops::Register_Requantize(this); + ops::Register_Reshape(this); + ops::Register_ResizeBilinear(this); + ops::Register_Shape(this); + ops::Register_Slice(this); + ops::Register_Softmax(this); + ops::Register_Stack(this); + ops::Register_StridedSlice(this); + ops::Register_SpaceToBatchND(this); + ops::Register_SpaceToDepth(this); + ops::Register_Squeeze(this); + ops::Register_Transpose(this); + ops::Register_WinogradInverseTransform(this); + ops::Register_WinogradTransform(this); + +#ifdef MACE_ENABLE_OPENCL + ops::Register_BufferToImage(this); + ops::Register_ImageToBuffer(this); +#endif // MACE_ENABLE_OPENCL +} + +} // namespace mace diff --git a/mace/ops/ops_register.h b/mace/ops/ops_register.h new file mode 100644 index 0000000000000000000000000000000000000000..9369fde5d7a717a8e74a155253f838eecf0e96cb --- /dev/null +++ b/mace/ops/ops_register.h @@ -0,0 +1,30 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_OPS_REGISTER_H_ +#define MACE_OPS_OPS_REGISTER_H_ + +#include "mace/core/operator.h" + +namespace mace { + +class OperatorRegistry : public OperatorRegistryBase { + public: + OperatorRegistry(); + ~OperatorRegistry() = default; +}; + +} // namespace mace + +#endif // MACE_OPS_OPS_REGISTER_H_ diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index e348ba1f82a603db0be689681abb69bab049206f..f34797c9cc20c41c417cfd938b07179e82d0399f 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -110,7 +110,7 @@ class OpDefBuilder { class OpsTestNet { public: - OpsTestNet() : op_registry_(new OperatorRegistry()) {} + OpsTestNet() : op_registry_(new OperatorRegistryBase()) {} template void AddInputFromArray(const std::string &name, @@ -397,7 +397,7 @@ class OpsTestNet { } public: - std::shared_ptr op_registry_; + std::shared_ptr op_registry_; Workspace ws_; std::vector op_defs_; std::unique_ptr net_; diff --git a/mace/ops/pad.cc b/mace/ops/pad.cc index 6875de6ab314b4d2ed183d7087e89774ebfacaed..e6d468b22cf1fa642d06b627e57d4c5f8f7e727a 100644 --- a/mace/ops/pad.cc +++ b/mace/ops/pad.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Pad(OperatorRegistry *op_registry) { +void Register_Pad(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pad") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc index 25cd44aad70a3052da27aa6e61b9c173edb27058..0b673b51ecf1a2da0107c3ae00c4c25c07fd4f9b 100644 --- a/mace/ops/pooling.cc +++ b/mace/ops/pooling.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Pooling(OperatorRegistry *op_registry) { +void Register_Pooling(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/proposal.cc b/mace/ops/proposal.cc index 4558bbb3d3bad5e9214303fb5f16401bac48308b..2b75eeafe777aa887602bbedb879185335ef3fa9 100644 --- a/mace/ops/proposal.cc +++ b/mace/ops/proposal.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Proposal(OperatorRegistry *op_registry) { +void Register_Proposal(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Proposal") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/quantize.cc b/mace/ops/quantize.cc index dad9610d25b8edfc688f90ab425167e0652ca24c..81a51fcec54e25c8af462136596090596a2d1ba7 100644 --- a/mace/ops/quantize.cc +++ b/mace/ops/quantize.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Quantize(OperatorRegistry *op_registry) { +void Register_Quantize(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize") .Device(DeviceType::CPU) .TypeConstraint("T") @@ -25,7 +25,7 @@ void Register_Quantize(OperatorRegistry *op_registry) { QuantizeOp); } -void Register_Dequantize(OperatorRegistry *op_registry) { +void Register_Dequantize(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize") .Device(DeviceType::CPU) .TypeConstraint("T") @@ -33,7 +33,7 @@ void Register_Dequantize(OperatorRegistry *op_registry) { DequantizeOp); } -void Register_Requantize(OperatorRegistry *op_registry) { +void Register_Requantize(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/reduce_mean.cc b/mace/ops/reduce_mean.cc index 4f181a776f5b5a589d657527c5124e5a1d524d26..ee4d171681ba56f2bfff5490d5742c8aeec9c70c 100644 --- a/mace/ops/reduce_mean.cc +++ b/mace/ops/reduce_mean.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_ReduceMean(OperatorRegistry *op_registry) { +void Register_ReduceMean(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ReduceMean") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index aefc63371b94ece503bea4375866be7c3ec0bd56..2831aeba12d1632e1e23773b4ccbba0fa2cee9e6 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Reshape(OperatorRegistry *op_registry) { +void Register_Reshape(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Reshape") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/resize_bilinear.cc b/mace/ops/resize_bilinear.cc index e18d70387345fc1bb857deba2bbaa9945c054c53..82bbfd0a3aea8caa88a22821852435dcd9567e62 100644 --- a/mace/ops/resize_bilinear.cc +++ b/mace/ops/resize_bilinear.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_ResizeBilinear(OperatorRegistry *op_registry) { +void Register_ResizeBilinear(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBilinear") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/shape.cc b/mace/ops/shape.cc index c65586e6e366197c4bd3d154bdfa73e66ca728a8..7014aa8d8ee86cc55ca2023354cd7971444eb5bf 100644 --- a/mace/ops/shape.cc +++ b/mace/ops/shape.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Shape(OperatorRegistry *op_registry) { +void Register_Shape(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Shape") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/slice.cc b/mace/ops/slice.cc index a9b1c9bd2c494721801345c7910b0768cc0c6f16..b6bf4b24e7fd6e974448e9751866503429cdea84 100644 --- a/mace/ops/slice.cc +++ b/mace/ops/slice.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Slice(OperatorRegistry *op_registry) { +void Register_Slice(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/softmax.cc b/mace/ops/softmax.cc index eff2b41565ad40140798a566576c08cdfd3c9822..6c1a895b76015488eb5e9788f2b51345ea5e2dd0 100644 --- a/mace/ops/softmax.cc +++ b/mace/ops/softmax.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Softmax(OperatorRegistry *op_registry) { +void Register_Softmax(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Softmax") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/space_to_batch.cc b/mace/ops/space_to_batch.cc index ca905e785ee884ebc6e6f81d9e202fe48a6720a5..e0291172bc1daccbef28c9662e4d0fc07657c8f6 100644 --- a/mace/ops/space_to_batch.cc +++ b/mace/ops/space_to_batch.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_SpaceToBatchND(OperatorRegistry *op_registry) { +void Register_SpaceToBatchND(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToBatchND") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/space_to_depth.cc b/mace/ops/space_to_depth.cc index 1807226505a1bd73aad7c1426cbf0cf37d74e108..67b520f6487f2115771fe6e0d05c7576febb4fd8 100644 --- a/mace/ops/space_to_depth.cc +++ b/mace/ops/space_to_depth.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_SpaceToDepth(OperatorRegistry *op_registry) { +void Register_SpaceToDepth(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/squeeze.cc b/mace/ops/squeeze.cc index e917936fc949d9ccf99ba611753e97fa7a503248..e30a87bdc5d870099d1e270b7424dac7a5974d32 100644 --- a/mace/ops/squeeze.cc +++ b/mace/ops/squeeze.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Squeeze(OperatorRegistry *op_registry) { +void Register_Squeeze(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Squeeze") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/stack.cc b/mace/ops/stack.cc index 992ee408a65cdcf528f3471cb2378b4a8d1cea8a..968f859d5945d8d353d56cf631f433e409c22f54 100644 --- a/mace/ops/stack.cc +++ b/mace/ops/stack.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Stack(OperatorRegistry *op_registry) { +void Register_Stack(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/strided_slice.cc b/mace/ops/strided_slice.cc index 84cf788394cab669b290ccff10ef82e3c79fb8f1..b449be038f33e34d03b3af9360634513f852f544 100644 --- a/mace/ops/strided_slice.cc +++ b/mace/ops/strided_slice.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_StridedSlice(OperatorRegistry *op_registry) { +void Register_StridedSlice(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/transpose.cc b/mace/ops/transpose.cc index a0c726af9461d3913d21d17b2b722166181a594b..73dcaf7b650dbd168bd5c74a38c3a8fbdc3a7318 100644 --- a/mace/ops/transpose.cc +++ b/mace/ops/transpose.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_Transpose(OperatorRegistry *op_registry) { +void Register_Transpose(OperatorRegistryBase *op_registry) { MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Transpose") .Device(DeviceType::CPU) .TypeConstraint("T") diff --git a/mace/ops/winograd_inverse_transform.cc b/mace/ops/winograd_inverse_transform.cc index f84b69a2074823bd9c97df1a5ba14acd7719ce02..62e86248136c3cd4b8f94ee305c700dcaa16277e 100644 --- a/mace/ops/winograd_inverse_transform.cc +++ b/mace/ops/winograd_inverse_transform.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_WinogradInverseTransform(OperatorRegistry *op_registry) { +void Register_WinogradInverseTransform(OperatorRegistryBase *op_registry) { #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradInverseTransform") .Device(DeviceType::GPU) diff --git a/mace/ops/winograd_transform.cc b/mace/ops/winograd_transform.cc index 24f822551ac536931e661d4ae2193d8509096fd5..a4dab0ec1d1d1cacdd30c292c481834b86d35918 100644 --- a/mace/ops/winograd_transform.cc +++ b/mace/ops/winograd_transform.cc @@ -17,7 +17,7 @@ namespace mace { namespace ops { -void Register_WinogradTransform(OperatorRegistry *op_registry) { +void Register_WinogradTransform(OperatorRegistryBase *op_registry) { #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("WinogradTransform") .Device(DeviceType::GPU) diff --git a/mace/tools/validation/BUILD b/mace/tools/validation/BUILD index a1ba419b27ea448b2d53ea2b75fa31fdcac83af5..822bd2e91af7bacccc6db8b37569ee18c109bdfb 100644 --- a/mace/tools/validation/BUILD +++ b/mace/tools/validation/BUILD @@ -16,7 +16,7 @@ cc_binary( "//external:gflags_nothreads", "//mace/codegen:generated_mace_engine_factory", "//mace/codegen:generated_models", - "//mace/ops:ops", + "//mace/libmace", ], )