From 18435f2a738b2baec680eea6fc2648dd094e5c87 Mon Sep 17 00:00:00 2001 From: xzl Date: Fri, 2 Jun 2017 16:31:49 +0800 Subject: [PATCH] modify the pruning from reading mask to specify sparsity_ratio --- paddle/parameter/ParameterUpdaterHook.cpp | 130 ++---------------- proto/ParameterConfig.proto | 3 +- python/paddle/trainer/config_parser.py | 9 +- python/paddle/trainer_config_helpers/attrs.py | 14 +- 4 files changed, 17 insertions(+), 139 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterHook.cpp b/paddle/parameter/ParameterUpdaterHook.cpp index 76cc3ecad14..e29494868bc 100644 --- a/paddle/parameter/ParameterUpdaterHook.cpp +++ b/paddle/parameter/ParameterUpdaterHook.cpp @@ -19,130 +19,31 @@ limitations under the License. */ #include #include #include +#include #include "paddle/math/Vector.h" #include "paddle/parameter/Parameter.h" #include "paddle/utils/Flags.h" #include "paddle/utils/Util.h" -using std::vector; -using std::pair; - namespace paddle { /** * The static pruning hook - * - * Static means user load a mask map before training started. This map will - * define which link/weight between neural is disabled. + * Static means user specific a sparsity_ratio map before training started. The + * network will + * hold the sparsity_ratio maximum numbers of parameters, and cut off the rest. */ -class StaticPruningHook : public IParameterUpdaterHook { -public: - /** - * The Mask Map Header. - * The map file started with this header. - * - * In Version 0, reset file will be: - * contains header.size bit, each bit means such weight is enabled or not. - * if bit is 1, then such weight is enabled. - * at end, the file will round to byte, and the low bits of end byte will be - * filled by zero. - * - */ - struct StaticMaskHeader { - uint32_t version; - size_t size; - } __attribute__((__packed__)); - - explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) { - bool ok = this->loadMaskFile(mask_filename); - if (!ok) { - LOG(WARNING) << "Fail to load mask file " << mask_filename - << " in current directory, searching in init_model_path"; - std::string combineMaskFilename = - path::join(FLAGS_init_model_path, mask_filename); - CHECK(this->loadMaskFile(combineMaskFilename)) - << "Cannot load " << mask_filename << " in ./" << mask_filename - << " and " << combineMaskFilename; - } - VLOG(3) << mask_filename << " mask size = " << this->mask_.size(); - } - void update(Parameter* para) { - updateThreadChecker_.check(); - auto& vec = para->getBuf(PARAMETER_GRADIENT); - if (vec) { - vec->dotMul(*maskVec_); - } - } - - void init(Parameter* para) { - size_t initCount = this->initCount_.fetch_add(1); - CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " - "in same ParamterUpdater"; - VLOG(3) << "Initialize Parameter " << para; - SetDevice device(para->getDeviceId()); - - auto maskVec = Vector::create(this->mask_.size(), false); - { // Initialize maskVec with float mask vector - real* dataPtr = maskVec->getData(); - size_t i = 0; - for (bool m : mask_) { - dataPtr[i++] = m ? 1.0 : 0.0; - } - } - - // Currently just use a mask vector for hack. - // @TODO(yuyang18): Implemented the mask operation in vector. - if (para->useGpu()) { - maskVec_ = Vector::create(this->mask_.size(), para->useGpu()); - maskVec_->copyFrom(*maskVec); - } else { - maskVec_ = maskVec; - } - - auto& vec = para->getBuf(PARAMETER_VALUE); - vec->dotMul(*maskVec_); - } - -private: - bool loadMaskFile(const std::string& mask_filename) { - std::ifstream fin; - fin.open(mask_filename); - if (fin.is_open()) { - StaticMaskHeader header; - fin.read(reinterpret_cast(&header), sizeof(StaticMaskHeader)); - CHECK_EQ(header.version, 0UL); - mask_.resize(header.size); - uint8_t buf; - for (size_t i = 0; i < header.size; ++i, buf <<= 1) { - if (i % 8 == 0) { - fin.read(reinterpret_cast(&buf), sizeof(uint8_t)); - } - mask_[i] = buf & 0x80; - } - fin.close(); - return true; - } else { - return false; - } - } - - SameThreadChecker updateThreadChecker_; - std::atomic initCount_; - VectorPtr maskVec_; - std::vector mask_; -}; - -class DynamicPruningHook : public IParameterUpdaterHook { +class StaticPruningHook : public IParameterUpdaterHook { public: - explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig) + explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig) : initCount_(0) { sparsityRatio_ = hookConfig.sparsity_ratio(); } - static bool sortPairAscend(const pair& pair1, - const pair& pair2) { + static bool sortPairAscend(const std::pair& pair1, + const std::pair& pair2) { return pair1.first > pair2.first; } @@ -162,7 +63,7 @@ public: VectorPtr vecCpu = Vector::create(para->getSize(), false); vecCpu->copyFrom(*vec); - vector> param; + std::vector> param; for (size_t i = 0; i < para->getSize(); i++) param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); @@ -175,7 +76,7 @@ public: void init(Parameter* para) { generateMask(para); size_t initCount = this->initCount_.fetch_add(1); - CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must invoke " + CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke " "in same ParamterUpdater"; VLOG(3) << "Initialize Parameter " << para; SetDevice device(para->getDeviceId()); @@ -234,16 +135,9 @@ static WeakKVCache, static IParameterUpdaterHook* createImpl( const ParameterUpdaterHookConfig& config) { auto& type = config.type(); - if (type == "pruning_static") { - if (config.has_purning_mask_filename()) - return new StaticPruningHook(config.purning_mask_filename()); - else - LOG(FATAL) << "There must be mask_filename parameter for " << type - << " Hook"; - - } else if (type == "pruning") { + if (type == "pruning") { if (config.has_sparsity_ratio()) - return new DynamicPruningHook(config); + return new StaticPruningHook(config); else LOG(FATAL) << "There must be sparsity_ratio parameter for " << type << " Hook"; diff --git a/proto/ParameterConfig.proto b/proto/ParameterConfig.proto index 61f4b037cf0..53e3b94f031 100644 --- a/proto/ParameterConfig.proto +++ b/proto/ParameterConfig.proto @@ -26,8 +26,7 @@ enum ParameterInitStrategy { message ParameterUpdaterHookConfig { required string type = 1; - //hook type such as 'pruning', 'pruning_static' - optional string purning_mask_filename = 2; + //hook type such as 'pruning' optional double sparsity_ratio = 3; } diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 3775375c9b7..bebb76d9847 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -3171,14 +3171,7 @@ def Layer(name, type, **xargs): @config_func def ParameterHook(type, **kwargs): - if type == 'pruning_static': - hook = ParameterUpdaterHookConfig() - hook.type = type - mask_filename = kwargs.get('mask_filename', None) - assert mask_filename is not None - hook.pruning_mask_filename = mask_filename - return hook - elif type == 'pruning': + if type == 'pruning': hook = ParameterUpdaterHookConfig() hook.type = type sparsity_ratio = kwargs.get('sparsity_ratio', None) diff --git a/python/paddle/trainer_config_helpers/attrs.py b/python/paddle/trainer_config_helpers/attrs.py index 011147a3685..a0ad8c44525 100644 --- a/python/paddle/trainer_config_helpers/attrs.py +++ b/python/paddle/trainer_config_helpers/attrs.py @@ -64,32 +64,24 @@ class HookAttribute(object): here paddle/parameter/ParameterUpdaterHook.cpp NOTE: IT IS A HIGH LEVEL USER INTERFACE. - :param type: Hook type, eg: 'pruning', 'pruning_static' + :param type: Hook type, eg: 'pruning' :type type: string - :param mask_file: Must be specified if hook type is 'pruning_static', - the network reads the mask from the file to determine which parameters should be cut off - :type mask_file: string - :param sparsity_ratio: Must be specified if hook type is 'pruning', the network will hold the sparsity_ratio maximum parameters, and cut off the rest. :type sparsity_ratio: float number between 0 and 1 """ - def __init__(self, type, mask_filename=None, sparsity_ratio=None): + def __init__(self, type, sparsity_ratio=None): self.type = type - self.mask_filename = mask_filename self.sparsity_ratio = sparsity_ratio assert is_compatible_with(self.sparsity_ratio, float), 'sparisity_ratio must be float type' assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' def __call__(self): - return ParameterHook( - self.type, - mask_filename=self.mask_filename, - sparsity_ratio=self.sparsity_ratio) + return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) class ParameterAttribute(object): -- GitLab