提交 1eab8cce 编写于 作者: Z zlx

modify the annotations of HookAttribute, Variable declaration

上级 15bf6e05
...@@ -31,9 +31,9 @@ namespace paddle { ...@@ -31,9 +31,9 @@ namespace paddle {
/** /**
* The static pruning hook * The static pruning hook
* Static means user specific a sparsity_ratio before training start, and the * Static means user specify a sparsity_ratio before training started, and the
* network will prune the parameters based on the sparsity_ratio. More deatils * network will prune the parameters based on the sparsity_ratio. More deatils
* can see https://arxiv.org/pdf/1506.02626.pdf. * can be found https://arxiv.org/pdf/1506.02626.pdf.
*/ */
class StaticPruningHook : public IParameterUpdaterHook { class StaticPruningHook : public IParameterUpdaterHook {
...@@ -57,29 +57,31 @@ public: ...@@ -57,29 +57,31 @@ public:
} }
void generateMask(Parameter* para) { void generateMask(Parameter* para) {
VectorPtr vec = para->getBuf(PARAMETER_VALUE);
maskTemp_ = Vector::create(para->getSize(), false); VectorPtr maskTemp = Vector::create(para->getSize(), false);
maskTemp_->zeroMem(); maskTemp->zeroMem();
real* dataPtr = maskTemp_->getData(); real* maskTempData = maskTemp->getData();
size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_); size_t nonZeroNum = para->getSize() * (1 - sparsityRatio_);
VectorPtr vecCpu = Vector::create(para->getSize(), false); VectorPtr paraVec = para->getBuf(PARAMETER_VALUE);
vecCpu->copyFrom(*vec); VectorPtr paraCpuCopy = Vector::create(para->getSize(), false);
paraCpuCopy->copyFrom(*paraVec);
std::vector<std::pair<real, size_t>> param; std::vector<std::pair<real, size_t>> param;
for (size_t i = 0; i < para->getSize(); i++) for (size_t i = 0; i < para->getSize(); i++)
param.push_back(std::make_pair(fabs(vecCpu->getData()[i]), i)); param.push_back(std::make_pair(fabs(paraCpuCopy->getData()[i]), i));
std::partial_sort( std::partial_sort(
param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend); param.begin(), param.begin() + nonZeroNum, param.end(), sortPairAscend);
for (size_t i = 0; i < nonZeroNum; i++) dataPtr[param[i].second] = 1.0; for (size_t i = 0; i < nonZeroNum; i++) maskTempData[param[i].second] = 1.0;
// Currently just use a mask vector for hack. // Currently just use a mask vector for hack.
if (para->useGpu()) { if (para->useGpu()) {
maskVec_ = Vector::create(para->getSize(), para->useGpu()); maskVec_ = Vector::create(para->getSize(), para->useGpu());
maskVec_->copyFrom(*maskTemp_); maskVec_->copyFrom(*maskTemp);
} else { } else {
maskVec_ = maskTemp_; maskVec_ = maskTemp;
} }
} }
...@@ -91,15 +93,14 @@ public: ...@@ -91,15 +93,14 @@ public:
VLOG(3) << "Initialize Parameter " << para; VLOG(3) << "Initialize Parameter " << para;
SetDevice device(para->getDeviceId()); SetDevice device(para->getDeviceId());
auto& vec = para->getBuf(PARAMETER_VALUE); auto& paraVec = para->getBuf(PARAMETER_VALUE);
vec->dotMul(*maskVec_); paraVec->dotMul(*maskVec_);
} }
private: private:
SameThreadChecker updateThreadChecker_; SameThreadChecker updateThreadChecker_;
std::atomic<size_t> initCount_; std::atomic<size_t> initCount_;
VectorPtr maskVec_; VectorPtr maskVec_;
VectorPtr maskTemp_;
real sparsityRatio_; real sparsityRatio_;
}; };
......
...@@ -58,15 +58,21 @@ def is_compatible_with(x, Type): ...@@ -58,15 +58,21 @@ def is_compatible_with(x, Type):
class HookAttribute(object): class HookAttribute(object):
""" """
Hook Attribute object. The hook is an auxiliary operation that occurs Hook Attribute object. As a member of ParameterAttribute class, the hook is an auxiliary operation that occurs
during network propagation. during training process of a layer with parameters, such as img_conv layer, fc layer.
NOTE: IT IS A HIGH LEVEL USER INTERFACE.
:param type: Hook type, currently supported types:
:param type: Hook type, eg: 'pruning' 'pruning' : user specify a sparsity_ratio before training started, and the
network will prune the parameters based on the sparsity_ratio.
eg: The definition of Hook object can be hk = HookAttribute('pruning', 0.6)
The specific usage can be paddle.layer.img_conv(input=img, filter_size=3,
num_channels=3, num_filters=64,
param_attr=ParameterAttribute(update_hooks=hk) )
The pruning deatils can be found https://arxiv.org/pdf/1506.02626.pdf
:type type: string :type type: string
:param sparsity_ratio: Must be specified if hook type is 'pruning', :param sparsity_ratio: Must be specified if hook type is 'pruning',
it represents the ratio of the zero elements to be set by the Parameter. it represents the ratio of the zero elements to be set by the Parameter.
:type sparsity_ratio: float or None :type sparsity_ratio: float or None
""" """
...@@ -78,7 +84,7 @@ class HookAttribute(object): ...@@ -78,7 +84,7 @@ class HookAttribute(object):
assert is_compatible_with( assert is_compatible_with(
self.sparsity_ratio, self.sparsity_ratio,
float), 'sparisity_ratio must be float type' float), 'sparisity_ratio must be float type'
assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity must be a flaot between [0, 1] ' assert self.sparsity_ratio <= 1 and self.sparsity_ratio >= 0, 'sparisity_ratio must be a float between [0, 1] '
def __call__(self): def __call__(self):
return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio) return ParameterHook(self.type, sparsity_ratio=self.sparsity_ratio)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册