Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
18435f2a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
18435f2a
编写于
6月 02, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify the pruning from reading mask to specify sparsity_ratio
上级
ca55a24e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
17 addition
and
139 deletion
+17
-139
paddle/parameter/ParameterUpdaterHook.cpp
paddle/parameter/ParameterUpdaterHook.cpp
+12
-118
proto/ParameterConfig.proto
proto/ParameterConfig.proto
+1
-2
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+1
-8
python/paddle/trainer_config_helpers/attrs.py
python/paddle/trainer_config_helpers/attrs.py
+3
-11
未找到文件。
paddle/parameter/ParameterUpdaterHook.cpp
浏览文件 @
18435f2a
...
@@ -19,130 +19,31 @@ limitations under the License. */
...
@@ -19,130 +19,31 @@ limitations under the License. */
#include <mutex>
#include <mutex>
#include <thread>
#include <thread>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h"
#include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h"
#include "paddle/parameter/Parameter.h"
#include "paddle/utils/Flags.h"
#include "paddle/utils/Flags.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/Util.h"
using
std
::
vector
;
using
std
::
pair
;
namespace
paddle
{
namespace
paddle
{
/**
/**
* The static pruning hook
* The static pruning hook
*
*
Static means user specific a sparsity_ratio map before training started. The
*
Static means user load a mask map before training started. This map
will
*
network
will
*
define which link/weight between neural is disabled
.
*
hold the sparsity_ratio maximum numbers of parameters, and cut off the rest
.
*/
*/
class
StaticPruningHook
:
public
IParameterUpdaterHook
{
public:
/**
* The Mask Map Header.
* The map file started with this header.
*
* In Version 0, reset file will be:
* contains header.size bit, each bit means such weight is enabled or not.
* if bit is 1, then such weight is enabled.
* at end, the file will round to byte, and the low bits of end byte will be
* filled by zero.
*
*/
struct
StaticMaskHeader
{
uint32_t
version
;
size_t
size
;
}
__attribute__
((
__packed__
));
explicit
StaticPruningHook
(
const
std
::
string
&
mask_filename
)
:
initCount_
(
0
)
{
bool
ok
=
this
->
loadMaskFile
(
mask_filename
);
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
"Fail to load mask file "
<<
mask_filename
<<
" in current directory, searching in init_model_path"
;
std
::
string
combineMaskFilename
=
path
::
join
(
FLAGS_init_model_path
,
mask_filename
);
CHECK
(
this
->
loadMaskFile
(
combineMaskFilename
))
<<
"Cannot load "
<<
mask_filename
<<
" in ./"
<<
mask_filename
<<
" and "
<<
combineMaskFilename
;
}
VLOG
(
3
)
<<
mask_filename
<<
" mask size = "
<<
this
->
mask_
.
size
();
}
void
update
(
Parameter
*
para
)
{
class
StaticPruningHook
:
public
IParameterUpdaterHook
{
updateThreadChecker_
.
check
();
auto
&
vec
=
para
->
getBuf
(
PARAMETER_GRADIENT
);
if
(
vec
)
{
vec
->
dotMul
(
*
maskVec_
);
}
}
void
init
(
Parameter
*
para
)
{
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the StaticPruningHook must invoke "
"in same ParamterUpdater"
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
SetDevice
device
(
para
->
getDeviceId
());
auto
maskVec
=
Vector
::
create
(
this
->
mask_
.
size
(),
false
);
{
// Initialize maskVec with float mask vector
real
*
dataPtr
=
maskVec
->
getData
();
size_t
i
=
0
;
for
(
bool
m
:
mask_
)
{
dataPtr
[
i
++
]
=
m
?
1.0
:
0.0
;
}
}
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if
(
para
->
useGpu
())
{
maskVec_
=
Vector
::
create
(
this
->
mask_
.
size
(),
para
->
useGpu
());
maskVec_
->
copyFrom
(
*
maskVec
);
}
else
{
maskVec_
=
maskVec
;
}
auto
&
vec
=
para
->
getBuf
(
PARAMETER_VALUE
);
vec
->
dotMul
(
*
maskVec_
);
}
private:
bool
loadMaskFile
(
const
std
::
string
&
mask_filename
)
{
std
::
ifstream
fin
;
fin
.
open
(
mask_filename
);
if
(
fin
.
is_open
())
{
StaticMaskHeader
header
;
fin
.
read
(
reinterpret_cast
<
char
*>
(
&
header
),
sizeof
(
StaticMaskHeader
));
CHECK_EQ
(
header
.
version
,
0UL
);
mask_
.
resize
(
header
.
size
);
uint8_t
buf
;
for
(
size_t
i
=
0
;
i
<
header
.
size
;
++
i
,
buf
<<=
1
)
{
if
(
i
%
8
==
0
)
{
fin
.
read
(
reinterpret_cast
<
char
*>
(
&
buf
),
sizeof
(
uint8_t
));
}
mask_
[
i
]
=
buf
&
0x80
;
}
fin
.
close
();
return
true
;
}
else
{
return
false
;
}
}
SameThreadChecker
updateThreadChecker_
;
std
::
atomic
<
size_t
>
initCount_
;
VectorPtr
maskVec_
;
std
::
vector
<
bool
>
mask_
;
};
class
DynamicPruningHook
:
public
IParameterUpdaterHook
{
public:
public:
explicit
Dynam
icPruningHook
(
const
ParameterUpdaterHookConfig
&
hookConfig
)
explicit
Stat
icPruningHook
(
const
ParameterUpdaterHookConfig
&
hookConfig
)
:
initCount_
(
0
)
{
:
initCount_
(
0
)
{
sparsityRatio_
=
hookConfig
.
sparsity_ratio
();
sparsityRatio_
=
hookConfig
.
sparsity_ratio
();
}
}
static
bool
sortPairAscend
(
const
pair
<
real
,
size_t
>&
pair1
,
static
bool
sortPairAscend
(
const
std
::
pair
<
real
,
size_t
>&
pair1
,
const
pair
<
real
,
size_t
>&
pair2
)
{
const
std
::
pair
<
real
,
size_t
>&
pair2
)
{
return
pair1
.
first
>
pair2
.
first
;
return
pair1
.
first
>
pair2
.
first
;
}
}
...
@@ -162,7 +63,7 @@ public:
...
@@ -162,7 +63,7 @@ public:
VectorPtr
vecCpu
=
Vector
::
create
(
para
->
getSize
(),
false
);
VectorPtr
vecCpu
=
Vector
::
create
(
para
->
getSize
(),
false
);
vecCpu
->
copyFrom
(
*
vec
);
vecCpu
->
copyFrom
(
*
vec
);
vector
<
pair
<
real
,
size_t
>>
param
;
std
::
vector
<
std
::
pair
<
real
,
size_t
>>
param
;
for
(
size_t
i
=
0
;
i
<
para
->
getSize
();
i
++
)
for
(
size_t
i
=
0
;
i
<
para
->
getSize
();
i
++
)
param
.
push_back
(
std
::
make_pair
(
fabs
(
vecCpu
->
getData
()[
i
]),
i
));
param
.
push_back
(
std
::
make_pair
(
fabs
(
vecCpu
->
getData
()[
i
]),
i
));
...
@@ -175,7 +76,7 @@ public:
...
@@ -175,7 +76,7 @@ public:
void
init
(
Parameter
*
para
)
{
void
init
(
Parameter
*
para
)
{
generateMask
(
para
);
generateMask
(
para
);
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
size_t
initCount
=
this
->
initCount_
.
fetch_add
(
1
);
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the
Dynam
icPruningHook must invoke "
CHECK_EQ
(
initCount
,
0UL
)
<<
"Currently the
Stat
icPruningHook must invoke "
"in same ParamterUpdater"
;
"in same ParamterUpdater"
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
VLOG
(
3
)
<<
"Initialize Parameter "
<<
para
;
SetDevice
device
(
para
->
getDeviceId
());
SetDevice
device
(
para
->
getDeviceId
());
...
@@ -234,16 +135,9 @@ static WeakKVCache<std::pair<std::string, int>,
...
@@ -234,16 +135,9 @@ static WeakKVCache<std::pair<std::string, int>,
static
IParameterUpdaterHook
*
createImpl
(
static
IParameterUpdaterHook
*
createImpl
(
const
ParameterUpdaterHookConfig
&
config
)
{
const
ParameterUpdaterHookConfig
&
config
)
{
auto
&
type
=
config
.
type
();
auto
&
type
=
config
.
type
();
if
(
type
==
"pruning_static"
)
{
if
(
type
==
"pruning"
)
{
if
(
config
.
has_purning_mask_filename
())
return
new
StaticPruningHook
(
config
.
purning_mask_filename
());
else
LOG
(
FATAL
)
<<
"There must be mask_filename parameter for "
<<
type
<<
" Hook"
;
}
else
if
(
type
==
"pruning"
)
{
if
(
config
.
has_sparsity_ratio
())
if
(
config
.
has_sparsity_ratio
())
return
new
Dynam
icPruningHook
(
config
);
return
new
Stat
icPruningHook
(
config
);
else
else
LOG
(
FATAL
)
<<
"There must be sparsity_ratio parameter for "
<<
type
LOG
(
FATAL
)
<<
"There must be sparsity_ratio parameter for "
<<
type
<<
" Hook"
;
<<
" Hook"
;
...
...
proto/ParameterConfig.proto
浏览文件 @
18435f2a
...
@@ -26,8 +26,7 @@ enum ParameterInitStrategy {
...
@@ -26,8 +26,7 @@ enum ParameterInitStrategy {
message
ParameterUpdaterHookConfig
{
message
ParameterUpdaterHookConfig
{
required
string
type
=
1
;
required
string
type
=
1
;
//hook type such as 'pruning', 'pruning_static'
//hook type such as 'pruning'
optional
string
purning_mask_filename
=
2
;
optional
double
sparsity_ratio
=
3
;
optional
double
sparsity_ratio
=
3
;
}
}
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
18435f2a
...
@@ -3171,14 +3171,7 @@ def Layer(name, type, **xargs):
...
@@ -3171,14 +3171,7 @@ def Layer(name, type, **xargs):
@
config_func
@
config_func
def
ParameterHook
(
type
,
**
kwargs
):
def
ParameterHook
(
type
,
**
kwargs
):
if
type
==
'pruning_static'
:
if
type
==
'pruning'
:
hook
=
ParameterUpdaterHookConfig
()
hook
.
type
=
type
mask_filename
=
kwargs
.
get
(
'mask_filename'
,
None
)
assert
mask_filename
is
not
None
hook
.
pruning_mask_filename
=
mask_filename
return
hook
elif
type
==
'pruning'
:
hook
=
ParameterUpdaterHookConfig
()
hook
=
ParameterUpdaterHookConfig
()
hook
.
type
=
type
hook
.
type
=
type
sparsity_ratio
=
kwargs
.
get
(
'sparsity_ratio'
,
None
)
sparsity_ratio
=
kwargs
.
get
(
'sparsity_ratio'
,
None
)
...
...
python/paddle/trainer_config_helpers/attrs.py
浏览文件 @
18435f2a
...
@@ -64,32 +64,24 @@ class HookAttribute(object):
...
@@ -64,32 +64,24 @@ class HookAttribute(object):
here paddle/parameter/ParameterUpdaterHook.cpp
here paddle/parameter/ParameterUpdaterHook.cpp
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, eg: 'pruning'
, 'pruning_static'
:param type: Hook type, eg: 'pruning'
:type type: string
:type type: string
:param mask_file: Must be specified if hook type is 'pruning_static',
the network reads the mask from the file to determine which parameters should be cut off
:type mask_file: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
:param sparsity_ratio: Must be specified if hook type is 'pruning',
the network will hold the sparsity_ratio maximum parameters, and cut off the rest.
the network will hold the sparsity_ratio maximum parameters, and cut off the rest.
:type sparsity_ratio: float number between 0 and 1
:type sparsity_ratio: float number between 0 and 1
"""
"""
def
__init__
(
self
,
type
,
mask_filename
=
None
,
sparsity_ratio
=
None
):
def
__init__
(
self
,
type
,
sparsity_ratio
=
None
):
self
.
type
=
type
self
.
type
=
type
self
.
mask_filename
=
mask_filename
self
.
sparsity_ratio
=
sparsity_ratio
self
.
sparsity_ratio
=
sparsity_ratio
assert
is_compatible_with
(
self
.
sparsity_ratio
,
assert
is_compatible_with
(
self
.
sparsity_ratio
,
float
),
'sparisity_ratio must be float type'
float
),
'sparisity_ratio must be float type'
assert
self
.
sparsity_ratio
<=
1
and
self
.
sparsity_ratio
>=
0
,
'sparisity must be a flaot between [0, 1] '
assert
self
.
sparsity_ratio
<=
1
and
self
.
sparsity_ratio
>=
0
,
'sparisity must be a flaot between [0, 1] '
def
__call__
(
self
):
def
__call__
(
self
):
return
ParameterHook
(
return
ParameterHook
(
self
.
type
,
sparsity_ratio
=
self
.
sparsity_ratio
)
self
.
type
,
mask_filename
=
self
.
mask_filename
,
sparsity_ratio
=
self
.
sparsity_ratio
)
class
ParameterAttribute
(
object
):
class
ParameterAttribute
(
object
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录