提交 18435f2a 编写于 作者: X xzl

modify the pruning from reading mask to specify sparsity_ratio

上级 ca55a24e
...@@ -19,130 +19,31 @@ limitations under the License. */ ...@@ -19,130 +19,31 @@ limitations under the License. */
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h" #include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h" #include "paddle/parameter/Parameter.h"
#include "paddle/utils/Flags.h" #include "paddle/utils/Flags.h"
#include "paddle/utils/Util.h" #include "paddle/utils/Util.h"
using std::vector;
using std::pair;
namespace paddle { namespace paddle {
/** /**
* The static pruning hook * The static pruning hook
* * Static means user specific a sparsity_ratio map before training started. The
* Static means user load a mask map before training started. This map will * network will
* define which link/weight between neural is disabled. * hold the sparsity_ratio maximum numbers of parameters, and cut off the rest.
*/
class StaticPruningHook : public IParameterUpdaterHook {
public:
/**
* The Mask Map Header.
* The map file started with this header.
*
* In Version 0, reset file will be:
* contains header.size bit, each bit means such weight is enabled or not.
* if bit is 1, then such weight is enabled.
* at end, the file will round to byte, and the low bits of end byte will be
* filled by zero.
*
*/ */
struct StaticMaskHeader {
uint32_t version;
size_t size;
} __attribute__((__packed__));
explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) {
bool ok = this->loadMaskFile(mask_filename);
if (!ok) {
LOG(WARNING) << "Fail to load mask file " << mask_filename
<< " in current directory, searching in init_model_path";
std::string combineMaskFilename =
path::join(FLAGS_init_model_path, mask_filename);
CHECK(this->loadMaskFile(combineMaskFilename))
<< "Cannot load " << mask_filename << " in ./" << mask_filename
<< " and " << combineMaskFilename;
}
VLOG(3) << mask_filename << " mask size = " << this->mask_.size();
}
void update(Parameter* para) {
updateThreadChecker_.check();
auto& vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) {
vec->dotMul(*maskVec_);
}
}
void init(Parameter* para) {
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
"in same ParamterUpdater";
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
auto maskVec = Vector::create(this->mask_.size(), false);
{ // Initialize maskVec with float mask vector
real* dataPtr = maskVec->getData();
size_t i = 0;
for (bool m : mask_) {
dataPtr[i++] = m ? 1.0 : 0.0;
}
}
// Currently just use a mask vector for hack. class StaticPruningHook : public IParameterUpdaterHook {
// @TODO(yuyang18): Implemented the mask operation in vector.
if (para->useGpu()) {
maskVec_ = Vector::create(this->mask_.size(), para->useGpu());
maskVec_->copyFrom(*maskVec);
} else {
maskVec_ = maskVec;
}
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
}
private:
bool loadMaskFile(const std::string& mask_filename) {
std::ifstream fin;
fin.open(mask_filename);
if (fin.is_open()) {
StaticMaskHeader header;
fin.read(reinterpret_cast<char*>(&header), sizeof(StaticMaskHeader));
CHECK_EQ(header.version, 0UL);
mask_.resize(header.size);
uint8_t buf;
for (size_t i = 0; i < header.size; ++i, buf <<= 1) {
if (i % 8 == 0) {
fin.read(reinterpret_cast<char*>(&buf), sizeof(uint8_t));
}
mask_[i] = buf & 0x80;
}
fin.close();
return true;
} else {
return false;
}
}
SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
std::vector<bool> mask_;
};
class DynamicPruningHook : public IParameterUpdaterHook {
public: public:
explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig) explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig)
: initCount_(0) { : initCount_(0) {
sparsityRatio_ = hookConfig.sparsity_ratio(); sparsityRatio_ = hookConfig.sparsity_ratio();
} }
static bool sortPairAscend(const pair<real, size_t>& pair1, static bool sortPairAscend(const std::pair<real, size_t>& pair1,
const pair<real, size_t>& pair2) { const std::pair<real, size_t>& pair2) {
return pair1.first > pair2.first; return pair1.first > pair2.first;
} }
...@@ -162,7 +63,7 @@ public: ...@@ -162,7 +63,7 @@ public:
VectorPtr vecCpu = Vector::create(para->getSize(), false); VectorPtr vecCpu = Vector::create(para->getSize(), false);
vecCpu->copyFrom(*vec); vecCpu->copyFrom(*vec);
vector<pair<real, size_t>> param; std::vector<std::pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++) for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
...@@ -175,7 +76,7 @@ public: ...@@ -175,7 +76,7 @@ public:
void init(Parameter* para) { void init(Parameter* para) {
generateMask(para); generateMask(para);
size_t initCount = this->initCount_.fetch_add(1); size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must invoke " CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must invoke "
"in same ParamterUpdater"; "in same ParamterUpdater";
VLOG(3) << "Initialize Parameter " << para; VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId()); SetDevice device(para->getDeviceId());
...@@ -234,16 +135,9 @@ static WeakKVCache<std::pair<std::string, int>, ...@@ -234,16 +135,9 @@ static WeakKVCache<std::pair<std::string, int>,
static IParameterUpdaterHook* createImpl( static IParameterUpdaterHook* createImpl(
const ParameterUpdaterHookConfig& config) { const ParameterUpdaterHookConfig& config) {
auto& type = config.type(); auto& type = config.type();
if (type == "pruning_static") { if (type == "pruning") {
if (config.has_purning_mask_filename())
return new StaticPruningHook(config.purning_mask_filename());
else
LOG(FATAL) << "There must be mask_filename parameter for " << type
<< " Hook";
} else if (type == "pruning") {
if (config.has_sparsity_ratio()) if (config.has_sparsity_ratio())
return new DynamicPruningHook(config); return new StaticPruningHook(config);
else else
LOG(FATAL) << "There must be sparsity_ratio parameter for " << type LOG(FATAL) << "There must be sparsity_ratio parameter for " << type
<< " Hook"; << " Hook";
......
...@@ -26,8 +26,7 @@ enum ParameterInitStrategy { ...@@ -26,8 +26,7 @@ enum ParameterInitStrategy {
message ParameterUpdaterHookConfig { message ParameterUpdaterHookConfig {
required string type = 1; required string type = 1;
//hook type such as 'pruning', 'pruning_static' //hook type such as 'pruning'
optional string purning_mask_filename = 2;
optional double sparsity_ratio = 3; optional double sparsity_ratio = 3;
} }
......
...@@ -3171,14 +3171,7 @@ def Layer(name, type, **xargs): ...@@ -3171,14 +3171,7 @@ def Layer(name, type, **xargs):
@config_func @config_func
def ParameterHook(type, **kwargs): def ParameterHook(type, **kwargs):
if type == 'pruning_static': if type == 'pruning':
hook = ParameterUpdaterHookConfig()
hook.type = type
mask_filename = kwargs.get('mask_filename', None)
assert mask_filename is not None
hook.pruning_mask_filename = mask_filename
return hook
elif type == 'pruning':
hook = ParameterUpdaterHookConfig() hook = ParameterUpdaterHookConfig()
hook.type = type hook.type = type
sparsity_ratio = kwargs.get('sparsity_ratio', None) sparsity_ratio = kwargs.get('sparsity_ratio', None)
......
...@@ -64,32 +64,24 @@ class HookAttribute(object): ...@@ -64,32 +64,24 @@ class HookAttribute(object):
here paddle/parameter/ParameterUpdaterHook.cpp here paddle/parameter/ParameterUpdaterHook.cpp
NOTE: IT IS A HIGH LEVEL USER INTERFACE. NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, eg: 'pruning', 'pruning_static' :param type: Hook type, eg: 'pruning'
:type type: string :type type: string
:param mask_file: Must be specified if hook type is 'pruning_static',
the network reads the mask from the file to determine which parameters should be cut off
:type mask_file: string
:param sparsity_ratio: Must be specified if hook type is 'pruning', :param sparsity_ratio: Must be specified if hook type is 'pruning',
the network will hold the sparsity_ratio maximum parameters, and cut off the rest. the network will hold the sparsity_ratio maximum parameters, and cut off the rest.
:type sparsity_ratio: float number between 0 and 1 :type sparsity_ratio: float number between 0 and 1
""" """
def __init__(self, type, mask_filename=None, sparsity_ratio=None): def __init__(self, type, sparsity_ratio=None):
self.type = type self.type = type
self.mask_filename = mask_filename
self.sparsity_ratio = sparsity_ratio self.sparsity_ratio = sparsity_ratio
assert is_compatible_with(self.sparsity_ratio, assert is_compatible_with(self.sparsity_ratio,
float), 'sparisity_ratio must be float type' float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
def __call__(self): def __call__(self):
return ParameterHook( return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
self.type,
mask_filename=self.mask_filename,
sparsity_ratio=self.sparsity_ratio)
class ParameterAttribute(object): class ParameterAttribute(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册