提交 2895c3ee 编写于 作者: L luxuhui

refactor: refactor op base module and op delegator mechanism

N/A
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 28954099
...@@ -19,7 +19,7 @@ Define the new Op class in `mace/ops/my_custom_op.cc`. ...@@ -19,7 +19,7 @@ Define the new Op class in `mace/ops/my_custom_op.cc`.
The structure of Op is like the following code. The structure of Op is like the following code.
```c++ ```c++
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -39,7 +39,7 @@ class MyCustomOp<DeviceType::GPU, float> : public Operation { ...@@ -39,7 +39,7 @@ class MyCustomOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterMyCustomOp(OpRegistryBase *op_registry) { void RegisterMyCustomOp(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "MyCustomOp", MyCustomOp, MACE_REGISTER_OP(op_registry, "MyCustomOp", MyCustomOp,
DeviceType::CPU, float); DeviceType::CPU, float);
...@@ -63,14 +63,14 @@ namespace ops { ...@@ -63,14 +63,14 @@ namespace ops {
... ...
extern void RegisterMyCustomOp(OpRegistryBase *op_registry); extern void RegisterMyCustomOp(OpRegistry *op_registry);
... ...
} // namespace ops } // namespace ops
OpRegistry::OpRegistry() : OpRegistryBase() { OpRegistry::OpRegistry() {
// Keep in lexicographical order // Keep in lexicographical order
... ...
......
...@@ -557,7 +557,7 @@ which will reduce the library size significantly. the final binary just link the ...@@ -557,7 +557,7 @@ which will reduce the library size significantly. the final binary just link the
} // namespace ops } // namespace ops
OpRegistry::OpRegistry() : OpRegistryBase() { OpRegistry::OpRegistry() {
// Just leave the ops used in your models // Just leave the ops used in your models
... ...
......
...@@ -370,12 +370,13 @@ the sample code show how to calculate the Top-1 accuracy with imagenet validatio ...@@ -370,12 +370,13 @@ the sample code show how to calculate the Top-1 accuracy with imagenet validatio
Reduce Library Size Reduce Library Size
------------------- -------------------
Remove the registration of the ops unused for your models in the ``mace/ops/ops_register.cc``, Remove the registration of the ops and delegators unused for your models in the
which will reduce the library size significantly. the final binary just link the registered ops' code. ``mace/ops/registry/ops_registry.cc`` and ``mace/ops/registry/op_delegators_registry.cc``,
which will reduce the library size significantly. the final binary just link the registered ops and delegators' code.
.. code-block:: cpp .. code-block:: cpp
#include "mace/ops/ops_register.h" #include "mace/ops/registry/registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -386,12 +387,38 @@ which will reduce the library size significantly. the final binary just link the ...@@ -386,12 +387,38 @@ which will reduce the library size significantly. the final binary just link the
} // namespace ops } // namespace ops
OpRegistry::OpRegistry() : OpRegistryBase() { void RegisterAllOps(OpRegistry *registry) {
// Just leave the ops used in your models // Just leave the ops used in your models
... ...
ops::RegisterMyCustomOp(this); ops::RegisterMyCustomOp(registry);
...
}
} // namespace mace
.. code-block:: cpp
#include "mace/ops/registry/registry.h"
namespace mace {
namespace ops {
// Just leave the delegators used in your ops
...
} // namespace ops
void RegisterAllOpDelegators(OpDelegatorRegistry *registry) {
// Just leave the delegators used in your ops
...
ops::RegisterMyCustomDelegator(registry);
... ...
......
...@@ -26,6 +26,8 @@ cc_library( ...@@ -26,6 +26,8 @@ cc_library(
srcs = glob( srcs = glob(
[ [
"*.cc", "*.cc",
"ops/*.cc",
"registry/*.cc",
"runtime/cpu/*.cc", "runtime/cpu/*.cc",
], ],
exclude = [ exclude = [
...@@ -53,6 +55,8 @@ cc_library( ...@@ -53,6 +55,8 @@ cc_library(
hdrs = glob( hdrs = glob(
[ [
"*.h", "*.h",
"ops/*.h",
"registry/*.h",
"runtime/cpu/*.h", "runtime/cpu/*.h",
], ],
exclude = [ exclude = [
...@@ -68,7 +72,7 @@ cc_library( ...@@ -68,7 +72,7 @@ cc_library(
])) + if_hta_enabled(glob([ ])) + if_hta_enabled(glob([
"runtime/hexagon/*hta*.h", "runtime/hexagon/*hta*.h",
])) + if_apu_enabled(glob([ ])) + if_apu_enabled(glob([
"runtime/apu/*.h" "runtime/apu/*.h",
])) + if_rpcmem_enabled([ ])) + if_rpcmem_enabled([
"rpcmem.h", "rpcmem.h",
]), ]),
......
...@@ -8,9 +8,16 @@ set(CORE_SRCS ...@@ -8,9 +8,16 @@ set(CORE_SRCS
net.cc net.cc
net_def_adapter.cc net_def_adapter.cc
net_optimizer.cc net_optimizer.cc
op_context.cc ops/op_condition_builder.cc
operator.cc ops/op_condition_context.cc
ops/op_construct_context.cc
ops/op_context.cc
ops/operator.cc
ops/op_init_context.cc
quantize.cc quantize.cc
registry/op_delegator_registry.cc
registry/op_registration_info.cc
registry/ops_registry.cc
runtime_failure_mock.cc runtime_failure_mock.cc
types.cc types.cc
workspace.cc workspace.cc
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/net.h"
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <set> #include <set>
...@@ -20,8 +22,9 @@ ...@@ -20,8 +22,9 @@
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/memory_optimizer.h" #include "mace/core/memory_optimizer.h"
#include "mace/core/net.h" #include "mace/core/ops/op_init_context.h"
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
#include "mace/port/env.h" #include "mace/port/env.h"
#include "mace/utils/conf_util.h" #include "mace/utils/conf_util.h"
...@@ -33,7 +36,7 @@ ...@@ -33,7 +36,7 @@
namespace mace { namespace mace {
SerialNet::SerialNet(const OpRegistryBase *op_registry, SerialNet::SerialNet(const OpRegistry *op_registry,
const NetDef *net_def, const NetDef *net_def,
Workspace *ws, Workspace *ws,
Device *target_device, Device *target_device,
......
...@@ -21,13 +21,14 @@ ...@@ -21,13 +21,14 @@
#include <unordered_map> #include <unordered_map>
#include <sstream> #include <sstream>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
namespace mace { namespace mace {
class RunMetadata; class RunMetadata;
class Workspace; class Workspace;
class MemoryOptimizer; class MemoryOptimizer;
class OpRegistry;
class NetBase { class NetBase {
public: public:
...@@ -44,7 +45,7 @@ class NetBase { ...@@ -44,7 +45,7 @@ class NetBase {
class SerialNet : public NetBase { class SerialNet : public NetBase {
public: public:
SerialNet(const OpRegistryBase *op_registry, SerialNet(const OpRegistry *op_registry,
const NetDef *net_def, const NetDef *net_def,
Workspace *ws, Workspace *ws,
Device *target_device, Device *target_device,
......
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/ops/op_condition_context.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_util.h" #include "mace/core/runtime/opencl/opencl_util.h"
...@@ -82,7 +84,7 @@ void BuildTransposeOpDef( ...@@ -82,7 +84,7 @@ void BuildTransposeOpDef(
} // namespace } // namespace
NetDefAdapter::NetDefAdapter(const OpRegistryBase *op_registry, NetDefAdapter::NetDefAdapter(const OpRegistry *op_registry,
const Workspace *ws) const Workspace *ws)
: op_registry_(op_registry), ws_(ws) {} : op_registry_(op_registry), ws_(ws) {}
......
...@@ -23,14 +23,17 @@ ...@@ -23,14 +23,17 @@
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/proto/mace.pb.h" #include "mace/proto/mace.pb.h"
#include "mace/port/port.h" #include "mace/port/port.h"
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/net_optimizer.h" #include "mace/core/net_optimizer.h"
namespace mace { namespace mace {
class OpRegistryBase;
class Workspace;
class Device; class Device;
class OpConditionContext;
class OperatorDef;
class OpRegistry;
class Workspace;
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
/// Conventions /// Conventions
...@@ -49,7 +52,7 @@ class Device; ...@@ -49,7 +52,7 @@ class Device;
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
class NetDefAdapter { class NetDefAdapter {
public: public:
NetDefAdapter(const OpRegistryBase *op_registry, NetDefAdapter(const OpRegistry *op_registry,
const Workspace *ws); const Workspace *ws);
// Adapt original net_def to a better net. // Adapt original net_def to a better net.
// 1. Adapt device: choose best device for every op in the net. // 1. Adapt device: choose best device for every op in the net.
...@@ -122,7 +125,7 @@ class NetDefAdapter { ...@@ -122,7 +125,7 @@ class NetDefAdapter {
std::string DebugString(const NetDef *net_def); std::string DebugString(const NetDef *net_def);
private: private:
const OpRegistryBase *op_registry_; const OpRegistry *op_registry_;
const Workspace *ws_; const Workspace *ws_;
NetOptimizer net_optimizer_; NetOptimizer net_optimizer_;
}; };
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_OPERATOR_H_
#define MACE_CORE_OPERATOR_H_
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "mace/core/arg_helper.h"
#include "mace/core/op_context.h"
#include "mace/core/tensor.h"
#include "mace/core/workspace.h"
#include "mace/proto/mace.pb.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_util.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
// OpConditionContext has all information used for choosing a proper Op
// implementation before the op is constructed: the workspace, known tensor
// shapes, and the memory/data types of the op's inputs and output.
class OpConditionContext {
 public:
  // Maps tensor name -> tensor shape, shared while adapting the net.
  typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
  OpConditionContext(const Workspace *ws, TensorShapeMap *info);
  ~OpConditionContext() = default;

  // Selects the op definition under inspection; clears cached input types.
  void set_operator_def(const OperatorDef *operator_def);
  inline const OperatorDef *operator_def() const {
    return operator_def_;
  }
  inline const Workspace *workspace() const {
    return ws_;
  }
  inline void set_device(Device *device) {
    device_ = device;
  }
  inline Device *device() const {
    return device_;
  }
  inline TensorShapeMap *tensor_shape_info() const {
    return tensor_shape_info_;
  }
  // Sets the (single) output memory type; resets per-input overrides.
  void set_output_mem_type(MemoryType type);
  inline MemoryType output_mem_type() const {
    return output_mem_type_;
  }
  // Overrides the memory type and data type of input `idx`.
  void SetInputInfo(size_t idx, MemoryType mem_type, DataType dt);
  MemoryType GetInputMemType(size_t idx) const;
  DataType GetInputDataType(size_t idx) const;
#ifdef MACE_ENABLE_OPENCL
  void SetInputOpenCLBufferType(size_t idx, OpenCLBufferType buffer_type);
  OpenCLBufferType GetInputOpenCLBufferType(size_t idx) const;
#endif  // MACE_ENABLE_OPENCL

 private:
  const OperatorDef *operator_def_;    // not owned
  const Workspace *ws_;                // not owned
  Device *device_;                     // not owned
  TensorShapeMap *tensor_shape_info_;  // not owned
  // used for memory transform
  std::vector<MemoryType> input_mem_types_;
  std::vector<DataType> input_data_types_;
  MemoryType output_mem_type_;  // there is only one output memory type now.
#ifdef MACE_ENABLE_OPENCL
  std::vector<OpenCLBufferType> input_opencl_buffer_types_;
#endif  // MACE_ENABLE_OPENCL
};
// OpConstructContext holds everything an Operation constructor needs:
// the workspace, the op definition, and the target device.
class OpConstructContext {
  typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;

 public:
  explicit OpConstructContext(Workspace *ws);
  ~OpConstructContext() = default;

  void set_operator_def(std::shared_ptr<OperatorDef> operator_def);
  inline std::shared_ptr<OperatorDef> operator_def() const {
    return operator_def_;
  }
  inline Workspace *workspace() const {
    return ws_;
  }
  inline void set_device(Device *device) {
    device_ = device;
  }
  inline Device *device() const {
    return device_;
  }
#ifdef MACE_ENABLE_OPENCL
  // Reads the op's output-memory-type tag; CPU_BUFFER when absent.
  inline MemoryType GetOpMemoryType() const {
    return static_cast<MemoryType>(
        ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
            *operator_def_, OutputMemoryTypeTagName(),
            static_cast<int>(MemoryType::CPU_BUFFER)));
  }
#endif  // MACE_ENABLE_OPENCL

 private:
  std::shared_ptr<OperatorDef> operator_def_;
  Workspace *ws_;   // not owned
  Device *device_;  // not owned
};
// OpInitContext supplies the workspace and device to Operation::Init.
class OpInitContext {
 public:
  explicit OpInitContext(Workspace *ws, Device *device = nullptr);
  ~OpInitContext() = default;

  inline Workspace *workspace() const {
    return ws_;
  }
  inline void set_device(Device *device) {
    device_ = device;
  }
  inline Device *device() const {
    return device_;
  }

 private:
  Workspace *ws_;   // not owned
  Device *device_;  // not owned
};
// Conventions
// * If there exist format, NHWC is the default format
// * The input/output format of CPU ops with float data type is NCHW
// * The input/output format of GPU ops and CPU Quantization ops is NHWC
// * Inputs' data type is same as the operation data type by default.
// * The outputs' data type is same as the operation data type by default.
//
// Operation is the base class of all ops; concrete ops override Run()
// (and optionally Init()).
class Operation {
 public:
  explicit Operation(OpConstructContext *context);
  virtual ~Operation() = default;

  // Reads a single optional argument from the op definition.
  template<typename T>
  inline T GetOptionalArg(const std::string &name,
                          const T &default_value) const {
    MACE_CHECK(operator_def_, "operator_def was null!");
    return ProtoArgHelper::GetOptionalArg<OperatorDef, T>(
        *operator_def_, name, default_value);
  }
  // Reads a repeated argument from the op definition.
  template<typename T>
  inline std::vector<T> GetRepeatedArgs(
      const std::string &name, const std::vector<T> &default_value = {}) const {
    MACE_CHECK(operator_def_, "operator_def was null!");
    return ProtoArgHelper::GetRepeatedArgs<OperatorDef, T>(
        *operator_def_, name, default_value);
  }

  inline DeviceType device_type() const {
    return static_cast<DeviceType>(operator_def_->device_type());
  }
  inline const Tensor *Input(unsigned int idx) {
    MACE_CHECK(idx < inputs_.size());
    return inputs_[idx];
  }
  inline Tensor *Output(int idx) { return outputs_[idx]; }
  inline int InputSize() { return inputs_.size(); }
  inline int OutputSize() { return outputs_.size(); }
  inline const std::vector<const Tensor *> &Inputs() const { return inputs_; }
  inline const std::vector<Tensor *> &Outputs() { return outputs_; }

  // Run Op asynchronously (depends on device), return a future if not nullptr.
  virtual MaceStatus Init(OpInitContext *);
  virtual MaceStatus Run(OpContext *) = 0;

  inline const OperatorDef &debug_def() const {
    MACE_CHECK(has_debug_def(), "operator_def was null!");
    return *operator_def_;
  }
  inline void set_debug_def(
      const std::shared_ptr<OperatorDef> &operator_def) {
    operator_def_ = operator_def;
  }
  inline bool has_debug_def() const { return operator_def_ != nullptr; }
  inline std::shared_ptr<OperatorDef> operator_def() {
    return operator_def_;
  }

 protected:
  std::shared_ptr<OperatorDef> operator_def_;
  std::vector<const Tensor *> inputs_;
  std::vector<Tensor *> outputs_;

  MACE_DISABLE_COPY_AND_ASSIGN(Operation);
};
// MACE_OP_INPUT_TAGS and MACE_OP_OUTPUT_TAGS are optional features to name the
// indices of the operator's inputs and outputs, in order to avoid confusion.
// For example, for a fully convolution layer that has input, weight and bias,
// you can define its input tags as:
//     MACE_OP_INPUT_TAGS(INPUT, WEIGHT, BIAS);
// And in the code, instead of doing
//     auto& weight = Input(1);
// you can now do
//     auto& weight = Input(WEIGHT);
// to make it more clear.
#define MACE_OP_INPUT_TAGS(first_input, ...) \
  enum _InputTags { first_input = 0, __VA_ARGS__ }
#define MACE_OP_OUTPUT_TAGS(first_input, ...) \
  enum _OutputTags { first_input = 0, __VA_ARGS__ }
// OpRegistrationInfo stores everything registered for one op type:
// supported devices, string-keyed creators, and the optional condition
// functions consulted while adapting the net.
struct OpRegistrationInfo {
 public:
  typedef std::function<std::unique_ptr<Operation>(OpConstructContext *)>
      OpCreator;
  typedef std::function<std::set<DeviceType>(OpConditionContext *)>
      DevicePlacer;
  typedef std::function<void(OpConditionContext *)> MemoryTypeSetter;
  typedef std::function<std::vector<DataFormat>(OpConditionContext *)>
      DataFormatSelector;

  OpRegistrationInfo();

  // Marks a device type as supported by this op.
  void AddDevice(DeviceType);
  // Adds a creator under `key`.
  void Register(const std::string &key, OpCreator creator);

  std::set<DeviceType> devices;
  std::unordered_map<std::string, OpCreator> creators;
  DevicePlacer device_placer;
  MemoryTypeSetter memory_type_setter;
  DataFormatSelector data_format_selector;
};
// OpConditionBuilder is a fluent builder that collects the optional
// condition functions for one op type and copies them into an
// OpRegistrationInfo via Finalize().
class OpConditionBuilder {
 public:
  explicit OpConditionBuilder(const std::string &type);

  const std::string type() const;

  OpConditionBuilder &SetDevicePlacerFunc(
      OpRegistrationInfo::DevicePlacer placer);

  // If you set input memory type for specified Op,
  // you must call OpConditionContext::set_output_mem_type
  OpConditionBuilder &SetInputMemoryTypeSetter(
      OpRegistrationInfo::MemoryTypeSetter setter);

  OpConditionBuilder &SetInputsDataFormatSelector(
      OpRegistrationInfo::DataFormatSelector selector);

  // Copies the configured functions into `info`; unset ones are skipped.
  void Finalize(OpRegistrationInfo *info) const;

 private:
  std::string type_;
  OpRegistrationInfo::DevicePlacer placer_;
  OpRegistrationInfo::MemoryTypeSetter memory_type_setter_;
  OpRegistrationInfo::DataFormatSelector data_format_selector_;
};
// OpRegistryBase maps op type names to their OpRegistrationInfo and
// creates Operation instances from it.
class OpRegistryBase {
 public:
  OpRegistryBase() = default;
  virtual ~OpRegistryBase() = default;

  // Registers a creator for (op_type, device_type, dt).
  MaceStatus Register(const std::string &op_type,
                      const DeviceType device_type,
                      const DataType dt,
                      OpRegistrationInfo::OpCreator creator);
  // Registers the condition functions collected by `builder`.
  MaceStatus Register(const OpConditionBuilder &builder);

  // Devices the op can run on under the given condition context.
  const std::set<DeviceType> AvailableDevices(
      const std::string &op_type, OpConditionContext *context) const;
  void GetInOutMemoryTypes(
      const std::string &op_type, OpConditionContext *context) const;
  const std::vector<DataFormat> InputsDataFormat(
      const std::string &op_type, OpConditionContext *context) const;

  std::unique_ptr<Operation> CreateOperation(
      OpConstructContext *context,
      DeviceType device_type) const;

  // Default creator: constructs DerivedType from the construct context.
  template<class DerivedType>
  static std::unique_ptr<Operation> DefaultCreator(
      OpConstructContext *context) {
    return std::unique_ptr<Operation>(new DerivedType(context));
  }

 private:
  std::unordered_map<
      std::string,
      std::unique_ptr<OpRegistrationInfo>> registry_;
  MACE_DISABLE_COPY_AND_ASSIGN(OpRegistryBase);
};
// Registers class_name<device, dt> as the implementation of op_type for
// the given device and data type.
#define MACE_REGISTER_OP(op_registry, op_type, class_name, device, dt) \
  op_registry->Register(op_type,                                       \
                        device,                                        \
                        DataTypeToEnum<dt>::value,                     \
                        OpRegistryBase::DefaultCreator<class_name<device, dt>>)

// Same as MACE_REGISTER_OP but for a non-template op class.
#define MACE_REGISTER_OP_BY_CLASS(                   \
    op_registry, op_type, class_name, device, dt)    \
  op_registry->Register(op_type,                     \
                        device,                      \
                        DataTypeToEnum<dt>::value,   \
                        OpRegistryBase::DefaultCreator<class_name>)

#ifdef MACE_ENABLE_OPENCL
// Registers the float GPU implementation; expands to nothing when
// OpenCL support is disabled.
#define MACE_REGISTER_GPU_OP(op_registry, op_type, class_name) \
  op_registry->Register(                                       \
      op_type,                                                 \
      DeviceType::GPU,                                         \
      DT_FLOAT,                                                \
      OpRegistryBase::DefaultCreator<class_name<DeviceType::GPU, float>>)
#else
#define MACE_REGISTER_GPU_OP(op_registry, op_type, class_name)
#endif

// Registers extra op conditions built with an OpConditionBuilder.
#define MACE_REGISTER_OP_CONDITION(op_registry, builder) \
  op_registry->Register(builder)
} // namespace mace
#endif // MACE_CORE_OPERATOR_H_
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,39 +12,48 @@ ...@@ -12,39 +12,48 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_ARM_FP32_CONV_GENERAL_H_ #include "mace/core/ops/op_condition_builder.h"
#define MACE_OPS_ARM_FP32_CONV_GENERAL_H_
#include <vector>
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/conv_2d.h"
namespace mace { namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Conv2dGeneral : public Conv2dBase {
public:
Conv2dGeneral(const std::vector<int> &strides,
const std::vector<int> &dilations,
const std::vector<int> &paddings,
const Padding padding_type)
: Conv2dBase(strides, dilations, paddings, padding_type) {}
virtual ~Conv2dGeneral() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) override;
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_CONV_GENERAL_H_ OpConditionBuilder::OpConditionBuilder(const std::string &type)
: type_(type) {}
// Returns the op type name this builder targets.
const std::string OpConditionBuilder::type() const {
  return type_;
}
// Stores the device-placement function; returns *this for chaining.
OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
    OpRegistrationInfo::DevicePlacer placer) {
  placer_ = placer;
  return *this;
}
// Stores the input-memory-type setter; returns *this for chaining.
OpConditionBuilder &OpConditionBuilder::SetInputMemoryTypeSetter(
    OpRegistrationInfo::MemoryTypeSetter setter) {
  memory_type_setter_ = setter;
  return *this;
}
// Stores the inputs-data-format selector; returns *this for chaining.
OpConditionBuilder &OpConditionBuilder::SetInputsDataFormatSelector(
    OpRegistrationInfo::DataFormatSelector selector) {
  data_format_selector_ = selector;
  return *this;
}
// Copies every configured condition function into `info`. Unset
// functions are left untouched so the registry's defaults survive;
// a null `info` is tolerated and simply ignored.
void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const {
  if (info == nullptr) {
    return;
  }
  if (placer_) {
    info->device_placer = placer_;
  }
  if (memory_type_setter_) {
    info->memory_type_setter = memory_type_setter_;
  }
  if (data_format_selector_) {
    info->data_format_selector = data_format_selector_;
  }
}
} // namespace mace
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_OPS_OP_CONDITION_BUILDER_H_
#define MACE_CORE_OPS_OP_CONDITION_BUILDER_H_
#include <memory>
#include <string>
#include "mace/core/registry/op_registration_info.h"
#include "mace/core/types.h"
namespace mace {
// OpConditionBuilder is a fluent builder that collects the optional
// condition functions for one op type and copies them into an
// OpRegistrationInfo via Finalize().
class OpConditionBuilder {
 public:
  explicit OpConditionBuilder(const std::string &type);

  const std::string type() const;

  OpConditionBuilder &SetDevicePlacerFunc(
      OpRegistrationInfo::DevicePlacer placer);

  // If you set input memory type for specified Op,
  // you must call OpConditionContext::set_output_mem_type
  OpConditionBuilder &SetInputMemoryTypeSetter(
      OpRegistrationInfo::MemoryTypeSetter setter);

  OpConditionBuilder &SetInputsDataFormatSelector(
      OpRegistrationInfo::DataFormatSelector selector);

  // Copies the configured functions into `info`; unset ones are skipped.
  void Finalize(OpRegistrationInfo *info) const;

 private:
  std::string type_;
  OpRegistrationInfo::DevicePlacer placer_;
  OpRegistrationInfo::MemoryTypeSetter memory_type_setter_;
  OpRegistrationInfo::DataFormatSelector data_format_selector_;
};
} // namespace mace
#endif // MACE_CORE_OPS_OP_CONDITION_BUILDER_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/ops/op_condition_context.h"
#include "mace/core/arg_helper.h"
#include "mace/proto/mace.pb.h"
#include "mace/utils/logging.h"
namespace mace {
// Both pointers are borrowed; the context does not take ownership.
// The op definition and device are supplied later via the setters.
OpConditionContext::OpConditionContext(
    const Workspace *ws,
    OpConditionContext::TensorShapeMap *info)
    : operator_def_(nullptr),
      ws_(ws),
      device_(nullptr),
      tensor_shape_info_(info) {}
// Switches the context to a new op definition and drops the cached
// per-input data types, which belonged to the previous op.
// NOTE(review): input_mem_types_ is cleared in set_output_mem_type
// instead — confirm callers always set the output type per op.
void OpConditionContext::set_operator_def(
    const OperatorDef *operator_def) {
  operator_def_ = operator_def;
  input_data_types_.clear();
}
// Records the memory type and data type of input `idx`. The override
// vectors are built lazily on first call, sized to the op's input count:
// memory types default to the output memory type, data types to the
// op's "T" argument (DT_FLOAT when absent).
void OpConditionContext::SetInputInfo(size_t idx,
                                      MemoryType mem_type,
                                      DataType dt) {
  if (input_mem_types_.empty()) {
    // the default inputs' memory types are same as output memory type.
    input_mem_types_.resize(operator_def_->input_size(), output_mem_type_);
  }
  if (input_data_types_.empty()) {
    // the default inputs' data types are same as operation's data type.
    DataType op_dt = static_cast<DataType>(
        ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
            *operator_def_, "T", static_cast<int>(DataType::DT_FLOAT)));
    input_data_types_.resize(operator_def_->input_size(), op_dt);
  }
  MACE_CHECK(idx < input_mem_types_.size() && idx < input_data_types_.size());
  input_mem_types_[idx] = mem_type;
  input_data_types_[idx] = dt;
}
// Sets the single output memory type and invalidates the per-input
// memory-type overrides (they default to the new output type).
// Requires an op definition to have been set first.
void OpConditionContext::set_output_mem_type(MemoryType type) {
  MACE_CHECK(operator_def_ != nullptr);
  output_mem_type_ = type;
  input_mem_types_.clear();
}
// Returns the memory type of input `idx`. When no per-input override
// was recorded via SetInputInfo, every input shares the single output
// memory type.
MemoryType OpConditionContext::GetInputMemType(size_t idx) const {
  if (!input_mem_types_.empty()) {
    MACE_CHECK(idx < input_mem_types_.size(),
               idx, " < ", input_mem_types_.size());
    return input_mem_types_[idx];
  }
  return output_mem_type_;
}
// Returns the data type of input `idx`. When no per-input override was
// recorded, every input defaults to the op's "T" argument (DT_FLOAT
// when that argument is absent).
DataType OpConditionContext::GetInputDataType(size_t idx) const {
  if (!input_data_types_.empty()) {
    MACE_CHECK(idx < input_data_types_.size());
    return input_data_types_[idx];
  }
  return static_cast<DataType>(
      ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
          *operator_def_, "T", static_cast<int>(DataType::DT_FLOAT)));
}
#ifdef MACE_ENABLE_OPENCL
// Records the OpenCL buffer type of input `idx`. The override vector is
// built lazily, sized to the op's input count.
void OpConditionContext::SetInputOpenCLBufferType(
    size_t idx, OpenCLBufferType buffer_type) {
  if (input_opencl_buffer_types_.empty()) {
    // the default inputs' OpenCL buffer types are IN_OUT_CHANNEL.
    input_opencl_buffer_types_.resize(operator_def_->input_size(),
                                      OpenCLBufferType::IN_OUT_CHANNEL);
  }
  MACE_CHECK(idx < input_opencl_buffer_types_.size());
  input_opencl_buffer_types_[idx] = buffer_type;
}
// Returns the OpenCL buffer type of input `idx`; IN_OUT_CHANNEL when no
// override was recorded.
OpenCLBufferType OpConditionContext::GetInputOpenCLBufferType(
    size_t idx) const {
  if (input_opencl_buffer_types_.empty()) {
    return OpenCLBufferType::IN_OUT_CHANNEL;
  }
  MACE_CHECK(idx < input_opencl_buffer_types_.size());
  return input_opencl_buffer_types_[idx];
}
#endif // MACE_ENABLE_OPENCL
} // namespace mace
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_OPS_OP_CONDITION_CONTEXT_H_
#define MACE_CORE_OPS_OP_CONDITION_CONTEXT_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "mace/core/types.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_util.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
class Workspace;
class Device;
// OpConditionContext has all information used for choosing a proper Op
// implementation: the workspace, known tensor shapes, and the memory/data
// types of the op's inputs and output.
class OpConditionContext {
 public:
  // Maps tensor name -> tensor shape, shared while adapting the net.
  typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
  OpConditionContext(const Workspace *ws, TensorShapeMap *info);
  ~OpConditionContext() = default;

  // Selects the op definition under inspection; clears cached input types.
  void set_operator_def(const OperatorDef *operator_def);
  const OperatorDef *operator_def() const {
    return operator_def_;
  }
  const Workspace *workspace() const {
    return ws_;
  }
  void set_device(Device *device) {
    device_ = device;
  }
  Device *device() const {
    return device_;
  }
  TensorShapeMap *tensor_shape_info() const {
    return tensor_shape_info_;
  }
  // Sets the (single) output memory type; resets per-input overrides.
  void set_output_mem_type(MemoryType type);
  MemoryType output_mem_type() const {
    return output_mem_type_;
  }
  // Overrides the memory type and data type of input `idx`.
  void SetInputInfo(size_t idx, MemoryType mem_type, DataType dt);
  MemoryType GetInputMemType(size_t idx) const;
  DataType GetInputDataType(size_t idx) const;
#ifdef MACE_ENABLE_OPENCL
  void SetInputOpenCLBufferType(size_t idx, OpenCLBufferType buffer_type);
  OpenCLBufferType GetInputOpenCLBufferType(size_t idx) const;
#endif  // MACE_ENABLE_OPENCL

 private:
  const OperatorDef *operator_def_;    // not owned
  const Workspace *ws_;                // not owned
  Device *device_;                     // not owned
  TensorShapeMap *tensor_shape_info_;  // not owned
  // used for memory transform
  std::vector<MemoryType> input_mem_types_;
  std::vector<DataType> input_data_types_;
  MemoryType output_mem_type_;  // there is only one output memory type now.
#ifdef MACE_ENABLE_OPENCL
  std::vector<OpenCLBufferType> input_opencl_buffer_types_;
#endif  // MACE_ENABLE_OPENCL
};
} // namespace mace
#endif // MACE_CORE_OPS_OP_CONDITION_CONTEXT_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/ops/op_construct_context.h"
namespace mace {
// `ws` is borrowed; the op definition and device are supplied later via
// the corresponding setters.
OpConstructContext::OpConstructContext(Workspace *ws)
    : operator_def_(nullptr),
      ws_(ws),
      device_(nullptr) {}
// Takes shared ownership of the definition of the op being constructed.
void OpConstructContext::set_operator_def(
    std::shared_ptr<OperatorDef> operator_def) {
  operator_def_ = operator_def;
}
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,49 +12,62 @@ ...@@ -12,49 +12,62 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_ARM_FP32_DECONV_2D_GENERAL_H_ #ifndef MACE_CORE_OPS_OP_CONSTRUCT_CONTEXT_H_
#define MACE_OPS_ARM_FP32_DECONV_2D_GENERAL_H_ #define MACE_CORE_OPS_OP_CONSTRUCT_CONTEXT_H_
#include <vector>
#include <memory> #include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "mace/public/mace.h" #include "mace/core/arg_helper.h"
#include "mace/core/tensor.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h" #include "mace/proto/mace.pb.h"
#include "mace/ops/arm/fp32/deconv_2d.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace { namespace mace {
namespace ops { class Device;
namespace arm { class Workspace;
namespace fp32 {
// memory_optimizer, device
class OpConstructContext {
typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
class Deconv2dGeneral : public Deconv2dBase {
public: public:
Deconv2dGeneral(const std::vector<int> &strides, explicit OpConstructContext(Workspace *ws);
const std::vector<int> &dilations, ~OpConstructContext() = default;
const std::vector<int> &paddings,
const Padding padding_type, void set_operator_def(std::shared_ptr<OperatorDef> operator_def);
const FrameworkType framework_type)
: Deconv2dBase(strides, std::shared_ptr<OperatorDef> operator_def() const {
dilations, return operator_def_;
paddings, }
padding_type,
framework_type) {} Workspace *workspace() const {
virtual ~Deconv2dGeneral() {} return ws_;
}
MaceStatus Compute(
const OpContext *context, void set_device(Device *device) {
const Tensor *input, device_ = device;
const Tensor *filter, }
const Tensor *output_shape,
Tensor *output) override; Device *device() const {
return device_;
}
#ifdef MACE_ENABLE_OPENCL
inline MemoryType GetOpMemoryType() const {
return static_cast<MemoryType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def_, OutputMemoryTypeTagName(),
static_cast<int>(MemoryType::CPU_BUFFER)));
}
#endif // MACE_ENABLE_OPENCL
private:
std::shared_ptr<OperatorDef> operator_def_;
Workspace *ws_;
Device *device_;
}; };
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_ARM_FP32_DECONV_2D_GENERAL_H_ #endif // MACE_CORE_OPS_OP_CONSTRUCT_CONTEXT_H_
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
namespace mace { namespace mace {
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_CORE_OP_CONTEXT_H_ #ifndef MACE_CORE_OPS_OP_CONTEXT_H_
#define MACE_CORE_OP_CONTEXT_H_ #define MACE_CORE_OPS_OP_CONTEXT_H_
#include "mace/core/device.h" #include "mace/core/device.h"
#include "mace/core/workspace.h" #include "mace/core/workspace.h"
...@@ -35,8 +35,7 @@ class OpContext { ...@@ -35,8 +35,7 @@ class OpContext {
Device *device_; Device *device_;
Workspace *ws_; Workspace *ws_;
StatsFuture *future_; StatsFuture *future_;
// metadata
}; };
} // namespace mace } // namespace mace
#endif // MACE_CORE_OP_CONTEXT_H_ #endif // MACE_CORE_OPS_OP_CONTEXT_H_
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,40 +12,47 @@ ...@@ -12,40 +12,47 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_REF_ACTIVATION_H_ #ifndef MACE_CORE_OPS_OP_DELEGATOR_H_
#define MACE_OPS_REF_ACTIVATION_H_ #define MACE_CORE_OPS_OP_DELEGATOR_H_
#include "mace/core/op_context.h" #include <memory>
#include "mace/ops/common/activation_type.h"
#include "mace/utils/macros.h"
#include "mace/utils/memory.h"
namespace mace { namespace mace {
namespace ops {
namespace ref {
class Activation { enum ImplType {
REF = 0,
NEON,
};
#ifdef MACE_ENABLE_NEON
#define MACE_CPU_IMPL_TYPE NEON
#else
#define MACE_CPU_IMPL_TYPE REF
#endif
struct DelegatorParam {
public:
DelegatorParam() = default;
virtual ~DelegatorParam() = default;
};
class OpDelegator {
public: public:
explicit Activation(ActivationType type, explicit OpDelegator(const DelegatorParam &param) {
const float limit, MACE_UNUSED(param);
const float leakyrelu_coefficient); }
~Activation() = default; virtual ~OpDelegator() = default;
MaceStatus Compute( template<class DerivedType, class ParamType>
const OpContext *context, static std::unique_ptr<OpDelegator> DefaultCreator(
const Tensor *input, const DelegatorParam &param) {
Tensor *output); return make_unique<DerivedType>(static_cast<const ParamType &>(param));
}
private:
void DoActivation(const OpContext *context,
const Tensor *input,
Tensor *output);
ActivationType type_;
const float limit_;
const float leakyrelu_coefficient_;
}; };
} // namespace ref
} // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_REF_ACTIVATION_H_ #endif // MACE_CORE_OPS_OP_DELEGATOR_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/ops/op_init_context.h"
namespace mace {
// Binds the init context to a workspace and (optionally) a device; the
// device may also be injected later through OpInitContext::set_device().
OpInitContext::OpInitContext(Workspace *ws, Device *device)
    : ws_(ws), device_(device) {}
} // namespace mace
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_OPS_OP_INIT_CONTEXT_H_
#define MACE_CORE_OPS_OP_INIT_CONTEXT_H_
namespace mace {
class Workspace;
class Device;
// Context handed to Operation::Init(): exposes the workspace that holds the
// op's tensors and the device the op will run on.
class OpInitContext {
 public:
  // The device defaults to nullptr and can be supplied later via set_device().
  explicit OpInitContext(Workspace *ws, Device *device = nullptr);
  ~OpInitContext() = default;

  // Workspace owning the input/output tensors (not owned by this context).
  Workspace *workspace() const {
    return ws_;
  }

  void set_device(Device *device) {
    device_ = device;
  }

  // Target device; may still be nullptr if set_device() was never called.
  Device *device() const {
    return device_;
  }

 private:
  Workspace *ws_;   // not owned
  Device *device_;  // not owned
};
} // namespace mace
#endif // MACE_CORE_OPS_OP_INIT_CONTEXT_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/ops/operator.h"
#include <vector>
#include "mace/core/ops/op_construct_context.h"
#include "mace/core/ops/op_init_context.h"
namespace mace {
// Captures (shares ownership of) the operator definition; input/output
// tensors are resolved later in Init(), once the workspace and device exist.
Operation::Operation(OpConstructContext *context)
    : operator_def_(context->operator_def()) {}
// Resolves the op's input tensors from the workspace and creates (or reuses)
// its output tensors. Returns MACE_SUCCESS; hard failures abort via
// MACE_CHECK.
MaceStatus Operation::Init(OpInitContext *context) {
  Workspace *ws = context->workspace();
  // Every input must already exist in the workspace (model constants or
  // outputs of previously initialized ops).
  for (const std::string &input_str : operator_def_->input()) {
    const Tensor *tensor = ws->GetTensor(input_str);
    MACE_CHECK(tensor != nullptr, "op ", operator_def_->type(),
               ": Encountered a non-existing input tensor: ", input_str);
    inputs_.push_back(tensor);
  }
  for (int i = 0; i < operator_def_->output_size(); ++i) {
    const std::string output_str = operator_def_->output(i);
    if (ws->HasTensor(output_str)) {
      // Reuse the tensor if the workspace already holds one under this name.
      outputs_.push_back(ws->GetTensor(output_str));
    } else {
      // If output types are specified at all, one must be given per output.
      MACE_CHECK(
          operator_def_->output_type_size() == 0 ||
              operator_def_->output_size() == operator_def_->output_type_size(),
          "operator output size != operator output type size",
          operator_def_->output_size(),
          operator_def_->output_type_size());
      DataType output_type;
      if (i < operator_def_->output_type_size()) {
        output_type = operator_def_->output_type(i);
      } else {
        // Fall back to the op's "T" argument (default: float).
        output_type = static_cast<DataType>(
            ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
                *operator_def_, "T", static_cast<int>(DT_FLOAT)));
      }
      outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
          output_str, context->device()->allocator(), output_type)));
    }
    // Propagate the output shape recorded in the model, if one was given.
    if (i < operator_def_->output_shape_size()) {
      std::vector<index_t>
          shape_configured(operator_def_->output_shape(i).dims_size());
      for (size_t dim = 0; dim < shape_configured.size(); ++dim) {
        shape_configured[dim] = operator_def_->output_shape(i).dims(dim);
      }
      ws->GetTensor(output_str)->SetShapeConfigured(shape_configured);
    }
  }
  return MaceStatus::MACE_SUCCESS;
}
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_OPS_OPERATOR_H_
#define MACE_CORE_OPS_OPERATOR_H_
#include <memory>
#include <string>
#include <vector>
#include "mace/core/arg_helper.h"
#include "mace/core/ops/op_construct_context.h"
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_util.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
class OpInitContext;
// Conventions
// * If there exist format, NHWC is the default format
// * The input/output format of CPU ops with float data type is NCHW
// * The input/output format of GPU ops and CPU Quantization ops is NHWC
// * Inputs' data type is same as the operation data type by default.
// * The outputs' data type is same as the operation data type by default.
class Operation {
public:
explicit Operation(OpConstructContext *context);
virtual ~Operation() = default;
template<typename T>
T GetOptionalArg(const std::string &name,
const T &default_value) const {
MACE_CHECK(operator_def_, "operator_def was null!");
return ProtoArgHelper::GetOptionalArg<OperatorDef, T>(
*operator_def_, name, default_value);
}
template<typename T>
std::vector<T> GetRepeatedArgs(
const std::string &name, const std::vector<T> &default_value = {}) const {
MACE_CHECK(operator_def_, "operator_def was null!");
return ProtoArgHelper::GetRepeatedArgs<OperatorDef, T>(
*operator_def_, name, default_value);
}
DeviceType device_type() const {
return static_cast<DeviceType>(operator_def_->device_type());
}
const Tensor *Input(unsigned int idx) {
MACE_CHECK(idx < inputs_.size());
return inputs_[idx];
}
Tensor *Output(int idx) { return outputs_[idx]; }
int InputSize() { return inputs_.size(); }
int OutputSize() { return outputs_.size(); }
const std::vector<const Tensor *> &Inputs() const { return inputs_; }
const std::vector<Tensor *> &Outputs() { return outputs_; }
// Run Op asynchronously (depends on device), return a future if not nullptr.
virtual MaceStatus Init(OpInitContext *);
virtual MaceStatus Run(OpContext *) = 0;
const OperatorDef &debug_def() const {
MACE_CHECK(has_debug_def(), "operator_def was null!");
return *operator_def_;
}
void set_debug_def(
const std::shared_ptr<OperatorDef> &operator_def) {
operator_def_ = operator_def;
}
bool has_debug_def() const { return operator_def_ != nullptr; }
inline std::shared_ptr<OperatorDef> operator_def() {
return operator_def_;
}
protected:
std::shared_ptr<OperatorDef> operator_def_;
std::vector<const Tensor *> inputs_;
std::vector<Tensor *> outputs_;
MACE_DISABLE_COPY_AND_ASSIGN(Operation);
};
// MACE_OP_INPUT_TAGS and MACE_OP_OUTPUT_TAGS are optional features to name the
// indices of the operator's inputs and outputs, in order to avoid confusion.
// For example, for a fully convolution layer that has input, weight and bias,
// you can define its input tags as:
// MACE_OP_INPUT_TAGS(INPUT, WEIGHT, BIAS);
// And in the code, instead of doing
// auto& weight = Input(1);
// you can now do
// auto& weight = Input(WEIGHT);
// to make it more clear.
#define MACE_OP_INPUT_TAGS(first_input, ...) \
enum _InputTags { first_input = 0, __VA_ARGS__ }
#define MACE_OP_OUTPUT_TAGS(first_input, ...) \
enum _OutputTags { first_input = 0, __VA_ARGS__ }
} // namespace mace
#endif // MACE_CORE_OPS_OPERATOR_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/registry/op_delegator_registry.h"
#include <utility>
#include "mace/utils/logging.h"
namespace mace {
// Registers a delegator factory under the given key.
// Duplicate registration of a key is a fatal error.
MaceStatus OpDelegatorRegistry::Register(const std::string &key,
                                         DelegatorCreator creator) {
  // emplace() performs the duplicate check and the insertion with a single
  // lookup; the previous message ("Register an exist key.") was ungrammatical
  // and did not report which key collided.
  const bool inserted = registry_.emplace(key, std::move(creator)).second;
  MACE_CHECK(inserted, "Key already registered: ", key);
  return MaceStatus::MACE_SUCCESS;
}
// Looks up the factory registered under the given key.
// An unknown key is a fatal error.
DelegatorCreator OpDelegatorRegistry::GetCreator(const std::string &key) const {
  // Single find() instead of count() + at(): one hash lookup, not two.
  auto iter = registry_.find(key);
  MACE_CHECK(iter != registry_.end(), key, " not exist.");
  return iter->second;
}
// Human-readable type names used as the trailing component of delegator
// registry keys (see MACE_DELEGATOR_KEY in op_delegator_registry.h).
template<> const char *DType<float>::name_ = "float";
template<> const char *DType<int>::name_ = "int";
template<> const char *DType<uint8_t>::name_ = "uint8_t";
} // namespace mace
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_REGISTRY_OP_DELEGATOR_REGISTRY_H_
#define MACE_CORE_REGISTRY_OP_DELEGATOR_REGISTRY_H_
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "mace/core/ops/op_delegator.h"
#include "mace/proto/mace.pb.h"
#include "mace/public/mace.h"
namespace mace {
// Factory signature: builds a delegator from its (derived) parameter struct.
typedef std::function<std::unique_ptr<OpDelegator>(const DelegatorParam &)>
    DelegatorCreator;

// String-keyed registry mapping delegator keys (built with the
// MACE_DELEGATOR_KEY* macros below) to factory functions.
class OpDelegatorRegistry {
 public:
  OpDelegatorRegistry() = default;
  ~OpDelegatorRegistry() = default;

  // Fatal error if key is already registered.
  MaceStatus Register(const std::string &key, DelegatorCreator creator);
  // Fatal error if key is not registered.
  DelegatorCreator GetCreator(const std::string &key) const;

 private:
  std::unordered_map<std::string, DelegatorCreator> registry_;
};

// Maps a C++ scalar type to the name used inside delegator keys; the
// specializations are defined in op_delegator_registry.cc.
template<typename T>
struct DType { static const char *name_; };
template<> const char *DType<float>::name_;
template<> const char *DType<int>::name_;
template<> const char *DType<uint8_t>::name_;
} // namespace mace
// Builds the registry key "<name>_<device>_<impl>_<dtype name>".  The *_TMP
// level of indirection lets macro arguments expand (e.g. an impl-selecting
// macro) before stringization.
#ifndef MACE_DELEGATOR_KEY_TMP
#define MACE_DELEGATOR_KEY_TMP(delegator_name, device, DT, impl) \
  (std::string(#delegator_name"_"#device"_"#impl"_") + DType<DT>::name_)
#endif  // MACE_DELEGATOR_KEY_TMP

#ifndef MACE_DELEGATOR_KEY
#define MACE_DELEGATOR_KEY(delegator_name, device, DT, impl) \
  MACE_DELEGATOR_KEY_TMP(delegator_name, device, DT, impl)
#endif  // MACE_DELEGATOR_KEY

// Extended key with an extra tag component:
// "<name>_<device>_<impl>_<tag>_<dtype name>".
#ifndef MACE_DELEGATOR_KEY_EX_TMP
#define MACE_DELEGATOR_KEY_EX_TMP(delegator_name, device, DT, impl, tag) \
  (std::string(#delegator_name"_"#device"_"#impl"_"#tag"_") + DType<DT>::name_)
#endif  // MACE_DELEGATOR_KEY_EX_TMP

#ifndef MACE_DELEGATOR_KEY_EX
#define MACE_DELEGATOR_KEY_EX(delegator_name, device, DT, impl, tag) \
  MACE_DELEGATOR_KEY_EX_TMP(delegator_name, device, DT, impl, tag)
#endif  // MACE_DELEGATOR_KEY_EX

// Emits the registration function for a delegator class: registers the
// class's DefaultCreator under the given key.
#ifndef MACE_REGISTER_DELEGATOR
#define MACE_REGISTER_DELEGATOR(registry, class_name, param_name, key) \
  void Register##class_name##Delegator(OpDelegatorRegistry *registry) { \
    registry->Register( \
        key, OpDelegator::DefaultCreator<class_name, param_name>); \
  }
#endif  // MACE_REGISTER_DELEGATOR

// Injects a static Create() helper into a delegator interface class: looks up
// the registered creator by tag and downcasts the result to class_name.  The
// static_cast is unchecked — the tag must name a class_name-derived delegator.
#ifndef MACE_DEFINE_DELEGATOR_CREATOR
#define MACE_DEFINE_DELEGATOR_CREATOR(class_name) \
  static std::unique_ptr<class_name> Create( \
      Workspace *workspace, const std::string &tag, \
      const DelegatorParam &param) { \
    DelegatorCreator creator = \
        workspace->GetDelegatorRegistry()->GetCreator(tag); \
    std::unique_ptr<OpDelegator> delegator = creator(param); \
    return std::unique_ptr<class_name>( \
        static_cast<class_name *>(delegator.release())); \
  }
#endif  // MACE_DEFINE_DELEGATOR_CREATOR
#endif // MACE_CORE_REGISTRY_OP_DELEGATOR_REGISTRY_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/registry/op_registration_info.h"
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "mace/core/ops/op_condition_context.h"
namespace mace {
// Installs the default condition callbacks; op registrations may override any
// of them through an OpConditionBuilder.
OpRegistrationInfo::OpRegistrationInfo() {
  // Default device placer: an op may run on every device it was registered
  // for, regardless of context.
  device_placer = [this](OpConditionContext *context) -> std::set<DeviceType> {
    MACE_UNUSED(context);
    return this->devices;
  };

  // Default memory type setter: GPU ops output image or buffer memory
  // depending on the runtime; everything else outputs CPU buffers.
  // NOTE(review): with MACE_ENABLE_OPENCL undefined, the GPU branch sets no
  // output memory type at all — presumably GPU devices cannot occur in that
  // build; verify against the device construction path.
  memory_type_setter = [](OpConditionContext *context) -> void {
    if (context->device()->device_type() == DeviceType::GPU) {
#ifdef MACE_ENABLE_OPENCL
      if (context->device()->gpu_runtime()->UseImageMemory()) {
        context->set_output_mem_type(MemoryType::GPU_IMAGE);
      } else {
        context->set_output_mem_type(MemoryType::GPU_BUFFER);
      }
#endif  // MACE_ENABLE_OPENCL
    } else {
      context->set_output_mem_type(MemoryType::CPU_BUFFER);
    }
  };

  // Default data format selector: every input uses the op's "data_format"
  // argument (NONE when absent).
  data_format_selector = [](OpConditionContext *context)
      -> std::vector<DataFormat> {
    DataFormat op_data_format =
        static_cast<DataFormat>(
            ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
                *context->operator_def(), "data_format",
                static_cast<int>(DataFormat::NONE)));
    return std::vector<DataFormat>(context->operator_def()->input_size(),
                                   op_data_format);
  };
}
// Records that this op has an implementation for the given device type;
// adding the same device twice is a harmless no-op (set semantics).
void OpRegistrationInfo::AddDevice(DeviceType device) {
  devices.emplace(device);
}
// Registers an op creator under the given key.
// Duplicate registration of a key is a fatal error.
void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) {
  VLOG(3) << "Registering: " << key;
  // emplace() performs the duplicate check and the insertion with a single
  // lookup instead of count() followed by operator[].
  const bool inserted = creators.emplace(key, std::move(creator)).second;
  MACE_CHECK(inserted, "Key already registered: ", key);
}
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,40 +12,45 @@ ...@@ -12,40 +12,45 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_ARM_FP32_CONV_2D_1X1_H_
#define MACE_OPS_ARM_FP32_CONV_2D_1X1_H_
#ifndef MACE_CORE_REGISTRY_OP_REGISTRATION_INFO_H_
#define MACE_CORE_REGISTRY_OP_REGISTRATION_INFO_H_
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "mace/public/mace.h"
#include "mace/core/tensor.h" #include "mace/core/ops/operator.h"
#include "mace/core/op_context.h" #include "mace/proto/mace.pb.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/arm/fp32/conv_2d.h"
namespace mace { namespace mace {
namespace ops { class OpConstructContext;
namespace arm { class OpConditionContext;
namespace fp32 {
class Conv2dK1x1 : public Conv2dBase { class OpRegistrationInfo {
public: public:
Conv2dK1x1(const std::vector<int> &paddings, const Padding padding_type) typedef std::function<std::unique_ptr<Operation>(OpConstructContext *)>
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {} OpCreator;
virtual ~Conv2dK1x1() {} typedef std::function<std::set<DeviceType>(OpConditionContext *)>
DevicePlacer;
MaceStatus Compute( typedef std::function<void(OpConditionContext *)> MemoryTypeSetter;
const OpContext *context, typedef std::function<std::vector<DataFormat>(OpConditionContext *)>
const Tensor *input, DataFormatSelector;
const Tensor *filter,
Tensor *output) override; OpRegistrationInfo();
private:
Gemm gemm_;
};
} // namespace fp32 void AddDevice(DeviceType);
} // namespace arm
} // namespace ops void Register(const std::string &key, OpCreator creator);
std::set<DeviceType> devices;
std::unordered_map<std::string, OpCreator> creators;
DevicePlacer device_placer;
MemoryTypeSetter memory_type_setter;
DataFormatSelector data_format_selector;
};
} // namespace mace } // namespace mace
#endif // MACE_OPS_ARM_FP32_CONV_2D_1X1_H_ #endif // MACE_CORE_REGISTRY_OP_REGISTRATION_INFO_H_
// Copyright 2018 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,153 +12,15 @@ ...@@ -12,153 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <sstream> #include "mace/core/registry/ops_registry.h"
#include <map> #include <map>
#include <memory> #include <memory>
#include <set>
#include <string>
#include <vector> #include <vector>
#include "mace/core/operator.h"
namespace mace { namespace mace {
OpConditionContext::OpConditionContext(
const Workspace *ws,
OpConditionContext::TensorShapeMap *info)
: operator_def_(nullptr),
ws_(ws),
device_(nullptr),
tensor_shape_info_(info) {}
void OpConditionContext::set_operator_def(
const OperatorDef *operator_def) {
operator_def_ = operator_def;
input_data_types_.clear();
}
void OpConditionContext::SetInputInfo(size_t idx,
MemoryType mem_type,
DataType dt) {
if (input_mem_types_.empty()) {
// the default inputs' memory types are same as output memory type.
input_mem_types_.resize(operator_def_->input_size(), output_mem_type_);
}
if (input_data_types_.empty()) {
// the default inputs' data types are same as operation's data type.
DataType op_dt = static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def_, "T", static_cast<int>(DataType::DT_FLOAT)));
input_data_types_.resize(operator_def_->input_size(), op_dt);
}
MACE_CHECK(idx < input_mem_types_.size() && idx < input_data_types_.size());
input_mem_types_[idx] = mem_type;
input_data_types_[idx] = dt;
}
void OpConditionContext::set_output_mem_type(MemoryType type) {
MACE_CHECK(operator_def_ != nullptr);
output_mem_type_ = type;
input_mem_types_.clear();
}
MemoryType OpConditionContext::GetInputMemType(size_t idx) const {
if (input_mem_types_.empty()) {
return output_mem_type_;
}
MACE_CHECK(idx < input_mem_types_.size(),
idx, " < ", input_mem_types_.size());
return input_mem_types_[idx];
}
DataType OpConditionContext::GetInputDataType(size_t idx) const {
if (input_data_types_.empty()) {
// the default inputs' data types are same as operation's data type.
return static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def_, "T", static_cast<int>(DataType::DT_FLOAT)));
}
MACE_CHECK(idx < input_data_types_.size());
return input_data_types_[idx];
}
#ifdef MACE_ENABLE_OPENCL
void OpConditionContext::SetInputOpenCLBufferType(
size_t idx, OpenCLBufferType buffer_type) {
if (input_opencl_buffer_types_.empty()) {
// the default inputs' memory types are same as output memory type.
input_opencl_buffer_types_.resize(operator_def_->input_size(),
OpenCLBufferType::IN_OUT_CHANNEL);
}
MACE_CHECK(idx < input_opencl_buffer_types_.size());
input_opencl_buffer_types_[idx] = buffer_type;
}
OpenCLBufferType OpConditionContext::GetInputOpenCLBufferType(
size_t idx) const {
if (input_opencl_buffer_types_.empty()) {
return OpenCLBufferType::IN_OUT_CHANNEL;
}
MACE_CHECK(idx < input_opencl_buffer_types_.size());
return input_opencl_buffer_types_[idx];
}
#endif // MACE_ENABLE_OPENCL
OpConstructContext::OpConstructContext(Workspace *ws)
: operator_def_(nullptr),
ws_(ws),
device_(nullptr) {}
void OpConstructContext::set_operator_def(
std::shared_ptr<OperatorDef> operator_def) {
operator_def_ = operator_def;
}
OpInitContext::OpInitContext(Workspace *ws, Device *device)
: ws_(ws), device_(device) {}
Operation::Operation(OpConstructContext *context)
: operator_def_(context->operator_def()) {}
MaceStatus Operation::Init(OpInitContext *context) {
Workspace *ws = context->workspace();
for (const std::string &input_str : operator_def_->input()) {
const Tensor *tensor = ws->GetTensor(input_str);
MACE_CHECK(tensor != nullptr, "op ", operator_def_->type(),
": Encountered a non-existing input tensor: ", input_str);
inputs_.push_back(tensor);
}
for (int i = 0; i < operator_def_->output_size(); ++i) {
const std::string output_str = operator_def_->output(i);
if (ws->HasTensor(output_str)) {
outputs_.push_back(ws->GetTensor(output_str));
} else {
MACE_CHECK(
operator_def_->output_type_size() == 0 ||
operator_def_->output_size() == operator_def_->output_type_size(),
"operator output size != operator output type size",
operator_def_->output_size(),
operator_def_->output_type_size());
DataType output_type;
if (i < operator_def_->output_type_size()) {
output_type = operator_def_->output_type(i);
} else {
output_type = static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def_, "T", static_cast<int>(DT_FLOAT)));
}
outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
output_str, context->device()->allocator(), output_type)));
}
if (i < operator_def_->output_shape_size()) {
std::vector<index_t>
shape_configured(operator_def_->output_shape(i).dims_size());
for (size_t dim = 0; dim < shape_configured.size(); ++dim) {
shape_configured[dim] = operator_def_->output_shape(i).dims(dim);
}
ws->GetTensor(output_str)->SetShapeConfigured(shape_configured);
}
}
return MaceStatus::MACE_SUCCESS;
}
// op registry
namespace { namespace {
class OpKeyBuilder { class OpKeyBuilder {
public: public:
...@@ -203,51 +65,7 @@ const std::string OpKeyBuilder::Build() { ...@@ -203,51 +65,7 @@ const std::string OpKeyBuilder::Build() {
} }
} // namespace } // namespace
OpRegistrationInfo::OpRegistrationInfo() { MaceStatus OpRegistry::Register(
// default device type placer
device_placer = [this](OpConditionContext *context) -> std::set<DeviceType> {
MACE_UNUSED(context);
return this->devices;
};
// default input and output memory type setter
memory_type_setter = [](OpConditionContext *context) -> void {
if (context->device()->device_type() == DeviceType::GPU) {
#ifdef MACE_ENABLE_OPENCL
if (context->device()->gpu_runtime()->UseImageMemory()) {
context->set_output_mem_type(MemoryType::GPU_IMAGE);
} else {
context->set_output_mem_type(MemoryType::GPU_BUFFER);
}
#endif // MACE_ENABLE_OPENCL
} else {
context->set_output_mem_type(MemoryType::CPU_BUFFER);
}
};
data_format_selector = [](OpConditionContext *context)
-> std::vector<DataFormat> {
DataFormat op_data_format =
static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
return std::vector<DataFormat>(context->operator_def()->input_size(),
op_data_format);
};
}
void OpRegistrationInfo::AddDevice(DeviceType device) {
devices.insert(device);
}
void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) {
VLOG(3) << "Registering: " << key;
MACE_CHECK(creators.count(key) == 0, "Key already registered: ", key);
creators[key] = creator;
}
MaceStatus OpRegistryBase::Register(
const std::string &op_type, const std::string &op_type,
const DeviceType device_type, const DeviceType device_type,
const DataType dt, const DataType dt,
...@@ -266,7 +84,7 @@ MaceStatus OpRegistryBase::Register( ...@@ -266,7 +84,7 @@ MaceStatus OpRegistryBase::Register(
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MaceStatus OpRegistryBase::Register( MaceStatus OpRegistry::Register(
const OpConditionBuilder &builder) { const OpConditionBuilder &builder) {
std::string op_type = builder.type(); std::string op_type = builder.type();
if (registry_.count(op_type) == 0) { if (registry_.count(op_type) == 0) {
...@@ -277,7 +95,7 @@ MaceStatus OpRegistryBase::Register( ...@@ -277,7 +95,7 @@ MaceStatus OpRegistryBase::Register(
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
const std::set<DeviceType> OpRegistryBase::AvailableDevices( const std::set<DeviceType> OpRegistry::AvailableDevices(
const std::string &op_type, OpConditionContext *context) const { const std::string &op_type, OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0, MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered."); op_type, " operation is not registered.");
...@@ -285,7 +103,7 @@ const std::set<DeviceType> OpRegistryBase::AvailableDevices( ...@@ -285,7 +103,7 @@ const std::set<DeviceType> OpRegistryBase::AvailableDevices(
return registry_.at(op_type)->device_placer(context); return registry_.at(op_type)->device_placer(context);
} }
void OpRegistryBase::GetInOutMemoryTypes( void OpRegistry::GetInOutMemoryTypes(
const std::string &op_type, const std::string &op_type,
OpConditionContext *context) const { OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0, MACE_CHECK(registry_.count(op_type) != 0,
...@@ -293,7 +111,7 @@ void OpRegistryBase::GetInOutMemoryTypes( ...@@ -293,7 +111,7 @@ void OpRegistryBase::GetInOutMemoryTypes(
return registry_.at(op_type)->memory_type_setter(context); return registry_.at(op_type)->memory_type_setter(context);
} }
const std::vector<DataFormat> OpRegistryBase::InputsDataFormat( const std::vector<DataFormat> OpRegistry::InputsDataFormat(
const std::string &op_type, const std::string &op_type,
OpConditionContext *context) const { OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0, MACE_CHECK(registry_.count(op_type) != 0,
...@@ -301,7 +119,7 @@ const std::vector<DataFormat> OpRegistryBase::InputsDataFormat( ...@@ -301,7 +119,7 @@ const std::vector<DataFormat> OpRegistryBase::InputsDataFormat(
return registry_.at(op_type)->data_format_selector(context); return registry_.at(op_type)->data_format_selector(context);
} }
std::unique_ptr<Operation> OpRegistryBase::CreateOperation( std::unique_ptr<Operation> OpRegistry::CreateOperation(
OpConstructContext *context, OpConstructContext *context,
DeviceType device_type) const { DeviceType device_type) const {
auto operator_def = context->operator_def(); auto operator_def = context->operator_def();
...@@ -328,44 +146,4 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation( ...@@ -328,44 +146,4 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
return registry_.at(op_type)->creators.at(key)(context); return registry_.at(op_type)->creators.at(key)(context);
} }
// Begins building registration conditions for the op type named `type`.
OpConditionBuilder::OpConditionBuilder(const std::string &type)
    : type_(type) {}
// Returns the op type name this builder configures.
const std::string OpConditionBuilder::type() const {
  return type_;
}
// Stores the function that selects which devices the op may run on.
// Returns *this so setter calls can be chained.
OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
    OpRegistrationInfo::DevicePlacer device_placer_func) {
  placer_ = device_placer_func;
  return *this;
}
// Stores the function that decides the op's input/output memory types.
// Returns *this so setter calls can be chained.
OpConditionBuilder &OpConditionBuilder::SetInputMemoryTypeSetter(
    OpRegistrationInfo::MemoryTypeSetter mem_type_setter) {
  memory_type_setter_ = mem_type_setter;
  return *this;
}
// Stores the function that selects the data formats of the op's inputs.
// Returns *this so setter calls can be chained.
OpConditionBuilder &OpConditionBuilder::SetInputsDataFormatSelector(
    OpRegistrationInfo::DataFormatSelector df_selector) {
  data_format_selector_ = df_selector;
  return *this;
}
// Copies every condition function that was configured on this builder
// into `info`. A null `info` is silently ignored; unset functions leave
// the corresponding `info` defaults untouched.
void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const {
  if (info == nullptr) {
    return;
  }
  if (placer_) {
    info->device_placer = placer_;
  }
  if (memory_type_setter_) {
    info->memory_type_setter = memory_type_setter_;
  }
  if (data_format_selector_) {
    info->data_format_selector = data_format_selector_;
  }
}
} // namespace mace } // namespace mace
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_REGISTRY_OPS_REGISTRY_H_
#define MACE_CORE_REGISTRY_OPS_REGISTRY_H_
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "mace/core/ops/operator.h"
#include "mace/core/ops/op_condition_builder.h"
#include "mace/core/ops/op_condition_context.h"
#include "mace/public/mace.h"
#include "mace/proto/mace.pb.h"
#include "mace/utils/memory.h"
namespace mace {
// Central registry mapping op type names to per-device / per-data-type
// creator functions plus optional registration conditions (device
// placement, memory types, input data formats).
class OpRegistry {
 public:
  OpRegistry() = default;
  virtual ~OpRegistry() = default;

  // Registers `creator` as the factory for `op_type` on `device_type`
  // with data type `dt`.
  MaceStatus Register(const std::string &op_type,
                      const DeviceType device_type,
                      const DataType dt,
                      OpRegistrationInfo::OpCreator creator);

  // Registers the condition functions collected by `builder` for its op type.
  MaceStatus Register(const OpConditionBuilder &builder);

  // Devices on which `op_type` can run, as decided by the op's registered
  // device-placer function evaluated against `context`.
  const std::set<DeviceType> AvailableDevices(
      const std::string &op_type, OpConditionContext *context) const;

  // Invokes the op's registered memory-type setter on `context` to
  // determine input/output memory types.
  void GetInOutMemoryTypes(
      const std::string &op_type, OpConditionContext *context) const;

  // Data formats for the op's inputs, from the registered selector.
  const std::vector<DataFormat> InputsDataFormat(
      const std::string &op_type, OpConditionContext *context) const;

  // Instantiates the operation described by `context` for `device_type`.
  std::unique_ptr<Operation> CreateOperation(
      OpConstructContext *context,
      DeviceType device_type) const;

  // Default factory: constructs a DerivedType from the construct context.
  template<class DerivedType>
  static std::unique_ptr<Operation> DefaultCreator(
      OpConstructContext *context) {
    return make_unique<DerivedType>(context);
  }

 private:
  // op type name -> registration info (creators + condition functions).
  std::unordered_map<std::string, std::unique_ptr<OpRegistrationInfo>>
      registry_;
  MACE_DISABLE_COPY_AND_ASSIGN(OpRegistry);
};
// Registers class_name<device, dt> as the implementation of `op_type`
// for the given device and data type.
#define MACE_REGISTER_OP(op_registry, op_type, class_name, device, dt) \
  op_registry->Register(op_type, \
                        device, \
                        DataTypeToEnum<dt>::value, \
                        OpRegistry::DefaultCreator<class_name<device, dt>>)

// Like MACE_REGISTER_OP, but for a non-template op class.
#define MACE_REGISTER_OP_BY_CLASS(\
    op_registry, op_type, class_name, device, dt) \
  op_registry->Register(op_type, \
                        device, \
                        DataTypeToEnum<dt>::value, \
                        OpRegistry::DefaultCreator<class_name>)

#ifdef MACE_ENABLE_OPENCL
// Registers the GPU (float) specialization of `class_name` for `op_type`.
#define MACE_REGISTER_GPU_OP(op_registry, op_type, class_name) \
  op_registry->Register( \
      op_type, \
      DeviceType::GPU, \
      DT_FLOAT, \
      OpRegistry::DefaultCreator<class_name<DeviceType::GPU, float>>)
#else
// No-op when OpenCL support is compiled out.
#define MACE_REGISTER_GPU_OP(op_registry, op_type, class_name)
#endif

// Registers op registration conditions built via an OpConditionBuilder.
#define MACE_REGISTER_OP_CONDITION(op_registry, builder) \
  op_registry->Register(builder)
} // namespace mace
#endif // MACE_CORE_REGISTRY_OPS_REGISTRY_H_
...@@ -46,7 +46,7 @@ bool HasHalfTensor(const NetDef &net_def) { ...@@ -46,7 +46,7 @@ bool HasHalfTensor(const NetDef &net_def) {
return false; return false;
} }
template <typename T> template<typename T>
void DequantizeTensor(Device *device, void DequantizeTensor(Device *device,
const unsigned char *model_data, const unsigned char *model_data,
const ConstTensor &const_tensor, const ConstTensor &const_tensor,
...@@ -66,7 +66,8 @@ void DequantizeTensor(Device *device, ...@@ -66,7 +66,8 @@ void DequantizeTensor(Device *device,
} // namespace } // namespace
Workspace::Workspace() = default; Workspace::Workspace(const OpDelegatorRegistry *registry) :
op_delegator_registry_(registry) {}
Tensor *Workspace::CreateTensor(const std::string &name, Tensor *Workspace::CreateTensor(const std::string &name,
Allocator *alloc, Allocator *alloc,
...@@ -401,4 +402,8 @@ void Workspace::RemoveTensor(const std::string &name) { ...@@ -401,4 +402,8 @@ void Workspace::RemoveTensor(const std::string &name) {
} }
} }
// Returns the non-owning OpDelegatorRegistry pointer that was injected
// into this workspace at construction time.
const OpDelegatorRegistry *Workspace::GetDelegatorRegistry() const {
  return op_delegator_registry_;
}
} // namespace mace } // namespace mace
...@@ -27,13 +27,14 @@ ...@@ -27,13 +27,14 @@
namespace mace { namespace mace {
class OpDelegatorRegistry;
class MemoryOptimizer; class MemoryOptimizer;
class Workspace { class Workspace {
public: public:
typedef std::map<std::string, std::unique_ptr<Tensor>> TensorMap; typedef std::map<std::string, std::unique_ptr<Tensor>> TensorMap;
Workspace(); explicit Workspace(const OpDelegatorRegistry *registry);
~Workspace() {} ~Workspace() {}
Tensor *CreateTensor(const std::string &name, Tensor *CreateTensor(const std::string &name,
...@@ -71,15 +72,16 @@ class Workspace { ...@@ -71,15 +72,16 @@ class Workspace {
void RemoveTensor(const std::string &name); void RemoveTensor(const std::string &name);
const OpDelegatorRegistry *GetDelegatorRegistry() const;
private: private:
TensorMap tensor_map_; TensorMap tensor_map_;
std::unique_ptr<BufferBase> tensor_buffer_; std::unique_ptr<BufferBase> tensor_buffer_;
PreallocatedPooledAllocator preallocated_allocator_; PreallocatedPooledAllocator preallocated_allocator_;
bool diffused_buffer_; bool diffused_buffer_;
const OpDelegatorRegistry *op_delegator_registry_;
MACE_DISABLE_COPY_AND_ASSIGN(Workspace); MACE_DISABLE_COPY_AND_ASSIGN(Workspace);
}; };
......
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
#include "mace/core/device_context.h" #include "mace/core/device_context.h"
#include "mace/core/memory_optimizer.h" #include "mace/core/memory_optimizer.h"
#include "mace/core/net.h" #include "mace/core/net.h"
#include "mace/ops/registry/ops_registry.h" #include "mace/core/registry/ops_registry.h"
#include "mace/core/registry/op_delegator_registry.h"
#include "mace/ops/common/transpose.h" #include "mace/ops/common/transpose.h"
#include "mace/ops/registry/registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/utils/stl_util.h" #include "mace/utils/stl_util.h"
...@@ -451,7 +453,8 @@ class MaceEngine::Impl { ...@@ -451,7 +453,8 @@ class MaceEngine::Impl {
private: private:
std::unique_ptr<port::ReadOnlyMemoryRegion> model_data_; std::unique_ptr<port::ReadOnlyMemoryRegion> model_data_;
std::unique_ptr<OpRegistryBase> op_registry_; std::unique_ptr<OpRegistry> op_registry_;
std::unique_ptr<OpDelegatorRegistry> op_delegator_registry_;
DeviceType device_type_; DeviceType device_type_;
std::unique_ptr<Device> device_; std::unique_ptr<Device> device_;
std::unique_ptr<Workspace> ws_; std::unique_ptr<Workspace> ws_;
...@@ -478,9 +481,10 @@ class MaceEngine::Impl { ...@@ -478,9 +481,10 @@ class MaceEngine::Impl {
MaceEngine::Impl::Impl(const MaceEngineConfig &config) MaceEngine::Impl::Impl(const MaceEngineConfig &config)
: model_data_(nullptr), : model_data_(nullptr),
op_registry_(new OpRegistry), op_registry_(new OpRegistry),
op_delegator_registry_(new OpDelegatorRegistry),
device_type_(config.impl_->device_type()), device_type_(config.impl_->device_type()),
device_(nullptr), device_(nullptr),
ws_(new Workspace()), ws_(new Workspace(op_delegator_registry_.get())),
net_(nullptr), net_(nullptr),
is_quantized_model_(false), is_quantized_model_(false),
thread_pool_(new utils::ThreadPool(config.impl_->num_threads(), thread_pool_(new utils::ThreadPool(config.impl_->num_threads(),
...@@ -498,6 +502,8 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config) ...@@ -498,6 +502,8 @@ MaceEngine::Impl::Impl(const MaceEngineConfig &config)
#endif #endif
{ {
LOG(INFO) << "Creating MaceEngine, MACE version: " << MaceVersion(); LOG(INFO) << "Creating MaceEngine, MACE version: " << MaceVersion();
ops::RegisterAllOps(op_registry_.get());
ops::RegisterAllOpDelegators(op_delegator_registry_.get());
thread_pool_->Init(); thread_pool_->Init();
if (device_type_ == DeviceType::CPU) { if (device_type_ == DeviceType::CPU) {
device_.reset(new CPUDevice(config.impl_->num_threads(), device_.reset(new CPUDevice(config.impl_->num_threads(),
......
...@@ -22,11 +22,13 @@ cc_library( ...@@ -22,11 +22,13 @@ cc_library(
srcs = glob( srcs = glob(
[ [
"common/*.cc", "common/*.cc",
"delegator/*.cc",
], ],
), ),
hdrs = glob( hdrs = glob(
[ [
"common/*.h", "common/*.h",
"delegator/*.h",
], ],
), ),
copts = [ copts = [
...@@ -58,12 +60,16 @@ cc_library( ...@@ -58,12 +60,16 @@ cc_library(
[ [
"ref/*.cc", "ref/*.cc",
], ],
), ) + if_quantize_enabled(glob([
"ref/q8/*.cc",
])),
hdrs = glob( hdrs = glob(
[ [
"ref/*.h", "ref/*.h",
], ],
), ) + if_quantize_enabled(glob([
"ref/q8/*.h",
])),
copts = [ copts = [
"-Werror", "-Werror",
"-Wextra", "-Wextra",
...@@ -236,12 +242,12 @@ cc_library( ...@@ -236,12 +242,12 @@ cc_library(
cc_library( cc_library(
name = "ops", name = "ops",
srcs = [ srcs = glob([
"registry/ops_registry.cc", "registry/*.cc",
], ]),
hdrs = [ hdrs = glob([
"registry/ops_registry.h", "registry/*.h",
], ]),
copts = [ copts = [
"-Werror", "-Werror",
"-Wextra", "-Wextra",
......
file(GLOB OPS_COMMON_SRCS common/*.cc) file(GLOB OPS_COMMON_SRCS common/*.cc)
file(GLOB OPS_REF_KERNELS_SRCS ref/*.cc) file(GLOB OPS_REF_KERNELS_SRCS ref/*.cc)
file(GLOB OPS_REF_Q8_KERNELS_SRCS
ref/q8/*.cc
)
file(GLOB OPS_ARM_NEON_FP32_KERNELS_SRCS file(GLOB OPS_ARM_NEON_FP32_KERNELS_SRCS
arm/fp32/*.cc arm/fp32/*.cc
) )
...@@ -17,19 +21,22 @@ file(GLOB OPS_OPENCL_KERNELS_SRCS ...@@ -17,19 +21,22 @@ file(GLOB OPS_OPENCL_KERNELS_SRCS
file(GLOB OPS_INTERNAL_OPS_SRCS *.cc) file(GLOB OPS_INTERNAL_OPS_SRCS *.cc)
set(OPS_SRCS registry/ops_registry.cc) set(OPS_SRCS registry/ops_registry.cc registry/op_delegators_registry.cc)
set(OPS_SRCS ${OPS_SRCS} ${OPS_COMMON_SRCS}) set(OPS_SRCS ${OPS_SRCS} ${OPS_COMMON_SRCS})
set(OPS_SRCS ${OPS_SRCS} ${OPS_INTERNAL_OPS_SRCS}) set(OPS_SRCS ${OPS_SRCS} ${OPS_INTERNAL_OPS_SRCS})
# TODO we need to remove this in production build # TODO we need to remove this in production build
set(OPS_SRCS ${OPS_SRCS} ${OPS_REF_KERNELS_SRCS}) set(OPS_SRCS ${OPS_SRCS} ${OPS_REF_KERNELS_SRCS})
if(MACE_ENABLE_QUANTIZE)
set(OPS_SRCS ${OPS_SRCS} ${OPS_REF_Q8_KERNELS_SRCS})
endif(MACE_ENABLE_QUANTIZE)
if(MACE_ENABLE_NEON) if(MACE_ENABLE_NEON)
set(OPS_SRCS ${OPS_SRCS} ${OPS_ARM_NEON_FP32_KERNELS_SRCS}) set(OPS_SRCS ${OPS_SRCS} ${OPS_ARM_NEON_FP32_KERNELS_SRCS})
endif(MACE_ENABLE_NEON) if(MACE_ENABLE_QUANTIZE)
if(MACE_ENABLE_QUANTIZE)
set(OPS_SRCS ${OPS_SRCS} ${OPS_ARM_NEON_Q8_KERNELS_SRCS}) set(OPS_SRCS ${OPS_SRCS} ${OPS_ARM_NEON_Q8_KERNELS_SRCS})
endif(MACE_ENABLE_QUANTIZE) endif(MACE_ENABLE_QUANTIZE)
endif(MACE_ENABLE_NEON)
if(MACE_ENABLE_OPENCL) if(MACE_ENABLE_OPENCL)
set(OPS_SRCS ${OPS_SRCS} ${OPS_OPENCL_KERNELS_SRCS}) set(OPS_SRCS ${OPS_SRCS} ${OPS_OPENCL_KERNELS_SRCS})
......
...@@ -17,13 +17,10 @@ ...@@ -17,13 +17,10 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#if defined(MACE_ENABLE_NEON) #include "mace/ops/delegator/activation.h"
#include "mace/ops/arm/fp32/activation.h"
#else
#include "mace/ops/ref/activation.h"
#endif
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/buffer_transformer.h" #include "mace/ops/opencl/buffer_transformer.h"
...@@ -37,19 +34,20 @@ namespace ops { ...@@ -37,19 +34,20 @@ namespace ops {
template<DeviceType D, class T> template<DeviceType D, class T>
class ActivationOp; class ActivationOp;
template<> template<typename T>
class ActivationOp<DeviceType::CPU, float> : public Operation { class ActivationOp<DeviceType::CPU, T> : public Operation {
public: public:
explicit ActivationOp(OpConstructContext *context) explicit ActivationOp(OpConstructContext *context)
: Operation(context), : Operation(context),
activation_type_(ops::StringToActivationType( activation_type_(ops::StringToActivationType(
Operation::GetOptionalArg<std::string>("activation", Operation::GetOptionalArg<std::string>("activation", "NOOP"))),
"NOOP"))), activation_delegator_(delegator::Activation::Create(
activation_delegator_(activation_type_, context->workspace(),
Operation::GetOptionalArg<float>("max_limit", MACE_DELEGATOR_KEY(Activation, CPU, T, MACE_CPU_IMPL_TYPE),
0.0f), delegator::ActivationParam(
Operation::GetOptionalArg<float>( activation_type_,
"leakyrelu_coefficient", 0.0f)) {} Operation::GetOptionalArg<T>("max_limit", 0),
Operation::GetOptionalArg<T>("leakyrelu_coefficient", 0)))) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
...@@ -58,28 +56,24 @@ class ActivationOp<DeviceType::CPU, float> : public Operation { ...@@ -58,28 +56,24 @@ class ActivationOp<DeviceType::CPU, float> : public Operation {
if (activation_type_ == PRELU) { if (activation_type_ == PRELU) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input)); MACE_RETURN_IF_ERROR(output->ResizeLike(input));
const float *input_ptr = input->data<float>(); const T *input_ptr = input->data<T>();
float *output_ptr = output->mutable_data<float>(); T *output_ptr = output->mutable_data<T>();
MACE_CHECK(this->InputSize() > 1); MACE_CHECK(this->InputSize() > 1);
const Tensor *alpha = this->Input(1); const Tensor *alpha = this->Input(1);
const float *alpha_ptr = alpha->data<float>(); const T *alpha_ptr = alpha->data<T>();
const index_t outer_size = output->dim(0); const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3); const index_t inner_size = output->dim(2) * output->dim(3);
PReLUActivation(context, input_ptr, outer_size, input->dim(1), inner_size, PReLUActivation(context, input_ptr, outer_size, input->dim(1), inner_size,
alpha_ptr, output_ptr); alpha_ptr, output_ptr);
} else { } else {
activation_delegator_.Compute(context, input, output); activation_delegator_->Compute(context, input, output);
} }
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
private: private:
ActivationType activation_type_; ActivationType activation_type_;
#if defined(MACE_ENABLE_NEON) std::unique_ptr<delegator::Activation> activation_delegator_;
arm::fp32::Activation activation_delegator_;
#else
ref::Activation activation_delegator_;
#endif // MACE_ENABLE_NEON
}; };
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
...@@ -122,7 +116,7 @@ class ActivationOp<DeviceType::GPU, float> : public Operation { ...@@ -122,7 +116,7 @@ class ActivationOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterActivation(OpRegistryBase *op_registry) { void RegisterActivation(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Activation", ActivationOp, MACE_REGISTER_OP(op_registry, "Activation", ActivationOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "Activation", ActivationOp); MACE_REGISTER_GPU_OP(op_registry, "Activation", ActivationOp);
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <string> #include <string>
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/utils/logging.h" #include "mace/utils/logging.h"
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/addn.h" #include "mace/ops/opencl/image/addn.h"
...@@ -92,7 +93,7 @@ class AddNOp<DeviceType::GPU, float> : public Operation { ...@@ -92,7 +93,7 @@ class AddNOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterAddN(OpRegistryBase *op_registry) { void RegisterAddN(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "AddN", AddNOp, DeviceType::CPU, float); MACE_REGISTER_OP(op_registry, "AddN", AddNOp, DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "AddN", AddNOp); MACE_REGISTER_GPU_OP(op_registry, "AddN", AddNOp);
MACE_REGISTER_OP_CONDITION( MACE_REGISTER_OP_CONDITION(
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -109,7 +110,7 @@ class ArgMaxOp : public Operation { ...@@ -109,7 +110,7 @@ class ArgMaxOp : public Operation {
void RegisterArgMax(OpRegistryBase *op_registry) { void RegisterArgMax(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ArgMax", ArgMaxOp, MACE_REGISTER_OP(op_registry, "ArgMax", ArgMaxOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/arm/fp32/activation.h" #include "mace/ops/delegator/activation.h"
#include <arm_neon.h> #include <arm_neon.h>
#include <algorithm> #include <algorithm>
...@@ -22,16 +22,22 @@ namespace ops { ...@@ -22,16 +22,22 @@ namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
Activation::Activation(ActivationType type, class Activation : public delegator::Activation {
const float limit, public:
const float leakyrelu_coefficient) explicit Activation(const delegator::ActivationParam &param)
: type_(type), : delegator::Activation(param) {}
limit_(limit), ~Activation() = default;
leakyrelu_coefficient_(leakyrelu_coefficient) {}
MaceStatus Compute(const OpContext *context,
const Tensor *input, Tensor *output) override;
private:
void DoActivation(const OpContext *context,
const Tensor *input, Tensor *output);
};
MaceStatus Activation::Compute(const OpContext *context, MaceStatus Activation::Compute(const OpContext *context,
const Tensor *input, const Tensor *input, Tensor *output) {
Tensor *output) {
Tensor::MappingGuard input_guard(input); Tensor::MappingGuard input_guard(input);
if (input != output) { if (input != output) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input)); MACE_RETURN_IF_ERROR(output->ResizeLike(input));
...@@ -169,14 +175,19 @@ void Activation::DoActivation(const OpContext *context, ...@@ -169,14 +175,19 @@ void Activation::DoActivation(const OpContext *context,
break; break;
} }
case NOOP: case NOOP: {
break; break;
}
default: default: {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
}
} }
MACE_REGISTER_DELEGATOR(registry, Activation, delegator::ActivationParam,
MACE_DELEGATOR_KEY(Activation, CPU, float, NEON))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -12,15 +12,27 @@ ...@@ -12,15 +12,27 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/arm/fp32/bias_add.h"
#include <arm_neon.h> #include <arm_neon.h>
#include "mace/ops/delegator/bias_add.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
// NEON (CPU/float) implementation of the BiasAdd delegator
// (registered below with MACE_DELEGATOR_KEY(BiasAdd, CPU, float, NEON)).
class BiasAdd : public delegator::BiasAdd {
 public:
  explicit BiasAdd(const DelegatorParam &param) : delegator::BiasAdd(param) {}
  ~BiasAdd() = default;

  // Combines `input` with `bias` and writes the result to `output`.
  MaceStatus Compute(const OpContext *context, const Tensor *input,
                     const Tensor *bias, Tensor *output) override;

 private:
  // Core kernel used by Compute (definition elsewhere in this file).
  void AddBias(const OpContext *context, const Tensor *input,
               const Tensor *bias, Tensor *output);
};
MaceStatus BiasAdd::Compute(const OpContext *context, MaceStatus BiasAdd::Compute(const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *bias, const Tensor *bias,
...@@ -117,6 +129,9 @@ void BiasAdd::AddBias(const OpContext *context, ...@@ -117,6 +129,9 @@ void BiasAdd::AddBias(const OpContext *context,
} }
} }
MACE_REGISTER_DELEGATOR(registry, BiasAdd, DelegatorParam,
MACE_DELEGATOR_KEY(BiasAdd, CPU, float, NEON))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,36 +18,25 @@ ...@@ -18,36 +18,25 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h" #include "mace/ops/delegator/conv_2d.h"
#include "mace/ops/arm/fp32/gemm.h" #include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
class Conv2dBase { class Conv2dBase : public delegator::Conv2d {
public: public:
Conv2dBase(const std::vector<int> &strides, explicit Conv2dBase(const delegator::Conv2dParam &param)
const std::vector<int> &dilations, : delegator::Conv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
virtual ~Conv2dBase() = default; virtual ~Conv2dBase() = default;
virtual MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) = 0;
protected: protected:
void CalOutputShapeAndInputPadSize(const std::vector<index_t> &input_shape, void CalOutputShapeAndInputPadSize(const std::vector<index_t> &input_shape,
const std::vector<index_t> &filter_shape, const std::vector<index_t> &filter_shape,
...@@ -83,11 +72,6 @@ class Conv2dBase { ...@@ -83,11 +72,6 @@ class Conv2dBase {
const int pad_left, const int pad_left,
Tensor *dst); Tensor *dst);
void UnPadOutput(const Tensor &src, Tensor *dst); void UnPadOutput(const Tensor &src, Tensor *dst);
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
}; };
} // namespace fp32 } // namespace fp32
......
...@@ -12,13 +12,32 @@ ...@@ -12,13 +12,32 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/arm/fp32/conv_2d_1x1.h" #include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/delegator/conv_2d.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
// Conv2d delegator specialization for 1x1 kernels.
class Conv2dK1x1 : public Conv2dBase {
 public:
  explicit Conv2dK1x1(const delegator::Conv2dParam &param)
      : Conv2dBase(param),
        gemm_(delegator::GemmParam()) {}
  virtual ~Conv2dK1x1() {}

  MaceStatus Compute(
      const OpContext *context,
      const Tensor *input,
      const Tensor *filter,
      Tensor *output) override;

 private:
  // GEMM engine; a 1x1 conv is expressible as a matrix multiply —
  // presumably Compute delegates to this (confirm in its definition).
  Gemm gemm_;
};
MaceStatus Conv2dK1x1::Compute(const OpContext *context, MaceStatus Conv2dK1x1::Compute(const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
...@@ -94,6 +113,9 @@ MaceStatus Conv2dK1x1::Compute(const OpContext *context, ...@@ -94,6 +113,9 @@ MaceStatus Conv2dK1x1::Compute(const OpContext *context,
output); output);
} }
MACE_REGISTER_DELEGATOR(registry, Conv2dK1x1, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K1x1))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <arm_neon.h> #include <arm_neon.h>
#include <memory> #include <memory>
#include "mace/ops/delegator/conv_2d.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
...@@ -859,6 +861,19 @@ MaceStatus Conv2dK15x1S1::Compute(const OpContext *context, ...@@ -859,6 +861,19 @@ MaceStatus Conv2dK15x1S1::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Conv2dK1x7S1, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K1x7S1))
MACE_REGISTER_DELEGATOR(registry, Conv2dK7x1S1, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K7x1S1))
MACE_REGISTER_DELEGATOR(registry, Conv2dK1x15S1, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
NEON, K1x15S1))
MACE_REGISTER_DELEGATOR(registry, Conv2dK15x1S1, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
NEON, K15x1S1))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
#define MACE_OPS_ARM_FP32_CONV_2D_1XN_H_ #define MACE_OPS_ARM_FP32_CONV_2D_1XN_H_
#include <vector> #include <vector>
#include "mace/public/mace.h"
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/conv_2d.h" #include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -28,8 +29,8 @@ namespace fp32 { ...@@ -28,8 +29,8 @@ namespace fp32 {
class Conv2dK1x7S1 : public Conv2dBase { class Conv2dK1x7S1 : public Conv2dBase {
public: public:
Conv2dK1x7S1(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK1x7S1(const delegator::Conv2dParam &param)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK1x7S1() {} virtual ~Conv2dK1x7S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -41,8 +42,8 @@ class Conv2dK1x7S1 : public Conv2dBase { ...@@ -41,8 +42,8 @@ class Conv2dK1x7S1 : public Conv2dBase {
class Conv2dK7x1S1 : public Conv2dBase { class Conv2dK7x1S1 : public Conv2dBase {
public: public:
Conv2dK7x1S1(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK7x1S1(const delegator::Conv2dParam &param)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK7x1S1() {} virtual ~Conv2dK7x1S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -54,8 +55,8 @@ class Conv2dK7x1S1 : public Conv2dBase { ...@@ -54,8 +55,8 @@ class Conv2dK7x1S1 : public Conv2dBase {
class Conv2dK1x15S1 : public Conv2dBase { class Conv2dK1x15S1 : public Conv2dBase {
public: public:
Conv2dK1x15S1(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK1x15S1(const delegator::Conv2dParam &param)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK1x15S1() {} virtual ~Conv2dK1x15S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -67,8 +68,8 @@ class Conv2dK1x15S1 : public Conv2dBase { ...@@ -67,8 +68,8 @@ class Conv2dK1x15S1 : public Conv2dBase {
class Conv2dK15x1S1 : public Conv2dBase { class Conv2dK15x1S1 : public Conv2dBase {
public: public:
Conv2dK15x1S1(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK15x1S1(const delegator::Conv2dParam &param)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK15x1S1() {} virtual ~Conv2dK15x1S1() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <arm_neon.h> #include <arm_neon.h>
#include <memory> #include <memory>
#include "mace/ops/delegator/conv_2d.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
...@@ -735,6 +737,11 @@ MaceStatus Conv2dK3x3S2::Compute(const OpContext *context, ...@@ -735,6 +737,11 @@ MaceStatus Conv2dK3x3S2::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Conv2dK3x3S1, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K3x3S1))
MACE_REGISTER_DELEGATOR(registry, Conv2dK3x3S2, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K3x3S2))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
#define MACE_OPS_ARM_FP32_CONV_2D_3X3_H_ #define MACE_OPS_ARM_FP32_CONV_2D_3X3_H_
#include <vector> #include <vector>
#include "mace/public/mace.h"
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/conv_2d.h" #include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -28,8 +29,8 @@ namespace fp32 { ...@@ -28,8 +29,8 @@ namespace fp32 {
class Conv2dK3x3S1 : public Conv2dBase { class Conv2dK3x3S1 : public Conv2dBase {
public: public:
Conv2dK3x3S1(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK3x3S1(const delegator::Conv2dParam &param)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK3x3S1() {} virtual ~Conv2dK3x3S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -41,8 +42,8 @@ class Conv2dK3x3S1 : public Conv2dBase { ...@@ -41,8 +42,8 @@ class Conv2dK3x3S1 : public Conv2dBase {
class Conv2dK3x3S2 : public Conv2dBase { class Conv2dK3x3S2 : public Conv2dBase {
public: public:
Conv2dK3x3S2(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK3x3S2(const delegator::Conv2dParam &param)
: Conv2dBase({2, 2}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK3x3S2() {} virtual ~Conv2dK3x3S2() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/conv_2d.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
...@@ -800,6 +801,10 @@ void Conv2dK3x3Winograd::TransformOutput8x8(const OpContext *context, ...@@ -800,6 +801,10 @@ void Conv2dK3x3Winograd::TransformOutput8x8(const OpContext *context,
}, 0, batch, 1, 0, out_channels, 1); }, 0, batch, 1, 0, out_channels, 1);
} }
MACE_REGISTER_DELEGATOR(registry, Conv2dK3x3Winograd, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(
Conv2d, CPU, float, NEON, K3x3Winograd))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,11 +18,11 @@ ...@@ -18,11 +18,11 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/arm/fp32/conv_2d.h" #include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -31,10 +31,9 @@ namespace fp32 { ...@@ -31,10 +31,9 @@ namespace fp32 {
class Conv2dK3x3Winograd : public Conv2dBase { class Conv2dK3x3Winograd : public Conv2dBase {
public: public:
Conv2dK3x3Winograd(const std::vector<int> &paddings, explicit Conv2dK3x3Winograd(const delegator::Conv2dParam &param)
const Padding padding_type) : Conv2dBase(param),
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type), gemm_(delegator::GemmParam()),
gemm_(),
transformed_filter_(nullptr), transformed_filter_(nullptr),
out_tile_size_(0) {} out_tile_size_(0) {}
......
...@@ -12,16 +12,30 @@ ...@@ -12,16 +12,30 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/arm/fp32/conv_2d_5x5.h"
#include <arm_neon.h> #include <arm_neon.h>
#include <memory> #include <memory>
#include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/ops/delegator/conv_2d.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
class Conv2dK5x5S1 : public Conv2dBase {
public:
explicit Conv2dK5x5S1(const delegator::Conv2dParam &param)
: Conv2dBase(param) {}
virtual ~Conv2dK5x5S1() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) override;
};
#define MACE_Conv2dNeonK5x5SnLoadCalc4 \ #define MACE_Conv2dNeonK5x5SnLoadCalc4 \
/* load filter (4 outch x 1 height x 4 width) */ \ /* load filter (4 outch x 1 height x 4 width) */ \
float32x4_t vf00, vf10, vf20, vf30; \ float32x4_t vf00, vf10, vf20, vf30; \
...@@ -244,6 +258,9 @@ MaceStatus Conv2dK5x5S1::Compute(const OpContext *context, ...@@ -244,6 +258,9 @@ MaceStatus Conv2dK5x5S1::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Conv2dK5x5S1, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K5x5S1))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_CONV_2D_5X5_H_
#define MACE_OPS_ARM_FP32_CONV_2D_5X5_H_
#include <vector>
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/conv_2d.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Conv2dK5x5S1 : public Conv2dBase {
public:
Conv2dK5x5S1(const std::vector<int> &paddings, const Padding padding_type)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~Conv2dK5x5S1() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) override;
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_CONV_2D_5X5_H_
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <arm_neon.h> #include <arm_neon.h>
#include <memory> #include <memory>
#include "mace/ops/delegator/conv_2d.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
...@@ -720,6 +722,13 @@ MaceStatus Conv2dK7x7S3::Compute(const OpContext *context, ...@@ -720,6 +722,13 @@ MaceStatus Conv2dK7x7S3::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Conv2dK7x7S1, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K7x7S1))
MACE_REGISTER_DELEGATOR(registry, Conv2dK7x7S2, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K7x7S2))
MACE_REGISTER_DELEGATOR(registry, Conv2dK7x7S3, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, K7x7S3))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -16,10 +16,11 @@ ...@@ -16,10 +16,11 @@
#define MACE_OPS_ARM_FP32_CONV_2D_7X7_H_ #define MACE_OPS_ARM_FP32_CONV_2D_7X7_H_
#include <vector> #include <vector>
#include "mace/public/mace.h"
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/conv_2d.h" #include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -28,8 +29,8 @@ namespace fp32 { ...@@ -28,8 +29,8 @@ namespace fp32 {
class Conv2dK7x7S1 : public Conv2dBase { class Conv2dK7x7S1 : public Conv2dBase {
public: public:
Conv2dK7x7S1(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK7x7S1(const delegator::Conv2dParam &param)
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK7x7S1() {} virtual ~Conv2dK7x7S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -41,8 +42,8 @@ class Conv2dK7x7S1 : public Conv2dBase { ...@@ -41,8 +42,8 @@ class Conv2dK7x7S1 : public Conv2dBase {
class Conv2dK7x7S2 : public Conv2dBase { class Conv2dK7x7S2 : public Conv2dBase {
public: public:
Conv2dK7x7S2(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK7x7S2(const delegator::Conv2dParam &param)
: Conv2dBase({2, 2}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK7x7S2() {} virtual ~Conv2dK7x7S2() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -54,8 +55,8 @@ class Conv2dK7x7S2 : public Conv2dBase { ...@@ -54,8 +55,8 @@ class Conv2dK7x7S2 : public Conv2dBase {
class Conv2dK7x7S3 : public Conv2dBase { class Conv2dK7x7S3 : public Conv2dBase {
public: public:
Conv2dK7x7S3(const std::vector<int> &paddings, const Padding padding_type) explicit Conv2dK7x7S3(const delegator::Conv2dParam &param)
: Conv2dBase({3, 3}, {1, 1}, paddings, padding_type) {} : Conv2dBase(param) {}
virtual ~Conv2dK7x7S3() {} virtual ~Conv2dK7x7S3() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -12,15 +12,30 @@ ...@@ -12,15 +12,30 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/arm/fp32/conv_general.h" #include "mace/ops/arm/fp32/conv_2d.h"
#include <memory> #include <memory>
#include "mace/ops/delegator/conv_2d.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
class Conv2dGeneral : public Conv2dBase {
public:
explicit Conv2dGeneral(const delegator::Conv2dParam &param)
: Conv2dBase(param) {}
virtual ~Conv2dGeneral() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
Tensor *output) override;
};
MaceStatus Conv2dGeneral::Compute(const OpContext *context, MaceStatus Conv2dGeneral::Compute(const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
...@@ -237,6 +252,10 @@ MaceStatus Conv2dGeneral::Compute(const OpContext *context, ...@@ -237,6 +252,10 @@ MaceStatus Conv2dGeneral::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(
registry, Conv2dGeneral, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, NEON, General))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,54 +18,27 @@ ...@@ -18,54 +18,27 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h" #include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/deconv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
class Deconv2dBase { class Deconv2dBase : public delegator::Deconv2d {
public: public:
Deconv2dBase(const std::vector<int> &strides, explicit Deconv2dBase(const delegator::Deconv2dParam &param)
const std::vector<int> &dilations, : delegator::Deconv2d(param),
const std::vector<int> &paddings, group_(param.group_) {}
const Padding padding_type,
const index_t group,
const FrameworkType framework_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type),
group_(group),
framework_type_(framework_type) {}
Deconv2dBase(const std::vector<int> &strides,
const std::vector<int> &dilations,
const std::vector<int> &paddings,
const Padding padding_type,
const FrameworkType framework_type)
: Deconv2dBase(strides,
dilations,
paddings,
padding_type,
1,
framework_type) {}
virtual ~Deconv2dBase() = default; virtual ~Deconv2dBase() = default;
virtual MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
const Tensor *output_shape,
Tensor *output) = 0;
protected: protected:
MaceStatus ResizeOutAndPadOut(const OpContext *context, MaceStatus ResizeOutAndPadOut(const OpContext *context,
const Tensor *input, const Tensor *input,
...@@ -78,13 +51,7 @@ class Deconv2dBase { ...@@ -78,13 +51,7 @@ class Deconv2dBase {
void UnPadOutput(const Tensor &src, void UnPadOutput(const Tensor &src,
const std::vector<int> &out_pad_size, const std::vector<int> &out_pad_size,
Tensor *dst); Tensor *dst);
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
index_t group_; index_t group_;
const FrameworkType framework_type_;
}; };
} // namespace fp32 } // namespace fp32
......
...@@ -330,12 +330,18 @@ MaceStatus Deconv2dK2x2S2::Compute(const OpContext *context, ...@@ -330,12 +330,18 @@ MaceStatus Deconv2dK2x2S2::Compute(const OpContext *context,
} }
}, 0, batch, 1, 0, outch, 1); }, 0, batch, 1, 0, outch, 1);
UnPadOutput(*out_tensor, out_pad_size, output); UnPadOutput(*out_tensor, out_pad_size, output);
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Deconv2dK2x2S1, delegator::Deconv2dParam,
MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
NEON, K2x2S1))
MACE_REGISTER_DELEGATOR(registry, Deconv2dK2x2S2, delegator::Deconv2dParam,
MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
NEON, K2x2S2))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,12 +18,12 @@ ...@@ -18,12 +18,12 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/deconv_2d.h" #include "mace/ops/arm/fp32/deconv_2d.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -32,10 +32,8 @@ namespace fp32 { ...@@ -32,10 +32,8 @@ namespace fp32 {
class Deconv2dK2x2S1 : public Deconv2dBase { class Deconv2dK2x2S1 : public Deconv2dBase {
public: public:
Deconv2dK2x2S1(const std::vector<int> &paddings, explicit Deconv2dK2x2S1(const delegator::Deconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const FrameworkType framework_type)
: Deconv2dBase({1, 1}, {1, 1}, paddings, padding_type, framework_type) {}
virtual ~Deconv2dK2x2S1() {} virtual ~Deconv2dK2x2S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -48,10 +46,8 @@ class Deconv2dK2x2S1 : public Deconv2dBase { ...@@ -48,10 +46,8 @@ class Deconv2dK2x2S1 : public Deconv2dBase {
class Deconv2dK2x2S2 : public Deconv2dBase { class Deconv2dK2x2S2 : public Deconv2dBase {
public: public:
Deconv2dK2x2S2(const std::vector<int> &paddings, explicit Deconv2dK2x2S2(const delegator::Deconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const FrameworkType framework_type)
: Deconv2dBase({2, 2}, {1, 1}, paddings, padding_type, framework_type) {}
virtual ~Deconv2dK2x2S2() {} virtual ~Deconv2dK2x2S2() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -464,6 +464,13 @@ MaceStatus Deconv2dK3x3S2::Compute(const OpContext *context, ...@@ -464,6 +464,13 @@ MaceStatus Deconv2dK3x3S2::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Deconv2dK3x3S1, delegator::Deconv2dParam,
MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
NEON, K3x3S1))
MACE_REGISTER_DELEGATOR(registry, Deconv2dK3x3S2, delegator::Deconv2dParam,
MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
NEON, K3x3S2))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,12 +18,12 @@ ...@@ -18,12 +18,12 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/deconv_2d.h" #include "mace/ops/arm/fp32/deconv_2d.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -32,10 +32,8 @@ namespace fp32 { ...@@ -32,10 +32,8 @@ namespace fp32 {
class Deconv2dK3x3S1 : public Deconv2dBase { class Deconv2dK3x3S1 : public Deconv2dBase {
public: public:
Deconv2dK3x3S1(const std::vector<int> &paddings, explicit Deconv2dK3x3S1(const delegator::Deconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const FrameworkType framework_type)
: Deconv2dBase({1, 1}, {1, 1}, paddings, padding_type, framework_type) {}
virtual ~Deconv2dK3x3S1() {} virtual ~Deconv2dK3x3S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -48,10 +46,8 @@ class Deconv2dK3x3S1 : public Deconv2dBase { ...@@ -48,10 +46,8 @@ class Deconv2dK3x3S1 : public Deconv2dBase {
class Deconv2dK3x3S2 : public Deconv2dBase { class Deconv2dK3x3S2 : public Deconv2dBase {
public: public:
Deconv2dK3x3S2(const std::vector<int> &paddings, explicit Deconv2dK3x3S2(const delegator::Deconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const FrameworkType framework_type)
: Deconv2dBase({2, 2}, {1, 1}, paddings, padding_type, framework_type) {}
virtual ~Deconv2dK3x3S2() {} virtual ~Deconv2dK3x3S2() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -449,7 +449,6 @@ MaceStatus Deconv2dK4x4S2::Compute(const OpContext *context, ...@@ -449,7 +449,6 @@ MaceStatus Deconv2dK4x4S2::Compute(const OpContext *context,
const index_t outw = out_shape[3]; const index_t outw = out_shape[3];
const index_t out_img_size = outh * outw; const index_t out_img_size = outh * outw;
utils::ThreadPool utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool(); &thread_pool = context->device()->cpu_runtime()->thread_pool();
...@@ -575,6 +574,13 @@ MaceStatus Deconv2dK4x4S2::Compute(const OpContext *context, ...@@ -575,6 +574,13 @@ MaceStatus Deconv2dK4x4S2::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Deconv2dK4x4S1, delegator::Deconv2dParam,
MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
NEON, K4x4S1))
MACE_REGISTER_DELEGATOR(registry, Deconv2dK4x4S2, delegator::Deconv2dParam,
MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
NEON, K4x4S2))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,12 +18,12 @@ ...@@ -18,12 +18,12 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/deconv_2d.h" #include "mace/ops/arm/fp32/deconv_2d.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -32,10 +32,8 @@ namespace fp32 { ...@@ -32,10 +32,8 @@ namespace fp32 {
class Deconv2dK4x4S1 : public Deconv2dBase { class Deconv2dK4x4S1 : public Deconv2dBase {
public: public:
Deconv2dK4x4S1(const std::vector<int> &paddings, explicit Deconv2dK4x4S1(const delegator::Deconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const FrameworkType framework_type)
: Deconv2dBase({1, 1}, {1, 1}, paddings, padding_type, framework_type) {}
virtual ~Deconv2dK4x4S1() {} virtual ~Deconv2dK4x4S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -48,10 +46,8 @@ class Deconv2dK4x4S1 : public Deconv2dBase { ...@@ -48,10 +46,8 @@ class Deconv2dK4x4S1 : public Deconv2dBase {
class Deconv2dK4x4S2 : public Deconv2dBase { class Deconv2dK4x4S2 : public Deconv2dBase {
public: public:
Deconv2dK4x4S2(const std::vector<int> &paddings, explicit Deconv2dK4x4S2(const delegator::Deconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const FrameworkType framework_type)
: Deconv2dBase({2, 2}, {1, 1}, paddings, padding_type, framework_type) {}
virtual ~Deconv2dK4x4S2() {} virtual ~Deconv2dK4x4S2() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/arm/fp32/deconv_2d_general.h" #include "mace/ops/arm/fp32/deconv_2d.h"
// TODO(liutuo): optimize it // TODO(liutuo): optimize it
...@@ -21,6 +21,20 @@ namespace ops { ...@@ -21,6 +21,20 @@ namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
class Deconv2dGeneral : public Deconv2dBase {
public:
explicit Deconv2dGeneral(const delegator::Deconv2dParam &param)
: Deconv2dBase(param) {}
virtual ~Deconv2dGeneral() {}
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *filter,
const Tensor *output_shape,
Tensor *output) override;
};
MaceStatus Deconv2dGeneral::Compute(const OpContext *context, MaceStatus Deconv2dGeneral::Compute(const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
...@@ -110,6 +124,10 @@ MaceStatus Deconv2dGeneral::Compute(const OpContext *context, ...@@ -110,6 +124,10 @@ MaceStatus Deconv2dGeneral::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Deconv2dGeneral, delegator::Deconv2dParam,
MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
NEON, General))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -512,6 +512,13 @@ MaceStatus DepthwiseConv2dK3x3S2::Compute(const mace::OpContext *context, ...@@ -512,6 +512,13 @@ MaceStatus DepthwiseConv2dK3x3S2::Compute(const mace::OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(
registry, DepthwiseConv2dK3x3S1, delegator::DepthwiseConv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, CPU, float, NEON, K3x3S1))
MACE_REGISTER_DELEGATOR(
registry, DepthwiseConv2dK3x3S2, delegator::DepthwiseConv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, CPU, float, NEON, K3x3S2))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -16,10 +16,12 @@ ...@@ -16,10 +16,12 @@
#define MACE_OPS_ARM_FP32_DEPTHWISE_CONV_2D_3X3_H_ #define MACE_OPS_ARM_FP32_DEPTHWISE_CONV_2D_3X3_H_
#include <vector> #include <vector>
#include "mace/public/mace.h"
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/conv_2d.h" #include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/ops/delegator/depthwise_conv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -28,9 +30,8 @@ namespace fp32 { ...@@ -28,9 +30,8 @@ namespace fp32 {
class DepthwiseConv2dK3x3S1 : public Conv2dBase { class DepthwiseConv2dK3x3S1 : public Conv2dBase {
public: public:
DepthwiseConv2dK3x3S1(const std::vector<int> &paddings, explicit DepthwiseConv2dK3x3S1(const delegator::DepthwiseConv2dParam &param)
const Padding padding_type) : Conv2dBase(param) {}
: Conv2dBase({1, 1}, {1, 1}, paddings, padding_type) {}
virtual ~DepthwiseConv2dK3x3S1() {} virtual ~DepthwiseConv2dK3x3S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -42,9 +43,8 @@ class DepthwiseConv2dK3x3S1 : public Conv2dBase { ...@@ -42,9 +43,8 @@ class DepthwiseConv2dK3x3S1 : public Conv2dBase {
class DepthwiseConv2dK3x3S2 : public Conv2dBase { class DepthwiseConv2dK3x3S2 : public Conv2dBase {
public: public:
DepthwiseConv2dK3x3S2(const std::vector<int> &paddings, explicit DepthwiseConv2dK3x3S2(const delegator::DepthwiseConv2dParam &param)
const Padding padding_type) : Conv2dBase(param) {}
: Conv2dBase({2, 2}, {1, 1}, paddings, padding_type) {}
virtual ~DepthwiseConv2dK3x3S2() {} virtual ~DepthwiseConv2dK3x3S2() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -776,6 +776,20 @@ MaceStatus GroupDeconv2dK3x3S2::Compute(const OpContext *context, ...@@ -776,6 +776,20 @@ MaceStatus GroupDeconv2dK3x3S2::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(
registry, DepthwiseDeconv2dK3x3S1, delegator::DepthwiseDeconv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float, NEON, K3x3S1))
MACE_REGISTER_DELEGATOR(
registry, DepthwiseDeconv2dK3x3S2, delegator::DepthwiseDeconv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float, NEON, K3x3S2))
MACE_REGISTER_DELEGATOR(
registry, GroupDeconv2dK3x3S1, delegator::GroupDeconv2dParam,
MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float, NEON, K3x3S1))
MACE_REGISTER_DELEGATOR(
registry, GroupDeconv2dK3x3S2, delegator::GroupDeconv2dParam,
MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float, NEON, K3x3S2))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,12 +18,13 @@ ...@@ -18,12 +18,13 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/deconv_2d.h" #include "mace/ops/arm/fp32/deconv_2d.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/depthwise_deconv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -32,14 +33,9 @@ namespace fp32 { ...@@ -32,14 +33,9 @@ namespace fp32 {
class DepthwiseDeconv2dK3x3S1 : public Deconv2dBase { class DepthwiseDeconv2dK3x3S1 : public Deconv2dBase {
public: public:
DepthwiseDeconv2dK3x3S1(const std::vector<int> &paddings, explicit DepthwiseDeconv2dK3x3S1(
const Padding padding_type, const delegator::DepthwiseDeconv2dParam &param)
const FrameworkType framework_type) : Deconv2dBase(param) {}
: Deconv2dBase({1, 1},
{1, 1},
paddings,
padding_type,
framework_type) {}
virtual ~DepthwiseDeconv2dK3x3S1() {} virtual ~DepthwiseDeconv2dK3x3S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -52,14 +48,9 @@ class DepthwiseDeconv2dK3x3S1 : public Deconv2dBase { ...@@ -52,14 +48,9 @@ class DepthwiseDeconv2dK3x3S1 : public Deconv2dBase {
class DepthwiseDeconv2dK3x3S2 : public Deconv2dBase { class DepthwiseDeconv2dK3x3S2 : public Deconv2dBase {
public: public:
DepthwiseDeconv2dK3x3S2(const std::vector<int> &paddings, explicit DepthwiseDeconv2dK3x3S2(
const Padding padding_type, const delegator::DepthwiseDeconv2dParam &param)
const FrameworkType framework_type) : Deconv2dBase(param) {}
: Deconv2dBase({2, 2},
{1, 1},
paddings,
padding_type,
framework_type) {}
virtual ~DepthwiseDeconv2dK3x3S2() {} virtual ~DepthwiseDeconv2dK3x3S2() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -72,16 +63,9 @@ class DepthwiseDeconv2dK3x3S2 : public Deconv2dBase { ...@@ -72,16 +63,9 @@ class DepthwiseDeconv2dK3x3S2 : public Deconv2dBase {
class GroupDeconv2dK3x3S1 : public Deconv2dBase { class GroupDeconv2dK3x3S1 : public Deconv2dBase {
public: public:
GroupDeconv2dK3x3S1(const std::vector<int> &paddings, explicit GroupDeconv2dK3x3S1(
const Padding padding_type, const delegator::GroupDeconv2dParam &param)
const int group, : Deconv2dBase(param) {}
const FrameworkType framework_type)
: Deconv2dBase({1, 1},
{1, 1},
paddings,
padding_type,
group,
framework_type) {}
virtual ~GroupDeconv2dK3x3S1() {} virtual ~GroupDeconv2dK3x3S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -94,16 +78,8 @@ class GroupDeconv2dK3x3S1 : public Deconv2dBase { ...@@ -94,16 +78,8 @@ class GroupDeconv2dK3x3S1 : public Deconv2dBase {
class GroupDeconv2dK3x3S2 : public Deconv2dBase { class GroupDeconv2dK3x3S2 : public Deconv2dBase {
public: public:
GroupDeconv2dK3x3S2(const std::vector<int> &paddings, explicit GroupDeconv2dK3x3S2(const delegator::GroupDeconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const int group,
const FrameworkType framework_type)
: Deconv2dBase({2, 2},
{1, 1},
paddings,
padding_type,
group,
framework_type) {}
virtual ~GroupDeconv2dK3x3S2() {} virtual ~GroupDeconv2dK3x3S2() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -959,6 +959,20 @@ MaceStatus GroupDeconv2dK4x4S2::Compute(const OpContext *context, ...@@ -959,6 +959,20 @@ MaceStatus GroupDeconv2dK4x4S2::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(
registry, DepthwiseDeconv2dK4x4S1, delegator::DepthwiseDeconv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float, NEON, K4x4S1))
MACE_REGISTER_DELEGATOR(
registry, DepthwiseDeconv2dK4x4S2, delegator::DepthwiseDeconv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float, NEON, K4x4S2))
MACE_REGISTER_DELEGATOR(
registry, GroupDeconv2dK4x4S1, delegator::GroupDeconv2dParam,
MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float, NEON, K4x4S1))
MACE_REGISTER_DELEGATOR(
registry, GroupDeconv2dK4x4S2, delegator::GroupDeconv2dParam,
MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float, NEON, K4x4S2))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,12 +18,13 @@ ...@@ -18,12 +18,13 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/deconv_2d.h" #include "mace/ops/arm/fp32/deconv_2d.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/depthwise_deconv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -32,14 +33,9 @@ namespace fp32 { ...@@ -32,14 +33,9 @@ namespace fp32 {
class DepthwiseDeconv2dK4x4S1 : public Deconv2dBase { class DepthwiseDeconv2dK4x4S1 : public Deconv2dBase {
public: public:
DepthwiseDeconv2dK4x4S1(const std::vector<int> &paddings, explicit DepthwiseDeconv2dK4x4S1(
const Padding padding_type, const delegator::DepthwiseDeconv2dParam &param)
const FrameworkType framework_type) : Deconv2dBase(param) {}
: Deconv2dBase({1, 1},
{1, 1},
paddings,
padding_type,
framework_type) {}
virtual ~DepthwiseDeconv2dK4x4S1() {} virtual ~DepthwiseDeconv2dK4x4S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -52,14 +48,9 @@ class DepthwiseDeconv2dK4x4S1 : public Deconv2dBase { ...@@ -52,14 +48,9 @@ class DepthwiseDeconv2dK4x4S1 : public Deconv2dBase {
class DepthwiseDeconv2dK4x4S2 : public Deconv2dBase { class DepthwiseDeconv2dK4x4S2 : public Deconv2dBase {
public: public:
DepthwiseDeconv2dK4x4S2(const std::vector<int> &paddings, explicit DepthwiseDeconv2dK4x4S2(
const Padding padding_type, const delegator::DepthwiseDeconv2dParam &param)
const FrameworkType framework_type) : Deconv2dBase(param) {}
: Deconv2dBase({2, 2},
{1, 1},
paddings,
padding_type,
framework_type) {}
virtual ~DepthwiseDeconv2dK4x4S2() {} virtual ~DepthwiseDeconv2dK4x4S2() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -72,16 +63,8 @@ class DepthwiseDeconv2dK4x4S2 : public Deconv2dBase { ...@@ -72,16 +63,8 @@ class DepthwiseDeconv2dK4x4S2 : public Deconv2dBase {
class GroupDeconv2dK4x4S1 : public Deconv2dBase { class GroupDeconv2dK4x4S1 : public Deconv2dBase {
public: public:
GroupDeconv2dK4x4S1(const std::vector<int> &paddings, explicit GroupDeconv2dK4x4S1(const delegator::GroupDeconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const int group,
const FrameworkType framework_type)
: Deconv2dBase({1, 1},
{1, 1},
paddings,
padding_type,
group,
framework_type) {}
virtual ~GroupDeconv2dK4x4S1() {} virtual ~GroupDeconv2dK4x4S1() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -94,16 +77,8 @@ class GroupDeconv2dK4x4S1 : public Deconv2dBase { ...@@ -94,16 +77,8 @@ class GroupDeconv2dK4x4S1 : public Deconv2dBase {
class GroupDeconv2dK4x4S2 : public Deconv2dBase { class GroupDeconv2dK4x4S2 : public Deconv2dBase {
public: public:
GroupDeconv2dK4x4S2(const std::vector<int> &paddings, explicit GroupDeconv2dK4x4S2(const delegator::GroupDeconv2dParam &param)
const Padding padding_type, : Deconv2dBase(param) {}
const int group,
const FrameworkType framework_type)
: Deconv2dBase({2, 2},
{1, 1},
paddings,
padding_type,
group,
framework_type) {}
virtual ~GroupDeconv2dK4x4S2() {} virtual ~GroupDeconv2dK4x4S2() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -207,6 +207,14 @@ MaceStatus GroupDeconv2dGeneral::Compute(const OpContext *context, ...@@ -207,6 +207,14 @@ MaceStatus GroupDeconv2dGeneral::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(
registry, DepthwiseDeconv2dGeneral, delegator::DepthwiseDeconv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float, NEON, General))
MACE_REGISTER_DELEGATOR(
registry, GroupDeconv2dGeneral, delegator::GroupDeconv2dParam,
MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float, NEON, General))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -18,12 +18,13 @@ ...@@ -18,12 +18,13 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/deconv_2d.h" #include "mace/ops/arm/fp32/deconv_2d.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/depthwise_deconv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -32,16 +33,9 @@ namespace fp32 { ...@@ -32,16 +33,9 @@ namespace fp32 {
class DepthwiseDeconv2dGeneral : public Deconv2dBase { class DepthwiseDeconv2dGeneral : public Deconv2dBase {
public: public:
DepthwiseDeconv2dGeneral(const std::vector<int> &strides, explicit DepthwiseDeconv2dGeneral(
const std::vector<int> &dilations, const delegator::DepthwiseDeconv2dParam &param)
const std::vector<int> &paddings, : Deconv2dBase(param) {}
const Padding padding_type,
const FrameworkType framework_type)
: Deconv2dBase(strides,
dilations,
paddings,
padding_type,
framework_type) {}
virtual ~DepthwiseDeconv2dGeneral() {} virtual ~DepthwiseDeconv2dGeneral() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -54,18 +48,8 @@ class DepthwiseDeconv2dGeneral : public Deconv2dBase { ...@@ -54,18 +48,8 @@ class DepthwiseDeconv2dGeneral : public Deconv2dBase {
class GroupDeconv2dGeneral : public Deconv2dBase { class GroupDeconv2dGeneral : public Deconv2dBase {
public: public:
GroupDeconv2dGeneral(const std::vector<int> &strides, explicit GroupDeconv2dGeneral(const delegator::GroupDeconv2dParam &param)
const std::vector<int> &dilations, : Deconv2dBase(param) {}
const std::vector<int> &paddings,
const Padding padding_type,
const int group,
const FrameworkType framework_type)
: Deconv2dBase(strides,
dilations,
paddings,
padding_type,
group,
framework_type) {}
virtual ~GroupDeconv2dGeneral() {} virtual ~GroupDeconv2dGeneral() {}
MaceStatus Compute( MaceStatus Compute(
......
...@@ -1224,6 +1224,9 @@ MaceStatus Gemm::Compute(const OpContext *context, ...@@ -1224,6 +1224,9 @@ MaceStatus Gemm::Compute(const OpContext *context,
output); output);
} }
MACE_REGISTER_DELEGATOR(registry, Gemm, delegator::GemmParam,
MACE_DELEGATOR_KEY(Gemm, CPU, float, NEON))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -15,10 +15,11 @@ ...@@ -15,10 +15,11 @@
#ifndef MACE_OPS_ARM_FP32_GEMM_H_ #ifndef MACE_OPS_ARM_FP32_GEMM_H_
#define MACE_OPS_ARM_FP32_GEMM_H_ #define MACE_OPS_ARM_FP32_GEMM_H_
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/matrix.h" #include "mace/ops/common/matrix.h"
#include "mace/ops/delegator/gemm.h"
#include "mace/public/mace.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
// This implements matrix-matrix multiplication. // This implements matrix-matrix multiplication.
...@@ -29,13 +30,12 @@ namespace ops { ...@@ -29,13 +30,12 @@ namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
class Gemm { class Gemm : public delegator::Gemm {
public: public:
explicit Gemm(const bool should_cache_pack) explicit Gemm(const delegator::GemmParam &param)
: pack_cache_(GetCPUAllocator()), : delegator::Gemm(param), pack_cache_(GetCPUAllocator()),
should_cache_pack_(should_cache_pack), should_cache_pack_(param.should_cache_pack_),
cached_(0) {} cached_(0) {}
Gemm() : Gemm(false) {}
~Gemm() {} ~Gemm() {}
MaceStatus Compute( MaceStatus Compute(
...@@ -51,7 +51,7 @@ class Gemm { ...@@ -51,7 +51,7 @@ class Gemm {
const MatrixMajor output_major, const MatrixMajor output_major,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
// Original matrix before transpose has row-major // Original matrix before transpose has row-major
MaceStatus Compute( MaceStatus Compute(
...@@ -68,7 +68,7 @@ class Gemm { ...@@ -68,7 +68,7 @@ class Gemm {
const bool transpose_out, const bool transpose_out,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
private: private:
void ComputeBlock(const float *packed_lhs_data, void ComputeBlock(const float *packed_lhs_data,
......
...@@ -378,6 +378,10 @@ MaceStatus Gemv::Compute(const OpContext *context, ...@@ -378,6 +378,10 @@ MaceStatus Gemv::Compute(const OpContext *context,
#undef vaddvq_f32 #undef vaddvq_f32
#endif #endif
MACE_REGISTER_DELEGATOR(registry, Gemv, DelegatorParam,
MACE_DELEGATOR_KEY(Gemv, CPU, float, NEON))
} // namespace fp32 } // namespace fp32
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -15,18 +15,19 @@ ...@@ -15,18 +15,19 @@
#ifndef MACE_OPS_ARM_FP32_GEMV_H_ #ifndef MACE_OPS_ARM_FP32_GEMV_H_
#define MACE_OPS_ARM_FP32_GEMV_H_ #define MACE_OPS_ARM_FP32_GEMV_H_
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h" #include "mace/ops/delegator/gemv.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace arm {
namespace fp32 { namespace fp32 {
class Gemv { class Gemv : public delegator::Gemv {
public: public:
Gemv() {} explicit Gemv(const DelegatorParam &param) : delegator::Gemv(param) {}
~Gemv() {} ~Gemv() {}
// Always row-major after transpose // Always row-major after transpose
MaceStatus Compute( MaceStatus Compute(
...@@ -39,7 +40,7 @@ class Gemv { ...@@ -39,7 +40,7 @@ class Gemv {
const index_t lhs_width, const index_t lhs_width,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
}; };
} // namespace fp32 } // namespace fp32
......
...@@ -12,12 +12,11 @@ ...@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/arm/q8/eltwise.h"
#include <arm_neon.h> #include <arm_neon.h>
#include <algorithm> #include <algorithm>
#include "mace/ops/common/gemmlowp_util.h" #include "mace/ops/common/gemmlowp_util.h"
#include "mace/ops/delegator/eltwise.h"
#include "mace/utils/logging.h" #include "mace/utils/logging.h"
namespace mace { namespace mace {
...@@ -25,6 +24,16 @@ namespace ops { ...@@ -25,6 +24,16 @@ namespace ops {
namespace arm { namespace arm {
namespace q8 { namespace q8 {
class Eltwise : public delegator::Eltwise {
public:
explicit Eltwise(const delegator::EltwiseParam &param)
: delegator::Eltwise(param) {}
~Eltwise() = default;
MaceStatus Compute(const OpContext *context, const Tensor *input0,
const Tensor *input1, Tensor *output) override;
};
MaceStatus Eltwise::Compute(const OpContext *context, MaceStatus Eltwise::Compute(const OpContext *context,
const Tensor *input0, const Tensor *input0,
const Tensor *input1, const Tensor *input1,
...@@ -153,6 +162,9 @@ MaceStatus Eltwise::Compute(const OpContext *context, ...@@ -153,6 +162,9 @@ MaceStatus Eltwise::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MACE_REGISTER_DELEGATOR(registry, Eltwise, delegator::EltwiseParam,
MACE_DELEGATOR_KEY(Eltwise, CPU, uint8_t, NEON))
} // namespace q8 } // namespace q8
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
...@@ -181,6 +181,14 @@ class Gemv<uint8_t>; ...@@ -181,6 +181,14 @@ class Gemv<uint8_t>;
template template
class Gemv<int32_t>; class Gemv<int32_t>;
typedef Gemv<uint8_t> GemvUint8;
MACE_REGISTER_DELEGATOR(registry, GemvUint8, DelegatorParam,
MACE_DELEGATOR_KEY(Gemv, CPU, uint8_t, NEON))
typedef Gemv<int32_t> GemvInt32;
MACE_REGISTER_DELEGATOR(registry, GemvInt32, DelegatorParam,
MACE_DELEGATOR_KEY(Gemv, CPU, int32_t, NEON))
} // namespace q8 } // namespace q8
} // namespace arm } // namespace arm
} // namespace ops } // namespace ops
......
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,15 +12,10 @@ ...@@ -12,15 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// This implements matrix-vector multiplication described as
// https://github.com/google/gemmlowp/blob/master/todo/fast-gemv.txt
#ifndef MACE_OPS_ARM_Q8_GEMV_H_ #ifndef MACE_OPS_ARM_Q8_GEMV_H_
#define MACE_OPS_ARM_Q8_GEMV_H_ #define MACE_OPS_ARM_Q8_GEMV_H_
#include "mace/public/mace.h" #include "mace/ops/delegator/gemv.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -28,11 +23,11 @@ namespace arm { ...@@ -28,11 +23,11 @@ namespace arm {
namespace q8 { namespace q8 {
template<typename OUTPUT_TYPE> template<typename OUTPUT_TYPE>
class Gemv { class Gemv : public delegator::Gemv {
public: public:
Gemv() : is_output_type_uint8_( explicit Gemv(const DelegatorParam &param)
DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8) { : delegator::Gemv(param), is_output_type_uint8_(
} DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8) {}
~Gemv() {} ~Gemv() {}
// Always row-major after transpose // Always row-major after transpose
MaceStatus Compute( MaceStatus Compute(
...@@ -45,7 +40,7 @@ class Gemv { ...@@ -45,7 +40,7 @@ class Gemv {
const index_t lhs_width, const index_t lhs_width,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
private: private:
bool is_output_type_uint8_; bool is_output_type_uint8_;
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/quantize.h" #include "mace/core/quantize.h"
...@@ -106,12 +107,12 @@ class DequantizeOp<DeviceType::CPU, T> : public Operation { ...@@ -106,12 +107,12 @@ class DequantizeOp<DeviceType::CPU, T> : public Operation {
QuantizeUtil<float, T> quantize_util_; QuantizeUtil<float, T> quantize_util_;
}; };
void RegisterQuantize(OpRegistryBase *op_registry) { void RegisterQuantize(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Quantize", QuantizeOp, MACE_REGISTER_OP(op_registry, "Quantize", QuantizeOp,
DeviceType::CPU, uint8_t); DeviceType::CPU, uint8_t);
} }
void RegisterDequantize(OpRegistryBase *op_registry) { void RegisterDequantize(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Dequantize", DequantizeOp, MACE_REGISTER_OP(op_registry, "Dequantize", DequantizeOp,
DeviceType::CPU, uint8_t); DeviceType::CPU, uint8_t);
MACE_REGISTER_OP(op_registry, "Dequantize", DequantizeOp, MACE_REGISTER_OP(op_registry, "Dequantize", DequantizeOp,
......
...@@ -16,14 +16,10 @@ ...@@ -16,14 +16,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/activation.h" #include "mace/ops/activation.h"
#include "mace/ops/delegator/activation.h"
#if defined(MACE_ENABLE_NEON)
#include "mace/ops/arm/fp32/activation.h"
#else
#include "mace/ops/ref/activation.h"
#endif
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/buffer_transformer.h" #include "mace/ops/opencl/buffer_transformer.h"
...@@ -45,11 +41,16 @@ class BatchNormOp<DeviceType::CPU, float> : public Operation { ...@@ -45,11 +41,16 @@ class BatchNormOp<DeviceType::CPU, float> : public Operation {
epsilon_(Operation::GetOptionalArg<float>("epsilon", epsilon_(Operation::GetOptionalArg<float>("epsilon",
static_cast<float>(1e-4))), static_cast<float>(1e-4))),
activation_delegator_( activation_delegator_(
delegator::Activation::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Activation, CPU, float, MACE_CPU_IMPL_TYPE),
delegator::ActivationParam(
ops::StringToActivationType( ops::StringToActivationType(
Operation::GetOptionalArg<std::string>("activation", "NOOP")), Operation::GetOptionalArg<std::string>("activation",
"NOOP")),
Operation::GetOptionalArg<float>("max_limit", 0.0f), Operation::GetOptionalArg<float>("max_limit", 0.0f),
Operation::GetOptionalArg<float>( Operation::GetOptionalArg<float>("leakyrelu_coefficient",
"leakyrelu_coefficient", 0.0f)) {} 0.0f)))) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
...@@ -142,18 +143,14 @@ class BatchNormOp<DeviceType::CPU, float> : public Operation { ...@@ -142,18 +143,14 @@ class BatchNormOp<DeviceType::CPU, float> : public Operation {
}, 0, batch, 1, 0, channels, 1); }, 0, batch, 1, 0, channels, 1);
} }
activation_delegator_.Compute(context, output, output); activation_delegator_->Compute(context, output, output);
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
private: private:
float epsilon_; float epsilon_;
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Activation> activation_delegator_;
arm::fp32::Activation activation_delegator_;
#else
ref::Activation activation_delegator_;
#endif // MACE_ENABLE_NEON
protected: protected:
MACE_OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR); MACE_OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR);
...@@ -232,7 +229,7 @@ class BatchNormOp<DeviceType::GPU, float> : public Operation { ...@@ -232,7 +229,7 @@ class BatchNormOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterBatchNorm(OpRegistryBase *op_registry) { void RegisterBatchNorm(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "BatchNorm", BatchNormOp, MACE_REGISTER_OP(op_registry, "BatchNorm", BatchNormOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "BatchNorm", BatchNormOp); MACE_REGISTER_GPU_OP(op_registry, "BatchNorm", BatchNormOp);
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/batch_to_space.h" #include "mace/ops/opencl/image/batch_to_space.h"
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
...@@ -285,7 +286,7 @@ class BatchToSpaceNDOp<DeviceType::GPU, float> : public BatchToSpaceOpBase { ...@@ -285,7 +286,7 @@ class BatchToSpaceNDOp<DeviceType::GPU, float> : public BatchToSpaceOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterBatchToSpaceND(OpRegistryBase *op_registry) { void RegisterBatchToSpaceND(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "BatchToSpaceND", MACE_REGISTER_OP(op_registry, "BatchToSpaceND",
BatchToSpaceNDOp, DeviceType::CPU, float); BatchToSpaceNDOp, DeviceType::CPU, float);
......
...@@ -16,14 +16,10 @@ ...@@ -16,14 +16,10 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/activation.h" #include "mace/ops/activation.h"
#include "mace/ops/delegator/bias_add.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/bias_add.h"
#else
#include "mace/ops/ref/bias_add.h"
#endif // MACE_ENABLE_NEON
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/buffer_transformer.h" #include "mace/ops/opencl/buffer_transformer.h"
...@@ -42,8 +38,11 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation { ...@@ -42,8 +38,11 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
public: public:
explicit BiasAddOp(OpConstructContext *context) explicit BiasAddOp(OpConstructContext *context)
: Operation(context), : Operation(context),
has_data_format_(Operation::GetOptionalArg<int>("has_data_format", has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 0)),
0)) {} bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
MACE_DELEGATOR_KEY(BiasAdd, CPU, float, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
...@@ -56,7 +55,7 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation { ...@@ -56,7 +55,7 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2, MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
"bias must be 1-dimensional or n*c for caffee.", "bias must be 1-dimensional or n*c for caffee.",
MakeString(bias->shape())); MakeString(bias->shape()));
bias_add_delegator_.Compute(context, input, bias, output); bias_add_delegator_->Compute(context, input, bias, output);
} else { // NHWC } else { // NHWC
MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2, MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
"bias must be 1 or 2 dimensionals for caffee.", "bias must be 1 or 2 dimensionals for caffee.",
...@@ -115,11 +114,7 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation { ...@@ -115,11 +114,7 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
private: private:
int has_data_format_; int has_data_format_;
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::BiasAdd> bias_add_delegator_;
arm::fp32::BiasAdd bias_add_delegator_;
#else
ref::BiasAdd bias_add_delegator_;
#endif // MACE_ENABLE_NEON
}; };
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
...@@ -164,7 +159,7 @@ class BiasAddOp<DeviceType::GPU, float> : public Operation { ...@@ -164,7 +159,7 @@ class BiasAddOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterBiasAdd(OpRegistryBase *op_registry) { void RegisterBiasAdd(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "BiasAdd", BiasAddOp, MACE_REGISTER_OP(op_registry, "BiasAdd", BiasAddOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "BiasAdd", BiasAddOp); MACE_REGISTER_GPU_OP(op_registry, "BiasAdd", BiasAddOp);
......
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#if defined(MACE_ENABLE_NEON) && defined(__ANDROID__) #if defined(MACE_ENABLE_NEON) && defined(__ANDROID__)
#include <arm_neon.h> #include <arm_neon.h>
...@@ -54,7 +55,7 @@ class CastOp : public Operation { ...@@ -54,7 +55,7 @@ class CastOp : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterCast(OpRegistryBase *op_registry) { void RegisterCast(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Cast", CastOp, MACE_REGISTER_OP(op_registry, "Cast", CastOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Cast", CastOp, MACE_REGISTER_OP(op_registry, "Cast", CastOp,
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/channel_shuffle.h" #include "mace/ops/opencl/image/channel_shuffle.h"
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
...@@ -98,7 +99,7 @@ class ChannelShuffleOp<DeviceType::GPU, float> : public Operation { ...@@ -98,7 +99,7 @@ class ChannelShuffleOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterChannelShuffle(OpRegistryBase *op_registry) { void RegisterChannelShuffle(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ChannelShuffle", MACE_REGISTER_OP(op_registry, "ChannelShuffle",
ChannelShuffleOp, DeviceType::CPU, float); ChannelShuffleOp, DeviceType::CPU, float);
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#ifndef MACE_OPS_COMMON_LSTM_H_ #ifndef MACE_OPS_COMMON_LSTM_H_
#define MACE_OPS_COMMON_LSTM_H_ #define MACE_OPS_COMMON_LSTM_H_
#include "mace/core/ops/op_context.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/core/op_context.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
namespace mace { namespace mace {
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/quantize.h" #include "mace/core/quantize.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
...@@ -221,7 +222,7 @@ class ConcatOp<DeviceType::GPU, float> : public ConcatOpBase { ...@@ -221,7 +222,7 @@ class ConcatOp<DeviceType::GPU, float> : public ConcatOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterConcat(OpRegistryBase *op_registry) { void RegisterConcat(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Concat", ConcatOp, MACE_REGISTER_OP(op_registry, "Concat", ConcatOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -24,32 +24,18 @@ ...@@ -24,32 +24,18 @@
#include <vector> #include <vector>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/activation.h" #include "mace/ops/activation.h"
#include "mace/ops/conv_pool_2d_base.h" #include "mace/ops/conv_pool_2d_base.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/activation.h"
#include "mace/ops/delegator/bias_add.h"
#include "mace/ops/delegator/conv_2d.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/conv_2d.h"
#include "mace/ops/arm/fp32/conv_2d_1x1.h"
#include "mace/ops/arm/fp32/conv_2d_3x3.h"
#include "mace/ops/arm/fp32/conv_2d_3x3_winograd.h"
#include "mace/ops/arm/fp32/conv_2d_5x5.h"
#include "mace/ops/arm/fp32/conv_2d_7x7.h"
#include "mace/ops/arm/fp32/conv_2d_1xn.h"
#include "mace/ops/arm/fp32/conv_general.h"
#include "mace/ops/arm/fp32/bias_add.h"
#include "mace/ops/arm/fp32/activation.h"
#else
#include "mace/ops/ref/activation.h"
#include "mace/ops/ref/bias_add.h"
#endif // MACE_ENABLE_NEON
#include "mace/ops/ref/conv_2d.h"
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/common/gemmlowp_util.h" #include "mace/ops/common/gemmlowp_util.h"
#include "mace/ops/arm/q8/quantization_util.h" #include "mace/ops/arm/q8/quantization_util.h"
...@@ -72,13 +58,21 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase { ...@@ -72,13 +58,21 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
public: public:
explicit Conv2dOp(OpConstructContext *context) explicit Conv2dOp(OpConstructContext *context)
: ConvPool2dOpBase(context), : ConvPool2dOpBase(context),
activation_delegator_(ops::StringToActivationType( activation_delegator_(
delegator::Activation::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Activation, CPU, float, MACE_CPU_IMPL_TYPE),
delegator::ActivationParam(
ops::StringToActivationType(
Operation::GetOptionalArg<std::string>("activation", Operation::GetOptionalArg<std::string>("activation",
"NOOP")), "NOOP")),
Operation::GetOptionalArg<float>("max_limit", Operation::GetOptionalArg<float>("max_limit", 0.0f),
0.0f), Operation::GetOptionalArg<float>("leakyrelu_coefficient",
Operation::GetOptionalArg<float>( 0.0f)))),
"leakyrelu_coefficient", 0.0f)) {} bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
MACE_DELEGATOR_KEY(BiasAdd, CPU, float, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
...@@ -86,7 +80,10 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase { ...@@ -86,7 +80,10 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr; const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
#ifdef MACE_ENABLE_NEON if (conv2d_delegator_ == nullptr) {
std::string tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
MACE_CPU_IMPL_TYPE, General);
if (MACE_CPU_IMPL_TYPE == NEON) {
// the following params are used to decide which conv delegator to use // the following params are used to decide which conv delegator to use
const index_t stride_h = strides_[0]; const index_t stride_h = strides_[0];
const index_t stride_w = strides_[1]; const index_t stride_w = strides_[1];
...@@ -96,106 +93,87 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase { ...@@ -96,106 +93,87 @@ class Conv2dOp<DeviceType::CPU, float> : public ConvPool2dOpBase {
const index_t filter_w = filter->dim(3); const index_t filter_w = filter->dim(3);
const index_t input_channels = input->dim(1); const index_t input_channels = input->dim(1);
const index_t channels = filter->dim(0); const index_t channels = filter->dim(0);
// NOTE: delegator is fixed after first round of running, // NOTE: delegator is fixed after first round of running,
// although winograd depends on input params. // although winograd depends on input params.
// We do not support changeable filter for now. // We do not support changeable filter for now.
if (conv2d_delegator_ == nullptr) {
if (filter_h == 1 && filter_w == 1 && stride_h == 1 && stride_w == 1 if (filter_h == 1 && filter_w == 1 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) { && dilation_h == 1 && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK1x1>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K1x1);
} else if (filter_h == 3 && filter_w == 3 } else if (filter_h == 3 && filter_w == 3
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
if (input_channels >= 8 && channels >= 8) { if (input_channels >= 8 && channels >= 8) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK3x3Winograd>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K3x3Winograd);
} else { } else {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK3x3S1>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K3x3S1);
} }
} else if (filter_h == 3 && filter_w == 3 } else if (filter_h == 3 && filter_w == 3
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && stride_h == 2 && stride_w == 2 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK3x3S2>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K3x3S2);
} else if (filter_h == 5 && filter_w == 5 } else if (filter_h == 5 && filter_w == 5
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK5x5S1>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K5x5S1);
} else if (filter_h == 7 && filter_w == 7 } else if (filter_h == 7 && filter_w == 7
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK7x7S1>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K7x7S1);
} else if (filter_h == 7 && filter_w == 7 } else if (filter_h == 7 && filter_w == 7
&& stride_h == 2 && stride_w == 2 && dilation_h == 1 && stride_h == 2 && stride_w == 2 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK7x7S2>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K7x7S2);
} else if (filter_h == 7 && filter_w == 7 } else if (filter_h == 7 && filter_w == 7
&& stride_h == 3 && stride_w == 3 && dilation_h == 1 && stride_h == 3 && stride_w == 3 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK7x7S3>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K7x7S3);
} else if (filter_h == 1 && filter_w == 7 } else if (filter_h == 1 && filter_w == 7
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK1x7S1>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K1x7S1);
} else if (filter_h == 7 && filter_w == 1 } else if (filter_h == 7 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK7x1S1>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K7x1S1);
} else if (filter_h == 1 && filter_w == 15 } else if (filter_h == 1 && filter_w == 15
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK1x15S1>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K1x15S1);
} else if (filter_h == 15 && filter_w == 1 } else if (filter_h == 15 && filter_w == 1
&& stride_h == 1 && stride_w == 1 && dilation_h == 1 && stride_h == 1 && stride_w == 1 && dilation_h == 1
&& dilation_w == 1) { && dilation_w == 1) {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dK15x1S1>( tag = MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float,
paddings_, padding_type_); MACE_CPU_IMPL_TYPE, K15x1S1);
} else {
conv2d_delegator_ = make_unique<arm::fp32::Conv2dGeneral>(
strides_,
dilations_,
paddings_,
padding_type_);
} }
} }
delegator::Conv2dParam param(strides_, dilations_,
conv2d_delegator_->Compute(context, input, filter, output); paddings_, padding_type_);
#else conv2d_delegator_ = delegator::Conv2d::Create(context->workspace(),
if (ref_conv2d_delegator_ == nullptr) { tag, param);
ref_conv2d_delegator_ = make_unique<ref::Conv2d<float>>(strides_,
dilations_,
paddings_,
padding_type_);
} }
ref_conv2d_delegator_->Compute(context, input, filter, output);
#endif
bias_add_delegator_.Compute(context, output, bias, output); conv2d_delegator_->Compute(context, input, filter, output);
activation_delegator_.Compute(context, output, output); bias_add_delegator_->Compute(context, output, bias, output);
activation_delegator_->Compute(context, output, output);
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
private: private:
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Activation> activation_delegator_;
std::unique_ptr<arm::fp32::Conv2dBase> conv2d_delegator_; std::unique_ptr<delegator::BiasAdd> bias_add_delegator_;
arm::fp32::BiasAdd bias_add_delegator_; std::unique_ptr<delegator::Conv2d> conv2d_delegator_;
arm::fp32::Activation activation_delegator_;
#else
std::unique_ptr<ref::Conv2d<float>> ref_conv2d_delegator_;
ref::BiasAdd bias_add_delegator_;
ref::Activation activation_delegator_;
#endif // MACE_ENABLE_NEON
private: private:
MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS); MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS);
...@@ -518,7 +496,7 @@ class Conv2dOp<DeviceType::GPU, float> : public ConvPool2dOpBase { ...@@ -518,7 +496,7 @@ class Conv2dOp<DeviceType::GPU, float> : public ConvPool2dOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterConv2D(OpRegistryBase *op_registry) { void RegisterConv2D(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Conv2D", Conv2dOp, MACE_REGISTER_OP(op_registry, "Conv2D", Conv2dOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
namespace mace { namespace mace {
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
...@@ -132,7 +133,7 @@ class CropOp<DeviceType::GPU, float> : public Operation { ...@@ -132,7 +133,7 @@ class CropOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterCrop(OpRegistryBase *op_registry) { void RegisterCrop(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Crop", CropOp, MACE_REGISTER_OP(op_registry, "Crop", CropOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "Crop", CropOp); MACE_REGISTER_GPU_OP(op_registry, "Crop", CropOp);
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include <functional> #include <functional>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -141,7 +142,7 @@ class CumsumOp<DeviceType::CPU, T> : public Operation { ...@@ -141,7 +142,7 @@ class CumsumOp<DeviceType::CPU, T> : public Operation {
bool checked_; bool checked_;
}; };
void RegisterCumsum(OpRegistryBase *op_registry) { void RegisterCumsum(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Cumsum", CumsumOp, MACE_REGISTER_OP(op_registry, "Cumsum", CumsumOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -14,20 +14,6 @@ ...@@ -14,20 +14,6 @@
#include "mace/ops/deconv_2d.h" #include "mace/ops/deconv_2d.h"
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#include "mace/ops/arm/fp32/deconv_2d_2x2.h"
#include "mace/ops/arm/fp32/deconv_2d_3x3.h"
#include "mace/ops/arm/fp32/deconv_2d_4x4.h"
#include "mace/ops/arm/fp32/deconv_2d_general.h"
#include "mace/ops/arm/fp32/bias_add.h"
#include "mace/ops/arm/fp32/activation.h"
#else
#include "mace/ops/ref/bias_add.h"
#include "mace/ops/ref/activation.h"
#include "mace/ops/ref/deconv_2d.h"
#endif
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <memory> #include <memory>
...@@ -35,9 +21,13 @@ ...@@ -35,9 +21,13 @@
#include <vector> #include <vector>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/activation.h" #include "mace/ops/activation.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/activation.h"
#include "mace/ops/delegator/bias_add.h"
#include "mace/ops/delegator/deconv_2d.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
...@@ -49,6 +39,10 @@ ...@@ -49,6 +39,10 @@
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace {
const std::vector<int> kDeconv2dStrides = {1, 1};
}
template<DeviceType D, class T> template<DeviceType D, class T>
class Deconv2dOp; class Deconv2dOp;
...@@ -57,9 +51,16 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase { ...@@ -57,9 +51,16 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase {
public: public:
explicit Deconv2dOp(OpConstructContext *context) explicit Deconv2dOp(OpConstructContext *context)
: Deconv2dOpBase(context), : Deconv2dOpBase(context),
activation_delegator_(activation_, activation_delegator_(
relux_max_limit_, delegator::Activation::Create(
leakyrelu_coefficient_) {} context->workspace(),
MACE_DELEGATOR_KEY(Activation, CPU, float, MACE_CPU_IMPL_TYPE),
delegator::ActivationParam(activation_, relux_max_limit_,
leakyrelu_coefficient_))),
bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
MACE_DELEGATOR_KEY(BiasAdd, CPU, float, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(0); const Tensor *input = this->Input(0);
...@@ -79,7 +80,11 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase { ...@@ -79,7 +80,11 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase {
MACE_CHECK_NOTNULL(filter); MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output); MACE_CHECK_NOTNULL(output);
#ifdef MACE_ENABLE_NEON
if (deconv2d_delegator_ == nullptr) {
std::string tag = MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
MACE_CPU_IMPL_TYPE, General);
if (MACE_CPU_IMPL_TYPE == NEON) {
const index_t kernel_h = filter->dim(2); const index_t kernel_h = filter->dim(2);
const index_t kernel_w = filter->dim(3); const index_t kernel_w = filter->dim(3);
...@@ -98,72 +103,44 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase { ...@@ -98,72 +103,44 @@ class Deconv2dOp<DeviceType::CPU, float> : public Deconv2dOpBase {
bool use_neon_4x4_s2 = kernel_h == kernel_w && kernel_h == 4 && bool use_neon_4x4_s2 = kernel_h == kernel_w && kernel_h == 4 &&
strides_[0] == strides_[1] && strides_[0] == 2; strides_[0] == strides_[1] && strides_[0] == 2;
if (deconv2d_delegator_ == nullptr) {
if (use_neon_2x2_s1) { if (use_neon_2x2_s1) {
deconv2d_delegator_ = make_unique<arm::fp32::Deconv2dK2x2S1>( tag = MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
paddings_, padding_type_, model_type_); MACE_CPU_IMPL_TYPE, K2x2S1);
} else if (use_neon_2x2_s2) { } else if (use_neon_2x2_s2) {
deconv2d_delegator_ = make_unique<arm::fp32::Deconv2dK2x2S2>( tag = MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
paddings_, padding_type_, model_type_); MACE_CPU_IMPL_TYPE, K2x2S2);
} else if (use_neon_3x3_s1) { } else if (use_neon_3x3_s1) {
deconv2d_delegator_ = make_unique<arm::fp32::Deconv2dK3x3S1>( tag = MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
paddings_, padding_type_, model_type_); MACE_CPU_IMPL_TYPE, K3x3S1);
} else if (use_neon_3x3_s2) { } else if (use_neon_3x3_s2) {
deconv2d_delegator_ = make_unique<arm::fp32::Deconv2dK3x3S2>( tag = MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
paddings_, padding_type_, model_type_); MACE_CPU_IMPL_TYPE, K3x3S2);
} else if (use_neon_4x4_s1) { } else if (use_neon_4x4_s1) {
deconv2d_delegator_ = make_unique<arm::fp32::Deconv2dK4x4S1>( tag = MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
paddings_, padding_type_, model_type_); MACE_CPU_IMPL_TYPE, K4x4S1);
} else if (use_neon_4x4_s2) { } else if (use_neon_4x4_s2) {
deconv2d_delegator_ = make_unique<arm::fp32::Deconv2dK4x4S2>( tag = MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float,
paddings_, padding_type_, model_type_); MACE_CPU_IMPL_TYPE, K4x4S2);
} else {
deconv2d_delegator_ =
make_unique<arm::fp32::Deconv2dGeneral>(strides_,
std::vector<int>{1, 1},
paddings_,
padding_type_,
model_type_);
} }
} }
deconv2d_delegator_->Compute(context, delegator::Deconv2dParam param(strides_, kDeconv2dStrides, paddings_,
input, padding_type_, model_type_);
filter, deconv2d_delegator_ = delegator::Deconv2d::Create(context->workspace(),
output_shape_tensor, tag, param);
output);
#else
if (deconv2d_delegator_ == nullptr) {
deconv2d_delegator_ = make_unique<ref::Deconv2d<float>>(strides_,
std::vector<int>{
1, 1},
paddings_,
padding_type_,
model_type_);
} }
deconv2d_delegator_->Compute(context,
input,
filter,
output_shape_tensor,
output);
#endif // MACE_ENABLE_NEON
bias_add_delegator_.Compute(context, output, bias, output); deconv2d_delegator_->Compute(context, input, filter,
activation_delegator_.Compute(context, output, output); output_shape_tensor, output);
bias_add_delegator_->Compute(context, output, bias, output);
activation_delegator_->Compute(context, output, output);
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
private: private:
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Activation> activation_delegator_;
std::unique_ptr<arm::fp32::Deconv2dBase> deconv2d_delegator_; std::unique_ptr<delegator::BiasAdd> bias_add_delegator_;
arm::fp32::BiasAdd bias_add_delegator_; std::unique_ptr<delegator::Deconv2d> deconv2d_delegator_;
arm::fp32::Activation activation_delegator_;
#else
ref::BiasAdd bias_add_delegator_;
ref::Activation activation_delegator_;
std::unique_ptr<ref::Deconv2d<float>> deconv2d_delegator_;
#endif // MACE_ENABLE_NEON
}; };
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
...@@ -258,7 +235,7 @@ class Deconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase { ...@@ -258,7 +235,7 @@ class Deconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterDeconv2D(OpRegistryBase *op_registry) { void RegisterDeconv2D(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Deconv2D", Deconv2dOp, MACE_REGISTER_OP(op_registry, "Deconv2D", Deconv2dOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "Deconv2D", Deconv2dOp); MACE_REGISTER_GPU_OP(op_registry, "Deconv2D", Deconv2dOp);
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/ops/activation.h" #include "mace/ops/activation.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
......
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,42 +12,50 @@ ...@@ -12,42 +12,50 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_ARM_FP32_ACTIVATION_H_ #ifndef MACE_OPS_DELEGATOR_ACTIVATION_H_
#define MACE_OPS_ARM_FP32_ACTIVATION_H_ #define MACE_OPS_DELEGATOR_ACTIVATION_H_
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/ops/op_delegator.h"
#include "mace/core/registry/op_delegator_registry.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace delegator {
namespace fp32 {
class Activation { struct ActivationParam : public DelegatorParam {
explicit ActivationParam(ActivationType type, const float limit,
const float leakyrelu_coefficient)
: type_(type), limit_(limit),
leakyrelu_coefficient_(leakyrelu_coefficient) {}
ActivationType type_;
const float limit_;
const float leakyrelu_coefficient_;
};
class Activation : public OpDelegator {
public: public:
explicit Activation(ActivationType type, explicit Activation(const ActivationParam &param)
const float limit, : OpDelegator(param), type_(param.type_), limit_(param.limit_),
const float leakyrelu_coefficient); leakyrelu_coefficient_(param.leakyrelu_coefficient_) {}
~Activation() = default; virtual ~Activation() = default;
MaceStatus Compute( MACE_DEFINE_DELEGATOR_CREATOR(Activation)
const OpContext *context,
const Tensor *input,
Tensor *output);
private: virtual MaceStatus Compute(const OpContext *context,
void DoActivation(const OpContext *context,
const Tensor *input, const Tensor *input,
Tensor *output); Tensor *output) = 0;
protected:
ActivationType type_; ActivationType type_;
const float limit_; const float limit_;
const float leakyrelu_coefficient_; const float leakyrelu_coefficient_;
}; };
} // namespace fp32 } // namespace delegator
} // namespace arm
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_ARM_FP32_ACTIVATION_H_ #endif // MACE_OPS_DELEGATOR_ACTIVATION_H_
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,37 +12,32 @@ ...@@ -12,37 +12,32 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_ARM_FP32_BIAS_ADD_H_ #ifndef MACE_OPS_DELEGATOR_BIAS_ADD_H_
#define MACE_OPS_ARM_FP32_BIAS_ADD_H_ #define MACE_OPS_DELEGATOR_BIAS_ADD_H_
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/ops/op_delegator.h"
#include "mace/core/registry/op_delegator_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace delegator {
namespace fp32 {
class BiasAdd { class BiasAdd : public OpDelegator {
public: public:
BiasAdd() = default; explicit BiasAdd(const DelegatorParam &param) : OpDelegator(param) {}
~BiasAdd() = default; virtual ~BiasAdd() = default;
MaceStatus Compute( MACE_DEFINE_DELEGATOR_CREATOR(BiasAdd)
const OpContext *context,
const Tensor *input,
const Tensor *bias,
Tensor *output);
private: virtual MaceStatus Compute(const OpContext *context,
void AddBias(const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *bias, const Tensor *bias,
Tensor *output); Tensor *output) = 0;
}; };
} // namespace fp32 } // namespace delegator
} // namespace arm
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_ARM_FP32_BIAS_ADD_H_ #endif // MACE_OPS_DELEGATOR_BIAS_ADD_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_OPS_DELEGATOR_CONV_2D_H_
#define MACE_OPS_DELEGATOR_CONV_2D_H_

#include <vector>

#include "mace/core/ops/op_context.h"
#include "mace/core/ops/op_delegator.h"
#include "mace/core/registry/op_delegator_registry.h"
#include "mace/ops/common/conv_pool_2d_util.h"

namespace mace {
namespace ops {

// Kernel-shape / stride specializations a Conv2d delegator implementation
// may be registered for.  The enumerator name is used as the trailing
// token of the registry tag built by MACE_DELEGATOR_KEY_EX, e.g.
// K7x7S3 selects a 7x7-kernel, stride-3 implementation.
enum ConvType {
  General,
  K1x1,
  K1x7S1,
  K7x1S1,
  K1x15S1,
  K15x1S1,
  K3x3S1,
  K3x3S2,
  K3x3Winograd,
  K5x5S1,
  K7x7S1,
  K7x7S2,
  K7x7S3,
};

namespace delegator {

// Construction-time parameters handed to Conv2d::Create.
// NOTE(review): the vector members are stored by reference, so a
// Conv2dParam must not outlive the vectors it was constructed from.
// It is meant to be consumed immediately by the delegator factory
// (the Conv2d constructor below copies the data into owned members).
struct Conv2dParam : public DelegatorParam {
  explicit Conv2dParam(const std::vector<int> &strides,
                       const std::vector<int> &dilations,
                       const std::vector<int> &paddings,
                       const Padding padding_type)
      : strides_(strides), dilations_(dilations),
        paddings_(paddings), padding_type_(padding_type) {}

  const std::vector<int> &strides_;    // [stride_h, stride_w]
  const std::vector<int> &dilations_;  // [dilation_h, dilation_w]
  const std::vector<int> &paddings_;   // explicit padding values
  const Padding padding_type_;         // padding policy
};

// Abstract base class for CPU 2-D convolution delegators.  Concrete
// implementations (NEON- or reference-based, per MACE_CPU_IMPL_TYPE)
// register themselves in the delegator registry and are resolved by
// tag through Create().
class Conv2d : public OpDelegator {
 public:
  // Copies the parameter vectors so the delegator owns its
  // configuration independently of the (reference-holding) param.
  explicit Conv2d(const delegator::Conv2dParam &param)
      : OpDelegator(param),
        strides_(param.strides_),
        dilations_(param.dilations_),
        paddings_(param.paddings_),
        padding_type_(param.padding_type_) {}
  virtual ~Conv2d() = default;

  // Declares the static Create(workspace, tag, param) factory used as
  // delegator::Conv2d::Create(context->workspace(), tag, param).
  MACE_DEFINE_DELEGATOR_CREATOR(Conv2d)

  // Computes output = conv2d(input, filter) under this delegator's
  // stride/dilation/padding configuration.
  virtual MaceStatus Compute(const OpContext *context,
                             const Tensor *input,
                             const Tensor *filter,
                             Tensor *output) = 0;

 protected:
  const std::vector<int> strides_;
  const std::vector<int> dilations_;
  const std::vector<int> paddings_;
  const Padding padding_type_;
};

}  // namespace delegator
}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_DELEGATOR_CONV_2D_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_OPS_DELEGATOR_DECONV_2D_H_
#define MACE_OPS_DELEGATOR_DECONV_2D_H_

#include <vector>

#include "mace/core/ops/op_context.h"
#include "mace/core/ops/op_delegator.h"
#include "mace/core/registry/op_delegator_registry.h"
// Included directly for Padding, matching the sibling conv_2d.h
// delegator header, instead of relying on transitive includes.
// NOTE(review): FrameworkType must also be visible here -- confirm its
// defining header is reached via the includes above.
#include "mace/ops/common/conv_pool_2d_util.h"

namespace mace {
namespace ops {

// Kernel-shape / stride specializations a Deconv2d delegator
// implementation may be registered for; the enumerator name forms the
// trailing token of the registry tag built by MACE_DELEGATOR_KEY_EX.
enum DeconvType {
  General,
  K2x2S1,
  K2x2S2,
  K3x3S1,
  K3x3S2,
  K4x4S1,
  K4x4S2,
};

namespace delegator {

// Construction-time parameters handed to Deconv2d::Create.
// NOTE(review): the vector members are stored by reference, so a
// Deconv2dParam must not outlive the vectors it was constructed from;
// the Deconv2d constructor below copies them into owned members.
struct Deconv2dParam : public DelegatorParam {
  explicit Deconv2dParam(const std::vector<int> &strides,
                         const std::vector<int> &dilations,
                         const std::vector<int> &paddings,
                         const Padding padding_type,
                         const FrameworkType framework_type,
                         const int group = 1)
      : strides_(strides), dilations_(dilations),
        paddings_(paddings), padding_type_(padding_type),
        framework_type_(framework_type),
        group_(group) {}

  const std::vector<int> &strides_;    // [stride_h, stride_w]
  const std::vector<int> &dilations_;  // [dilation_h, dilation_w]
  const std::vector<int> &paddings_;   // explicit padding values
  const Padding padding_type_;         // padding policy
  const FrameworkType framework_type_;  // source framework (model_type_)
  const int group_;                    // group count; 1 = plain deconv
};

// Abstract base class for CPU 2-D deconvolution (transposed
// convolution) delegators, resolved by tag through Create().
class Deconv2d : public OpDelegator {
 public:
  // Copies the parameter vectors so the delegator owns its
  // configuration independently of the (reference-holding) param.
  explicit Deconv2d(const Deconv2dParam &param)
      : OpDelegator(param),
        strides_(param.strides_),
        dilations_(param.dilations_),
        paddings_(param.paddings_),
        padding_type_(param.padding_type_),
        framework_type_(param.framework_type_),
        group_(param.group_) {}
  virtual ~Deconv2d() = default;

  // Declares the static Create(workspace, tag, param) factory.
  MACE_DEFINE_DELEGATOR_CREATOR(Deconv2d)

  // Computes output = deconv2d(input, filter); output_shape optionally
  // carries the requested output dimensions.
  virtual MaceStatus Compute(const OpContext *context,
                             const Tensor *input,
                             const Tensor *filter,
                             const Tensor *output_shape,
                             Tensor *output) = 0;

 protected:
  const std::vector<int> strides_;
  const std::vector<int> dilations_;
  const std::vector<int> paddings_;
  const Padding padding_type_;
  const FrameworkType framework_type_;
  const int group_;
};

}  // namespace delegator
}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_DELEGATOR_DECONV_2D_H_
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,35 +12,22 @@ ...@@ -12,35 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_REF_BIAS_ADD_H_
#define MACE_OPS_REF_BIAS_ADD_H_
#include "mace/core/op_context.h" #ifndef MACE_OPS_DELEGATOR_DEPTHWISE_CONV_2D_H_
#define MACE_OPS_DELEGATOR_DEPTHWISE_CONV_2D_H_
#include "mace/ops/delegator/conv_2d.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace delegator {
class BiasAdd { typedef Conv2dParam DepthwiseConv2dParam;
public: typedef Conv2d DepthwiseConv2d;
BiasAdd() = default;
~BiasAdd() = default; } // namespace delegator
MaceStatus Compute(
const OpContext *context,
const Tensor *input,
const Tensor *bias,
Tensor *output);
private:
void AddBias(const OpContext *context,
const Tensor *input,
const Tensor *bias,
Tensor *output);
};
} // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_REF_BIAS_ADD_H_ #endif // MACE_OPS_DELEGATOR_DEPTHWISE_CONV_2D_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_OPS_DELEGATOR_DEPTHWISE_DECONV_2D_H_
#define MACE_OPS_DELEGATOR_DEPTHWISE_DECONV_2D_H_

#include "mace/ops/delegator/deconv_2d.h"

namespace mace {
namespace ops {
namespace delegator {

// Depthwise and group deconvolution reuse the plain Deconv2d delegator
// interface unchanged (Deconv2dParam presumably carries the group count
// in its `group_` field -- confirm against implementations).
// C++11 alias-declarations are used instead of typedef; the meaning is
// identical but the aliases read left-to-right.
using DepthwiseDeconv2dParam = Deconv2dParam;
using GroupDeconv2dParam = Deconv2dParam;
using DepthwiseDeconv2d = Deconv2d;
using GroupDeconv2d = Deconv2d;

}  // namespace delegator
}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_DELEGATOR_DEPTHWISE_DECONV_2D_H_
// Copyright 2019 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -15,34 +15,43 @@ ...@@ -15,34 +15,43 @@
// This implements matrix-vector multiplication described as // This implements matrix-vector multiplication described as
// https://github.com/google/gemmlowp/blob/master/todo/fast-gemv.txt // https://github.com/google/gemmlowp/blob/master/todo/fast-gemv.txt
#ifndef MACE_OPS_ARM_Q8_ELTWISE_H_ #ifndef MACE_OPS_DELEGATOR_ELTWISE_H_
#define MACE_OPS_ARM_Q8_ELTWISE_H_ #define MACE_OPS_DELEGATOR_ELTWISE_H_
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/ops/op_delegator.h"
#include "mace/core/registry/op_delegator_registry.h"
#include "mace/core/types.h" #include "mace/core/types.h"
#include "mace/ops/common/eltwise_type.h" #include "mace/ops/common/eltwise_type.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace arm { namespace delegator {
namespace q8 {
class Eltwise { struct EltwiseParam : public DelegatorParam {
explicit EltwiseParam(EltwiseType type)
: type_(type) {}
EltwiseType type_;
};
class Eltwise : public OpDelegator {
public: public:
explicit Eltwise(const EltwiseType type) : type_(type) {} explicit Eltwise(const EltwiseParam &param) : OpDelegator(param),
type_(param.type_) {}
virtual ~Eltwise() = default;
MACE_DEFINE_DELEGATOR_CREATOR(Eltwise)
MaceStatus Compute(const OpContext *context, virtual MaceStatus Compute(const OpContext *context, const Tensor *input0,
const Tensor *input0, const Tensor *input1, Tensor *output) = 0;
const Tensor *input1,
Tensor *output);
private: protected:
EltwiseType type_; EltwiseType type_;
}; };
} // namespace q8 } // namespace delegator
} // namespace arm
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_ARM_Q8_ELTWISE_H_ #endif // MACE_OPS_DELEGATOR_ELTWISE_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_OPS_DELEGATOR_GEMM_H_
#define MACE_OPS_DELEGATOR_GEMM_H_

#include "mace/core/ops/op_context.h"
#include "mace/core/ops/op_delegator.h"
#include "mace/core/registry/op_delegator_registry.h"
#include "mace/ops/common/matrix.h"

namespace mace {
namespace ops {
namespace delegator {

// Construction-time parameters for a Gemm delegator.
struct GemmParam : public DelegatorParam {
  // should_cache_pack: when true the implementation may keep packed
  // operand buffers alive between calls -- presumably a speed/memory
  // trade-off for repeated multiplies; confirm against the concrete
  // implementations.
  explicit GemmParam(const bool should_cache_pack = false)
      : should_cache_pack_(should_cache_pack) {}

  const bool should_cache_pack_;
};

// Abstract base class for CPU general matrix-matrix multiplication
// delegators, resolved through the delegator registry via Create().
class Gemm : public OpDelegator {
 public:
  explicit Gemm(const GemmParam &param) : OpDelegator(param) {}
  virtual ~Gemm() = default;

  // Declares the static Create(workspace, tag, param) factory.
  MACE_DEFINE_DELEGATOR_CREATOR(Gemm)

  // Batched multiply with explicit storage orders.  By the parameter
  // names this follows the usual GEMM convention -- lhs: rows x depth,
  // rhs: depth x cols, output: rows x cols -- with each operand's
  // memory layout given by its MatrixMajor.  lhs_batched / rhs_batched
  // presumably tell whether the operand carries the batch dimension or
  // is shared across the batch -- TODO confirm.
  virtual MaceStatus Compute(const OpContext *context,
                             const Tensor *lhs,
                             const Tensor *rhs,
                             const index_t batch,
                             const index_t rows,
                             const index_t cols,
                             const index_t depth,
                             const MatrixMajor lhs_major,
                             const MatrixMajor rhs_major,
                             const MatrixMajor output_major,
                             const bool lhs_batched,
                             const bool rhs_batched,
                             Tensor *output) = 0;

  // Original matrix before transpose has row-major
  // Convenience overload: operands are described by their pre-transpose
  // row-major shapes plus transpose flags, rather than by MatrixMajor
  // storage orders.
  virtual MaceStatus Compute(const OpContext *context,
                             const Tensor *lhs,
                             const Tensor *rhs,
                             const index_t batch,
                             const index_t lhs_rows,
                             const index_t lhs_cols,
                             const index_t rhs_rows,
                             const index_t rhs_cols,
                             const bool transpose_lhs,
                             const bool transpose_rhs,
                             const bool transpose_out,
                             const bool lhs_batched,
                             const bool rhs_batched,
                             Tensor *output) = 0;
};

}  // namespace delegator
}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_DELEGATOR_GEMM_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_OPS_DELEGATOR_GEMV_H_
#define MACE_OPS_DELEGATOR_GEMV_H_

#include "mace/core/ops/op_context.h"
#include "mace/core/ops/op_delegator.h"
#include "mace/core/registry/op_delegator_registry.h"

namespace mace {
namespace ops {
namespace delegator {

// Abstract base class for CPU matrix-vector multiplication (GEMV)
// delegators, resolved through the delegator registry via Create().
// Takes a plain DelegatorParam: it has no construction-time
// configuration of its own.
class Gemv : public OpDelegator {
 public:
  explicit Gemv(const DelegatorParam &param) : OpDelegator(param) {}
  virtual ~Gemv() = default;

  // Declares the static Create(workspace, tag, param) factory.
  MACE_DEFINE_DELEGATOR_CREATOR(Gemv)

  // Always row-major after transpose
  // By the parameter names, computes output = lhs * rhs (+ bias) per
  // batch, where lhs is batch x lhs_height x lhs_width and rhs is a
  // vector of length lhs_width.  lhs_batched / rhs_batched presumably
  // tell whether the operand carries the batch dimension or is shared
  // across the batch -- TODO confirm.  NOTE(review): whether bias may
  // be nullptr is not visible here; verify against implementations.
  virtual MaceStatus Compute(const OpContext *context,
                             const Tensor *lhs,
                             const Tensor *rhs,
                             const Tensor *bias,
                             const index_t batch,
                             const index_t lhs_height,
                             const index_t lhs_width,
                             const bool lhs_batched,
                             const bool rhs_batched,
                             Tensor *output) = 0;
};

}  // namespace delegator
}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_DELEGATOR_GEMV_H_
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/depth_to_space.h" #include "mace/ops/opencl/image/depth_to_space.h"
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
...@@ -184,7 +185,7 @@ class DepthToSpaceOp<DeviceType::GPU, float> : public Operation { ...@@ -184,7 +185,7 @@ class DepthToSpaceOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterDepthToSpace(OpRegistryBase *op_registry) { void RegisterDepthToSpace(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "DepthToSpace", MACE_REGISTER_OP(op_registry, "DepthToSpace",
DepthToSpaceOp, DeviceType::CPU, float); DepthToSpaceOp, DeviceType::CPU, float);
......
...@@ -17,17 +17,6 @@ ...@@ -17,17 +17,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/ops/ref/depthwise_conv_2d.h"
#if defined(MACE_ENABLE_NEON)
#include "mace/ops/arm/fp32/depthwise_conv_2d_3x3.h"
#include "mace/ops/arm/fp32/bias_add.h"
#include "mace/ops/arm/fp32/activation.h"
#else
#include "mace/ops/ref/activation.h"
#include "mace/ops/ref/bias_add.h"
#endif // MACE_ENABLE_NEON
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/arm/q8/quantization_util.h" #include "mace/ops/arm/q8/quantization_util.h"
// We reuse TensorFlow Lite's optimized depthwiseconv_uint8 and parallelized it // We reuse TensorFlow Lite's optimized depthwiseconv_uint8 and parallelized it
...@@ -36,9 +25,13 @@ ...@@ -36,9 +25,13 @@
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/activation.h" #include "mace/ops/activation.h"
#include "mace/ops/conv_pool_2d_base.h" #include "mace/ops/conv_pool_2d_base.h"
#include "mace/ops/delegator/activation.h"
#include "mace/ops/delegator/bias_add.h"
#include "mace/ops/delegator/depthwise_conv_2d.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/core/quantize.h" #include "mace/core/quantize.h"
...@@ -75,9 +68,16 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase { ...@@ -75,9 +68,16 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
public: public:
explicit DepthwiseConv2dOp(OpConstructContext *context) explicit DepthwiseConv2dOp(OpConstructContext *context)
: DepthwiseConv2dOpBase(context), : DepthwiseConv2dOpBase(context),
activation_delegator_(activation_, activation_delegator_(
relux_max_limit_, delegator::Activation::Create(
leakyrelu_coefficient_) {} context->workspace(),
MACE_DELEGATOR_KEY(Activation, CPU, float, MACE_CPU_IMPL_TYPE),
delegator::ActivationParam(activation_, relux_max_limit_,
leakyrelu_coefficient_))),
bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
MACE_DELEGATOR_KEY(BiasAdd, CPU, float, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
...@@ -92,67 +92,44 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase { ...@@ -92,67 +92,44 @@ class DepthwiseConv2dOp<DeviceType::CPU, float> : public DepthwiseConv2dOpBase {
MACE_CHECK_NOTNULL(filter); MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output); MACE_CHECK_NOTNULL(output);
#ifdef MACE_ENABLE_NEON if (depthwise_conv2d_delegator_ == nullptr) {
std::string tag = MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, CPU, float,
REF, General);
if (MACE_CPU_IMPL_TYPE == NEON) {
const index_t filter_h = filter->dim(2); const index_t filter_h = filter->dim(2);
const index_t filter_w = filter->dim(3); const index_t filter_w = filter->dim(3);
const index_t stride_h = strides_[0]; const index_t stride_h = strides_[0];
const index_t stride_w = strides_[1]; const index_t stride_w = strides_[1];
const index_t dilation_h = dilations_[0]; const index_t dilation_h = dilations_[0];
const index_t dilation_w = dilations_[1]; const index_t dilation_w = dilations_[1];
if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1 if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) { && dilation_h == 1 && dilation_w == 1) {
if (conv2d_delegator_.get() == nullptr) { tag = MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, CPU, float,
conv2d_delegator_ = MACE_CPU_IMPL_TYPE, K3x3S1);
make_unique<arm::fp32::DepthwiseConv2dK3x3S1>(paddings_, } else if (filter_h == 3 && filter_w == 3 && stride_h == 2
padding_type_); && stride_w == 2
}
conv2d_delegator_->Compute(context, input, filter, output);
} else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
&& dilation_h == 1 && dilation_w == 1) { && dilation_h == 1 && dilation_w == 1) {
if (conv2d_delegator_.get() == nullptr) { tag = MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, CPU, float,
conv2d_delegator_ = MACE_CPU_IMPL_TYPE, K3x3S2);
make_unique<arm::fp32::DepthwiseConv2dK3x3S2>(paddings_,
padding_type_);
}
conv2d_delegator_->Compute(context, input, filter, output);
} else {
if (ref_conv2d_delegator_.get() == nullptr) {
ref_conv2d_delegator_ =
make_unique<ref::DepthwiseConv2d<float>>(strides_,
dilations_,
paddings_,
padding_type_);
} }
ref_conv2d_delegator_->Compute(context, input, filter, output);
} }
#else delegator::Conv2dParam param(strides_, dilations_,
if (ref_conv2d_delegator_.get() == nullptr) { paddings_, padding_type_);
ref_conv2d_delegator_ = depthwise_conv2d_delegator_ = delegator::DepthwiseConv2d::Create(
make_unique<ref::DepthwiseConv2d<float>>(strides_, context->workspace(), tag, param);
dilations_,
paddings_,
padding_type_);
} }
ref_conv2d_delegator_->Compute(context, input, filter, output);
#endif // MACE_ENABLE_NEON
bias_add_delegator_.Compute(context, output, bias, output); depthwise_conv2d_delegator_->Compute(context, input, filter, output);
activation_delegator_.Compute(context, output, output); bias_add_delegator_->Compute(context, output, bias, output);
activation_delegator_->Compute(context, output, output);
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
private: private:
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Activation> activation_delegator_;
std::unique_ptr<arm::fp32::Conv2dBase> conv2d_delegator_; std::unique_ptr<delegator::BiasAdd> bias_add_delegator_;
arm::fp32::BiasAdd bias_add_delegator_; std::unique_ptr<delegator::DepthwiseConv2d> depthwise_conv2d_delegator_;
arm::fp32::Activation activation_delegator_;
#else
ref::BiasAdd bias_add_delegator_;
ref::Activation activation_delegator_;
#endif // MACE_ENABLE_NEON
std::unique_ptr<ref::DepthwiseConv2d<float>> ref_conv2d_delegator_;
protected: protected:
MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS); MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS);
...@@ -422,7 +399,7 @@ class DepthwiseConv2dOp<DeviceType::GPU, float> : public DepthwiseConv2dOpBase { ...@@ -422,7 +399,7 @@ class DepthwiseConv2dOp<DeviceType::GPU, float> : public DepthwiseConv2dOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterDepthwiseConv2d(OpRegistryBase *op_registry) { void RegisterDepthwiseConv2d(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "DepthwiseConv2d", MACE_REGISTER_OP(op_registry, "DepthwiseConv2d",
DepthwiseConv2dOp, DeviceType::CPU, float); DepthwiseConv2dOp, DeviceType::CPU, float);
......
...@@ -12,33 +12,22 @@ ...@@ -12,33 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/deconv_2d.h"
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#include "mace/ops/arm/fp32/depthwise_deconv_2d_general.h"
#include "mace/ops/arm/fp32/depthwise_deconv_2d_3x3.h"
#include "mace/ops/arm/fp32/depthwise_deconv_2d_4x4.h"
#include "mace/ops/arm/fp32/bias_add.h"
#include "mace/ops/arm/fp32/activation.h"
#else
#include "mace/ops/ref/depthwise_deconv_2d.h"
#include "mace/ops/ref/bias_add.h"
#include "mace/ops/ref/activation.h"
#endif
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/utils/math.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/deconv_2d.h"
#include "mace/ops/delegator/activation.h"
#include "mace/ops/delegator/bias_add.h"
#include "mace/ops/delegator/depthwise_deconv_2d.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
#include "mace/utils/math.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/buffer_transformer.h" #include "mace/ops/opencl/buffer_transformer.h"
...@@ -48,6 +37,10 @@ ...@@ -48,6 +37,10 @@
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace {
const std::vector<int> kDepthwiseStrides = {1, 1};
}
template<DeviceType D, class T> template<DeviceType D, class T>
class DepthwiseDeconv2dOp; class DepthwiseDeconv2dOp;
...@@ -57,9 +50,16 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float> ...@@ -57,9 +50,16 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float>
public: public:
explicit DepthwiseDeconv2dOp(OpConstructContext *context) explicit DepthwiseDeconv2dOp(OpConstructContext *context)
: Deconv2dOpBase(context), : Deconv2dOpBase(context),
activation_delegator_(activation_, activation_delegator_(
relux_max_limit_, delegator::Activation::Create(
leakyrelu_coefficient_) {} context->workspace(),
MACE_DELEGATOR_KEY(Activation, CPU, float, MACE_CPU_IMPL_TYPE),
delegator::ActivationParam(activation_, relux_max_limit_,
leakyrelu_coefficient_))),
bias_add_delegator_(delegator::BiasAdd::Create(
context->workspace(),
MACE_DELEGATOR_KEY(BiasAdd, CPU, float, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(0); const Tensor *input = this->Input(0);
...@@ -74,7 +74,8 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float> ...@@ -74,7 +74,8 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float>
const index_t in_channels = input->dim(1); const index_t in_channels = input->dim(1);
bool is_depthwise = group_ == in_channels; bool is_depthwise = group_ == in_channels;
#ifdef MACE_ENABLE_NEON if (depthwise_deconv2d_delegator_ == nullptr) {
if (MACE_CPU_IMPL_TYPE == NEON) {
const index_t kernel_h = filter->dim(2); const index_t kernel_h = filter->dim(2);
const index_t kernel_w = filter->dim(3); const index_t kernel_w = filter->dim(3);
bool use_neon_3x3_s1 = kernel_h == kernel_w && kernel_h == 3 && bool use_neon_3x3_s1 = kernel_h == kernel_w && kernel_h == 3 &&
...@@ -86,101 +87,64 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float> ...@@ -86,101 +87,64 @@ class DepthwiseDeconv2dOp<DeviceType::CPU, float>
bool use_neon_4x4_s2 = kernel_h == kernel_w && kernel_h == 4 && bool use_neon_4x4_s2 = kernel_h == kernel_w && kernel_h == 4 &&
strides_[0] == strides_[1] && strides_[0] == 2; strides_[0] == strides_[1] && strides_[0] == 2;
if (deconv2d_delegator_ == nullptr) {
if (is_depthwise) { if (is_depthwise) {
std::string tag = MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float,
MACE_CPU_IMPL_TYPE, General);
if (use_neon_3x3_s1) { if (use_neon_3x3_s1) {
deconv2d_delegator_ = make_unique<arm::fp32::DepthwiseDeconv2dK3x3S1>( tag = MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float,
paddings_, padding_type_, CAFFE); MACE_CPU_IMPL_TYPE, K3x3S1);
} else if (use_neon_3x3_s2) { } else if (use_neon_3x3_s2) {
deconv2d_delegator_ = make_unique<arm::fp32::DepthwiseDeconv2dK3x3S2>( tag = MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float,
paddings_, padding_type_, CAFFE); MACE_CPU_IMPL_TYPE, K3x3S2);
} else if (use_neon_4x4_s1) { } else if (use_neon_4x4_s1) {
deconv2d_delegator_ = make_unique<arm::fp32::DepthwiseDeconv2dK4x4S1>( tag = MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float,
paddings_, padding_type_, CAFFE); MACE_CPU_IMPL_TYPE, K4x4S1);
} else if (use_neon_4x4_s2) { } else if (use_neon_4x4_s2) {
deconv2d_delegator_ = make_unique<arm::fp32::DepthwiseDeconv2dK4x4S2>( tag = MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float,
paddings_, padding_type_, CAFFE); MACE_CPU_IMPL_TYPE, K4x4S2);
} else {
deconv2d_delegator_ =
make_unique<arm::fp32::DepthwiseDeconv2dGeneral>(
strides_,
std::vector<int>{1, 1},
paddings_,
padding_type_,
CAFFE);
} }
delegator::DepthwiseDeconv2dParam param(strides_, kDepthwiseStrides,
paddings_, padding_type_,
CAFFE, group_);
depthwise_deconv2d_delegator_ = delegator::DepthwiseDeconv2d::Create(
context->workspace(), tag, param);
} else { } else {
std::string tag = MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float,
MACE_CPU_IMPL_TYPE, General);
if (use_neon_3x3_s1) { if (use_neon_3x3_s1) {
deconv2d_delegator_ = make_unique<arm::fp32::GroupDeconv2dK3x3S1>( tag = MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float,
paddings_, padding_type_, group_, CAFFE); MACE_CPU_IMPL_TYPE, K3x3S1);
} else if (use_neon_3x3_s2) { } else if (use_neon_3x3_s2) {
deconv2d_delegator_ = make_unique<arm::fp32::GroupDeconv2dK3x3S2>( tag = MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float,
paddings_, padding_type_, group_, CAFFE); MACE_CPU_IMPL_TYPE, K3x3S2);
} else if (use_neon_4x4_s1) { } else if (use_neon_4x4_s1) {
deconv2d_delegator_ = make_unique<arm::fp32::GroupDeconv2dK4x4S1>( tag = MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float,
paddings_, padding_type_, group_, CAFFE); MACE_CPU_IMPL_TYPE, K4x4S1);
} else if (use_neon_4x4_s2) { } else if (use_neon_4x4_s2) {
deconv2d_delegator_ = make_unique<arm::fp32::GroupDeconv2dK4x4S2>( tag = MACE_DELEGATOR_KEY_EX(GroupDeconv2d, CPU, float,
paddings_, padding_type_, group_, CAFFE); MACE_CPU_IMPL_TYPE, K4x4S2);
} else {
deconv2d_delegator_ = make_unique<arm::fp32::GroupDeconv2dGeneral>(
strides_,
std::vector<int>{1, 1},
paddings_,
padding_type_,
group_,
CAFFE);
}
} }
delegator::GroupDeconv2dParam param(strides_, kDepthwiseStrides,
paddings_, padding_type_,
CAFFE, group_);
depthwise_deconv2d_delegator_ = delegator::GroupDeconv2d::Create(
context->workspace(), tag, param);
} }
deconv2d_delegator_->Compute(context,
input,
filter,
nullptr,
output);
#else
if (deconv2d_delegator_ == nullptr) {
if (is_depthwise) {
deconv2d_delegator_ = make_unique<ref::DepthwiseDeconv2d<float>>(
strides_,
std::vector<int>{1, 1},
paddings_,
padding_type_,
CAFFE);
} else {
deconv2d_delegator_ = make_unique<ref::GroupDeconv2d<float>>(
strides_,
std::vector<int>{1, 1},
paddings_,
padding_type_,
group_,
CAFFE);
} }
} }
deconv2d_delegator_->Compute(context,
input,
filter,
nullptr,
output);
#endif
bias_add_delegator_.Compute(context, output, bias, output); depthwise_deconv2d_delegator_->Compute(context, input, filter,
activation_delegator_.Compute(context, output, output); nullptr, output);
bias_add_delegator_->Compute(context, output, bias, output);
activation_delegator_->Compute(context, output, output);
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
private: private:
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Activation> activation_delegator_;
std::unique_ptr<arm::fp32::Deconv2dBase> deconv2d_delegator_; std::unique_ptr<delegator::BiasAdd> bias_add_delegator_;
arm::fp32::BiasAdd bias_add_delegator_; std::unique_ptr<delegator::DepthwiseDeconv2d> depthwise_deconv2d_delegator_;
arm::fp32::Activation activation_delegator_;
#else
std::unique_ptr<ref::GroupDeconv2d<float>> deconv2d_delegator_;
ref::BiasAdd bias_add_delegator_;
ref::Activation activation_delegator_;
#endif // MACE_ENABLE_NEON
}; };
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
...@@ -251,7 +215,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase { ...@@ -251,7 +215,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterDepthwiseDeconv2d(OpRegistryBase *op_registry) { void RegisterDepthwiseDeconv2d(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "DepthwiseDeconv2d", MACE_REGISTER_OP(op_registry, "DepthwiseDeconv2d",
DepthwiseDeconv2dOp, DeviceType::CPU, float); DepthwiseDeconv2dOp, DeviceType::CPU, float);
......
...@@ -35,14 +35,13 @@ ...@@ -35,14 +35,13 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/common/lstm.h" #include "mace/ops/common/lstm.h"
#include "mace/ops/delegator/gemv.h"
#ifdef MACE_ENABLE_NEON #ifdef MACE_ENABLE_NEON
#include <arm_neon.h> #include <arm_neon.h>
#include "mace/ops/arm/fp32/gemv.h"
#else
#include "mace/ops/ref/gemv.h"
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
namespace mace { namespace mace {
...@@ -73,7 +72,11 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation { ...@@ -73,7 +72,11 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
cell_cache_indexes_( cell_cache_indexes_(
Operation::GetRepeatedArgs<index_t>("cell_cache_indexes")), Operation::GetRepeatedArgs<index_t>("cell_cache_indexes")),
out_cache_indexes_( out_cache_indexes_(
Operation::GetRepeatedArgs<index_t>("out_cache_indexes")) {} Operation::GetRepeatedArgs<index_t>("out_cache_indexes")),
gemv_(delegator::Gemv::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Gemv, CPU, T, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
inline void Validate() { inline void Validate() {
const Tensor *input = this->Input(0); const Tensor *input = this->Input(0);
...@@ -316,7 +319,7 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation { ...@@ -316,7 +319,7 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
prev_out_buf_data + i % out_buf_chunk * prev_out_dim_, prev_out_buf_data + i % out_buf_chunk * prev_out_dim_,
prev_out_dim_ * sizeof(float)); prev_out_dim_ * sizeof(float));
// Affine // Affine
gemv_.Compute(context, gemv_->Compute(context,
weights_a, weights_a,
&affine_a_in, &affine_a_in,
bias_a, bias_a,
...@@ -343,7 +346,7 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation { ...@@ -343,7 +346,7 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
affine_b_in_data); affine_b_in_data);
UpdateCell(curr_cell_ptr, prev_cell_dim_, scale_); UpdateCell(curr_cell_ptr, prev_cell_dim_, scale_);
// Affine // Affine
gemv_.Compute(context, gemv_->Compute(context,
weights_b, weights_b,
&affine_b_in, &affine_b_in,
bias_b, bias_b,
...@@ -404,18 +407,13 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation { ...@@ -404,18 +407,13 @@ class DynamicLSTMOp<DeviceType::CPU, T> : public Operation {
std::vector<index_t> forward_indexes_; std::vector<index_t> forward_indexes_;
std::vector<index_t> cell_cache_indexes_; std::vector<index_t> cell_cache_indexes_;
std::vector<index_t> out_cache_indexes_; std::vector<index_t> out_cache_indexes_;
std::unique_ptr<delegator::Gemv> gemv_;
#ifdef MACE_ENABLE_NEON
arm::fp32::Gemv gemv_;
#else
ref::Gemv<float> gemv_;
#endif // MACE_ENABLE_NEON
MACE_OP_INPUT_TAGS(INPUT, PREV_OUT, PREV_CELL, WEIGHTS_A, PARAMS, WEIGHTS_B); MACE_OP_INPUT_TAGS(INPUT, PREV_OUT, PREV_CELL, WEIGHTS_A, PARAMS, WEIGHTS_B);
MACE_OP_OUTPUT_TAGS(OUTPUT, OUT_CACHE, CELL_CACHE); MACE_OP_OUTPUT_TAGS(OUTPUT, OUT_CACHE, CELL_CACHE);
}; };
void RegisterDynamicLSTM(OpRegistryBase *op_registry) { void RegisterDynamicLSTM(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "DynamicLSTM", DynamicLSTMOp, MACE_REGISTER_OP(op_registry, "DynamicLSTM", DynamicLSTMOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -12,11 +12,9 @@ ...@@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifdef MACE_ENABLE_NEON
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/arm/q8/eltwise.h" #include "mace/ops/delegator/eltwise.h"
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
#endif // MACE_ENABLE_NEON
#include "mace/ops/eltwise.h" #include "mace/ops/eltwise.h"
...@@ -28,7 +26,8 @@ ...@@ -28,7 +26,8 @@
#include <vector> #include <vector>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/core/quantize.h" #include "mace/core/quantize.h"
...@@ -1061,7 +1060,7 @@ class EltwiseOp : public Operation { ...@@ -1061,7 +1060,7 @@ class EltwiseOp : public Operation {
}; };
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
template <> template<>
class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation { class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation {
public: public:
explicit EltwiseOp(OpConstructContext *context) explicit EltwiseOp(OpConstructContext *context)
...@@ -1071,12 +1070,15 @@ class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation { ...@@ -1071,12 +1070,15 @@ class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation {
coeff_(Operation::GetRepeatedArgs<float>("coeff")), coeff_(Operation::GetRepeatedArgs<float>("coeff")),
scalar_input_(Operation::GetOptionalArg<float>("scalar_input", 1.0)), scalar_input_(Operation::GetOptionalArg<float>("scalar_input", 1.0)),
scalar_input_index_(Operation::GetOptionalArg<int32_t>( scalar_input_index_(Operation::GetOptionalArg<int32_t>(
"scalar_input_index", 1)) "scalar_input_index", 1)),
#ifdef MACE_ENABLE_NEON eltwise_delegator_(delegator::Eltwise::Create(
, eltwise_(static_cast<ops::EltwiseType>(Operation::GetOptionalArg<int>( context->workspace(),
"type", static_cast<int>(ops::EltwiseType::NONE)))) MACE_DELEGATOR_KEY(Eltwise, CPU, uint8_t, MACE_CPU_IMPL_TYPE),
#endif delegator::EltwiseParam(
{} static_cast<ops::EltwiseType>(
Operation::GetOptionalArg<int>(
"type",
static_cast<int>(ops::EltwiseType::NONE)))))) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
...@@ -1092,77 +1094,7 @@ class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation { ...@@ -1092,77 +1094,7 @@ class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation {
MACE_CHECK(output->scale() != 0); MACE_CHECK(output->scale() != 0);
MACE_RETURN_IF_ERROR(output->Resize(input0->shape())); MACE_RETURN_IF_ERROR(output->Resize(input0->shape()));
#ifdef MACE_ENABLE_NEON return eltwise_delegator_->Compute(context, input0, input1, output);
eltwise_.Compute(context, input0, input1, output);
#else
constexpr int left_shift = 20;
const double doubled_scale = 2 * std::max(input0->scale(), input1->scale());
const double adjusted_input0_scale = input0->scale() / doubled_scale;
const double adjusted_input1_scale = input1->scale() / doubled_scale;
const double adjusted_output_scale =
doubled_scale / ((1 << left_shift) * output->scale());
int32_t input0_multiplier;
int32_t input1_multiplier;
int32_t output_multiplier;
int32_t input0_shift;
int32_t input1_shift;
int32_t output_shift;
QuantizeMultiplier(adjusted_input0_scale,
&input0_multiplier,
&input0_shift);
QuantizeMultiplier(adjusted_input1_scale,
&input1_multiplier,
&input1_shift);
QuantizeMultiplier(adjusted_output_scale,
&output_multiplier,
&output_shift);
Tensor::MappingGuard input0_guard(input0);
Tensor::MappingGuard input1_guard(input1);
Tensor::MappingGuard output_guard(output);
auto input0_ptr = input0->data<uint8_t>();
auto input1_ptr = input1->data<uint8_t>();
auto output_ptr = output->mutable_data<uint8_t>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
const int32_t offset_input0 = input0_ptr[i] - input0->zero_point();
const int32_t offset_input1 = input1_ptr[i] - input1->zero_point();
const int32_t shifted_input0 = offset_input0 * (1 << left_shift);
const int32_t shifted_input1 = offset_input1 * (1 << left_shift);
const int32_t multiplied_input0 =
gemmlowp::RoundingDivideByPOT(
gemmlowp::SaturatingRoundingDoublingHighMul(shifted_input0,
input0_multiplier),
-input0_shift);
const int32_t multiplied_input1 =
gemmlowp::RoundingDivideByPOT(
gemmlowp::SaturatingRoundingDoublingHighMul(shifted_input1,
input1_multiplier),
-input1_shift);
int32_t res;
if (type_ == SUM) {
res = multiplied_input0 + multiplied_input1;
} else {
res = multiplied_input0 - multiplied_input1;
}
const int32_t output_val =
gemmlowp::RoundingDivideByPOT(
gemmlowp::SaturatingRoundingDoublingHighMul(res,
output_multiplier),
-output_shift) + output->zero_point();
output_ptr[i] = Saturate<uint8_t>(output_val);
}
}, 0, output->size(), 1);
#endif // NEON
return MaceStatus::MACE_SUCCESS;
} }
private: private:
...@@ -1171,9 +1103,7 @@ class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation { ...@@ -1171,9 +1103,7 @@ class EltwiseOp<DeviceType::CPU, uint8_t> : public Operation {
float scalar_input_; float scalar_input_;
int32_t scalar_input_index_; int32_t scalar_input_index_;
Tensor scalar_tensor_; Tensor scalar_tensor_;
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Eltwise> eltwise_delegator_;
arm::q8::Eltwise eltwise_;
#endif
}; };
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
...@@ -1244,7 +1174,7 @@ class EltwiseOp<DeviceType::GPU, float> : public Operation { ...@@ -1244,7 +1174,7 @@ class EltwiseOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterEltwise(OpRegistryBase *op_registry) { void RegisterEltwise(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Eltwise", EltwiseOp, MACE_REGISTER_OP(op_registry, "Eltwise", EltwiseOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
namespace mace { namespace mace {
...@@ -53,7 +54,7 @@ class ExpandDimsOp<DeviceType::CPU, T> : public Operation { ...@@ -53,7 +54,7 @@ class ExpandDimsOp<DeviceType::CPU, T> : public Operation {
int axis_; int axis_;
}; };
void RegisterExpandDims(OpRegistryBase *op_registry) { void RegisterExpandDims(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ExpandDims", ExpandDimsOp, MACE_REGISTER_OP(op_registry, "ExpandDims", ExpandDimsOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
...@@ -176,7 +177,7 @@ class ExtractPoolingOp<DeviceType::CPU, T> : public Operation { ...@@ -176,7 +177,7 @@ class ExtractPoolingOp<DeviceType::CPU, T> : public Operation {
std::vector<float> counts_; std::vector<float> counts_;
}; };
void RegisterExtractPooling(OpRegistryBase *op_registry) { void RegisterExtractPooling(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ExtractPooling", ExtractPoolingOp, MACE_REGISTER_OP(op_registry, "ExtractPooling", ExtractPoolingOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -61,7 +62,7 @@ class FillOp<DeviceType::CPU, float> : public Operation { ...@@ -61,7 +62,7 @@ class FillOp<DeviceType::CPU, float> : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterFill(OpRegistryBase *op_registry) { void RegisterFill(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Fill", FillOp, MACE_REGISTER_OP(op_registry, "Fill", FillOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -17,22 +17,12 @@ ...@@ -17,22 +17,12 @@
#include <vector> #include <vector>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/activation.h" #include "mace/ops/activation.h"
#include "mace/ops/delegator/activation.h"
#ifdef MACE_ENABLE_NEON #include "mace/ops/delegator/gemv.h"
#include "mace/ops/arm/fp32/gemv.h"
#include "mace/ops/arm/fp32/activation.h"
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/arm/q8/gemv.h"
#endif // MACE_ENABLE_QUANTIZE
#else
#include "mace/ops/ref/gemv.h"
#include "mace/ops/ref/activation.h"
#endif // MACE_ENABLE_NEON
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/buffer_transformer.h" #include "mace/ops/opencl/buffer_transformer.h"
...@@ -71,9 +61,16 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase { ...@@ -71,9 +61,16 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
public: public:
explicit FullyConnectedOp(OpConstructContext *context) explicit FullyConnectedOp(OpConstructContext *context)
: FullyConnectedOpBase(context), : FullyConnectedOpBase(context),
activation_delegator_(activation_, activation_delegator_(delegator::Activation::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Activation, CPU, float, MACE_CPU_IMPL_TYPE),
delegator::ActivationParam(activation_,
relux_max_limit_, relux_max_limit_,
leakyrelu_coefficient_) {} leakyrelu_coefficient_))),
gemv_(delegator::Gemv::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Gemv, CPU, float, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
...@@ -100,7 +97,7 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase { ...@@ -100,7 +97,7 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
const index_t input_size = weight->dim(1) * weight->dim(2) * weight->dim(3); const index_t input_size = weight->dim(1) * weight->dim(2) * weight->dim(3);
const index_t output_size = weight->dim(0); const index_t output_size = weight->dim(0);
gemv_.Compute(context, gemv_->Compute(context,
weight, weight,
input, input,
bias, bias,
...@@ -111,19 +108,14 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase { ...@@ -111,19 +108,14 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
true, true,
output); output);
activation_delegator_.Compute(context, output, output); activation_delegator_->Compute(context, output, output);
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
private: private:
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Activation> activation_delegator_;
arm::fp32::Gemv gemv_; std::unique_ptr<delegator::Gemv> gemv_;
arm::fp32::Activation activation_delegator_;
#else
ref::Gemv<float> gemv_;
ref::Activation activation_delegator_;
#endif // MACE_ENABLE_NEON
}; };
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
...@@ -132,7 +124,11 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t> ...@@ -132,7 +124,11 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
: public FullyConnectedOpBase { : public FullyConnectedOpBase {
public: public:
explicit FullyConnectedOp(OpConstructContext *context) explicit FullyConnectedOp(OpConstructContext *context)
: FullyConnectedOpBase(context) {} : FullyConnectedOpBase(context),
gemv_(delegator::Gemv::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Gemv, CPU, uint8_t, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
...@@ -161,7 +157,7 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t> ...@@ -161,7 +157,7 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
const int input_size = const int input_size =
static_cast<int>(weight->dim(1) * weight->dim(2) * weight->dim(3)); static_cast<int>(weight->dim(1) * weight->dim(2) * weight->dim(3));
const int output_size = static_cast<int>(weight->dim(0)); const int output_size = static_cast<int>(weight->dim(0));
gemv_.Compute(context, gemv_->Compute(context,
weight, weight,
input, input,
bias, bias,
...@@ -175,11 +171,7 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t> ...@@ -175,11 +171,7 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
} }
private: private:
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Gemv> gemv_;
::mace::ops::arm::q8::Gemv<uint8_t> gemv_;
#else
ref::Gemv<uint8_t> gemv_;
#endif // MACE_ENABLE_NEON
}; };
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
...@@ -231,7 +223,7 @@ class FullyConnectedOp<DeviceType::GPU, float> : public FullyConnectedOpBase { ...@@ -231,7 +223,7 @@ class FullyConnectedOp<DeviceType::GPU, float> : public FullyConnectedOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterFullyConnected(OpRegistryBase *op_registry) { void RegisterFullyConnected(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "FullyConnected", MACE_REGISTER_OP(op_registry, "FullyConnected",
FullyConnectedOp, DeviceType::CPU, float); FullyConnectedOp, DeviceType::CPU, float);
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include <algorithm> #include <algorithm>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -85,7 +86,7 @@ class GatherOp : public Operation { ...@@ -85,7 +86,7 @@ class GatherOp : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterGather(OpRegistryBase *op_registry) { void RegisterGather(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Gather", GatherOp, MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -32,7 +33,7 @@ class IdentityOp : public Operation { ...@@ -32,7 +33,7 @@ class IdentityOp : public Operation {
} }
}; };
void RegisterIdentity(OpRegistryBase *op_registry) { void RegisterIdentity(OpRegistry *op_registry) {
MACE_REGISTER_OP_BY_CLASS(op_registry, "Identity", IdentityOp, MACE_REGISTER_OP_BY_CLASS(op_registry, "Identity", IdentityOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP_BY_CLASS(op_registry, "Identity", IdentityOp, MACE_REGISTER_OP_BY_CLASS(op_registry, "Identity", IdentityOp,
......
...@@ -25,7 +25,8 @@ ...@@ -25,7 +25,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -162,7 +163,7 @@ class IfDefinedOp<DeviceType::CPU, T> : public Operation { ...@@ -162,7 +163,7 @@ class IfDefinedOp<DeviceType::CPU, T> : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterIfDefined(OpRegistryBase *op_registry) { void RegisterIfDefined(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "IfDefined", IfDefinedOp, MACE_REGISTER_OP(op_registry, "IfDefined", IfDefinedOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
namespace mace { namespace mace {
...@@ -101,7 +102,7 @@ class InferConv2dShapeOp : public Operation { ...@@ -101,7 +102,7 @@ class InferConv2dShapeOp : public Operation {
} }
}; };
void RegisterInferConv2dShape(OpRegistryBase *op_registry) { void RegisterInferConv2dShape(OpRegistry *op_registry) {
MACE_REGISTER_OP_BY_CLASS(op_registry, "InferConv2dShape", MACE_REGISTER_OP_BY_CLASS(op_registry, "InferConv2dShape",
InferConv2dShapeOp, DeviceType::CPU, float); InferConv2dShapeOp, DeviceType::CPU, float);
MACE_REGISTER_OP_BY_CLASS(op_registry, "InferConv2dShape", MACE_REGISTER_OP_BY_CLASS(op_registry, "InferConv2dShape",
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -167,7 +168,7 @@ class KaldiBatchNormOp<DeviceType::CPU, float> : public Operation { ...@@ -167,7 +168,7 @@ class KaldiBatchNormOp<DeviceType::CPU, float> : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterKaldiBatchNorm(OpRegistryBase *op_registry) { void RegisterKaldiBatchNorm(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "KaldiBatchNorm", KaldiBatchNormOp, MACE_REGISTER_OP(op_registry, "KaldiBatchNorm", KaldiBatchNormOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -91,7 +92,7 @@ class LocalResponseNormOp<DeviceType::CPU, float> : public Operation { ...@@ -91,7 +92,7 @@ class LocalResponseNormOp<DeviceType::CPU, float> : public Operation {
float beta_; float beta_;
}; };
void RegisterLocalResponseNorm(OpRegistryBase *op_registry) { void RegisterLocalResponseNorm(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "LocalResponseNorm", MACE_REGISTER_OP(op_registry, "LocalResponseNorm",
LocalResponseNormOp, DeviceType::CPU, float); LocalResponseNormOp, DeviceType::CPU, float);
} }
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/lpnorm.h" #include "mace/ops/opencl/image/lpnorm.h"
...@@ -147,7 +148,7 @@ class LpNormOp<DeviceType::GPU, float> : public Operation { ...@@ -147,7 +148,7 @@ class LpNormOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterLpNorm(OpRegistryBase *op_registry) { void RegisterLpNorm(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "LpNorm", LpNormOp, MACE_REGISTER_OP(op_registry, "LpNorm", LpNormOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "LpNorm", LpNormOp); MACE_REGISTER_GPU_OP(op_registry, "LpNorm", LpNormOp);
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/common/lstm.h" #include "mace/ops/common/lstm.h"
namespace mace { namespace mace {
...@@ -100,7 +101,7 @@ class LSTMNonlinearOp<DeviceType::CPU, T> : public Operation { ...@@ -100,7 +101,7 @@ class LSTMNonlinearOp<DeviceType::CPU, T> : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterLSTMNonlinear(OpRegistryBase *op_registry) { void RegisterLSTMNonlinear(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "LSTMNonlinear", LSTMNonlinearOp, MACE_REGISTER_OP(op_registry, "LSTMNonlinear", LSTMNonlinearOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -19,25 +19,18 @@ ...@@ -19,25 +19,18 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/delegator/gemm.h"
#include "mace/ops/delegator/gemv.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/arm/fp32/gemv.h"
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/common/gemmlowp_util.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/q8/gemv.h" #include "mace/ops/arm/q8/gemv.h"
#endif // MACE_ENABLE_QUANTIZE
#else
#include "mace/ops/ref/gemm.h"
#include "mace/ops/ref/gemv.h"
#endif // MACE_ENABLE_NEON #endif // MACE_ENABLE_NEON
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/common/gemmlowp_util.h"
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
...@@ -103,7 +96,15 @@ template<> ...@@ -103,7 +96,15 @@ template<>
class MatMulOp<CPU, float> : public MatMulOpBase { class MatMulOp<CPU, float> : public MatMulOpBase {
public: public:
explicit MatMulOp(OpConstructContext *context) explicit MatMulOp(OpConstructContext *context)
: MatMulOpBase(context) {} : MatMulOpBase(context),
gemm_(delegator::Gemm::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Gemm, CPU, float, MACE_CPU_IMPL_TYPE),
delegator::GemmParam())),
gemv_(delegator::Gemv::Create(
context->workspace(),
MACE_DELEGATOR_KEY(Gemv, CPU, float, MACE_CPU_IMPL_TYPE),
DelegatorParam())) {}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
Validate(); Validate();
...@@ -154,7 +155,7 @@ class MatMulOp<CPU, float> : public MatMulOpBase { ...@@ -154,7 +155,7 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
MACE_RETURN_IF_ERROR(C->Resize(output_shape)); MACE_RETURN_IF_ERROR(C->Resize(output_shape));
if (rows == 1 && transpose_b_) { if (rows == 1 && transpose_b_) {
return gemv_.Compute(context, return gemv_->Compute(context,
rhs, rhs,
lhs, lhs,
bias, bias,
...@@ -165,7 +166,7 @@ class MatMulOp<CPU, float> : public MatMulOpBase { ...@@ -165,7 +166,7 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
lhs_batched, lhs_batched,
C); C);
} else if (cols == 1 && !transpose_a_) { } else if (cols == 1 && !transpose_a_) {
return gemv_.Compute(context, return gemv_->Compute(context,
lhs, lhs,
rhs, rhs,
bias, bias,
...@@ -177,7 +178,7 @@ class MatMulOp<CPU, float> : public MatMulOpBase { ...@@ -177,7 +178,7 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
C); C);
} else { } else {
context->device()->scratch_buffer()->Rewind(); context->device()->scratch_buffer()->Rewind();
MaceStatus ret = gemm_.Compute(context, MaceStatus ret = gemm_->Compute(context,
lhs, lhs,
rhs, rhs,
batch, batch,
...@@ -217,13 +218,8 @@ class MatMulOp<CPU, float> : public MatMulOpBase { ...@@ -217,13 +218,8 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
} }
private: private:
#ifdef MACE_ENABLE_NEON std::unique_ptr<delegator::Gemm> gemm_;
arm::fp32::Gemm gemm_; std::unique_ptr<delegator::Gemv> gemv_;
arm::fp32::Gemv gemv_;
#else
ref::Gemv<float> gemv_;
ref::Gemm<float> gemm_;
#endif // MACE_ENABLE_NEON
}; };
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
...@@ -234,6 +230,10 @@ class MatMulFixpointImpl; ...@@ -234,6 +230,10 @@ class MatMulFixpointImpl;
template<gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder> template<gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
class MatMulFixpointImpl<AOrder, BOrder, uint8_t> { class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
public: public:
#ifdef MACE_ENABLE_NEON
MatMulFixpointImpl<AOrder, BOrder, uint8_t>()
: gemv_kernel_(DelegatorParam()) {}
#endif // MACE_ENABLE_NEON
void operator()(OpContext *context, void operator()(OpContext *context,
const Tensor *A, const Tensor *A,
const Tensor *B, const Tensor *B,
...@@ -318,6 +318,10 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> { ...@@ -318,6 +318,10 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
template<gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder> template<gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
class MatMulFixpointImpl<AOrder, BOrder, int32_t> { class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
public: public:
#ifdef MACE_ENABLE_NEON
MatMulFixpointImpl<AOrder, BOrder, int32_t>()
: gemv_kernel_(DelegatorParam()) {}
#endif // MACE_ENABLE_NEON
void operator()(OpContext *context, void operator()(OpContext *context,
const Tensor *A, const Tensor *A,
const Tensor *B, const Tensor *B,
...@@ -592,7 +596,7 @@ class MatMulOp<CPU, float16_t> : public MatMulOpBase { ...@@ -592,7 +596,7 @@ class MatMulOp<CPU, float16_t> : public MatMulOpBase {
}; };
#endif // MACE_ENABLE_FP16_NEON #endif // MACE_ENABLE_FP16_NEON
void RegisterMatMul(OpRegistryBase *op_registry) { void RegisterMatMul(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp, MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/mvnorm.h" #include "mace/ops/opencl/image/mvnorm.h"
...@@ -165,7 +166,7 @@ class MVNormOp<DeviceType::GPU, float> : public Operation { ...@@ -165,7 +166,7 @@ class MVNormOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterMVNorm(OpRegistryBase *op_registry) { void RegisterMVNorm(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "MVNorm", MVNormOp, MACE_REGISTER_OP(op_registry, "MVNorm", MVNormOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "MVNorm", MVNormOp); MACE_REGISTER_GPU_OP(op_registry, "MVNorm", MVNormOp);
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -144,7 +145,7 @@ class OneHotOp<DeviceType::CPU, T> : public OneHotOpBase { ...@@ -144,7 +145,7 @@ class OneHotOp<DeviceType::CPU, T> : public OneHotOpBase {
}; };
void RegisterOneHot(OpRegistryBase *op_registry) { void RegisterOneHot(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "OneHot", OneHotOp, DeviceType::CPU, float); MACE_REGISTER_OP(op_registry, "OneHot", OneHotOp, DeviceType::CPU, float);
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#define MACE_OPS_OPENCL_BUFFER_UTILS_H_ #define MACE_OPS_OPENCL_BUFFER_UTILS_H_
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/opencl/buffer_transformer.h" #include "mace/ops/opencl/buffer_transformer.h"
namespace mace { namespace mace {
...@@ -51,7 +52,7 @@ class BufferTransformOp<DeviceType::GPU, float> : public Operation { ...@@ -51,7 +52,7 @@ class BufferTransformOp<DeviceType::GPU, float> : public Operation {
MemoryType out_mem_type_; MemoryType out_mem_type_;
}; };
void RegisterBufferTransform(OpRegistryBase *op_registry) { void RegisterBufferTransform(OpRegistry *op_registry) {
MACE_REGISTER_GPU_OP(op_registry, "BufferTransform", BufferTransformOp); MACE_REGISTER_GPU_OP(op_registry, "BufferTransform", BufferTransformOp);
} }
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/opencl/image/buffer_to_image.h" #include "mace/ops/opencl/image/buffer_to_image.h"
#include "mace/ops/opencl/image/image_to_buffer.h" #include "mace/ops/opencl/image/image_to_buffer.h"
#include "mace/ops/opencl/buffer/buffer_transform.h" #include "mace/ops/opencl/buffer/buffer_transform.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/common/eltwise_type.h" #include "mace/ops/common/eltwise_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/ops/opencl/buffer_transform_kernel.h" #include "mace/ops/opencl/buffer_transform_kernel.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#ifndef MACE_OPS_OPENCL_IMAGE_LPNORM_H_ #ifndef MACE_OPS_OPENCL_IMAGE_LPNORM_H_
#define MACE_OPS_OPENCL_IMAGE_LPNORM_H_ #define MACE_OPS_OPENCL_IMAGE_LPNORM_H_
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/common/pad_type.h" #include "mace/ops/common/pad_type.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
#include "mace/ops/common/reduce_type.h" #include "mace/ops/common/reduce_type.h"
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
#include "mace/ops/opencl/buffer_transform_kernel.h" #include "mace/ops/opencl/buffer_transform_kernel.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include <set> #include <set>
#include <string> #include <string>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/ops/common/activation_type.h" #include "mace/ops/common/activation_type.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/core/runtime/opencl/opencl_helper.h" #include "mace/core/runtime/opencl/opencl_helper.h"
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/opencl/buffer_transformer.h" #include "mace/ops/opencl/buffer_transformer.h"
#include "mace/ops/opencl/image/lstm_cell.h" #include "mace/ops/opencl/image/lstm_cell.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
...@@ -89,7 +90,7 @@ class LSTMCellOp<DeviceType::GPU, float> : public Operation { ...@@ -89,7 +90,7 @@ class LSTMCellOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterLSTMCell(OpRegistryBase *op_registry) { void RegisterLSTMCell(OpRegistry *op_registry) {
MACE_REGISTER_GPU_OP(op_registry, "LSTMCell", LSTMCellOp); MACE_REGISTER_GPU_OP(op_registry, "LSTMCell", LSTMCellOp);
} }
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/common/pad_type.h" #include "mace/ops/common/pad_type.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/pad.h" #include "mace/ops/opencl/image/pad.h"
...@@ -198,7 +199,7 @@ class PadOp<DeviceType::GPU, float> : public Operation { ...@@ -198,7 +199,7 @@ class PadOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterPad(OpRegistryBase *op_registry) { void RegisterPad(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Pad", PadOp, MACE_REGISTER_OP(op_registry, "Pad", PadOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
namespace mace { namespace mace {
...@@ -83,7 +84,7 @@ class PadContextOp<DeviceType::CPU, T> : public Operation { ...@@ -83,7 +84,7 @@ class PadContextOp<DeviceType::CPU, T> : public Operation {
int right_context_; int right_context_;
}; };
void RegisterPadContext(OpRegistryBase *op_registry) { void RegisterPadContext(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "PadContext", PadContextOp, MACE_REGISTER_OP(op_registry, "PadContext", PadContextOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -26,7 +26,8 @@ ...@@ -26,7 +26,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -132,7 +133,7 @@ class PNormOp<DeviceType::CPU, T> : public Operation { ...@@ -132,7 +133,7 @@ class PNormOp<DeviceType::CPU, T> : public Operation {
int output_dim_; int output_dim_;
}; };
void RegisterPNorm(OpRegistryBase *op_registry) { void RegisterPNorm(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "PNorm", PNormOp, MACE_REGISTER_OP(op_registry, "PNorm", PNormOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
#include <vector> #include <vector>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/ops/conv_pool_2d_base.h" #include "mace/ops/conv_pool_2d_base.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
...@@ -510,7 +511,7 @@ class PoolingOp<DeviceType::GPU, float> : public PoolingOpBase { ...@@ -510,7 +511,7 @@ class PoolingOp<DeviceType::GPU, float> : public PoolingOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterPooling(OpRegistryBase *op_registry) { void RegisterPooling(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Pooling", PoolingOp, MACE_REGISTER_OP(op_registry, "Pooling", PoolingOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -144,7 +145,7 @@ class PriorBoxOp : public Operation { ...@@ -144,7 +145,7 @@ class PriorBoxOp : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterPriorBox(OpRegistryBase *op_registry) { void RegisterPriorBox(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "PriorBox", PriorBoxOp, MACE_REGISTER_OP(op_registry, "PriorBox", PriorBoxOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
#include "mace/ops/common/reduce_type.h" #include "mace/ops/common/reduce_type.h"
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/runtime/cpu/cpu_runtime.h" #include "mace/core/runtime/cpu/cpu_runtime.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
...@@ -1032,7 +1033,7 @@ class ReduceOp<DeviceType::GPU, float> : public ReduceOpBase { ...@@ -1032,7 +1033,7 @@ class ReduceOp<DeviceType::GPU, float> : public ReduceOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterReduce(OpRegistryBase *op_registry) { void RegisterReduce(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
......
...@@ -13,18 +13,26 @@ ...@@ -13,18 +13,26 @@
// limitations under the License. // limitations under the License.
#include <algorithm> #include <algorithm>
#include "mace/ops/ref/activation.h"
#include "mace/ops/delegator/activation.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace ref {
Activation::Activation(ActivationType type, class Activation : public delegator::Activation {
const float limit, public:
const float leakyrelu_coefficient) explicit Activation(const delegator::ActivationParam &param)
: type_(type), : delegator::Activation(param) {}
limit_(limit), ~Activation() = default;
leakyrelu_coefficient_(leakyrelu_coefficient) {}
MaceStatus Compute(const OpContext *context, const Tensor *input,
Tensor *output) override;
private:
void DoActivation(const OpContext *context, const Tensor *input,
Tensor *output);
};
MaceStatus Activation::Compute(const OpContext *context, MaceStatus Activation::Compute(const OpContext *context,
const Tensor *input, const Tensor *input,
...@@ -99,6 +107,9 @@ void Activation::DoActivation(const OpContext *context, ...@@ -99,6 +107,9 @@ void Activation::DoActivation(const OpContext *context,
} }
} }
MACE_REGISTER_DELEGATOR(registry, Activation, delegator::ActivationParam,
MACE_DELEGATOR_KEY(Activation, CPU, float, REF))
} // namespace ref } // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -12,12 +12,25 @@ ...@@ -12,12 +12,25 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/ref/bias_add.h" #include "mace/ops/delegator/bias_add.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace ref {
class BiasAdd : public delegator::BiasAdd {
public:
explicit BiasAdd(const DelegatorParam &param) : delegator::BiasAdd(param) {}
~BiasAdd() = default;
MaceStatus Compute(const OpContext *context, const Tensor *input,
const Tensor *bias, Tensor *output) override;
private:
void AddBias(const OpContext *context, const Tensor *input,
const Tensor *bias, Tensor *output);
};
MaceStatus BiasAdd::Compute(const OpContext *context, MaceStatus BiasAdd::Compute(const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *bias, const Tensor *bias,
...@@ -71,6 +84,9 @@ void BiasAdd::AddBias(const OpContext *context, ...@@ -71,6 +84,9 @@ void BiasAdd::AddBias(const OpContext *context,
} }
} }
MACE_REGISTER_DELEGATOR(registry, BiasAdd, DelegatorParam,
MACE_DELEGATOR_KEY(BiasAdd, CPU, float, REF))
} // namespace ref } // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
......
...@@ -109,6 +109,10 @@ MaceStatus Conv2d<float>::Compute(const OpContext *context, ...@@ -109,6 +109,10 @@ MaceStatus Conv2d<float>::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
typedef Conv2d<float> Conv2dRef;
MACE_REGISTER_DELEGATOR(registry, Conv2dRef, delegator::Conv2dParam,
MACE_DELEGATOR_KEY_EX(Conv2d, CPU, float, REF, General))
} // namespace ref } // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
......
...@@ -18,64 +18,41 @@ ...@@ -18,64 +18,41 @@
#include <vector> #include <vector>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/conv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace ref {
template<typename OUTPUT_TYPE> template<typename OUTPUT_TYPE>
class Conv2d { class Conv2d : public delegator::Conv2d {
public: public:
Conv2d(const std::vector<int> &strides, explicit Conv2d(const delegator::Conv2dParam &param)
const std::vector<int> &dilations, : delegator::Conv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
~Conv2d() {} ~Conv2d() {}
MaceStatus Compute( MaceStatus Compute(
const OpContext *context, const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
Tensor *output); Tensor *output) override;
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
}; };
template<> template<>
class Conv2d<float> { class Conv2d<float> : public delegator::Conv2d {
public: public:
Conv2d(const std::vector<int> &strides, explicit Conv2d(const delegator::Conv2dParam &param)
const std::vector<int> &dilations, : delegator::Conv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
~Conv2d() {} ~Conv2d() {}
MaceStatus Compute( MaceStatus Compute(
const OpContext *context, const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
Tensor *output); Tensor *output) override;
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
}; };
} // namespace ref } // namespace ref
......
...@@ -162,6 +162,11 @@ MaceStatus Deconv2d<float>::Compute(const OpContext *context, ...@@ -162,6 +162,11 @@ MaceStatus Deconv2d<float>::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
typedef Deconv2d<float> Deconv2dRef;
MACE_REGISTER_DELEGATOR(
registry, Deconv2dRef, delegator::Deconv2dParam,
MACE_DELEGATOR_KEY_EX(Deconv2d, CPU, float, REF, General))
} // namespace ref } // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -18,28 +18,21 @@ ...@@ -18,28 +18,21 @@
#include <vector> #include <vector>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/deconv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace ref {
template<typename OUTPUT_TYPE> template<typename OUTPUT_TYPE>
class Deconv2d { class Deconv2d : public delegator::Deconv2d {
public: public:
Deconv2d(const std::vector<int> &strides, explicit Deconv2d(const delegator::Deconv2dParam &param)
const std::vector<int> &dilations, : delegator::Deconv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type,
const FrameworkType framework_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type),
framework_type_(framework_type) {}
~Deconv2d() = default; ~Deconv2d() = default;
...@@ -48,29 +41,14 @@ class Deconv2d { ...@@ -48,29 +41,14 @@ class Deconv2d {
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *output_shape, const Tensor *output_shape,
Tensor *output); Tensor *output) override;
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
const FrameworkType framework_type_;
}; };
template<> template<>
class Deconv2d<float> { class Deconv2d<float> : public delegator::Deconv2d {
public: public:
Deconv2d(const std::vector<int> &strides, explicit Deconv2d(const delegator::Deconv2dParam &param)
const std::vector<int> &dilations, : delegator::Deconv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type,
const FrameworkType framework_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type),
framework_type_(framework_type) {}
~Deconv2d() = default; ~Deconv2d() = default;
...@@ -79,14 +57,7 @@ class Deconv2d<float> { ...@@ -79,14 +57,7 @@ class Deconv2d<float> {
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *output_shape, const Tensor *output_shape,
Tensor *output); Tensor *output) override;
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
const FrameworkType framework_type_;
}; };
} // namespace ref } // namespace ref
......
...@@ -115,6 +115,11 @@ MaceStatus DepthwiseConv2d<float>::Compute(const OpContext *context, ...@@ -115,6 +115,11 @@ MaceStatus DepthwiseConv2d<float>::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
typedef DepthwiseConv2d<float> DepthwiseConv2dRef;
MACE_REGISTER_DELEGATOR(
registry, DepthwiseConv2dRef, delegator::DepthwiseConv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, CPU, float, REF, General))
} // namespace ref } // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -18,64 +18,41 @@ ...@@ -18,64 +18,41 @@
#include <vector> #include <vector>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/depthwise_conv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace ref {
template<typename OUTPUT_TYPE> template<typename OUTPUT_TYPE>
class DepthwiseConv2d { class DepthwiseConv2d : public delegator::DepthwiseConv2d {
public: public:
DepthwiseConv2d(const std::vector<int> &strides, explicit DepthwiseConv2d(const delegator::DepthwiseConv2dParam &param)
const std::vector<int> &dilations, : delegator::DepthwiseConv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
~DepthwiseConv2d() {} ~DepthwiseConv2d() {}
MaceStatus Compute( MaceStatus Compute(
const OpContext *context, const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
Tensor *output); Tensor *output) override;
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
}; };
template<> template<>
class DepthwiseConv2d<float> { class DepthwiseConv2d<float> : public delegator::DepthwiseConv2d {
public: public:
DepthwiseConv2d(const std::vector<int> &strides, explicit DepthwiseConv2d(const delegator::DepthwiseConv2dParam &param)
const std::vector<int> &dilations, : delegator::DepthwiseConv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type) {}
~DepthwiseConv2d() {} ~DepthwiseConv2d() {}
MaceStatus Compute( MaceStatus Compute(
const OpContext *context, const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
Tensor *output); Tensor *output) override;
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
}; };
} // namespace ref } // namespace ref
......
...@@ -302,6 +302,11 @@ MaceStatus GroupDeconv2d<float>::Compute(const OpContext *context, ...@@ -302,6 +302,11 @@ MaceStatus GroupDeconv2d<float>::Compute(const OpContext *context,
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
typedef DepthwiseDeconv2d<float> DepthwiseDeconv2dRef;
MACE_REGISTER_DELEGATOR(
registry, DepthwiseDeconv2dRef, delegator::DepthwiseDeconv2dParam,
MACE_DELEGATOR_KEY_EX(DepthwiseDeconv2d, CPU, float, REF, General))
} // namespace ref } // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -18,63 +18,37 @@ ...@@ -18,63 +18,37 @@
#include <vector> #include <vector>
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/delegator/depthwise_deconv_2d.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace ref {
template<typename OUTPUT_TYPE> template<typename OUTPUT_TYPE>
class GroupDeconv2d { class GroupDeconv2d : public delegator::GroupDeconv2d {
public: public:
GroupDeconv2d(const std::vector<int> &strides, explicit GroupDeconv2d(const delegator::GroupDeconv2dParam &param)
const std::vector<int> &dilations, : delegator::GroupDeconv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type,
const index_t group,
const FrameworkType framework_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type),
group_(group),
framework_type_(framework_type) {}
virtual ~GroupDeconv2d() = default; virtual ~GroupDeconv2d() = default;
virtual MaceStatus Compute( MaceStatus Compute(
const OpContext *context, const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *output_shape, const Tensor *output_shape,
Tensor *output); Tensor *output) override;
private:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
const index_t group_;
const FrameworkType framework_type_;
}; };
template<typename OUTPUT_TYPE> template<typename OUTPUT_TYPE>
class DepthwiseDeconv2d : public GroupDeconv2d<OUTPUT_TYPE> { class DepthwiseDeconv2d : public GroupDeconv2d<OUTPUT_TYPE> {
public: public:
DepthwiseDeconv2d(const std::vector<int> &strides, explicit DepthwiseDeconv2d(const delegator::DepthwiseDeconv2d &param)
const std::vector<int> &dilations, : GroupDeconv2d<OUTPUT_TYPE>(param) {}
const std::vector<int> &paddings,
const Padding padding_type,
const FrameworkType framework_type)
: GroupDeconv2d<OUTPUT_TYPE>(strides,
dilations,
paddings,
padding_type,
0,
framework_type) {}
~DepthwiseDeconv2d() = default; ~DepthwiseDeconv2d() = default;
...@@ -83,57 +57,30 @@ class DepthwiseDeconv2d : public GroupDeconv2d<OUTPUT_TYPE> { ...@@ -83,57 +57,30 @@ class DepthwiseDeconv2d : public GroupDeconv2d<OUTPUT_TYPE> {
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *output_shape, const Tensor *output_shape,
Tensor *output); Tensor *output) override;
}; };
template<> template<>
class GroupDeconv2d<float> { class GroupDeconv2d<float> : public delegator::GroupDeconv2d {
public: public:
GroupDeconv2d(const std::vector<int> &strides, explicit GroupDeconv2d(const delegator::GroupDeconv2dParam &param)
const std::vector<int> &dilations, : delegator::GroupDeconv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type,
const index_t group,
const FrameworkType framework_type)
: strides_(strides),
dilations_(dilations),
paddings_(paddings),
padding_type_(padding_type),
group_(group),
framework_type_(framework_type) {}
virtual ~GroupDeconv2d() = default; virtual ~GroupDeconv2d() = default;
virtual MaceStatus Compute( MaceStatus Compute(
const OpContext *context, const OpContext *context,
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *output_shape, const Tensor *output_shape,
Tensor *output); Tensor *output) override;
protected:
const std::vector<int> strides_;
const std::vector<int> dilations_;
const std::vector<int> paddings_;
const Padding padding_type_;
const index_t group_;
const FrameworkType framework_type_;
}; };
template<> template<>
class DepthwiseDeconv2d<float> : public GroupDeconv2d<float> { class DepthwiseDeconv2d<float> : public GroupDeconv2d<float> {
public: public:
DepthwiseDeconv2d(const std::vector<int> &strides, explicit DepthwiseDeconv2d(const delegator::DepthwiseDeconv2dParam &param)
const std::vector<int> &dilations, : GroupDeconv2d(param) {}
const std::vector<int> &paddings,
const Padding padding_type,
const FrameworkType framework_type)
: GroupDeconv2d<float>(strides,
dilations,
paddings,
padding_type,
0,
framework_type) {}
~DepthwiseDeconv2d() = default; ~DepthwiseDeconv2d() = default;
...@@ -142,7 +89,7 @@ class DepthwiseDeconv2d<float> : public GroupDeconv2d<float> { ...@@ -142,7 +89,7 @@ class DepthwiseDeconv2d<float> : public GroupDeconv2d<float> {
const Tensor *input, const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *output_shape, const Tensor *output_shape,
Tensor *output); Tensor *output) override;
}; };
} // namespace ref } // namespace ref
......
...@@ -111,6 +111,10 @@ MaceStatus Gemm<float>::Compute(const OpContext *context, ...@@ -111,6 +111,10 @@ MaceStatus Gemm<float>::Compute(const OpContext *context,
output); output);
} }
typedef Gemm<float> GemmRef;
MACE_REGISTER_DELEGATOR(registry, GemmRef, delegator::GemmParam,
MACE_DELEGATOR_KEY(Gemm, CPU, float, REF))
} // namespace ref } // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -16,19 +16,20 @@ ...@@ -16,19 +16,20 @@
#ifndef MACE_OPS_REF_GEMM_H_ #ifndef MACE_OPS_REF_GEMM_H_
#define MACE_OPS_REF_GEMM_H_ #define MACE_OPS_REF_GEMM_H_
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/common/matrix.h" #include "mace/ops/common/matrix.h"
#include "mace/ops/delegator/gemm.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace ref {
template<typename OUTPUT_TYPE> template<typename OUTPUT_TYPE>
class Gemm { class Gemm : public delegator::Gemm {
public: public:
Gemm() {} explicit Gemm(const delegator::GemmParam &param) : delegator::Gemm(param) {}
~Gemm() {} ~Gemm() {}
MaceStatus Compute(const OpContext *context, MaceStatus Compute(const OpContext *context,
const Tensor *lhs, const Tensor *lhs,
...@@ -42,13 +43,13 @@ class Gemm { ...@@ -42,13 +43,13 @@ class Gemm {
const MatrixMajor output_major, const MatrixMajor output_major,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
}; };
template<> template<>
class Gemm<float> { class Gemm<float> : public delegator::Gemm {
public: public:
Gemm() {} explicit Gemm(const delegator::GemmParam &param) : delegator::Gemm(param) {}
~Gemm() {} ~Gemm() {}
MaceStatus Compute(const OpContext *context, MaceStatus Compute(const OpContext *context,
const Tensor *lhs, const Tensor *lhs,
...@@ -62,7 +63,7 @@ class Gemm<float> { ...@@ -62,7 +63,7 @@ class Gemm<float> {
const MatrixMajor output_major, const MatrixMajor output_major,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
// Original matrix before transpose has row-major // Original matrix before transpose has row-major
MaceStatus Compute( MaceStatus Compute(
const OpContext *context, const OpContext *context,
...@@ -78,7 +79,7 @@ class Gemm<float> { ...@@ -78,7 +79,7 @@ class Gemm<float> {
const bool transpose_out, const bool transpose_out,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
}; };
} // namespace ref } // namespace ref
......
...@@ -159,8 +159,16 @@ MaceStatus Gemv<int32_t>::Compute(const OpContext *context, ...@@ -159,8 +159,16 @@ MaceStatus Gemv<int32_t>::Compute(const OpContext *context,
} // b } // b
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
typedef Gemv<uint8_t> GemvUint8Ref;
MACE_REGISTER_DELEGATOR(registry, GemvUint8Ref, DelegatorParam,
MACE_DELEGATOR_KEY(Gemv, CPU, uint8_t, Ref))
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
typedef Gemv<float> GemvRef;
MACE_REGISTER_DELEGATOR(registry, GemvRef, DelegatorParam,
MACE_DELEGATOR_KEY(Gemv, CPU, float, REF))
} // namespace ref } // namespace ref
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -16,18 +16,19 @@ ...@@ -16,18 +16,19 @@
#ifndef MACE_OPS_REF_GEMV_H_ #ifndef MACE_OPS_REF_GEMV_H_
#define MACE_OPS_REF_GEMV_H_ #define MACE_OPS_REF_GEMV_H_
#include "mace/public/mace.h" #include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h" #include "mace/ops/delegator/gemv.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace ref { namespace ref {
template<typename OUTPUT_TYPE> template<typename OUTPUT_TYPE>
class Gemv { class Gemv : public delegator::Gemv {
public: public:
Gemv() {} explicit Gemv(const DelegatorParam &param) : delegator::Gemv(param) {}
~Gemv() {} ~Gemv() {}
// Always row-major after transpose // Always row-major after transpose
MaceStatus Compute( MaceStatus Compute(
...@@ -40,13 +41,13 @@ class Gemv { ...@@ -40,13 +41,13 @@ class Gemv {
const index_t lhs_width, const index_t lhs_width,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
}; };
template<> template<>
class Gemv<float> { class Gemv<float> : public delegator::Gemv {
public: public:
Gemv() {} explicit Gemv(const DelegatorParam &param) : delegator::Gemv(param) {}
~Gemv() {} ~Gemv() {}
// Always row-major after transpose // Always row-major after transpose
MaceStatus Compute( MaceStatus Compute(
...@@ -59,14 +60,14 @@ class Gemv<float> { ...@@ -59,14 +60,14 @@ class Gemv<float> {
const index_t lhs_width, const index_t lhs_width,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
}; };
#if defined(MACE_ENABLE_QUANTIZE) #if defined(MACE_ENABLE_QUANTIZE)
template<> template<>
class Gemv<uint8_t> { class Gemv<uint8_t> : public delegator::Gemv {
public: public:
Gemv() {} explicit Gemv(const DelegatorParam &param) : delegator::Gemv(param) {}
~Gemv() {} ~Gemv() {}
// Always row-major after transpose // Always row-major after transpose
MaceStatus Compute( MaceStatus Compute(
...@@ -79,13 +80,13 @@ class Gemv<uint8_t> { ...@@ -79,13 +80,13 @@ class Gemv<uint8_t> {
const index_t lhs_width, const index_t lhs_width,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
}; };
template<> template<>
class Gemv<int32_t> { class Gemv<int32_t> : public delegator::Gemv {
public: public:
Gemv() {} explicit Gemv(const DelegatorParam &param) : delegator::Gemv(param) {}
~Gemv() {} ~Gemv() {}
// Always row-major after transpose // Always row-major after transpose
MaceStatus Compute( MaceStatus Compute(
...@@ -98,7 +99,7 @@ class Gemv<int32_t> { ...@@ -98,7 +99,7 @@ class Gemv<int32_t> {
const index_t lhs_width, const index_t lhs_width,
const bool lhs_batched, const bool lhs_batched,
const bool rhs_batched, const bool rhs_batched,
Tensor *output); Tensor *output) override;
}; };
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <arm_neon.h>
#include <algorithm>
#include "mace/ops/common/gemmlowp_util.h"
#include "mace/ops/delegator/eltwise.h"
#include "mace/utils/logging.h"
namespace mace {
namespace ops {
namespace ref {
namespace q8 {
// Reference (non-NEON) CPU implementation of the quantized (uint8) Eltwise
// delegator. Supports elementwise SUM/SUB on already-quantized tensors; the
// actual arithmetic lives in Compute() below.
class Eltwise : public delegator::Eltwise {
 public:
  explicit Eltwise(const delegator::EltwiseParam &param)
      : delegator::Eltwise(param) {}
  ~Eltwise() = default;

  // Computes output = input0 (+|-) input1 in the quantized domain.
  // Inputs and output carry their own scale/zero_point; see Compute() for
  // the fixed-point requantization scheme.
  MaceStatus Compute(const OpContext *context, const Tensor *input0,
                     const Tensor *input1, Tensor *output) override;
};
// Quantized elementwise add/sub.
//
// Both inputs are first rescaled into a common fixed-point domain: values are
// left-shifted by `left_shift` bits for precision, then multiplied by a
// per-input quantized multiplier (mult, shift) produced by
// QuantizeMultiplier(). The raw sum/difference is then requantized into the
// output's scale and offset by its zero point, saturating to uint8.
MaceStatus Eltwise::Compute(const OpContext *context,
                            const Tensor *input0,
                            const Tensor *input1,
                            Tensor *output) {
  // Extra precision bits used while accumulating in int32.
  constexpr int left_shift = 20;

  // Common scale: twice the larger input scale, so both inputs fit.
  const double max_scale_x2 = 2 * std::max(input0->scale(), input1->scale());
  const double in0_scale_adj = input0->scale() / max_scale_x2;
  const double in1_scale_adj = input1->scale() / max_scale_x2;
  const double out_scale_adj =
      max_scale_x2 / ((1 << left_shift) * output->scale());

  int32_t in0_mult = 0;
  int32_t in0_shift = 0;
  int32_t in1_mult = 0;
  int32_t in1_shift = 0;
  int32_t out_mult = 0;
  int32_t out_shift = 0;
  QuantizeMultiplier(in0_scale_adj, &in0_mult, &in0_shift);
  QuantizeMultiplier(in1_scale_adj, &in1_mult, &in1_shift);
  QuantizeMultiplier(out_scale_adj, &out_mult, &out_shift);

  Tensor::MappingGuard in0_guard(input0);
  Tensor::MappingGuard in1_guard(input1);
  Tensor::MappingGuard out_guard(output);
  const uint8_t *in0_data = input0->data<uint8_t>();
  const uint8_t *in1_data = input1->data<uint8_t>();
  uint8_t *out_data = output->mutable_data<uint8_t>();

  utils::ThreadPool
      &thread_pool = context->device()->cpu_runtime()->thread_pool();
  thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
    // Applies one (multiplier, shift) fixed-point rescale to `value`.
    const auto rescale = [](int32_t value, int32_t mult, int32_t shift) {
      return gemmlowp::RoundingDivideByPOT(
          gemmlowp::SaturatingRoundingDoublingHighMul(value, mult), -shift);
    };
    for (index_t i = start; i < end; i += step) {
      // Remove each input's zero point, then bring both into the shared
      // high-precision domain.
      const int32_t in0_centered = in0_data[i] - input0->zero_point();
      const int32_t in1_centered = in1_data[i] - input1->zero_point();
      const int32_t scaled0 =
          rescale(in0_centered * (1 << left_shift), in0_mult, in0_shift);
      const int32_t scaled1 =
          rescale(in1_centered * (1 << left_shift), in1_mult, in1_shift);
      // type_ comes from the delegator base; only SUM/SUB are handled here.
      const int32_t raw = (type_ == SUM) ? scaled0 + scaled1
                                         : scaled0 - scaled1;
      const int32_t requantized =
          rescale(raw, out_mult, out_shift) + output->zero_point();
      out_data[i] = Saturate<uint8_t>(requantized);
    }
  }, 0, output->size(), 1);
  return MaceStatus::MACE_SUCCESS;
}
// Registers this reference implementation under the
// (Eltwise, CPU, uint8_t, REF) delegator key so the framework can select it
// when no optimized (e.g. NEON) variant is available.
MACE_REGISTER_DELEGATOR(registry, Eltwise, delegator::EltwiseParam,
                        MACE_DELEGATOR_KEY(Eltwise, CPU, uint8_t, REF))
} // namespace q8
} // namespace ref
} // namespace ops
} // namespace mace
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/registry/registry.h"
namespace mace {
namespace ops {
namespace ref {
extern void RegisterActivationDelegator(OpDelegatorRegistry *registry);
extern void RegisterBiasAddDelegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dRefDelegator(OpDelegatorRegistry *registry);
extern void RegisterDeconv2dRefDelegator(OpDelegatorRegistry *registry);
extern void RegisterDepthwiseConv2dRefDelegator(OpDelegatorRegistry *registry);
extern void RegisterDepthwiseDeconv2dRefDelegator(
OpDelegatorRegistry *registry);
extern void RegisterGemmRefDelegator(OpDelegatorRegistry *registry);
extern void RegisterGemvRefDelegator(OpDelegatorRegistry *registry);
#ifdef MACE_ENABLE_QUANTIZE
namespace q8 {
extern void RegisterEltwiseDelegator(OpDelegatorRegistry *registry);
} // namespace q8
extern void RegisterGemvUint8RefDelegator(OpDelegatorRegistry *registry);
#endif // MACE_ENABLE_QUANTIZE
} // namespace ref
#ifdef MACE_ENABLE_NEON
namespace arm {
namespace fp32 {
extern void RegisterActivationDelegator(OpDelegatorRegistry *registry);
extern void RegisterBiasAddDelegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK1x1Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK1x7S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK7x1S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK1x15S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK15x1S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK3x3S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK3x3S2Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK3x3WinogradDelegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK5x5S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK7x7S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK7x7S2Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dK7x7S3Delegator(OpDelegatorRegistry *registry);
extern void RegisterConv2dGeneralDelegator(OpDelegatorRegistry *registry);
extern void RegisterDeconv2dK2x2S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterDeconv2dK2x2S2Delegator(OpDelegatorRegistry *registry);
extern void RegisterDeconv2dK3x3S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterDeconv2dK3x3S2Delegator(OpDelegatorRegistry *registry);
extern void RegisterDeconv2dK4x4S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterDeconv2dK4x4S2Delegator(OpDelegatorRegistry *registry);
extern void RegisterDeconv2dGeneralDelegator(OpDelegatorRegistry *registry);
extern void RegisterDepthwiseConv2dK3x3S1Delegator(
OpDelegatorRegistry *registry);
extern void RegisterDepthwiseConv2dK3x3S2Delegator(
OpDelegatorRegistry *registry);
extern void RegisterDepthwiseDeconv2dK3x3S1Delegator(
OpDelegatorRegistry *registry);
extern void RegisterDepthwiseDeconv2dK3x3S2Delegator(
OpDelegatorRegistry *registry);
extern void RegisterGroupDeconv2dK3x3S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterGroupDeconv2dK3x3S2Delegator(OpDelegatorRegistry *registry);
extern void RegisterDepthwiseDeconv2dK4x4S1Delegator(
OpDelegatorRegistry *registry);
extern void RegisterDepthwiseDeconv2dK4x4S2Delegator(
OpDelegatorRegistry *registry);
extern void RegisterGroupDeconv2dK4x4S1Delegator(OpDelegatorRegistry *registry);
extern void RegisterGroupDeconv2dK4x4S2Delegator(OpDelegatorRegistry *registry);
extern void RegisterDepthwiseDeconv2dGeneralDelegator(
OpDelegatorRegistry *registry);
extern void RegisterGroupDeconv2dGeneralDelegator(
OpDelegatorRegistry *registry);
extern void RegisterGemmDelegator(OpDelegatorRegistry *registry);
extern void RegisterGemvDelegator(OpDelegatorRegistry *registry);
} // namespace fp32
#ifdef MACE_ENABLE_QUANTIZE
namespace q8 {
extern void RegisterEltwiseDelegator(OpDelegatorRegistry *registry);
extern void RegisterGemvUint8Delegator(OpDelegatorRegistry *registry);
extern void RegisterGemvInt32Delegator(OpDelegatorRegistry *registry);
} // namespace q8
#endif // MACE_ENABLE_QUANTIZE
} // namespace arm
#endif // MACE_ENABLE_NEON
// Registers every op delegator built into this binary.
//
// Reference (portable C++) delegators are always registered; NEON-optimized
// variants are added when MACE_ENABLE_NEON is defined, and quantized (q8)
// variants when MACE_ENABLE_QUANTIZE is defined. Optimized delegators use
// more specific keys (e.g. kernel-size/stride specializations), so they are
// preferred over the reference ones at lookup time — presumably; confirm
// against the OpDelegatorRegistry lookup rules.
void RegisterAllOpDelegators(OpDelegatorRegistry *registry) {
  // Portable reference delegators (always available).
  ref::RegisterActivationDelegator(registry);
  ref::RegisterBiasAddDelegator(registry);
  ref::RegisterConv2dRefDelegator(registry);
  ref::RegisterDeconv2dRefDelegator(registry);
  ref::RegisterDepthwiseConv2dRefDelegator(registry);
  ref::RegisterDepthwiseDeconv2dRefDelegator(registry);
  ref::RegisterGemmRefDelegator(registry);
  ref::RegisterGemvRefDelegator(registry);
#ifdef MACE_ENABLE_QUANTIZE
  // Reference quantized (uint8) delegators.
  ref::q8::RegisterEltwiseDelegator(registry);
  ref::RegisterGemvUint8RefDelegator(registry);
#endif  // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_NEON
  // NEON fp32 delegators, including kernel-size/stride specializations.
  arm::fp32::RegisterActivationDelegator(registry);
  arm::fp32::RegisterBiasAddDelegator(registry);
  arm::fp32::RegisterConv2dK1x1Delegator(registry);
  arm::fp32::RegisterConv2dK1x7S1Delegator(registry);
  arm::fp32::RegisterConv2dK7x1S1Delegator(registry);
  arm::fp32::RegisterConv2dK1x15S1Delegator(registry);
  arm::fp32::RegisterConv2dK15x1S1Delegator(registry);
  arm::fp32::RegisterConv2dK3x3S1Delegator(registry);
  arm::fp32::RegisterConv2dK3x3S2Delegator(registry);
  arm::fp32::RegisterConv2dK3x3WinogradDelegator(registry);
  arm::fp32::RegisterConv2dK5x5S1Delegator(registry);
  arm::fp32::RegisterConv2dK7x7S1Delegator(registry);
  arm::fp32::RegisterConv2dK7x7S2Delegator(registry);
  arm::fp32::RegisterConv2dK7x7S3Delegator(registry);
  arm::fp32::RegisterConv2dGeneralDelegator(registry);
  arm::fp32::RegisterDeconv2dK2x2S1Delegator(registry);
  arm::fp32::RegisterDeconv2dK2x2S2Delegator(registry);
  arm::fp32::RegisterDeconv2dK3x3S1Delegator(registry);
  arm::fp32::RegisterDeconv2dK3x3S2Delegator(registry);
  arm::fp32::RegisterDeconv2dK4x4S1Delegator(registry);
  arm::fp32::RegisterDeconv2dK4x4S2Delegator(registry);
  arm::fp32::RegisterDeconv2dGeneralDelegator(registry);
  arm::fp32::RegisterDepthwiseConv2dK3x3S1Delegator(registry);
  arm::fp32::RegisterDepthwiseConv2dK3x3S2Delegator(registry);
  arm::fp32::RegisterDepthwiseDeconv2dK3x3S1Delegator(registry);
  arm::fp32::RegisterDepthwiseDeconv2dK3x3S2Delegator(registry);
  arm::fp32::RegisterGroupDeconv2dK3x3S1Delegator(registry);
  arm::fp32::RegisterGroupDeconv2dK3x3S2Delegator(registry);
  arm::fp32::RegisterDepthwiseDeconv2dK4x4S1Delegator(registry);
  arm::fp32::RegisterDepthwiseDeconv2dK4x4S2Delegator(registry);
  arm::fp32::RegisterGroupDeconv2dK4x4S1Delegator(registry);
  arm::fp32::RegisterGroupDeconv2dK4x4S2Delegator(registry);
  arm::fp32::RegisterDepthwiseDeconv2dGeneralDelegator(registry);
  arm::fp32::RegisterGroupDeconv2dGeneralDelegator(registry);
  arm::fp32::RegisterGemmDelegator(registry);
  arm::fp32::RegisterGemvDelegator(registry);
#ifdef MACE_ENABLE_QUANTIZE
  // NEON quantized (uint8/int32) delegators.
  arm::q8::RegisterEltwiseDelegator(registry);
  arm::q8::RegisterGemvUint8Delegator(registry);
  arm::q8::RegisterGemvInt32Delegator(registry);
#endif  // MACE_ENABLE_QUANTIZE
#endif  // MACE_ENABLE_NEON
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved. // Copyright 2020 The MACE Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,167 +12,167 @@ ...@@ -12,167 +12,167 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/registry/ops_registry.h" #include "mace/ops/registry/registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
// Keep in lexicographical order // Keep in lexicographical order
extern void RegisterActivation(OpRegistryBase *op_registry); extern void RegisterActivation(OpRegistry *op_registry);
extern void RegisterAddN(OpRegistryBase *op_registry); extern void RegisterAddN(OpRegistry *op_registry);
extern void RegisterArgMax(OpRegistryBase *op_registry); extern void RegisterArgMax(OpRegistry *op_registry);
extern void RegisterBatchNorm(OpRegistryBase *op_registry); extern void RegisterBatchNorm(OpRegistry *op_registry);
extern void RegisterBatchToSpaceND(OpRegistryBase *op_registry); extern void RegisterBatchToSpaceND(OpRegistry *op_registry);
extern void RegisterBiasAdd(OpRegistryBase *op_registry); extern void RegisterBiasAdd(OpRegistry *op_registry);
extern void RegisterCast(OpRegistryBase *op_registry); extern void RegisterCast(OpRegistry *op_registry);
extern void RegisterChannelShuffle(OpRegistryBase *op_registry); extern void RegisterChannelShuffle(OpRegistry *op_registry);
extern void RegisterConcat(OpRegistryBase *op_registry); extern void RegisterConcat(OpRegistry *op_registry);
extern void RegisterConv2D(OpRegistryBase *op_registry); extern void RegisterConv2D(OpRegistry *op_registry);
extern void RegisterCrop(OpRegistryBase *op_registry); extern void RegisterCrop(OpRegistry *op_registry);
extern void RegisterCumsum(OpRegistryBase *op_registry); extern void RegisterCumsum(OpRegistry *op_registry);
extern void RegisterDeconv2D(OpRegistryBase *op_registry); extern void RegisterDeconv2D(OpRegistry *op_registry);
extern void RegisterDepthToSpace(OpRegistryBase *op_registry); extern void RegisterDepthToSpace(OpRegistry *op_registry);
extern void RegisterDepthwiseConv2d(OpRegistryBase *op_registry); extern void RegisterDepthwiseConv2d(OpRegistry *op_registry);
extern void RegisterDepthwiseDeconv2d(OpRegistryBase *op_registry); extern void RegisterDepthwiseDeconv2d(OpRegistry *op_registry);
extern void RegisterDynamicLSTM(OpRegistryBase *op_registry); extern void RegisterDynamicLSTM(OpRegistry *op_registry);
extern void RegisterEltwise(OpRegistryBase *op_registry); extern void RegisterEltwise(OpRegistry *op_registry);
extern void RegisterExpandDims(OpRegistryBase *op_registry); extern void RegisterExpandDims(OpRegistry *op_registry);
extern void RegisterExtractPooling(OpRegistryBase *op_registry); extern void RegisterExtractPooling(OpRegistry *op_registry);
extern void RegisterFill(OpRegistryBase *op_registry); extern void RegisterFill(OpRegistry *op_registry);
extern void RegisterFullyConnected(OpRegistryBase *op_registry); extern void RegisterFullyConnected(OpRegistry *op_registry);
extern void RegisterGather(OpRegistryBase *op_registry); extern void RegisterGather(OpRegistry *op_registry);
extern void RegisterIdentity(OpRegistryBase *op_registry); extern void RegisterIdentity(OpRegistry *op_registry);
extern void RegisterIfDefined(OpRegistryBase *op_registry); extern void RegisterIfDefined(OpRegistry *op_registry);
extern void RegisterInferConv2dShape(OpRegistryBase *op_registry); extern void RegisterInferConv2dShape(OpRegistry *op_registry);
extern void RegisterKaldiBatchNorm(OpRegistryBase *op_registry); extern void RegisterKaldiBatchNorm(OpRegistry *op_registry);
extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry); extern void RegisterLocalResponseNorm(OpRegistry *op_registry);
extern void RegisterLpNorm(OpRegistryBase *op_registry); extern void RegisterLpNorm(OpRegistry *op_registry);
extern void RegisterLSTMNonlinear(OpRegistryBase *op_registry); extern void RegisterLSTMNonlinear(OpRegistry *op_registry);
extern void RegisterMatMul(OpRegistryBase *op_registry); extern void RegisterMatMul(OpRegistry *op_registry);
extern void RegisterMVNorm(OpRegistryBase *op_registry); extern void RegisterMVNorm(OpRegistry *op_registry);
extern void RegisterOneHot(OpRegistryBase *op_registry); extern void RegisterOneHot(OpRegistry *op_registry);
extern void RegisterPad(OpRegistryBase *op_registry); extern void RegisterPad(OpRegistry *op_registry);
extern void RegisterPadContext(OpRegistryBase *op_registry); extern void RegisterPadContext(OpRegistry *op_registry);
extern void RegisterPNorm(OpRegistryBase *op_registry); extern void RegisterPNorm(OpRegistry *op_registry);
extern void RegisterPooling(OpRegistryBase *op_registry); extern void RegisterPooling(OpRegistry *op_registry);
extern void RegisterReduce(OpRegistryBase *op_registry); extern void RegisterReduce(OpRegistry *op_registry);
extern void RegisterReplaceIndex(OpRegistryBase *op_registry); extern void RegisterReplaceIndex(OpRegistry *op_registry);
extern void RegisterPriorBox(OpRegistryBase *op_registry); extern void RegisterPriorBox(OpRegistry *op_registry);
extern void RegisterReshape(OpRegistryBase *op_registry); extern void RegisterReshape(OpRegistry *op_registry);
extern void RegisterResizeBicubic(OpRegistryBase *op_registry); extern void RegisterResizeBicubic(OpRegistry *op_registry);
extern void RegisterResizeBilinear(OpRegistryBase *op_registry); extern void RegisterResizeBilinear(OpRegistry *op_registry);
extern void RegisterResizeNearestNeighbor(OpRegistryBase *op_registry); extern void RegisterResizeNearestNeighbor(OpRegistry *op_registry);
extern void RegisterReverse(OpRegistryBase *op_registry); extern void RegisterReverse(OpRegistry *op_registry);
extern void RegisterScalarMath(OpRegistryBase *op_registry); extern void RegisterScalarMath(OpRegistry *op_registry);
extern void RegisterSelect(OpRegistryBase *op_registry); extern void RegisterSelect(OpRegistry *op_registry);
extern void RegisterShape(OpRegistryBase *op_registry); extern void RegisterShape(OpRegistry *op_registry);
extern void RegisterSlice(OpRegistryBase *op_registry); extern void RegisterSlice(OpRegistry *op_registry);
extern void RegisterSoftmax(OpRegistryBase *op_registry); extern void RegisterSoftmax(OpRegistry *op_registry);
extern void RegisterSpaceToBatchND(OpRegistryBase *op_registry); extern void RegisterSpaceToBatchND(OpRegistry *op_registry);
extern void RegisterSpaceToDepth(OpRegistryBase *op_registry); extern void RegisterSpaceToDepth(OpRegistry *op_registry);
extern void RegisterSplice(OpRegistryBase *op_registry); extern void RegisterSplice(OpRegistry *op_registry);
extern void RegisterSplit(OpRegistryBase *op_registry); extern void RegisterSplit(OpRegistry *op_registry);
extern void RegisterSqrDiffMean(OpRegistryBase *op_registry); extern void RegisterSqrDiffMean(OpRegistry *op_registry);
extern void RegisterSqueeze(OpRegistryBase *op_registry); extern void RegisterSqueeze(OpRegistry *op_registry);
extern void RegisterStack(OpRegistryBase *op_registry); extern void RegisterStack(OpRegistry *op_registry);
extern void RegisterStridedSlice(OpRegistryBase *op_registry); extern void RegisterStridedSlice(OpRegistry *op_registry);
extern void RegisterSubsample(OpRegistryBase *op_registry); extern void RegisterSubsample(OpRegistry *op_registry);
extern void RegisterSumGroup(OpRegistryBase *op_registry); extern void RegisterSumGroup(OpRegistry *op_registry);
extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry); extern void RegisterTargetRMSNorm(OpRegistry *op_registry);
extern void RegisterTile(OpRegistryBase *op_registry); extern void RegisterTile(OpRegistry *op_registry);
extern void RegisterTranspose(OpRegistryBase *op_registry); extern void RegisterTranspose(OpRegistry *op_registry);
extern void RegisterUnstack(OpRegistryBase *op_registry); extern void RegisterUnstack(OpRegistry *op_registry);
extern void RegisterUnsqueeze(OpRegistryBase *op_registry); extern void RegisterUnsqueeze(OpRegistry *op_registry);
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
extern void RegisterDequantize(OpRegistryBase *op_registry); extern void RegisterDequantize(OpRegistry *op_registry);
extern void RegisterQuantize(OpRegistryBase *op_registry); extern void RegisterQuantize(OpRegistry *op_registry);
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
extern void RegisterBufferTransform(OpRegistryBase *op_registry); extern void RegisterBufferTransform(OpRegistry *op_registry);
extern void RegisterLSTMCell(OpRegistryBase *op_registry); extern void RegisterLSTMCell(OpRegistry *op_registry);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
} // namespace ops
OpRegistry::OpRegistry() : OpRegistryBase() { void RegisterAllOps(OpRegistry *registry) {
// Keep in lexicographical order // Keep in lexicographical order
ops::RegisterActivation(this); ops::RegisterActivation(registry);
ops::RegisterAddN(this); ops::RegisterAddN(registry);
ops::RegisterArgMax(this); ops::RegisterArgMax(registry);
ops::RegisterBatchNorm(this); ops::RegisterBatchNorm(registry);
ops::RegisterBatchToSpaceND(this); ops::RegisterBatchToSpaceND(registry);
ops::RegisterBiasAdd(this); ops::RegisterBiasAdd(registry);
ops::RegisterCast(this); ops::RegisterCast(registry);
ops::RegisterChannelShuffle(this); ops::RegisterChannelShuffle(registry);
ops::RegisterConcat(this); ops::RegisterConcat(registry);
ops::RegisterConv2D(this); ops::RegisterConv2D(registry);
ops::RegisterCrop(this); ops::RegisterCrop(registry);
ops::RegisterCumsum(this); ops::RegisterCumsum(registry);
ops::RegisterDeconv2D(this); ops::RegisterDeconv2D(registry);
ops::RegisterDepthToSpace(this); ops::RegisterDepthToSpace(registry);
ops::RegisterDepthwiseConv2d(this); ops::RegisterDepthwiseConv2d(registry);
ops::RegisterDepthwiseDeconv2d(this); ops::RegisterDepthwiseDeconv2d(registry);
ops::RegisterDynamicLSTM(this); ops::RegisterDynamicLSTM(registry);
ops::RegisterEltwise(this); ops::RegisterEltwise(registry);
ops::RegisterExpandDims(this); ops::RegisterExpandDims(registry);
ops::RegisterExtractPooling(this); ops::RegisterExtractPooling(registry);
ops::RegisterFill(this); ops::RegisterFill(registry);
ops::RegisterFullyConnected(this); ops::RegisterFullyConnected(registry);
ops::RegisterGather(this); ops::RegisterGather(registry);
ops::RegisterIdentity(this); ops::RegisterIdentity(registry);
ops::RegisterIfDefined(this); ops::RegisterIfDefined(registry);
ops::RegisterInferConv2dShape(this); ops::RegisterInferConv2dShape(registry);
ops::RegisterKaldiBatchNorm(this); ops::RegisterKaldiBatchNorm(registry);
ops::RegisterLocalResponseNorm(this); ops::RegisterLocalResponseNorm(registry);
ops::RegisterLpNorm(this); ops::RegisterLpNorm(registry);
ops::RegisterLSTMNonlinear(this); ops::RegisterLSTMNonlinear(registry);
ops::RegisterMatMul(this); ops::RegisterMatMul(registry);
ops::RegisterMVNorm(this); ops::RegisterMVNorm(registry);
ops::RegisterOneHot(this); ops::RegisterOneHot(registry);
ops::RegisterPad(this); ops::RegisterPad(registry);
ops::RegisterPadContext(this); ops::RegisterPadContext(registry);
ops::RegisterPNorm(this); ops::RegisterPNorm(registry);
ops::RegisterPooling(this); ops::RegisterPooling(registry);
ops::RegisterReduce(this); ops::RegisterReduce(registry);
ops::RegisterReplaceIndex(this); ops::RegisterReplaceIndex(registry);
ops::RegisterPriorBox(this); ops::RegisterPriorBox(registry);
ops::RegisterReshape(this); ops::RegisterReshape(registry);
ops::RegisterResizeBicubic(this); ops::RegisterResizeBicubic(registry);
ops::RegisterResizeBilinear(this); ops::RegisterResizeBilinear(registry);
ops::RegisterResizeNearestNeighbor(this); ops::RegisterResizeNearestNeighbor(registry);
ops::RegisterReverse(this); ops::RegisterReverse(registry);
ops::RegisterScalarMath(this); ops::RegisterScalarMath(registry);
ops::RegisterSelect(this); ops::RegisterSelect(registry);
ops::RegisterShape(this); ops::RegisterShape(registry);
ops::RegisterSlice(this); ops::RegisterSlice(registry);
ops::RegisterSoftmax(this); ops::RegisterSoftmax(registry);
ops::RegisterSpaceToBatchND(this); ops::RegisterSpaceToBatchND(registry);
ops::RegisterSpaceToDepth(this); ops::RegisterSpaceToDepth(registry);
ops::RegisterSplice(this); ops::RegisterSplice(registry);
ops::RegisterSplit(this); ops::RegisterSplit(registry);
ops::RegisterStack(this); ops::RegisterStack(registry);
ops::RegisterStridedSlice(this); ops::RegisterStridedSlice(registry);
ops::RegisterSqrDiffMean(this); ops::RegisterSqrDiffMean(registry);
ops::RegisterSqueeze(this); ops::RegisterSqueeze(registry);
ops::RegisterSubsample(this); ops::RegisterSubsample(registry);
ops::RegisterSumGroup(this); ops::RegisterSumGroup(registry);
ops::RegisterTargetRMSNorm(this); ops::RegisterTargetRMSNorm(registry);
ops::RegisterTile(this); ops::RegisterTile(registry);
ops::RegisterTranspose(this); ops::RegisterTranspose(registry);
ops::RegisterUnstack(this); ops::RegisterUnstack(registry);
ops::RegisterUnsqueeze(this); ops::RegisterUnsqueeze(registry);
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
ops::RegisterDequantize(this); ops::RegisterDequantize(registry);
ops::RegisterQuantize(this); ops::RegisterQuantize(registry);
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
ops::RegisterBufferTransform(this); ops::RegisterBufferTransform(registry);
ops::RegisterLSTMCell(this); ops::RegisterLSTMCell(registry);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
} }
} // namespace ops
} // namespace mace } // namespace mace
...@@ -12,19 +12,19 @@ ...@@ -12,19 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_REGISTRY_OPS_REGISTRY_H_ #ifndef MACE_OPS_REGISTRY_REGISTRY_H_
#define MACE_OPS_REGISTRY_OPS_REGISTRY_H_ #define MACE_OPS_REGISTRY_REGISTRY_H_
#include "mace/core/operator.h"
namespace mace { namespace mace {
class OpRegistry;
class OpDelegatorRegistry;
namespace ops {
class OpRegistry : public OpRegistryBase { void RegisterAllOps(OpRegistry *registry);
public: void RegisterAllOpDelegators(OpDelegatorRegistry *registry);
OpRegistry();
~OpRegistry() = default;
};
} // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_REGISTRY_OPS_REGISTRY_H_ #endif // MACE_OPS_REGISTRY_REGISTRY_H_
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -94,7 +95,7 @@ class ReplaceIndexOp<DeviceType::CPU, T> : public Operation { ...@@ -94,7 +95,7 @@ class ReplaceIndexOp<DeviceType::CPU, T> : public Operation {
std::vector<index_t> forward_indexes_; std::vector<index_t> forward_indexes_;
}; };
void RegisterReplaceIndex(OpRegistryBase *op_registry) { void RegisterReplaceIndex(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ReplaceIndex", ReplaceIndexOp, MACE_REGISTER_OP(op_registry, "ReplaceIndex", ReplaceIndexOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
...@@ -149,7 +150,7 @@ class ReshapeOp<GPU, float> : public Operation { ...@@ -149,7 +150,7 @@ class ReshapeOp<GPU, float> : public Operation {
}; };
#endif #endif
void RegisterReshape(OpRegistryBase *op_registry) { void RegisterReshape(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, float); MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, int32_t); MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, int32_t);
MACE_REGISTER_GPU_OP(op_registry, "Reshape", ReshapeOp); MACE_REGISTER_GPU_OP(op_registry, "Reshape", ReshapeOp);
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/common/utils.h" #include "mace/ops/common/utils.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/resize_bicubic.h" #include "mace/ops/opencl/image/resize_bicubic.h"
...@@ -232,7 +233,7 @@ class ResizeBicubicOp<DeviceType::GPU, float> : public Operation { ...@@ -232,7 +233,7 @@ class ResizeBicubicOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterResizeBicubic(OpRegistryBase *op_registry) { void RegisterResizeBicubic(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ResizeBicubic", ResizeBicubicOp, MACE_REGISTER_OP(op_registry, "ResizeBicubic", ResizeBicubicOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/core/quantize.h" #include "mace/core/quantize.h"
#include "mace/ops/common/utils.h" #include "mace/ops/common/utils.h"
...@@ -366,7 +367,7 @@ class ResizeBilinearOp<DeviceType::GPU, float> : public Operation { ...@@ -366,7 +367,7 @@ class ResizeBilinearOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterResizeBilinear(OpRegistryBase *op_registry) { void RegisterResizeBilinear(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ResizeBilinear", ResizeBilinearOp, MACE_REGISTER_OP(op_registry, "ResizeBilinear", ResizeBilinearOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/common/utils.h" #include "mace/ops/common/utils.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/resize_nearest_neighbor.h" #include "mace/ops/opencl/image/resize_nearest_neighbor.h"
...@@ -172,7 +173,7 @@ class ResizeNearestNeighborOp<DeviceType::GPU, float> : public Operation { ...@@ -172,7 +173,7 @@ class ResizeNearestNeighborOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterResizeNearestNeighbor(OpRegistryBase *op_registry) { void RegisterResizeNearestNeighbor(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ResizeNearestNeighbor", MACE_REGISTER_OP(op_registry, "ResizeNearestNeighbor",
ResizeNearestNeighborOp, DeviceType::CPU, float); ResizeNearestNeighborOp, DeviceType::CPU, float);
......
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -72,7 +73,7 @@ class ReverseOp<DeviceType::CPU, T> : public Operation { ...@@ -72,7 +73,7 @@ class ReverseOp<DeviceType::CPU, T> : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterReverse(OpRegistryBase *op_registry) { void RegisterReverse(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Reverse", ReverseOp, MACE_REGISTER_OP(op_registry, "Reverse", ReverseOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/eltwise.h" #include "mace/ops/eltwise.h"
namespace mace { namespace mace {
...@@ -154,7 +155,7 @@ class ScalarMathOp : public Operation { ...@@ -154,7 +155,7 @@ class ScalarMathOp : public Operation {
int32_t scalar_input_index_; int32_t scalar_input_index_;
}; };
void RegisterScalarMath(OpRegistryBase *op_registry) { void RegisterScalarMath(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ScalarMath", ScalarMathOp, MACE_REGISTER_OP(op_registry, "ScalarMath", ScalarMathOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "ScalarMath", ScalarMathOp, MACE_REGISTER_OP(op_registry, "ScalarMath", ScalarMathOp,
......
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
namespace mace { namespace mace {
...@@ -204,7 +205,7 @@ class SelectOp<DeviceType::CPU, float> : public Operation { ...@@ -204,7 +205,7 @@ class SelectOp<DeviceType::CPU, float> : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterSelect(OpRegistryBase *op_registry) { void RegisterSelect(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Select", SelectOp, MACE_REGISTER_OP(op_registry, "Select", SelectOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -58,7 +59,7 @@ class ShapeOp : public Operation { ...@@ -58,7 +59,7 @@ class ShapeOp : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterShape(OpRegistryBase *op_registry) { void RegisterShape(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Shape", ShapeOp, MACE_REGISTER_OP(op_registry, "Shape", ShapeOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -83,7 +84,7 @@ class SliceOp<DeviceType::CPU, T> : public Operation { ...@@ -83,7 +84,7 @@ class SliceOp<DeviceType::CPU, T> : public Operation {
std::vector<int> ends_; std::vector<int> ends_;
}; };
void RegisterSlice(OpRegistryBase *op_registry) { void RegisterSlice(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Slice", SliceOp, MACE_REGISTER_OP(op_registry, "Slice", SliceOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/fixpoint.h" #include "mace/ops/fixpoint.h"
...@@ -520,7 +521,7 @@ class SoftmaxOp<DeviceType::GPU, float> : public Operation { ...@@ -520,7 +521,7 @@ class SoftmaxOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterSoftmax(OpRegistryBase *op_registry) { void RegisterSoftmax(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Softmax", SoftmaxOp, MACE_REGISTER_OP(op_registry, "Softmax", SoftmaxOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/space_to_batch.h" #include "mace/ops/opencl/image/space_to_batch.h"
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
...@@ -328,7 +329,7 @@ class SpaceToBatchNDOp<DeviceType::GPU, float> : public SpaceToBatchOpBase { ...@@ -328,7 +329,7 @@ class SpaceToBatchNDOp<DeviceType::GPU, float> : public SpaceToBatchOpBase {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterSpaceToBatchND(OpRegistryBase *op_registry) { void RegisterSpaceToBatchND(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "SpaceToBatchND", MACE_REGISTER_OP(op_registry, "SpaceToBatchND",
SpaceToBatchNDOp, DeviceType::CPU, float); SpaceToBatchNDOp, DeviceType::CPU, float);
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/space_to_depth.h" #include "mace/ops/opencl/image/space_to_depth.h"
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
...@@ -180,7 +181,7 @@ class SpaceToDepthOp<DeviceType::GPU, float> : public Operation { ...@@ -180,7 +181,7 @@ class SpaceToDepthOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterSpaceToDepth(OpRegistryBase *op_registry) { void RegisterSpaceToDepth(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "SpaceToDepth", MACE_REGISTER_OP(op_registry, "SpaceToDepth",
SpaceToDepthOp, DeviceType::CPU, float); SpaceToDepthOp, DeviceType::CPU, float);
......
...@@ -29,7 +29,8 @@ ...@@ -29,7 +29,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
namespace mace { namespace mace {
...@@ -153,7 +154,7 @@ class SpliceOp<DeviceType::CPU, T> : public Operation { ...@@ -153,7 +154,7 @@ class SpliceOp<DeviceType::CPU, T> : public Operation {
std::vector<index_t> forward_const_indexes_; std::vector<index_t> forward_const_indexes_;
}; };
void RegisterSplice(OpRegistryBase *op_registry) { void RegisterSplice(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Splice", SpliceOp, MACE_REGISTER_OP(op_registry, "Splice", SpliceOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/split.h" #include "mace/ops/opencl/image/split.h"
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
...@@ -128,7 +129,7 @@ class SplitOp<DeviceType::GPU, float> : public Operation { ...@@ -128,7 +129,7 @@ class SplitOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterSplit(OpRegistryBase *op_registry) { void RegisterSplit(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Split", SplitOp, MACE_REGISTER_OP(op_registry, "Split", SplitOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/sqrdiff_mean.h" #include "mace/ops/opencl/image/sqrdiff_mean.h"
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
...@@ -100,7 +101,7 @@ class SqrDiffMeanOp<DeviceType::GPU, float> : public Operation { ...@@ -100,7 +101,7 @@ class SqrDiffMeanOp<DeviceType::GPU, float> : public Operation {
}; };
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
void RegisterSqrDiffMean(OpRegistryBase *op_registry) { void RegisterSqrDiffMean(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "SqrDiffMean", SqrDiffMeanOp, MACE_REGISTER_OP(op_registry, "SqrDiffMean", SqrDiffMeanOp,
DeviceType::CPU, float); DeviceType::CPU, float);
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -77,7 +78,7 @@ class SqueezeOp : public SqueezeOpRaw { ...@@ -77,7 +78,7 @@ class SqueezeOp : public SqueezeOpRaw {
} }
}; };
void RegisterSqueeze(OpRegistryBase *op_registry) { void RegisterSqueeze(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::CPU, float); MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::CPU, float);
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::CPU, uint8_t); MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::CPU, uint8_t);
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -77,7 +78,7 @@ class StackOp : public Operation { ...@@ -77,7 +78,7 @@ class StackOp : public Operation {
int axis_; int axis_;
}; };
void RegisterStack(OpRegistryBase *op_registry) { void RegisterStack(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Stack", StackOp, DeviceType::CPU, float); MACE_REGISTER_OP(op_registry, "Stack", StackOp, DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Stack", StackOp, DeviceType::CPU, int32_t); MACE_REGISTER_OP(op_registry, "Stack", StackOp, DeviceType::CPU, int32_t);
} }
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
namespace mace { namespace mace {
...@@ -350,7 +351,7 @@ class StridedSliceOp : public Operation { ...@@ -350,7 +351,7 @@ class StridedSliceOp : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterStridedSlice(OpRegistryBase *op_registry) { void RegisterStridedSlice(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "StridedSlice", StridedSliceOp, MACE_REGISTER_OP(op_registry, "StridedSlice", StridedSliceOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "StridedSlice", StridedSliceOp, MACE_REGISTER_OP(op_registry, "StridedSlice", StridedSliceOp,
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
namespace mace { namespace mace {
...@@ -100,7 +101,7 @@ class SubsampleOp<DeviceType::CPU, T> : public Operation { ...@@ -100,7 +101,7 @@ class SubsampleOp<DeviceType::CPU, T> : public Operation {
std::vector<index_t> forward_indexes_; std::vector<index_t> forward_indexes_;
}; };
void RegisterSubsample(OpRegistryBase *op_registry) { void RegisterSubsample(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Subsample", SubsampleOp, MACE_REGISTER_OP(op_registry, "Subsample", SubsampleOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -101,7 +102,7 @@ class SumGroupOp<DeviceType::CPU, T> : public Operation { ...@@ -101,7 +102,7 @@ class SumGroupOp<DeviceType::CPU, T> : public Operation {
} }
}; };
void RegisterSumGroup(OpRegistryBase *op_registry) { void RegisterSumGroup(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "SumGroup", SumGroupOp, MACE_REGISTER_OP(op_registry, "SumGroup", SumGroupOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -148,7 +149,7 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation { ...@@ -148,7 +149,7 @@ class TargetRMSNormOp<DeviceType::CPU, T> : public Operation {
int block_dim_; int block_dim_;
}; };
void RegisterTargetRMSNorm(OpRegistryBase *op_registry) { void RegisterTargetRMSNorm(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "TargetRMSNorm", TargetRMSNormOp, MACE_REGISTER_OP(op_registry, "TargetRMSNorm", TargetRMSNormOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
namespace mace { namespace mace {
...@@ -110,7 +111,7 @@ class TileOp : public Operation { ...@@ -110,7 +111,7 @@ class TileOp : public Operation {
int has_data_format_; int has_data_format_;
}; };
void RegisterTile(OpRegistryBase *op_registry) { void RegisterTile(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Tile", TileOp, DeviceType::CPU, float); MACE_REGISTER_OP(op_registry, "Tile", TileOp, DeviceType::CPU, float);
MACE_REGISTER_OP_CONDITION( MACE_REGISTER_OP_CONDITION(
op_registry, OpConditionBuilder("Tile").SetDevicePlacerFunc( op_registry, OpConditionBuilder("Tile").SetDevicePlacerFunc(
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/ops/common/transpose.h" #include "mace/ops/common/transpose.h"
namespace mace { namespace mace {
...@@ -64,7 +65,7 @@ class TransposeOp<D, float> : public Operation { ...@@ -64,7 +65,7 @@ class TransposeOp<D, float> : public Operation {
std::vector<int> dims_; std::vector<int> dims_;
}; };
void RegisterTranspose(OpRegistryBase *op_registry) { void RegisterTranspose(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp, MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp,
DeviceType::CPU, float); DeviceType::CPU, float);
} }
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -59,7 +60,7 @@ class UnsqueezeOp : public Operation { ...@@ -59,7 +60,7 @@ class UnsqueezeOp : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT); MACE_OP_OUTPUT_TAGS(OUTPUT);
}; };
void RegisterUnsqueeze(OpRegistryBase *op_registry) { void RegisterUnsqueeze(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Unsqueeze", UnsqueezeOp, MACE_REGISTER_OP(op_registry, "Unsqueeze", UnsqueezeOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Unsqueeze", UnsqueezeOp, MACE_REGISTER_OP(op_registry, "Unsqueeze", UnsqueezeOp,
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -73,7 +74,7 @@ class UnstackOp : public Operation { ...@@ -73,7 +74,7 @@ class UnstackOp : public Operation {
int axis_; int axis_;
}; };
void RegisterUnstack(OpRegistryBase *op_registry) { void RegisterUnstack(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Unstack", UnstackOp, MACE_REGISTER_OP(op_registry, "Unstack", UnstackOp,
DeviceType::CPU, float); DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Unstack", UnstackOp, MACE_REGISTER_OP(op_registry, "Unstack", UnstackOp,
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <algorithm> #include <algorithm>
#include "mace/utils/statistics.h" #include "mace/utils/statistics.h"
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/benchmark_utils/test_benchmark.h" #include "mace/benchmark_utils/test_benchmark.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#ifdef MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_QUANTIZE
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/benchmark_utils/test_benchmark.h" #include "mace/benchmark_utils/test_benchmark.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/benchmark_utils/test_benchmark.h" #include "mace/benchmark_utils/test_benchmark.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemm.h" #include "mace/ops/arm/fp32/gemm.h"
#include "mace/ops/ref/gemm.h" #include "mace/ops/ref/gemm.h"
#include "mace/ops/testing/test_utils.h" #include "mace/ops/testing/test_utils.h"
...@@ -50,7 +50,7 @@ void TestGemmFloat32(const index_t batch, ...@@ -50,7 +50,7 @@ void TestGemmFloat32(const index_t batch,
GenerateRandomRealTypeData<float>(rhs.shape(), rhs_data); GenerateRandomRealTypeData<float>(rhs.shape(), rhs_data);
GenerateRandomRealTypeData<float>(output.shape(), output_data); GenerateRandomRealTypeData<float>(output.shape(), output_data);
} }
::mace::ops::arm::fp32::Gemm gemm; ::mace::ops::arm::fp32::Gemm gemm((delegator::GemmParam()));
utils::ThreadPool thread_pool(1, AFFINITY_NONE); utils::ThreadPool thread_pool(1, AFFINITY_NONE);
thread_pool.Init(); thread_pool.Init();
CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool); CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool);
...@@ -71,7 +71,7 @@ void TestGemmFloat32(const index_t batch, ...@@ -71,7 +71,7 @@ void TestGemmFloat32(const index_t batch,
Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT); Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT);
expected_output.Resize({batch, rows, cols}); expected_output.Resize({batch, rows, cols});
::mace::ops::ref::Gemm<float> gemm_ref; ::mace::ops::ref::Gemm<float> gemm_ref((delegator::GemmParam()));
gemm_ref.Compute(nullptr, gemm_ref.Compute(nullptr,
&lhs, &lhs,
&rhs, &rhs,
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemv.h" #include "mace/ops/arm/fp32/gemv.h"
#include "mace/ops/ref/gemv.h" #include "mace/ops/ref/gemv.h"
#include "mace/ops/testing/test_utils.h" #include "mace/ops/testing/test_utils.h"
...@@ -53,7 +53,8 @@ void TestGemvFloat32(const index_t batch, ...@@ -53,7 +53,8 @@ void TestGemvFloat32(const index_t batch,
thread_pool.Init(); thread_pool.Init();
CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool); CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool);
OpContext context(nullptr, &cpu_device); OpContext context(nullptr, &cpu_device);
::mace::ops::arm::fp32::Gemv gemv; ::mace::ops::arm::fp32::Gemv gemv =
::mace::ops::arm::fp32::Gemv(DelegatorParam());
gemv.Compute(&context, gemv.Compute(&context,
&lhs, &lhs,
&rhs, &rhs,
...@@ -67,7 +68,8 @@ void TestGemvFloat32(const index_t batch, ...@@ -67,7 +68,8 @@ void TestGemvFloat32(const index_t batch,
Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT); Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT);
expected_output.Resize({batch, height}); expected_output.Resize({batch, height});
::mace::ops::ref::Gemv<float> gemv_ref; ::mace::ops::ref::Gemv<float> gemv_ref =
::mace::ops::ref::Gemv<float>(DelegatorParam());
gemv_ref.Compute(nullptr, gemv_ref.Compute(nullptr,
&lhs, &lhs,
&rhs, &rhs,
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "mace/core/ops/op_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/q8/gemv.h" #include "mace/ops/arm/q8/gemv.h"
#include "mace/ops/ref/gemv.h" #include "mace/ops/ref/gemv.h"
#include "mace/ops/testing/test_utils.h" #include "mace/ops/testing/test_utils.h"
...@@ -58,7 +58,8 @@ void TestGemvInt32(const index_t batch, ...@@ -58,7 +58,8 @@ void TestGemvInt32(const index_t batch,
thread_pool.Init(); thread_pool.Init();
CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool); CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool);
OpContext context(nullptr, &cpu_device); OpContext context(nullptr, &cpu_device);
mace::ops::arm::q8::Gemv<int32_t> gemv; mace::ops::arm::q8::Gemv<int32_t> gemv =
mace::ops::arm::q8::Gemv<int32_t>(DelegatorParam());
gemv.Compute(&context, gemv.Compute(&context,
&lhs, &lhs,
&rhs, &rhs,
...@@ -72,7 +73,8 @@ void TestGemvInt32(const index_t batch, ...@@ -72,7 +73,8 @@ void TestGemvInt32(const index_t batch,
Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32); Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
expected_output.Resize({batch, height}); expected_output.Resize({batch, height});
mace::ops::ref::Gemv<int32_t> gemv_ref; mace::ops::ref::Gemv<int32_t> gemv_ref =
mace::ops::ref::Gemv<int32_t>(DelegatorParam());
gemv_ref.Compute(nullptr, gemv_ref.Compute(nullptr,
&lhs, &lhs,
&rhs, &rhs,
...@@ -130,7 +132,8 @@ void TestGemvUint8(const index_t batch, ...@@ -130,7 +132,8 @@ void TestGemvUint8(const index_t batch,
thread_pool.Init(); thread_pool.Init();
CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool); CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool);
OpContext context(nullptr, &cpu_device); OpContext context(nullptr, &cpu_device);
mace::ops::arm::q8::Gemv<uint8_t> gemv; mace::ops::arm::q8::Gemv<uint8_t> gemv =
mace::ops::arm::q8::Gemv<uint8_t>(DelegatorParam());
gemv.Compute(&context, gemv.Compute(&context,
&lhs, &lhs,
&rhs, &rhs,
...@@ -146,7 +149,8 @@ void TestGemvUint8(const index_t batch, ...@@ -146,7 +149,8 @@ void TestGemvUint8(const index_t batch,
expected_output.SetScale(0.6); expected_output.SetScale(0.6);
expected_output.SetZeroPoint(57); expected_output.SetZeroPoint(57);
expected_output.Resize({batch, height}); expected_output.Resize({batch, height});
mace::ops::ref::Gemv<uint8_t> gemv_ref; mace::ops::ref::Gemv<uint8_t> gemv_ref =
mace::ops::ref::Gemv<uint8_t>(DelegatorParam());
gemv_ref.Compute(nullptr, gemv_ref.Compute(nullptr,
&lhs, &lhs,
&rhs, &rhs,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <fstream> #include <fstream>
#include "mace/ops/delegator/gemm.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
#include "mace/ops/ref/gemm.h" #include "mace/ops/ref/gemm.h"
...@@ -111,7 +112,7 @@ void Complex(const std::vector<index_t> &batch, ...@@ -111,7 +112,7 @@ void Complex(const std::vector<index_t> &batch,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
ref::Gemm<float> gemm; ref::Gemm<float> gemm = ref::Gemm<float>(delegator::GemmParam());
Tensor expected_output_tensor; Tensor expected_output_tensor;
std::vector<index_t> expected_output_shape({rows, cols}); std::vector<index_t> expected_output_shape({rows, cols});
expected_output_shape.insert(expected_output_shape.begin(), expected_output_shape.insert(expected_output_shape.begin(),
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "mace/core/op_context.h" #include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/gpu_device.h" #include "mace/core/runtime/opencl/gpu_device.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
...@@ -134,7 +134,7 @@ TEST(OutOfRangeCheckTest, RandomTest) { ...@@ -134,7 +134,7 @@ TEST(OutOfRangeCheckTest, RandomTest) {
std::unique_ptr<Device> device = make_unique<GPUDevice>( std::unique_ptr<Device> device = make_unique<GPUDevice>(
gpu_context.opencl_tuner()); gpu_context.opencl_tuner());
Workspace ws; Workspace ws(nullptr);
OpContext context(&ws, device.get()); OpContext context(&ws, device.get());
std::vector<index_t> buffer_shape = {batch, height, width, channels}; std::vector<index_t> buffer_shape = {batch, height, width, channels};
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/core/operator.h" #include "mace/core/ops/operator.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
namespace mace { namespace mace {
......
...@@ -31,7 +31,9 @@ ...@@ -31,7 +31,9 @@
#include "mace/core/device_context.h" #include "mace/core/device_context.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/core/workspace.h" #include "mace/core/workspace.h"
#include "mace/ops/registry/ops_registry.h" #include "mace/core/registry/ops_registry.h"
#include "mace/core/registry/op_delegator_registry.h"
#include "mace/ops/registry/registry.h"
#include "mace/public/mace.h" #include "mace/public/mace.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
...@@ -109,7 +111,12 @@ class OpTestContext { ...@@ -109,7 +111,12 @@ class OpTestContext {
class OpsTestNet { class OpsTestNet {
public: public:
OpsTestNet() : OpsTestNet() :
op_registry_(make_unique<OpRegistry>()) {} op_registry_(make_unique<OpRegistry>()),
op_delegator_registry_(make_unique<OpDelegatorRegistry>()),
ws_(op_delegator_registry_.get()) {
ops::RegisterAllOps(op_registry_.get());
ops::RegisterAllOpDelegators(op_delegator_registry_.get());
}
template <DeviceType D, typename T> template <DeviceType D, typename T>
void AddInputFromArray(const std::string &name, void AddInputFromArray(const std::string &name,
...@@ -426,7 +433,8 @@ class OpsTestNet { ...@@ -426,7 +433,8 @@ class OpsTestNet {
void Sync(); void Sync();
public: public:
std::shared_ptr<OpRegistryBase> op_registry_; std::unique_ptr<OpRegistry> op_registry_;
std::unique_ptr<OpDelegatorRegistry> op_delegator_registry_;
Workspace ws_; Workspace ws_;
std::vector<OperatorDef> op_defs_; std::vector<OperatorDef> op_defs_;
std::unique_ptr<NetBase> net_; std::unique_ptr<NetBase> net_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册