提交 1eab8cce 编写于 作者: Z zlx

modify the annotations of HookAttribute, Variable declaration

上级 15bf6e05
......@@ -31,9 +31,9 @@ namespace paddle {
/**
* The static pruning hook
* Static means user specific a sparsity_ratio before training start, and the
* Static means user specify a sparsity_ratio before training started, and the
* network will prune the parameters based on the sparsity_ratio. More deatils
* can see https://arxiv.org/pdf/1506.02626.pdf.
* can be found https://arxiv.org/pdf/1506.02626.pdf.
*/
class StaticPruningHook : public IParameterUpdaterHook {
......@@ -57,29 +57,31 @@ public:
}
void generateMask(Parameter* para) {
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
maskTemp_ = Vector::create(para->getSize(), false);
maskTemp_->zeroMem();
real* dataPtr = maskTemp_->getData();
VectorPtr maskTemp = Vector::create(para->getSize(), false);
maskTemp->zeroMem();
real* maskTempData = maskTemp->getData();
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
VectorPtr vecCpu = Vector::create(para->getSize(), false);
vecCpu->copyFrom(*vec);
VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
paraCpuCopy->copyFrom(*paraVec);
std::vector<std::pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i));
param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0;
for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
// Currently just use a mask vector for hack.
if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_);
maskVec_->copyFrom(*maskTemp);
} else {
maskVec_ = maskTemp_;
maskVec_ = maskTemp;
}
}
......@@ -91,15 +93,14 @@ public:
VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId());
auto& vec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_);
auto& paraVec = para->getBuf(PARAMETER_VALUE);
paraVec->dotMul(*maskVec_);
}
private:
SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_;
VectorPtr maskVec_;
VectorPtr maskTemp_;
real sparsityRatio_;
};
......
......@@ -58,15 +58,21 @@ def is_compatible_with(x, Type):
class HookAttribute(object):
"""
Hook Attribute object. The hook is an auxiliary operation that occurs
during network propagation.
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, eg: 'pruning'
Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during training process of a layer with parameters, such as img_conv layer, fc layer.
:param type: Hook type, currently supported types:
'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
param_attr=ParameterAttribute(update_hooks=hk) )
The pruning deatils can be found https://arxiv.org/pdf/1506.02626.pdf
:type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning',
it represents the ratio of the zero elements to be set by the Parameter.
it represents the ratio of the zero elements to be set by the Parameter.
:type sparsity_ratio: float or None
"""
......@@ -78,7 +84,7 @@ class HookAttribute(object):
assert is_compatible_with(
self.sparsity_ratio,
float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] '
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity_ratio must be a float between [0, 1] '
def __call__(self):
return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册