diff --git a/mace/core/allocator.h b/mace/core/allocator.h index 110b012bec4ea91663c17e14651519abc1a0f9f4..79efd610be61c4bde548009c9e13ec6ec8302650 100644 --- a/mace/core/allocator.h +++ b/mace/core/allocator.h @@ -12,8 +12,13 @@ namespace mace { -// 16 bytes = 32 * 4 (Neon) +#ifdef __ANDROID__ +// 16 bytes = 128 bits = 32 * 4 (Neon) constexpr size_t kMaceAlignment = 16; +#else +// 32 bytes = 256 bits (AVX/AVX2) +constexpr size_t kMaceAlignment = 32; +#endif class Allocator { public: @@ -41,25 +46,18 @@ class CPUAllocator: public Allocator { void* data = nullptr; #ifdef __ANDROID__ data = memalign(kMaceAlignment, nbytes); -#elif defined(_MSC_VER) - data = _aligned_malloc(nbytes, kMaceAlignment); #else CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0); #endif CHECK_NOTNULL(data); + // TODO(heliangliang) This should be avoided sometimes memset(data, 0, nbytes); return data; } -#ifdef _MSC_VER - void Delete(void* data) { - _aligned_free(data); - } -#else void Delete(void* data) { free(data); } -#endif void CopyBytes(void* dst, const void* src, size_t size) { memcpy(dst, src, size); @@ -80,6 +78,11 @@ struct DeviceContext { static Allocator* allocator() { return cpu_allocator(); } }; +template <> +struct DeviceContext<DeviceType::NEON> { + static Allocator* allocator() { return cpu_allocator(); } +}; + Allocator* GetDeviceAllocator(DeviceType type); } // namespace mace diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 0072b58add999b1b60c7a1ea0a3bef0172931909..078574a7d48433f3ead86e3b50962eb5370a41cb 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -18,6 +18,13 @@ MACE_DEFINE_REGISTRY( Workspace*); MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry); +MACE_DEFINE_REGISTRY( + NEONOperatorRegistry, + OperatorBase, + const OperatorDef&, + Workspace*); +MACE_REGISTER_DEVICE_TYPE(DeviceType::NEON, NEONOperatorRegistry); + unique_ptr CreateOperator( const OperatorDef& operator_def, Workspace* ws, @@ -33,4 +40,4 @@
OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws) } -} // namespace mace \ No newline at end of file +} // namespace mace diff --git a/mace/core/operator.h b/mace/core/operator.h index 27e1fa16a772481406b0ce665bb61c1f620818b8..6ac672f641b27b4d2eaa39dcbba9dc1dc590a051 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -145,6 +145,17 @@ MACE_DECLARE_REGISTRY( #define REGISTER_CPU_OPERATOR(name, ...) \ MACE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__) +MACE_DECLARE_REGISTRY( + NEONOperatorRegistry, + OperatorBase, + const OperatorDef&, + Workspace*); + +#define REGISTER_NEON_OPERATOR_CREATOR(key, ...) \ + MACE_REGISTER_CREATOR(NEONOperatorRegistry, key, __VA_ARGS__) +#define REGISTER_NEON_OPERATOR(name, ...) \ + MACE_REGISTER_CLASS(NEONOperatorRegistry, name, __VA_ARGS__) + unique_ptr CreateOperator( const OperatorDef &operator_def, Workspace *ws, diff --git a/mace/ops/relu.cc b/mace/ops/relu.cc index 94646e0f3eca5c21bd8eb4510f2c5a41c54f7a05..66662f38bbb2c9960a2c3c6b7415d0bf140218cd 100644 --- a/mace/ops/relu.cc +++ b/mace/ops/relu.cc @@ -23,6 +23,22 @@ bool ReluOp::Run() { return true; } +template <> +bool ReluOp::Run() { + const Tensor* X = Input(0); + Tensor* Y = Output(0); + Y->ResizeLike(X); + + const float* Xdata = X-> data(); + float* Ydata = Y->mutable_data(); + for (int i = 0; i < X->size(); ++i) { + Ydata[i] = std::max(Xdata[i], 0.f); + } + + return true; +} + REGISTER_CPU_OPERATOR(Relu, ReluOp); +REGISTER_NEON_OPERATOR(Relu, ReluOp); } // namespace mace diff --git a/mace/proto/mace.proto b/mace/proto/mace.proto index 10c37f12267b996a30265a40cbffbf89ef01bb2a..05c317d137a3d9819174f18084e186029cfe3fc3 100644 --- a/mace/proto/mace.proto +++ b/mace/proto/mace.proto @@ -3,8 +3,9 @@ syntax = "proto2"; package mace; enum DeviceType { - CPU = 0; // In default, we will use CPU. - GPU = 1; + CPU = 0; // In default, we will use CPU.
+ OPENCL = 1; + NEON = 2; } enum DataType { @@ -70,4 +71,4 @@ message NetDef { optional string version = 3; repeated Argument arg = 4; repeated TensorProto tensors = 5; -} \ No newline at end of file +}