Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
72ef73a5
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
72ef73a5
编写于
8月 29, 2017
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor ops and net
上级
4656b708
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
42 addition
and
50 deletion
+42
-50
mace/core/allocator.h
mace/core/allocator.h
+4
-9
mace/core/common.h
mace/core/common.h
+2
-0
mace/core/net.cc
mace/core/net.cc
+2
-2
mace/core/net.h
mace/core/net.h
+10
-4
mace/core/operator.cc
mace/core/operator.cc
+2
-2
mace/core/operator.h
mace/core/operator.h
+2
-8
mace/core/registry.h
mace/core/registry.h
+0
-4
mace/core/tensor.h
mace/core/tensor.h
+6
-10
mace/core/types.h
mace/core/types.h
+1
-1
mace/core/workspace.cc
mace/core/workspace.cc
+3
-1
mace/core/workspace.h
mace/core/workspace.h
+1
-1
mace/mace.bzl
mace/mace.bzl
+1
-1
mace/ops/BUILD
mace/ops/BUILD
+5
-4
mace/ops/relu.h
mace/ops/relu.h
+3
-3
未找到文件。
mace/core/allocator.h
浏览文件 @
72ef73a5
...
...
@@ -6,20 +6,15 @@
#ifndef MACE_CORE_ALLOCATOR_H_
#define MACE_CORE_ALLOCATOR_H_
#include <unordered_map>
#include <functional>
#include <malloc.h>
#include <cstring>
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
namespace
mace
{
// 16 bytes = 32 * 4 (Neon)
constexpr
size_t
kMaceAlignment
=
16
;
using
MemoryDeleter
=
std
::
function
<
void
(
void
*
ptr
)
>
;
class
Allocator
{
public:
Allocator
()
{}
...
...
@@ -44,9 +39,9 @@ class CPUAllocator: public Allocator {
void
*
New
(
size_t
nbytes
)
override
{
void
*
data
=
nullptr
;
#ifdef __ANDROID__
data
=
memalign
(
g
MaceAlignment
,
nbytes
);
data
=
memalign
(
k
MaceAlignment
,
nbytes
);
#elif defined(_MSC_VER)
data
=
_aligned_malloc
(
nbytes
,
g
MaceAlignment
);
data
=
_aligned_malloc
(
nbytes
,
k
MaceAlignment
);
#else
CHECK
(
posix_memalign
(
&
data
,
kMaceAlignment
,
nbytes
)
==
0
);
#endif
...
...
@@ -72,7 +67,7 @@ CPUAllocator* cpu_allocator();
// ownership of the pointer.
void
SetCPUAllocator
(
CPUAllocator
*
alloc
);
template
<
DeviceType
D
T
>
template
<
DeviceType
D
>
struct
DeviceContext
{};
template
<
>
...
...
mace/core/common.h
浏览文件 @
72ef73a5
...
...
@@ -6,6 +6,7 @@
#define MACE_CORE_COMMON_H_
#include <set>
#include <map>
#include <string>
#include <memory>
#include <vector>
...
...
@@ -15,6 +16,7 @@
#include "mace/core/logging.h"
using
std
::
set
;
using
std
::
map
;
using
std
::
string
;
using
std
::
unique_ptr
;
using
std
::
vector
;
...
...
mace/core/net.cc
浏览文件 @
72ef73a5
...
...
@@ -8,8 +8,8 @@ namespace mace {
NetBase
::
NetBase
(
const
std
::
shared_ptr
<
const
NetDef
>
&
net_def
,
Workspace
*
ws
,
DeviceType
type
)
{
DeviceType
type
)
:
name_
(
net_def
->
name
())
{
}
...
...
mace/core/net.h
浏览文件 @
72ef73a5
...
...
@@ -14,7 +14,9 @@ namespace mace {
class
NetBase
{
public:
NetBase
(
const
std
::
shared_ptr
<
const
NetDef
>
&
net_def
,
Workspace
*
ws
,
DeviceType
type
);
NetBase
(
const
std
::
shared_ptr
<
const
NetDef
>
&
net_def
,
Workspace
*
ws
,
DeviceType
type
);
virtual
~
NetBase
()
noexcept
{}
virtual
bool
Run
()
=
0
;
...
...
@@ -31,9 +33,11 @@ class NetBase {
class
SimpleNet
:
public
NetBase
{
public:
SimpleNet
(
const
std
::
shared_ptr
<
const
NetDef
>&
net_def
,
Workspace
*
ws
,
DeviceType
type
);
SimpleNet
(
const
std
::
shared_ptr
<
const
NetDef
>&
net_def
,
Workspace
*
ws
,
DeviceType
type
);
virtual
bool
Run
()
override
;
bool
Run
()
override
;
protected:
vector
<
unique_ptr
<
OperatorBase
>
>
operators_
;
...
...
@@ -41,7 +45,9 @@ class SimpleNet : public NetBase {
DISABLE_COPY_AND_ASSIGN
(
SimpleNet
);
};
unique_ptr
<
NetBase
>
CreateNet
(
const
NetDef
&
net_def
,
Workspace
*
ws
,
DeviceType
type
);
unique_ptr
<
NetBase
>
CreateNet
(
const
NetDef
&
net_def
,
Workspace
*
ws
,
DeviceType
type
);
unique_ptr
<
NetBase
>
CreateNet
(
const
std
::
shared_ptr
<
const
NetDef
>&
net_def
,
Workspace
*
ws
,
...
...
mace/core/operator.cc
浏览文件 @
72ef73a5
...
...
@@ -6,8 +6,8 @@
namespace
mace
{
std
::
map
<
int32
_t
,
OperatorRegistry
*>*
gDeviceTypeRegistry
()
{
static
std
::
map
<
int32
_t
,
OperatorRegistry
*>
g_device_type_registry
;
std
::
map
<
int32
,
OperatorRegistry
*>*
gDeviceTypeRegistry
()
{
static
std
::
map
<
int32
,
OperatorRegistry
*>
g_device_type_registry
;
return
&
g_device_type_registry
;
}
...
...
mace/core/operator.h
浏览文件 @
72ef73a5
...
...
@@ -58,10 +58,7 @@ class OperatorBase {
inline
const
vector
<
const
Tensor
*>
&
Inputs
()
const
{
return
inputs_
;
}
inline
const
vector
<
Tensor
*>
&
Outputs
()
{
return
outputs_
;
}
virtual
bool
Run
()
{
MACE_NOT_IMPLEMENTED
;
return
false
;
}
virtual
bool
Run
()
=
0
;
inline
const
OperatorDef
&
debug_def
()
const
{
REQUIRE
(
has_debug_def
(),
"operator_def was null!"
);
...
...
@@ -108,10 +105,7 @@ class Operator : public OperatorBase {
DataTypeToEnum
<
T
>::
v
())));
}
}
virtual
bool
Run
()
{
MACE_NOT_IMPLEMENTED
;
return
false
;
}
virtual
bool
Run
()
=
0
;
~
Operator
()
noexcept
override
{}
};
...
...
mace/core/registry.h
浏览文件 @
72ef73a5
...
...
@@ -5,11 +5,7 @@
#ifndef MACE_CORE_REGISTRY_H_
#define MACE_CORE_REGISTRY_H_
#include <memory>
#include <mutex>
#include <string>
#include <map>
#include "mace/core/common.h"
namespace
mace
{
...
...
mace/core/tensor.h
浏览文件 @
72ef73a5
...
...
@@ -53,7 +53,7 @@ class Tensor {
size_
(
0
),
dtype_
(
DT_FLOAT
),
data_
(
nullptr
)
{};
Tensor
(
Allocator
*
a
,
DataType
type
)
:
alloc_
(
a
),
size_
(
0
),
dtype_
(
DT_FLOAT
),
data_
(
nullptr
)
{};
:
alloc_
(
a
),
size_
(
0
),
dtype_
(
type
),
data_
(
nullptr
)
{};
~
Tensor
()
{
if
(
alloc_
&&
data_
.
get
())
{
...
...
@@ -65,10 +65,6 @@ class Tensor {
inline
const
vector
<
TIndex
>&
shape
()
const
{
return
shape_
;
}
inline
int64
NumElements
()
const
{
return
std
::
accumulate
(
shape_
.
begin
(),
shape_
.
end
(),
1
,
std
::
multiplies
<
int64
>
());
}
inline
TIndex
dim_size
()
{
return
shape_
.
size
();
}
inline
TIndex
size
()
const
{
return
size_
;
}
...
...
@@ -86,10 +82,6 @@ class Tensor {
return
static_cast
<
T
*>
(
data_
.
get
());
}
void
Deleter
(
void
*
data
)
{
alloc_
->
Delete
(
data
);
}
inline
void
*
raw_mutable_data
()
{
if
(
data_
.
get
()
||
size_
==
0
)
{
return
data_
.
get
();
...
...
@@ -113,7 +105,7 @@ class Tensor {
shape_
=
shape
;
TIndex
size
=
NumElements
();
if
(
size_
!=
size
)
{
size_
=
NumElements
()
;
size_
=
size
;
data_
.
reset
();
}
}
...
...
@@ -127,6 +119,10 @@ class Tensor {
}
private:
inline
int64
NumElements
()
const
{
return
std
::
accumulate
(
shape_
.
begin
(),
shape_
.
end
(),
1
,
std
::
multiplies
<
int64
>
());
}
Allocator
*
alloc_
;
TIndex
size_
;
DataType
dtype_
;
...
...
mace/core/types.h
浏览文件 @
72ef73a5
...
...
@@ -16,7 +16,7 @@ struct IsValidDataType;
template
<
class
T
>
struct
DataTypeToEnum
{
static_assert
(
IsValidDataType
<
T
>::
value
,
"Specified Data Type not supported"
);
};
// Specializations below
};
// EnumToDataType<VALUE>::Type is the type for DataType constant VALUE, e.g.
...
...
mace/core/workspace.cc
浏览文件 @
72ef73a5
...
...
@@ -15,7 +15,9 @@ vector<string> Workspace::Tensors() const {
return
names
;
}
Tensor
*
Workspace
::
CreateTensor
(
const
string
&
name
,
Allocator
*
alloc
,
DataType
type
)
{
Tensor
*
Workspace
::
CreateTensor
(
const
string
&
name
,
Allocator
*
alloc
,
DataType
type
)
{
if
(
HasTensor
(
name
))
{
VLOG
(
1
)
<<
"Tensor "
<<
name
<<
" already exists. Skipping."
;
}
else
{
...
...
mace/core/workspace.h
浏览文件 @
72ef73a5
...
...
@@ -14,7 +14,7 @@ namespace mace {
class
Workspace
{
public:
typedef
std
::
map
<
string
,
unique_ptr
<
Tensor
>>
TensorMap
;
typedef
map
<
string
,
unique_ptr
<
Tensor
>>
TensorMap
;
Workspace
()
{}
...
...
mace/mace.bzl
浏览文件 @
72ef73a5
mace/ops/BUILD
浏览文件 @
72ef73a5
...
...
@@ -9,7 +9,7 @@ package(
licenses
([
"notice"
])
# Apache 2.0
cc_library
(
name
=
"op"
,
name
=
"op
s
"
,
srcs
=
[
"relu.cc"
],
hdrs
=
glob
([
"*.h"
]),
deps
=
[
...
...
@@ -19,10 +19,11 @@ cc_library(
)
cc_test
(
name
=
"
op
_test"
,
name
=
"
relu
_test"
,
srcs
=
[
"relu_test.cc"
,],
deps
=
[
"@gtest//:gtest"
,
":op"
,
":op
s
"
,
],
)
mace/ops/relu.h
浏览文件 @
72ef73a5
...
...
@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OP
ERATOR
S_RELU_H_
#define MACE_OP
ERATOR
S_RELU_H_
#ifndef MACE_OPS_RELU_H_
#define MACE_OPS_RELU_H_
#include "mace/core/operator.h"
...
...
@@ -19,4 +19,4 @@ class ReluOp : public Operator<D, T> {
}
// namespace mace
#endif // MACE_OP
ERATOR
S_RELU_H_
#endif // MACE_OPS_RELU_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录