提交 f56e394a 编写于 作者: W wuchenghui

Merge branch 'master' into feature_wuch

...@@ -16,10 +16,12 @@ void SetCPUAllocator(CPUAllocator* alloc) { ...@@ -16,10 +16,12 @@ void SetCPUAllocator(CPUAllocator* alloc) {
} }
Allocator* GetDeviceAllocator(DeviceType type) { Allocator* GetDeviceAllocator(DeviceType type) {
if (type == DeviceType::CPU) { switch (type) {
return cpu_allocator(); case DeviceType::CPU:
} else { case DeviceType::NEON:
REQUIRE(false, "device type ", type, " is not supported."); return cpu_allocator();
default:
REQUIRE(false, "device type ", type, " is not supported.");
} }
return nullptr; return nullptr;
} }
......
...@@ -12,8 +12,13 @@ ...@@ -12,8 +12,13 @@
namespace mace { namespace mace {
// 16 bytes = 32 * 4 (Neon) #ifdef __ANDROID__
// 16 bytes = 128 bits = 32 * 4 (Neon)
constexpr size_t kMaceAlignment = 16; constexpr size_t kMaceAlignment = 16;
#else
// 32 bytes = 256 bits (AVX512)
constexpr size_t kMaceAlignment = 32;
#endif
class Allocator { class Allocator {
public: public:
...@@ -41,27 +46,20 @@ class CPUAllocator: public Allocator { ...@@ -41,27 +46,20 @@ class CPUAllocator: public Allocator {
void* data = nullptr; void* data = nullptr;
#ifdef __ANDROID__ #ifdef __ANDROID__
data = memalign(kMaceAlignment, nbytes); data = memalign(kMaceAlignment, nbytes);
#elif defined(_MSC_VER)
data = _aligned_malloc(nbytes, kMaceAlignment);
#else #else
CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0); CHECK(posix_memalign(&data, kMaceAlignment, nbytes) == 0);
#endif #endif
CHECK_NOTNULL(data); CHECK_NOTNULL(data);
// TODO(heliangliang) This should be avoided sometimes
memset(data, 0, nbytes); memset(data, 0, nbytes);
return data; return data;
} }
#ifdef _MSC_VER void Delete(void* data) override {
void Delete(void* data) {
_aligned_free(data);
}
#else
void Delete(void* data) {
free(data); free(data);
} }
#endif
void CopyBytes(void* dst, const void* src, size_t size) { void CopyBytes(void* dst, const void* src, size_t size) override {
memcpy(dst, src, size); memcpy(dst, src, size);
} }
}; };
...@@ -80,6 +78,11 @@ struct DeviceContext<DeviceType::CPU> { ...@@ -80,6 +78,11 @@ struct DeviceContext<DeviceType::CPU> {
static Allocator* allocator() { return cpu_allocator(); } static Allocator* allocator() { return cpu_allocator(); }
}; };
template <>
struct DeviceContext<DeviceType::NEON> {
static Allocator* allocator() { return cpu_allocator(); }
};
Allocator* GetDeviceAllocator(DeviceType type); Allocator* GetDeviceAllocator(DeviceType type);
} // namespace mace } // namespace mace
......
...@@ -18,6 +18,13 @@ MACE_DEFINE_REGISTRY( ...@@ -18,6 +18,13 @@ MACE_DEFINE_REGISTRY(
Workspace*); Workspace*);
MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry); MACE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry);
MACE_DEFINE_REGISTRY(
NEONOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
MACE_REGISTER_DEVICE_TYPE(DeviceType::NEON, NEONOperatorRegistry);
unique_ptr<OperatorBase> CreateOperator( unique_ptr<OperatorBase> CreateOperator(
const OperatorDef& operator_def, const OperatorDef& operator_def,
Workspace* ws, Workspace* ws,
...@@ -33,4 +40,4 @@ OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws) ...@@ -33,4 +40,4 @@ OperatorBase::OperatorBase(const OperatorDef &operator_def, Workspace *ws)
} }
} // namespace mace } // namespace mace
\ No newline at end of file
...@@ -105,7 +105,7 @@ class Operator : public OperatorBase { ...@@ -105,7 +105,7 @@ class Operator : public OperatorBase {
DataTypeToEnum<T>::v()))); DataTypeToEnum<T>::v())));
} }
} }
virtual bool Run() = 0; virtual bool Run() override = 0;
~Operator() noexcept override {} ~Operator() noexcept override {}
}; };
...@@ -145,6 +145,17 @@ MACE_DECLARE_REGISTRY( ...@@ -145,6 +145,17 @@ MACE_DECLARE_REGISTRY(
#define REGISTER_CPU_OPERATOR(name, ...) \ #define REGISTER_CPU_OPERATOR(name, ...) \
MACE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__) MACE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
MACE_DECLARE_REGISTRY(
NEONOperatorRegistry,
OperatorBase,
const OperatorDef&,
Workspace*);
#define REGISTER_NEON_OPERATOR_CREATOR(key, ...) \
MACE_REGISTER_CREATOR(NEONOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_NEON_OPERATOR(name, ...) \
MACE_REGISTER_CLASS(NEONOperatorRegistry, name, __VA_ARGS__)
unique_ptr<OperatorBase> CreateOperator( unique_ptr<OperatorBase> CreateOperator(
const OperatorDef &operator_def, const OperatorDef &operator_def,
Workspace *ws, Workspace *ws,
......
# Description:
# Mace neon kernels.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android")
cc_library(
name = "kernels",
srcs = glob(["*.cc"]) + if_android(glob(["neon/*.cc"])),
hdrs = glob(["*.h"]) + if_android(glob(["neon/*.h"])),
deps = [
"//mace/core:core",
],
copts = ['-std=c++11'],
)
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include "mace/kernels/neon/relu_neon.h"
namespace mace {
namespace kernels{
void NeonReluFuntion_float(const Tensor *input_tensor,
Tensor *output_tensor) {
int64 size = input_tensor->size();
output_tensor->ResizeLike(input_tensor);
const float* input = input_tensor->data<float>();
float* output = output_tensor->mutable_data<float>();
float32x4_t _zero = vdupq_n_f32(0.f);
for (; size > 0; size--) {
float32x4_t _inp = vld1q_f32(input);
float32x4_t _outp = vmaxq_f32(_inp, _zero);
vst1q_f32(output, _outp);
input += 4;
output += 4;
}
}
} // namespace kernels
} // namespace mace
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_RELU_NEON_H_
#define MACE_KERNELS_RELU_NEON_H_
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
void NeonReluFuntion_float(const Tensor *input_tensor,
Tensor *output_tensor);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_RELU_NEON_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_RELU_H_
#define MACE_KERNELS_RELU_H_
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
template<typename T>
void ReluFuntion(const Tensor *input_tensor, Tensor *output_tensor) {
int64 size = input_tensor->size();
output_tensor->ResizeLike(input_tensor);
const float* input = input_tensor->data<float>();
float* output = output_tensor->mutable_data<float>();
for (int64 i = 0; i < size; ++i) {
output[i] = std::max(input[i], static_cast<T>(0));
}
}
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_RELU_H_
\ No newline at end of file
...@@ -17,6 +17,7 @@ cc_library( ...@@ -17,6 +17,7 @@ cc_library(
deps = [ deps = [
"//mace/proto:cc_proto", "//mace/proto:cc_proto",
"//mace/core:core", "//mace/core:core",
"//mace/kernels:kernels",
], ],
copts = ['-std=c++11'], copts = ['-std=c++11'],
alwayslink = 1, alwayslink = 1,
......
...@@ -4,25 +4,32 @@ ...@@ -4,25 +4,32 @@
#include "mace/ops/relu.h" #include "mace/ops/relu.h"
#include "mace/proto/mace.pb.h" #include "mace/proto/mace.pb.h"
#include "mace/kernels/relu.h"
#if __ARM_NEON
#include "mace/kernels/neon/relu_neon.h"
#endif // __ARM_NEON
namespace mace { namespace mace {
template <> template <>
bool ReluOp<DeviceType::CPU, float>::Run() { bool ReluOp<DeviceType::CPU, float>::Run() {
const Tensor* X = Input(0); const Tensor* input_tensor = Input(0);
Tensor* Y = Output(0); Tensor* output_tensor = Output(0);
Y->ResizeLike(X); kernels::ReluFuntion<float>(input_tensor, output_tensor);
return true;
}
REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>);
const float* Xdata = X-> data<float>();
float* Ydata = Y->mutable_data<float>();
for (int i = 0; i < X->size(); ++i) {
Ydata[i] = std::max(Xdata[i], 0.f);
VLOG(0) << i << ": " << Xdata[i] << " " << Ydata[i];
}
#if __ARM_NEON
template <>
bool ReluOp<DeviceType::NEON, float>::Run() {
const Tensor* input_tensor = Input(0);
Tensor* output_tensor = Output(0);
kernels::NeonReluFuntion_float(input_tensor, output_tensor);
return true; return true;
} }
REGISTER_NEON_OPERATOR(Relu, ReluOp<DeviceType::NEON, float>);
REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>); #endif // __ARM_NEON
} // namespace mace } // namespace mace
...@@ -3,8 +3,9 @@ syntax = "proto2"; ...@@ -3,8 +3,9 @@ syntax = "proto2";
package mace; package mace;
enum DeviceType { enum DeviceType {
CPU = 0; // In default, we will use CPU. CPU = 0; // In default, we will use CPU.
GPU = 1; NEON = 1;
OPENCL = 2;
} }
enum DataType { enum DataType {
...@@ -70,4 +71,4 @@ message NetDef { ...@@ -70,4 +71,4 @@ message NetDef {
optional string version = 3; optional string version = 3;
repeated Argument arg = 4; repeated Argument arg = 4;
repeated TensorProto tensors = 5; repeated TensorProto tensors = 5;
} }
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册