Commit 8b86624b authored by Zhaolong Xing, committed by GitHub

Merge pull request #2354 from NHZlX/improve_pruning

Improve pruning module
......@@ -14,11 +14,13 @@ limitations under the License. */
#include "ParameterUpdaterHook.h"
#include <algorithm>
#include <atomic>
#include <fstream>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h"
......@@ -29,106 +31,76 @@ namespace paddle {
/**
* The static pruning hook
*
 * Static means the user loads a mask map before training starts. This map
 * defines which links/weights between neurons are disabled.
 * Static means the user specifies a sparsity_ratio before training starts, and
 * the network will prune the parameters based on that sparsity_ratio. More
 * details can be found at https://arxiv.org/pdf/1506.02626.pdf.
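 * For example, with sparsity_ratio = 0.6 the 60% of weights with the smallest
 * absolute value are pruned to zero and the remaining 40% are kept.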
*/
class StaticPruningHook : public IParameterUpdaterHook {
public:
/**
* The Mask Map Header.
* The map file starts with this header.
*
* In version 0, the rest of the file contains header.size bits; each bit
* indicates whether the corresponding weight is enabled. If a bit is 1, the
* weight is enabled. At the end, the file is rounded up to a whole byte, and
* the low bits of the last byte are filled with zeros.
*
*/
struct StaticMaskHeader {
uint32_t version;
size_t size;
} __attribute__((__packed__));
explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) {
bool ok = this->loadMaskFile(mask_filename);
if (!ok) {
LOG(WARNING) << "Fail to load mask file " << mask_filename
<< " in current directory, searching in init_model_path";
std::string combineMaskFilename =
path::join(FLAGS_init_model_path, mask_filename);
CHECK(this->loadMaskFile(combineMaskFilename))
<< "Cannot load " << mask_filename << " in ./" << mask_filename
<< " and " << combineMaskFilename;
}
VLOG(3) << mask_filename << " mask size = " << this->mask_.size();
explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig)
: initCount_(0) {
sparsityRatio_ = hookConfig.sparsity_ratio();
}
void update(Parameter* para) {
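// Comparator for std::partial_sort below: despite its name, it orders pairs by
// descending first element (the weight's absolute value).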
static bool sortPairAscend(const std::pair<real, size_t> &pair1,
const std::pair<real, size_t> &pair2) {
return pair1.first > pair2.first;
}
void update(Parameter *para) {
updateThreadChecker_.check();
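// Zero out the gradients of pruned connections so that pruned weights stay at
// zero during training.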
auto& vec = para->getBuf(PARAMETER_GRADIENT);
auto &vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) {
vec->dotMul(*maskVec_);
}
}
void init(Parameter* para) {
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must be invoked "
"in the same ParameterUpdater";
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
void generateMask(Parameter *para) {
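// Build a 0/1 mask over the parameter: keep the (1 - sparsityRatio_) fraction
// of weights with the largest absolute value, prune the rest to zero.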
VectorPtr maskTemp = Vector::create(para->getSize(), false);
maskTemp->zeroMem();
real *maskTempData = maskTemp->getData();
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
auto maskVec = Vector::create(this->mask_.size(), false);
{ // Initialize maskVec with float mask vector
real* dataPtr = maskVec->getData();
size_t i = 0;
for (bool m : mask_) {
dataPtr[i++] = m ? 1.0 : 0.0;
}
}
VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
paraCpuCopy->copyFrom(*paraVec);
std::vector<std::pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
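// Bring the nonZeroNum entries with the largest absolute value to the front,
// then mark exactly those entries as kept in the mask.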
std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
// Currently we just use a mask vector as a hack.
// @TODO(yuyang18): Implement the mask operation in Vector.
if (para->useGpu()) {
maskVec_ = Vector::create(this->mask_.size(), para->useGpu());
maskVec_->copyFrom(*maskVec);
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp);
} else {
maskVec_ = maskVec;
maskVec_ = maskTemp;
}
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
}
private:
bool loadMaskFile(const std::string& mask_filename) {
std::ifstream fin;
fin.open(mask_filename);
if (fin.is_open()) {
StaticMaskHeader header;
fin.read(reinterpret_cast<char*>(&header), sizeof(StaticMaskHeader));
CHECK_EQ(header.version, 0UL);
mask_.resize(header.size);
uint8_t buf;
for (size_t i = 0; i < header.size; ++i, buf <<= 1) {
if (i % 8 == 0) {
fin.read(reinterpret_cast<char*>(&buf), sizeof(uint8_t));
}
mask_[i] = buf & 0x80;
}
fin.close();
return true;
} else {
return false;
}
void init(Parameter *para) {
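// Runs once per parameter: generate the pruning mask and apply it to the
// parameter values; update() then re-applies the same mask to the gradients.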
generateMask(para);
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must be invoked "
"in the same ParameterUpdater";
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
auto &paraVec = para->getBuf(PARAMETER_VALUE);
paraVec->dotMul(*maskVec_);
}
private:
SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
std::vector<bool> mask_;
real sparsityRatio_;
};
IParameterUpdaterHook::IParameterUpdaterHook() {}
......@@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
*/
class StringIntPairHasher {
public:
size_t operator()(const std::pair<std::string, int>& k) const {
size_t operator()(const std::pair<std::string, int> &k) const {
return intHasher_(strHasher_(k.first) + k.second);
}
......@@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>,
/**
* The actual factory method for ParameterUpdaterHook.
*/
static IParameterUpdaterHook* createImpl(
const ParameterUpdaterHookConfig& config) {
auto& type = config.type();
static IParameterUpdaterHook *createImpl(
const ParameterUpdaterHookConfig &config) {
auto &type = config.type();
if (type == "pruning") {
if (config.has_purning_mask_filename()) {
return new StaticPruningHook(config.purning_mask_filename());
}
return new StaticPruningHook(config);
}
LOG(FATAL) << "Unknown Hook type: " << type;
return nullptr;
}
std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create(
const ParameterConfig& paramConfig, int idx) {
const ParameterConfig &paramConfig, int idx) {
std::pair<std::string, int> key = {paramConfig.name(), idx};
return g_hookCache_.get(
key, [&] { return createImpl(paramConfig.update_hooks(idx)); });
......
......@@ -25,8 +25,10 @@ enum ParameterInitStrategy {
}
message ParameterUpdaterHookConfig {
// hook type such as 'pruning'
required string type = 1;
optional string purning_mask_filename = 2;
// the ratio of the parameter's elements to be set to zero
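// e.g. sparsity_ratio = 0.6 means 60% of the parameter's elements are pruned to zero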
optional double sparsity_ratio = 2 [default = 0.6];
}
message ParameterConfig {
......
......@@ -3139,11 +3139,11 @@ def Layer(name, type, **xargs):
@config_func
def ParameterHook(type, **kwargs):
if type == 'pruning':
mask_filename = kwargs.get('mask_filename', None)
assert mask_filename is not None
hook = ParameterUpdaterHookConfig()
hook.type = type
hook.purning_mask_filename = mask_filename
sparsity_ratio = kwargs.get('sparsity_ratio', None)
if sparsity_ratio is not None:
hook.sparsity_ratio = sparsity_ratio
return hook
else:
return None
......@@ -3251,13 +3251,13 @@ def Parameter(name,
if update_hooks is not None:
if hasattr(update_hooks, '__call__'):
update_hooks = update_hooks(para.name)
update_hooks = update_hooks()
if isinstance(update_hooks, list):
for hook in update_hooks:
para.update_hooks.extend([hook])
else:
para.update_hooks.extend(update_hooks)
para.update_hooks.extend([update_hooks])
g_parameter_map[name] = para
if initializer is not None:
......
......@@ -14,7 +14,8 @@
from paddle.trainer.config_parser import *
__all__ = [
'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute'
'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute',
'ExtraLayerAttribute'
]
......@@ -55,6 +56,40 @@ def is_compatible_with(x, Type):
return False
class HookAttribute(object):
"""
Hook Attribute object. As a member of the ParameterAttribute class, a hook is
an auxiliary operation applied during the training of a layer with parameters,
such as an img_conv layer or an fc layer.

:param type: Hook type. Currently the supported type is:
    'pruning': the user specifies a sparsity_ratio before training starts, and
    the network will prune the parameters based on that sparsity_ratio.
    e.g. a Hook object can be defined as hk = HookAttribute('pruning', 0.6)
    and used as paddle.layer.img_conv(input=img, filter_size=3,
                                      num_channels=3, num_filters=64,
                                      param_attr=ParameterAttribute(update_hooks=hk))
    The pruning details can be found at https://arxiv.org/pdf/1506.02626.pdf
:type type: string
:param sparsity_ratio: Must be specified if the hook type is 'pruning';
    it represents the ratio of the parameter's elements to be set to zero.
:type sparsity_ratio: float or None
"""
def __init__(self, type, sparsity_ratio=None):
self.type = type
self.sparsity_ratio = sparsity_ratio
if self.sparsity_ratio is not None:
assert is_compatible_with(
self.sparsity_ratio,
float), 'sparsity_ratio must be float type'
assert 0 <= self.sparsity_ratio <= 1, 'sparsity_ratio must be a float in [0, 1]'
def __call__(self):
return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
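
# A minimal usage sketch (illustrative only; it mirrors the docstring example
# above, and names such as `img` are placeholders, not part of this change):
#
#   hk = HookAttribute('pruning', sparsity_ratio=0.6)
#   conv = paddle.layer.img_conv(input=img, filter_size=3,
#                                num_channels=3, num_filters=64,
#                                param_attr=ParameterAttribute(update_hooks=hk))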
class ParameterAttribute(object):
"""
Parameter Attributes object. To fine-tuning network training process, user
......@@ -114,6 +149,7 @@ class ParameterAttribute(object):
momentum=None,
gradient_clipping_threshold=None,
sparse_update=False,
update_hooks=None,
initializer=None):
self.attr = {}
......@@ -169,6 +205,9 @@ class ParameterAttribute(object):
if initializer is not None:
self.attr['initializer'] = initializer
if update_hooks:
self.attr['update_hooks'] = update_hooks
def set_default_parameter_name(self, name):
"""
Set default parameter name. If parameter not set, then will use default
......@@ -244,5 +283,6 @@ class ExtraLayerAttribute(object):
return attr.attr
HookAttr = HookAttribute
ParamAttr = ParameterAttribute
ExtraAttr = ExtraLayerAttribute
......@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs
__all__ = [
"Param",
"Extra",
"Hook",
]
Param = paddle.trainer_config_helpers.attrs.ParameterAttribute
Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute
Hook = paddle.trainer_config_helpers.attrs.HookAttribute
for each in paddle.trainer_config_helpers.attrs.__all__:
globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each)
......