Commit 5413af8d authored by X xzl

improve pruning module

Parent da83d286
......@@ -25,6 +25,9 @@ limitations under the License. */
#include "paddle/utils/Flags.h"
#include "paddle/utils/Util.h"
using std::vector;
using std::pair;
namespace paddle {
/**
......@@ -131,6 +134,73 @@ private:
std::vector<bool> mask_;
};
class DynamicPruningHook : public IParameterUpdaterHook {
public:
explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig)
: initCount_(0) {
sparsityRatio_ = hookConfig.sparsity_ratio();
}
static bool sortPairAscend(const pair<real, size_t>& pair1,
const pair<real, size_t>& pair2) {
return pair1.first > pair2.first;
}
void update(Parameter* para) {
updateThreadChecker_.check();
auto& vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) {
vec->dotMul(*maskVec_);
}
}
void generateMask(Parameter* para) {
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
maskTemp_ = Vector::create(para->getSize(), false);
maskTemp_->zeroMem();
real* dataPtr = maskTemp_->getData();
VectorPtr vecCpu = Vector::create(para->getSize(), false);
vecCpu->copyFrom(*vec);
vector<pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
std::sort(param.begin(), param.end(), sortPairAscend);
for (size_t i = 0; i < para->getSize() * sparsityRatio_; i++)
dataPtr[param[i].second] = 1.0;
}
void init(Parameter* para) {
generateMask(para);
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must be invoked "
"in the same ParameterUpdater";
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
// Currently just use a mask vector as a hack.
// @TODO(yuyang18): Implement the mask operation in Vector.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_);
} else {
maskVec_ = maskTemp_;
}
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
}
private:
SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
VectorPtr maskTemp_;
real sparsityRatio_;
};
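To make the new hook's behavior concrete: generateMask() above sorts parameters by absolute value (the comparator orders them largest-first, despite its "Ascend" name) and marks the top sparsity_ratio fraction with 1.0; init() multiplies the parameter values by this mask and update() multiplies the gradients by it. A minimal NumPy sketch of the same rule, as an illustration only (not the C++ implementation):

```python
import numpy as np

def make_pruning_mask(weights, sparsity_ratio):
    """Keep the sparsity_ratio fraction of entries with the largest |w|; zero the rest."""
    flat = np.abs(weights).ravel()
    order = np.argsort(-flat)                  # indices sorted by descending magnitude
    n_keep = int(flat.size * sparsity_ratio)   # mirrors para->getSize() * sparsityRatio_
    mask = np.zeros_like(flat)
    mask[order[:n_keep]] = 1.0
    return mask.reshape(weights.shape)

# The hook applies the mask to both the values (in init) and the gradients (in update):
# pruned_w = w * mask;  pruned_grad = grad * mask
```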
IParameterUpdaterHook::IParameterUpdaterHook() {}
IParameterUpdaterHook::~IParameterUpdaterHook() {}
......@@ -156,8 +226,7 @@ private:
static WeakKVCache<std::pair<std::string, int>,
IParameterUpdaterHook,
StringIntPairHasher>
g_hookCache_;
StringIntPairHasher> g_hookCache_;
/**
* ParameterUpdaterHook actually factory method.
......@@ -165,11 +234,22 @@ static WeakKVCache<std::pair<std::string, int>,
static IParameterUpdaterHook* createImpl(
const ParameterUpdaterHookConfig& config) {
auto& type = config.type();
if (type == "pruning") {
if (config.has_purning_mask_filename()) {
if (type == "pruning_static") {
if (config.has_purning_mask_filename())
return new StaticPruningHook(config.purning_mask_filename());
}
else
LOG(FATAL) << "There must be mask_filename parameter for " << type
<< " Hook";
} else if (type == "pruning") {
if (config.has_sparsity_ratio())
return new DynamicPruningHook(config);
else
LOG(FATAL) << "There must be sparsity_ratio parameter for " << type
<< " Hook";
}
LOG(FATAL) << "Unknown Hook type: " << type;
return nullptr;
}
......
......@@ -26,7 +26,9 @@ enum ParameterInitStrategy {
message ParameterUpdaterHookConfig {
required string type = 1;
// Hook type, e.g. 'pruning' or 'pruning_static'.
optional string purning_mask_filename = 2;
optional double sparsity_ratio = 3;
}
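For reference, the resulting message can be filled in directly from Python. A sketch only: the generated-module import path and the concrete values are assumptions for illustration, not taken from this diff.

```python
# Sketch only; the import path and the values are illustrative assumptions.
from paddle.proto.ParameterConfig_pb2 import ParameterUpdaterHookConfig

dynamic_hook = ParameterUpdaterHookConfig()
dynamic_hook.type = 'pruning'          # dynamic, magnitude-based pruning
dynamic_hook.sparsity_ratio = 0.75     # keep the largest 75% of weights

static_hook = ParameterUpdaterHookConfig()
static_hook.type = 'pruning_static'
static_hook.purning_mask_filename = 'mask.data'  # field name as declared above
```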
message ParameterConfig {
......
......@@ -3171,12 +3171,19 @@ def Layer(name, type, **xargs):
@config_func
def ParameterHook(type, **kwargs):
if type == 'pruning':
if type == 'pruning_static':
hook = ParameterUpdaterHookConfig()
hook.type = type
mask_filename = kwargs.get('mask_filename', None)
assert mask_filename is not None
hook.purning_mask_filename = mask_filename
return hook
elif type == 'pruning':
hook = ParameterUpdaterHookConfig()
hook.type = type
hook.purning_mask_filename = mask_filename
sparsity_ratio = kwargs.get('sparsity_ratio', None)
assert sparsity_ratio is not None
hook.sparsity_ratio = sparsity_ratio
return hook
else:
return None
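A short usage sketch of the config_func above (the filename and ratio are illustrative placeholders):

```python
static_hook = ParameterHook('pruning_static', mask_filename='mask.data')
dynamic_hook = ParameterHook('pruning', sparsity_ratio=0.75)
other = ParameterHook('unknown_type')  # falls through to the else branch and returns None
```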
......@@ -3283,13 +3290,13 @@ def Parameter(name,
if update_hooks is not None:
if hasattr(update_hooks, '__call__'):
update_hooks = update_hooks(para.name)
update_hooks = update_hooks()
if isinstance(update_hooks, list):
for hook in update_hooks:
para.update_hooks.extend([hook])
else:
para.update_hooks.extend(update_hooks)
para.update_hooks.extend([update_hooks])
g_parameter_map[name] = para
......
......@@ -14,7 +14,8 @@
from paddle.trainer.config_parser import *
__all__ = [
'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute'
'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute',
'ExtraLayerAttribute'
]
......@@ -55,6 +56,42 @@ def is_compatible_with(x, Type):
return False
class HookAttribute(object):
"""
Hook Attribute object. The hook is an auxiliary operation that occurs
during network propagation. Such as pruning operation, It will cut off
redundant parameters in the network before training. More detail can see
here paddle/parameter/ParameterUpdaterHook.cpp
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, eg: 'pruning', 'pruning_static'
:type type: string
:param mask_file: Must be specified if hook type is 'pruning_static',
the network reads the mask from the file to determine which parameters should be cut off
:type mask_file: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
the network will hold the sparsity_ratio maximum parameters, and cut off the rest.
:type sparsity_ratio: float number between 0 and 1
"""
def __init__(self, type, mask_filename=None, sparsity_ratio=None):
self.type = type
self.mask_filename = mask_filename
self.sparsity_ratio = sparsity_ratio
if self.sparsity_ratio is not None:
    assert is_compatible_with(self.sparsity_ratio,
                              float), 'sparsity_ratio must be float type'
    assert 0 <= self.sparsity_ratio <= 1, 'sparsity_ratio must be a float between [0, 1]'
def __call__(self):
return ParameterHook(
self.type,
mask_filename=self.mask_filename,
sparsity_ratio=self.sparsity_ratio)
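For example (values illustrative), a dynamic pruning hook attribute is created once and later called to produce its proto config:

```python
hook_attr = HookAttribute('pruning', sparsity_ratio=0.75)
hook_config = hook_attr()  # delegates to ParameterHook(...) in config_parser.py
```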
class ParameterAttribute(object):
"""
Parameter Attributes object. To fine-tuning network training process, user
......@@ -109,7 +146,8 @@ class ParameterAttribute(object):
learning_rate=None,
momentum=None,
gradient_clipping_threshold=None,
sparse_update=False):
sparse_update=False,
update_hooks=None):
self.attr = {}
if is_static:
......@@ -162,6 +200,9 @@ class ParameterAttribute(object):
self.attr['gradient_clipping_threshold'] = \
gradient_clipping_threshold
if update_hooks:
self.attr['update_hooks'] = update_hooks
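With the new update_hooks keyword, a pruning hook can be attached to a parameter attribute; a sketch with illustrative values:

```python
param_attr = ParameterAttribute(
    learning_rate=1e-3,
    update_hooks=HookAttribute('pruning', sparsity_ratio=0.75))
# param_attr.attr['update_hooks'] carries the HookAttribute; Parameter(...) in
# config_parser.py later calls it (it is callable) to build the hook proto config.
```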
def set_default_parameter_name(self, name):
"""
Set default parameter name. If parameter not set, then will use default
......@@ -237,5 +278,6 @@ class ExtraLayerAttribute(object):
return attr.attr
HookAttr = HookAttribute
ParamAttr = ParameterAttribute
ExtraAttr = ExtraLayerAttribute
......@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
__all__ = [
"Param",
"Extra",
"Hook",
]
Param = paddle.trainer_config_helpers.attrs.ParameterAttribute
Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute
Hook = paddle.trainer_config_helpers.attrs.HookAttribute
for each in paddle.trainer_config_helpers.attrs.__all__:
globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each)
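To show where the new Hook shortcut would typically end up, here is a hedged configuration sketch; data_layer/fc_layer and their arguments come from the existing paddle.trainer_config_helpers API (not from this diff), and the sizes and ratio are illustrative:

```python
from paddle.trainer_config_helpers import data_layer, fc_layer

data = data_layer(name='input', size=784)
fc = fc_layer(input=data,
              size=128,
              param_attr=Param(update_hooks=Hook('pruning', sparsity_ratio=0.75)))
```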
......