Commit 8b86624b authored by Zhaolong Xing, committed by GitHub

Merge pull request #2354 from NHZlX/improve_pruning

Improve pruning module
@@ -14,11 +14,13 @@ limitations under the License. */

 #include "ParameterUpdaterHook.h"

+#include <algorithm>
 #include <atomic>
 #include <fstream>
 #include <mutex>
 #include <thread>
 #include <unordered_map>
+#include <vector>

 #include "paddle/math/Vector.h"
 #include "paddle/parameter/Parameter.h"
@@ -29,106 +31,76 @@ namespace paddle {

 /**
  * The static pruning hook
- *
- * Static means user load a mask map before training started. This map will
- * define which link/weight between neural is disabled.
+ * Static means the user specifies a sparsity_ratio before training starts, and
+ * the network will prune the parameters based on that sparsity_ratio. More
+ * details can be found at https://arxiv.org/pdf/1506.02626.pdf.
  */
 class StaticPruningHook : public IParameterUpdaterHook {
 public:
-  /**
-   * The Mask Map Header.
-   * The map file started with this header.
-   *
-   * In Version 0, reset file will be:
-   * contains header.size bit, each bit means such weight is enabled or not.
-   * if bit is 1, then such weight is enabled.
-   * at end, the file will round to byte, and the low bits of end byte will be
-   * filled by zero.
-   *
-   */
-  struct StaticMaskHeader {
-    uint32_t version;
-    size_t size;
-  } __attribute__((__packed__));
-
-  explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) {
-    bool ok = this->loadMaskFile(mask_filename);
-    if (!ok) {
-      LOG(WARNING) << "Fail to load mask file " << mask_filename
-                   << " in current directory, searching in init_model_path";
-      std::string combineMaskFilename =
-          path::join(FLAGS_init_model_path, mask_filename);
-      CHECK(this->loadMaskFile(combineMaskFilename))
-          << "Cannot load " << mask_filename << " in ./" << mask_filename
-          << " and " << combineMaskFilename;
-    }
-    VLOG(3) << mask_filename << " mask size = " << this->mask_.size();
+  explicit StaticPruningHook(const ParameterUpdaterHookConfig &hookConfig)
+      : initCount_(0) {
+    sparsityRatio_ = hookConfig.sparsity_ratio();
   }

-  void update(Parameter* para) {
+  static bool sortPairAscend(const std::pair<real, size_t> &pair1,
+                             const std::pair<real, size_t> &pair2) {
+    return pair1.first > pair2.first;
+  }
+
+  void update(Parameter *para) {
     updateThreadChecker_.check();
-    auto& vec = para->getBuf(PARAMETER_GRADIENT);
+    auto &vec = para->getBuf(PARAMETER_GRADIENT);
     if (vec) {
       vec->dotMul(*maskVec_);
     }
   }

-  void init(Parameter* para) {
-    size_t initCount = this->initCount_.fetch_add(1);
-    CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
-                                "in same ParamterUpdater";
-    VLOG(3) << "Initialize Parameter " << para;
-    SetDevice device(para->getDeviceId());
+  void generateMask(Parameter *para) {
+    VectorPtr maskTemp = Vector::create(para->getSize(), false);
+    maskTemp->zeroMem();
+    real *maskTempData = maskTemp->getData();
+    size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);

-    auto maskVec = Vector::create(this->mask_.size(), false);
-    {  // Initialize maskVec with float mask vector
-      real* dataPtr = maskVec->getData();
-      size_t i = 0;
-      for (bool m : mask_) {
-        dataPtr[i++] = m ? 1.0 : 0.0;
-      }
-    }
+    VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
+    VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
+
+    paraCpuCopy->copyFrom(*paraVec);
+    std::vector<std::pair<real, size_t>> param;
+
+    for (size_t i = 0; i < para->getSize(); i++)
+      param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
+
+    std::partial_sort(
+        param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
+    for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;

     // Currently just use a mask vector for hack.
-    // @TODO(yuyang18): Implemented the mask operation in vector.
     if (para->useGpu()) {
-      maskVec_ = Vector::create(this->mask_.size(), para->useGpu());
-      maskVec_->copyFrom(*maskVec);
+      maskVec_ = Vector::create(para->getSize(), para->useGpu());
+      maskVec_->copyFrom(*maskTemp);
     } else {
-      maskVec_ = maskVec;
+      maskVec_ = maskTemp;
     }
-
-    auto& vec = para->getBuf(PARAMETER_VALUE);
-    vec->dotMul(*maskVec_);
   }

-private:
-  bool loadMaskFile(const std::string& mask_filename) {
-    std::ifstream fin;
-    fin.open(mask_filename);
-    if (fin.is_open()) {
-      StaticMaskHeader header;
-      fin.read(reinterpret_cast<char*>(&header), sizeof(StaticMaskHeader));
-      CHECK_EQ(header.version, 0UL);
-      mask_.resize(header.size);
-      uint8_t buf;
-      for (size_t i = 0; i < header.size; ++i, buf <<= 1) {
-        if (i % 8 == 0) {
-          fin.read(reinterpret_cast<char*>(&buf), sizeof(uint8_t));
-        }
-        mask_[i] = buf & 0x80;
-      }
-      fin.close();
-      return true;
-    } else {
-      return false;
-    }
+  void init(Parameter *para) {
+    generateMask(para);
+    size_t initCount = this->initCount_.fetch_add(1);
+    CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must be "
+                                "invoked in the same ParameterUpdater";
+    VLOG(3) << "Initialize Parameter " << para;
+    SetDevice device(para->getDeviceId());
+
+    auto &paraVec = para->getBuf(PARAMETER_VALUE);
+    paraVec->dotMul(*maskVec_);
   }

+private:
   SameThreadChecker updateThreadChecker_;
   std::atomic<size_t> initCount_;
   VectorPtr maskVec_;
-  std::vector<bool> mask_;
+  real sparsityRatio_;
 };

 IParameterUpdaterHook::IParameterUpdaterHook() {}
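The rewritten hook derives the mask from the weights themselves: generateMask() keeps the (1 - sparsity_ratio) fraction of entries with the largest absolute value, init() applies the mask to the parameter values, and update() applies it to the gradient so pruned weights stay at zero during training. A minimal NumPy sketch of the same selection rule (the prune_mask helper and the array names are illustrative only, not part of this patch):

import numpy as np

def prune_mask(weights, sparsity_ratio):
    # keep the (1 - sparsity_ratio) fraction of entries with the largest |w|
    flat = np.abs(weights).ravel()
    non_zero_num = int(flat.size * (1 - sparsity_ratio))
    mask = np.zeros(flat.size)
    if non_zero_num > 0:
        keep = np.argpartition(-flat, non_zero_num - 1)[:non_zero_num]
        mask[keep] = 1.0
    return mask.reshape(weights.shape)

w = np.random.randn(4, 4)
m = prune_mask(w, sparsity_ratio=0.6)
pruned = w * m  # analogous to paraVec->dotMul(*maskVec_)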
@@ -145,7 +117,7 @@ IParameterUpdaterHook::~IParameterUpdaterHook() {}
  */
 class StringIntPairHasher {
 public:
-  size_t operator()(const std::pair<std::string, int>& k) const {
+  size_t operator()(const std::pair<std::string, int> &k) const {
     return intHasher_(strHasher_(k.first) + k.second);
   }
@@ -162,19 +134,19 @@ static WeakKVCache<std::pair<std::string, int>,

 /**
  * ParameterUpdaterHook actually factory method.
  */
-static IParameterUpdaterHook* createImpl(
-    const ParameterUpdaterHookConfig& config) {
-  auto& type = config.type();
+static IParameterUpdaterHook *createImpl(
+    const ParameterUpdaterHookConfig &config) {
+  auto &type = config.type();
   if (type == "pruning") {
-    if (config.has_purning_mask_filename()) {
-      return new StaticPruningHook(config.purning_mask_filename());
-    }
+    return new StaticPruningHook(config);
   }

+  LOG(FATAL) << "Unknown Hook type: " << type;
   return nullptr;
 }

 std::shared_ptr<IParameterUpdaterHook> IParameterUpdaterHook::create(
-    const ParameterConfig& paramConfig, int idx) {
+    const ParameterConfig &paramConfig, int idx) {
   std::pair<std::string, int> key = {paramConfig.name(), idx};
   return g_hookCache_.get(
       key, [&] { return createImpl(paramConfig.update_hooks(idx)); });
......
@@ -25,8 +25,10 @@ enum ParameterInitStrategy {
 }

 message ParameterUpdaterHookConfig {
+  // hook type such as 'pruning'
   required string type = 1;
-  optional string purning_mask_filename = 2;
+  // the ratio of the parameter's elements to be set to zero
+  optional double sparsity_ratio = 2 [default = 0.6];
 }

 message ParameterConfig {
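The new sparsity_ratio field defaults to 0.6 when left unset. A hedged sketch of building this message from Python, assuming the generated module lives at paddle.proto.ParameterConfig_pb2 (that import path is an assumption, not shown in this diff):

# assumption: the message is generated into paddle.proto.ParameterConfig_pb2
from paddle.proto.ParameterConfig_pb2 import ParameterUpdaterHookConfig

hook = ParameterUpdaterHookConfig()
hook.type = 'pruning'
hook.sparsity_ratio = 0.6  # optional; an unset field reads back as the default 0.6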
......
@@ -3139,11 +3139,11 @@ def Layer(name, type, **xargs):

 @config_func
 def ParameterHook(type, **kwargs):
     if type == 'pruning':
-        mask_filename = kwargs.get('mask_filename', None)
-        assert mask_filename is not None
         hook = ParameterUpdaterHookConfig()
         hook.type = type
-        hook.purning_mask_filename = mask_filename
+        sparsity_ratio = kwargs.get('sparsity_ratio', None)
+        if sparsity_ratio is not None:
+            hook.sparsity_ratio = sparsity_ratio
         return hook
     else:
         return None
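With this change the hook config is driven by a ratio instead of a mask file. A small sketch of calling the config function directly (done here only for illustration; in a real config it is reached through HookAttribute.__call__, shown further below):

hook = ParameterHook('pruning', sparsity_ratio=0.6)
assert hook.type == 'pruning'
assert abs(hook.sparsity_ratio - 0.6) < 1e-6

# omitting the ratio leaves the field unset, so the proto default of 0.6 applies
default_hook = ParameterHook('pruning')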
@@ -3251,13 +3251,13 @@ def Parameter(name,

     if update_hooks is not None:
         if hasattr(update_hooks, '__call__'):
-            update_hooks = update_hooks(para.name)
+            update_hooks = update_hooks()

         if isinstance(update_hooks, list):
             for hook in update_hooks:
                 para.update_hooks.extend([hook])
         else:
-            para.update_hooks.extend(update_hooks)
+            para.update_hooks.extend([update_hooks])

     g_parameter_map[name] = para
     if initializer is not None:
......
@@ -14,7 +14,8 @@

 from paddle.trainer.config_parser import *

 __all__ = [
-    'ParamAttr', 'ExtraAttr', 'ParameterAttribute', 'ExtraLayerAttribute'
+    'HookAttr', 'ParamAttr', 'ExtraAttr', 'ParameterAttribute',
+    'ExtraLayerAttribute'
 ]
@@ -55,6 +56,40 @@ def is_compatible_with(x, Type):
     return False


+class HookAttribute(object):
+    """
+    Hook Attribute object. As a member of the ParameterAttribute class, a hook
+    is an auxiliary operation that occurs during the training process of a
+    layer with parameters, such as an img_conv or fc layer.
+
+    :param type: Hook type, currently supported types:
+                 'pruning' : the user specifies a sparsity_ratio before training
+                 starts, and the network will prune the parameters based on that
+                 sparsity_ratio.
+                 eg: a Hook object can be defined as hk = HookAttribute('pruning', 0.6)
+                 and used as paddle.layer.img_conv(input=img, filter_size=3,
+                 num_channels=3, num_filters=64,
+                 param_attr=ParameterAttribute(update_hooks=hk))
+                 The pruning details can be found at https://arxiv.org/pdf/1506.02626.pdf
+    :type type: string
+
+    :param sparsity_ratio: Must be specified if the hook type is 'pruning';
+                           it is the ratio of the parameter's elements to be set to zero.
+    :type sparsity_ratio: float or None
+    """
+
+    def __init__(self, type, sparsity_ratio=None):
+        self.type = type
+        self.sparsity_ratio = sparsity_ratio
+        if self.sparsity_ratio is not None:
+            assert is_compatible_with(
+                self.sparsity_ratio,
+                float), 'sparsity_ratio must be float type'
+            assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, \
+                'sparsity_ratio must be a float between [0, 1]'
+
+    def __call__(self):
+        return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
+
+
 class ParameterAttribute(object):
     """
     Parameter Attributes object. To fine-tuning network training process, user
@@ -114,6 +149,7 @@ class ParameterAttribute(object):
                  momentum=None,
                  gradient_clipping_threshold=None,
                  sparse_update=False,
+                 update_hooks=None,
                  initializer=None):
         self.attr = {}
@@ -169,6 +205,9 @@ class ParameterAttribute(object):
         if initializer is not None:
             self.attr['initializer'] = initializer

+        if update_hooks:
+            self.attr['update_hooks'] = update_hooks
+
     def set_default_parameter_name(self, name):
         """
         Set default parameter name. If parameter not set, then will use default
@@ -244,5 +283,6 @@ class ExtraLayerAttribute(object):
         return attr.attr


+HookAttr = HookAttribute
 ParamAttr = ParameterAttribute
 ExtraAttr = ExtraLayerAttribute
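Putting the helper pieces together, a usage sketch that follows the HookAttribute docstring above (the layer call and its sizes are illustrative):

from paddle.trainer_config_helpers.attrs import HookAttr, ParamAttr

hk = HookAttr(type='pruning', sparsity_ratio=0.6)
param_attr = ParamAttr(update_hooks=hk)
# e.g. paddle.layer.img_conv(input=img, filter_size=3, num_channels=3,
#                            num_filters=64, param_attr=param_attr)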
@@ -17,10 +17,12 @@ import paddle.trainer_config_helpers.attrs

 __all__ = [
     "Param",
     "Extra",
+    "Hook",
 ]

 Param = paddle.trainer_config_helpers.attrs.ParameterAttribute
 Extra = paddle.trainer_config_helpers.attrs.ExtraLayerAttribute
+Hook = paddle.trainer_config_helpers.attrs.HookAttribute

 for each in paddle.trainer_config_helpers.attrs.__all__:
     globals()[each] = getattr(paddle.trainer_config_helpers.attrs, each)
......