提交 b1bca066 编写于 作者: H hedaoyuan

Refine the ExpandConvLayer.

上级 e4c340e4
...@@ -53,14 +53,15 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, ...@@ -53,14 +53,15 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
weights_.emplace_back(w); weights_.emplace_back(w);
index++; index++;
} }
if (biasParameter_.get()) { if (biasParameter_.get()) {
if (sharedBiases_) { if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize()); CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ = biases_ = std::unique_ptr<Weight>(
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_)); new Weight(1, numFilters_, biasParameter_, 0));
} else { } else {
biases_ = biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_)); std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_, 0));
} }
} }
...@@ -189,12 +190,7 @@ void ExpandConvLayer::forward(PassType passType) { ...@@ -189,12 +190,7 @@ void ExpandConvLayer::forward(PassType passType) {
/* add the bias-vector */ /* add the bias-vector */
if (biases_.get()) { if (biases_.get()) {
MatrixPtr bias = Matrix::create(biases_->getW()->getData(), output_.value->addBias(*biases_->getW(), 1.0, sharedBiases_);
1,
biases_->getW()->getElementCnt(),
false,
useGpu_);
output_.value->addBias(*bias, 1.0, sharedBiases_);
} }
/* activation */ /* activation */
...@@ -206,13 +202,7 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) { ...@@ -206,13 +202,7 @@ void ExpandConvLayer::backward(const UpdateCallback &callback) {
MatrixPtr outGrad = getOutputGrad(); MatrixPtr outGrad = getOutputGrad();
if (biases_ && biases_->getWGrad()) { if (biases_ && biases_->getWGrad()) {
// bpropBiases(outGrad); biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBiases_);
MatrixPtr bias = Matrix::create(biases_->getWGrad()->getData(),
1,
biases_->getWGrad()->getElementCnt(),
false,
useGpu_);
bias->collectBias(*getOutputGrad(), 1, sharedBiases_);
/* Increasing the number of gradient */ /* Increasing the number of gradient */
biases_->getParameterPtr()->incUpdate(callback); biases_->getParameterPtr()->incUpdate(callback);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册