提交 1eab8cce 编写于 作者: Z zlx

modify the annotations of HookAttribute, Variable declaration

上级 15bf6e05
......@@ -31,9 +31,9 @@ namespace paddle {
/**
* The static pruning hook
* Static means user specific a sparsity_ratio before training start, and the
* Static means user specify a sparsity_ratio before training started, and the
* network will prune the parameters based on the sparsity_ratio. More deatils
* can see https://arxiv.org/pdf/1506.02626.pdf.
* can be found https://arxiv.org/pdf/1506.02626.pdf.
*/
class StaticPruningHook : public IParameterUpdaterHook {
......@@ -57,29 +57,31 @@ public:
}
void generateMask(Parameter* para) {
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
maskTemp_ = Vector::create(para->getSize(), false);
maskTemp_->zeroMem();
real* dataPtr = maskTemp_->getData();
VectorPtr maskTemp = Vector::create(para->getSize(), false);
maskTemp->zeroMem();
real* maskTempData = maskTemp->getData();
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
VectorPtr vecCpu = Vector::create(para->getSize(), false);
vecCpu->copyFrom(*vec);
VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
paraCpuCopy->copyFrom(*paraVec);
std::vector<std::pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0;
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
// Currently just use a mask vector for hack.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_);
maskVec_->copyFrom(*maskTemp);
} else {
maskVec_ = maskTemp_;
maskVec_ = maskTemp;
}
}
......@@ -91,15 +93,14 @@ public:
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
auto& paraVec = para->getBuf(PARAMETER_VALUE);
paraVec->dotMul(*maskVec_);
}
private:
SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
VectorPtr maskTemp_;
real sparsityRatio_;
};
......
......@@ -58,11 +58,17 @@ def is_compatible_with(x, Type):
class HookAttribute(object):
"""
Hook Attribute object. The hook is an auxiliary operation that occurs
during network propagation.
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, eg: 'pruning'
Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during training process of a layer with parameters, such as img_conv layer, fc layer.
:param type: Hook type, currently supported types:
'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
param_attr=ParameterAttribute(update_hooks=hk) )
The pruning deatils can be found https://arxiv.org/pdf/1506.02626.pdf
:type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
......@@ -78,7 +84,7 @@ class HookAttribute(object):
assert is_compatible_with(
self.sparsity_ratio,
float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity_ratio must be a float between [0, 1] '
def __call__(self):
return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册