Commit 18435f2a authored by xzl

modify pruning: specify a sparsity_ratio instead of reading a mask file

Parent ca55a24e
......@@ -19,130 +19,31 @@ limitations under the License. */
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>
#include "paddle/math/Vector.h"
#include "paddle/parameter/Parameter.h"
#include "paddle/utils/Flags.h"
#include "paddle/utils/Util.h"
using std::vector;
using std::pair;
namespace paddle {
/**
* The static pruning hook
*
* Static means the user loads a mask map before training starts. This map
* defines which links/weights between neurons are disabled.
* Static means the user specifies a sparsity_ratio before training starts.
* The network holds the sparsity_ratio share of largest-magnitude parameters
* and cuts off the rest. (A plain-Python sketch of this selection follows
* this file's diff.)
*/
class StaticPruningHook : public IParameterUpdaterHook {
public:
/**
* The Mask Map Header.
* The map file starts with this header.
*
* In version 0, the rest of the file contains header.size bits; each bit
* indicates whether the corresponding weight is enabled (1 = enabled).
* The file is rounded up to a whole byte, and the low bits of the final
* byte are zero-filled. (A worked example follows the struct.)
*
*/
struct StaticMaskHeader {
uint32_t version;
size_t size;
} __attribute__((__packed__));
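// Worked example (illustrative): for header.size = 12, the header is
// followed by ceil(12 / 8) = 2 mask bytes. Weight i maps to bit (7 - i % 8)
// of byte i / 8: weights 0..7 fill byte 0 from the high bit down, weights
// 8..11 fill the high nibble of byte 1, and the low four bits of byte 1 are
// zero padding.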
explicit StaticPruningHook(const std::string& mask_filename) : initCount_(0) {
bool ok = this->loadMaskFile(mask_filename);
if (!ok) {
LOG(WARNING) << "Fail to load mask file " << mask_filename
<< " in current directory, searching in init_model_path";
std::string combineMaskFilename =
path::join(FLAGS_init_model_path, mask_filename);
CHECK(this->loadMaskFile(combineMaskFilename))
<< "Cannot load " << mask_filename << " in ./" << mask_filename
<< " and " << combineMaskFilename;
}
VLOG(3) << mask_filename << " mask size = " << this->mask_.size();
}
void update(Parameter* para) {
updateThreadChecker_.check();
auto& vec = para->getBuf(PARAMETER_GRADIENT);
if (vec) {
vec->dotMul(*maskVec_);
}
}
void init(Parameter* para) {
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must be invoked "
"in the same ParameterUpdater";
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
auto maskVec = Vector::create(this->mask_.size(), false);
{ // Initialize maskVec with float mask vector
real* dataPtr = maskVec->getData();
size_t i = 0;
for (bool m : mask_) {
dataPtr[i++] = m ? 1.0 : 0.0;
}
}
// Currently just use a mask vector for hack.
// @TODO(yuyang18): Implemented the mask operation in vector.
if (para->useGpu()) {
maskVec_ = Vector::create(this->mask_.size(), para->useGpu());
maskVec_->copyFrom(*maskVec);
} else {
maskVec_ = maskVec;
}
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
}
private:
bool loadMaskFile(const std::string& mask_filename) {
std::ifstream fin;
fin.open(mask_filename);
if (fin.is_open()) {
StaticMaskHeader header;
fin.read(reinterpret_cast<char*>(&header), sizeof(StaticMaskHeader));
CHECK_EQ(header.version, 0UL);
mask_.resize(header.size);
uint8_t buf;
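// Bits are consumed MSB-first: a new byte is read every 8th weight, the
// current weight's bit is tested against 0x80, and buf shifts left one bit
// per iteration via the loop's increment expression.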
for (size_t i = 0; i < header.size; ++i, buf <<= 1) {
if (i % 8 == 0) {
fin.read(reinterpret_cast<char*>(&buf), sizeof(uint8_t));
}
mask_[i] = buf & 0x80;
}
fin.close();
return true;
} else {
return false;
}
}
SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
std::vector<bool> mask_;
};
class DynamicPruningHook : public IParameterUpdaterHook {
class StaticPruningHook : public IParameterUpdaterHook {
public:
explicit DynamicPruningHook(const ParameterUpdaterHookConfig& hookConfig)
explicit StaticPruningHook(const ParameterUpdaterHookConfig& hookConfig)
: initCount_(0) {
sparsityRatio_ = hookConfig.sparsity_ratio();
}
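// Note: despite its name, this comparator orders by descending first element,
// so a sort using it places the largest-magnitude weights first.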
static bool sortPairAscend(const pair<real, size_t>& pair1,
const pair<real, size_t>& pair2) {
static bool sortPairAscend(const std::pair<real, size_t>& pair1,
const std::pair<real, size_t>& pair2) {
return pair1.first > pair2.first;
}
......@@ -162,7 +63,7 @@ public:
VectorPtr vecCpu = Vector::create(para->getSize(), false);
vecCpu->copyFrom(*vec);
vector<pair<real, size_t>> param;
std::vector<std::pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
......@@ -175,7 +76,7 @@ public:
void init(Parameter* para) {
generateMask(para);
size_t initCount = this->initCount_.fetch_add(1);
CHECK_EQ(initCount, 0UL) << "Currently the DynamicPruningHook must be invoked "
CHECK_EQ(initCount, 0UL) << "Currently the StaticPruningHook must be invoked "
"in the same ParameterUpdater";
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
......@@ -234,16 +135,9 @@ static WeakKVCache<std::pair<std::string, int>,
static IParameterUpdaterHook* createImpl(
const ParameterUpdaterHookConfig& config) {
auto& type = config.type();
if (type == "pruning_static") {
if (config.has_purning_mask_filename())
return new StaticPruningHook(config.purning_mask_filename());
else
LOG(FATAL) << "There must be mask_filename parameter for " << type
<< " Hook";
} else if (type == "pruning") {
if (type == "pruning") {
if (config.has_sparsity_ratio())
return new DynamicPruningHook(config);
return new StaticPruningHook(config);
else
LOG(FATAL) << "There must be sparsity_ratio parameter for " << type
<< " Hook";
......
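For context, a minimal plain-Python sketch of the magnitude-based selection the rewritten StaticPruningHook performs. It assumes sparsity_ratio is the fraction of weights cut off (the docstring's wording is ambiguous on this point), and the helper name prune_by_ratio is illustrative rather than part of the codebase:

def prune_by_ratio(weights, sparsity_ratio):
    """Return a 0/1 mask that keeps the largest-magnitude weights.

    Assumes sparsity_ratio is the fraction of weights to cut off;
    the remaining (1 - sparsity_ratio) share is kept.
    """
    assert 0.0 <= sparsity_ratio <= 1.0
    num_keep = int(len(weights) * (1.0 - sparsity_ratio))
    # Rank indices by descending magnitude, mirroring the hook's
    # descending comparator over (abs(value), index) pairs.
    order = sorted(range(len(weights)),
                   key=lambda i: abs(weights[i]), reverse=True)
    mask = [0.0] * len(weights)
    for i in order[:num_keep]:
        mask[i] = 1.0
    return mask

weights = [0.9, -0.05, 0.4, 0.01, -0.7, 0.2]
print(prune_by_ratio(weights, sparsity_ratio=0.5))
# -> [1.0, 0.0, 1.0, 0.0, 1.0, 0.0]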
......@@ -26,8 +26,7 @@ enum ParameterInitStrategy {
message ParameterUpdaterHookConfig {
required string type = 1;
//hook type such as 'pruning', 'pruning_static'
optional string purning_mask_filename = 2;
//hook type such as 'pruning'
optional double sparsity_ratio = 3;
}
......
......@@ -3171,14 +3171,7 @@ def Layer(name, type, **xargs):
@config_func
def ParameterHook(type, **kwargs):
if type == 'pruning_static':
hook = ParameterUpdaterHookConfig()
hook.type = type
mask_filename = kwargs.get('mask_filename', None)
assert mask_filename is not None
hook.pruning_mask_filename = mask_filename
return hook
elif type == 'pruning':
if type == 'pruning':
hook = ParameterUpdaterHookConfig()
hook.type = type
sparsity_ratio = kwargs.get('sparsity_ratio', None)
......
......@@ -64,32 +64,24 @@ class HookAttribute(object):
here paddle/parameter/ParameterUpdaterHook.cpp
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, e.g. 'pruning', 'pruning_static'
:param type: Hook type, e.g. 'pruning'
:type type: string
:param mask_file: Must be specified if hook type is 'pruning_static';
the network reads the mask from this file to determine which parameters
should be cut off.
:type mask_file: string
:param sparsity_ratio: Must be specified if hook type is 'pruning';
the network holds the sparsity_ratio share of largest parameters and cuts
off the rest.
:type sparsity_ratio: float number between 0 and 1
"""
def __init__(self, type, mask_filename=None, sparsity_ratio=None):
def __init__(self, type, sparsity_ratio=None):
self.type = type
self.mask_filename = mask_filename
self.sparsity_ratio = sparsity_ratio
assert is_compatible_with(self.sparsity_ratio,
                          float), 'sparsity_ratio must be float type'
assert 0 <= self.sparsity_ratio <= 1, 'sparsity_ratio must be a float in [0, 1]'
def __call__(self):
return ParameterHook(
self.type,
mask_filename=self.mask_filename,
sparsity_ratio=self.sparsity_ratio)
return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
class ParameterAttribute(object):
......
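A minimal usage sketch of the new interface, assuming only the HookAttribute class above; ParameterHook is a trainer-config function, so in practice this call runs during config parsing rather than as a standalone script:

# Cut off weights according to sparsity_ratio, keeping only the
# largest-magnitude ones.
hook = HookAttribute(type='pruning', sparsity_ratio=0.6)
hook_config = hook()  # builds a ParameterUpdaterHookConfig with type='pruning'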