DeConv3DLayer.cpp 7.7 KB
Newer Older
1
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

C
chengduoZH 已提交
15
#include "DeConv3DLayer.h"
X
Xin Pan 已提交
16 17
#include "paddle/legacy/utils/Logging.h"
#include "paddle/legacy/utils/Stat.h"
C
chengduoZH 已提交
18 19 20 21 22 23

namespace paddle {

// Expose DeConv3DLayer under the layer type name `deconv3d` via the
// REGISTER_LAYER macro (presumably the `type` string used in layer configs).
REGISTER_LAYER(deconv3d, DeConv3DLayer);

bool DeConv3DLayer::init(const LayerMap &layerMap,
C
chengduoZH 已提交
24
                         const ParameterMap &parameterMap) {
C
chengduoZH 已提交
25 26 27 28 29 30
  if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
  // for Deconv, the dimension of Kernel is
  // channel * output * depth * height * weigth
  // Matrix storage format: (output * depth * height * weigth) x  channel
  for (int index = 0; index < config_.inputs().size(); ++index) {
    M_.push_back(filterChannels_[index]);
C
chengduoZH 已提交
31
    K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index]));
C
chengduoZH 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44

    // create a new weight
    size_t height, width;
    height = filterPixels_[index] * numFilters_;
    width = filterChannels_[index];
    CHECK_EQ(parameters_[index]->getSize(), width * height);
    Weight *w = new Weight(height, width, parameters_[index]);
    weights_.emplace_back(w);
  }
  if (biasParameter_.get()) {
    if (sharedBiases_) {
      CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
      biases_ =
C
chengduoZH 已提交
45
          std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
C
chengduoZH 已提交
46 47
    } else {
      biases_ =
C
chengduoZH 已提交
48
          std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
C
chengduoZH 已提交
49
    }
C
chengduoZH 已提交
50 51 52 53 54 55
  }
  return true;
}

size_t DeConv3DLayer::getSize() {
  CHECK_NE(inputLayers_.size(), 0UL);
56 57 58
  imgSizeW_.clear();
  imgSizeH_.clear();
  imgSizeD_.clear();
C
chengduoZH 已提交
59
  N_.clear();
C
chengduoZH 已提交
60
  NOut_.clear();
C
chengduoZH 已提交
61 62
  size_t layerSize = 0;
  for (size_t i = 0; i < inputLayers_.size(); ++i) {
63 64 65 66 67 68 69 70
    imgSizeW_.push_back(
        imageSize(outputW_[i], filterSize_[i], padding_[i], stride_[i], true));
    imgSizeH_.push_back(imageSize(
        outputH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
    imgSizeD_.push_back(imageSize(
        outputD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
    NOut_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
    N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
C
chengduoZH 已提交
71
    CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
C
chengduoZH 已提交
72
    layerSize += NOut_[i] * numFilters_;
C
chengduoZH 已提交
73
  }
74 75 76
  getOutput().setFrameHeight(imgSizeH_[0]);
  getOutput().setFrameWidth(imgSizeW_[0]);
  getOutput().setFrameDepth(imgSizeD_[0]);
C
chengduoZH 已提交
77 78 79 80 81 82 83 84 85 86
  return layerSize;
}

/**
 * Forward pass.
 *
 * For each input and each sample, builds a column buffer group by group as
 * W_g (K x M) * input_g (M x N), then folds it into the sample's output
 * volume with col2Vol. Every input adds into the same output row. Afterwards
 * the bias (if configured) is added and the activation is applied.
 */
void DeConv3DLayer::forward(PassType passType) {
  Layer::forward(passType);
  int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
  // getSize() also refreshes imgSize*_, N_ and NOut_ used below.
  size_t outWidth = getSize();
  resetOutput(batchSize, outWidth);
  const MatrixPtr outMat = getOutputValue();

  REGISTER_TIMER_INFO("FwdDeConv3D", getName().c_str());
  // col2Vol below is called with alpha = 1.0, beta = 1.0 (presumably
  // target = col + target, i.e. accumulate), and every input adds into the
  // same rows. Nothing here guarantees resetOutput() cleared values left over
  // from a previous batch, so start the accumulation from zero explicitly.
  outMat->zeroMem();

  for (size_t i = 0; i != inputLayers_.size(); ++i) {
    const MatrixPtr &inMat = getInputValue(i);
    int M = M_[i];
    int N = N_[i];
    int K = K_[i];
    MatrixPtr wMat = weights_[i]->getW();
    Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
    for (int n = 0; n < batchSize; ++n) {
      // colBuf_ rows [g*K, (g+1)*K) = W_g * input_g for every group g.
      real *inData = inMat->getData() + n * inMat->getStride();
      for (int g = 0; g < groups_[i]; ++g) {
        MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_);
        MatrixPtr wMatSub = wMat->subMatrix(g * K, K);
        MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K);
        colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0);
        inData += M * N;
      }
      // Fold the columns back into the n-th output volume.
      colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(),
                       numFilters_,
                       imgSizeD_[i],
                       imgSizeH_[i],
                       imgSizeW_[i],
                       filterSizeZ_[i],
                       filterSizeY_[i],
                       filterSize_[i],
                       strideZ_[i],
                       strideY_[i],
                       stride_[i],
                       paddingZ_[i],
                       paddingY_[i],
                       padding_[i],
                       1.0,
                       1.0);
    }
  }
  if (nullptr != this->biasParameter_) {
    this->addBias();
  }
  forwardActivation();
}

void DeConv3DLayer::backward(const UpdateCallback &callback) {
  backwardActivation();
  int batchSize = getOutputGrad()->getHeight();
  if (biases_ && biases_->getWGrad()) {
    bpropBiases();
    biases_->getParameterPtr()->incUpdate(callback);
  }
135
  REGISTER_TIMER_INFO("BwdDeConv3D", getName().c_str());
C
chengduoZH 已提交
136 137 138 139 140 141 142 143 144 145 146
  for (size_t i = 0; i < inputLayers_.size(); ++i) {
    if (weights_[i]->getWGrad() || this->needGradient_) {
      int M = M_[i];
      int N = N_[i];
      int K = K_[i];
      Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
      const MatrixPtr &inMat = getInputValue(i);
      for (int n = 0; n < batchSize; ++n) {
        colBuf_->vol2Col(
            getOutputGrad()->getData() + n * getOutputGrad()->getStride(),
            numFilters_,
147 148 149
            imgSizeD_[i],
            imgSizeH_[i],
            imgSizeW_[i],
C
chengduoZH 已提交
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
            filterSizeZ_[i],
            filterSizeY_[i],
            filterSize_[i],
            strideZ_[i],
            strideY_[i],
            stride_[i],
            paddingZ_[i],
            paddingY_[i],
            padding_[i]);
        if (weights_[i]->getWGrad()) {
          real *inData = inMat->getData() + n * inMat->getStride();
          for (int g = 0; g < groups_[i]; ++g) {
            MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K);
            MatrixPtr wGradMatSub =
                weights_[i]->getWGrad()->subMatrix(g * K, K);
            MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_);
            wGradMatSub->mul(
                *colBufDataSub, *(inMatSub->getTranspose()), 1.0, 1.0);
            inData += M * N;
          }
C
chengduoZH 已提交
170
        }
C
chengduoZH 已提交
171 172 173 174 175 176 177 178 179 180 181
        if (getInputGrad(i)) {
          real *preGrad =
              getInputGrad(i)->getData() + n * getInputGrad(i)->getStride();
          for (int g = 0; g < groups_[i]; ++g) {
            MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K);
            MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K);
            MatrixPtr inGradMatSub =
                Matrix::create(preGrad, M, N, false, useGpu_);
            inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 1.0);
            preGrad += M * N;
          }
C
chengduoZH 已提交
182 183
        }
      }
C
chengduoZH 已提交
184
      weights_[i]->getParameterPtr()->incUpdate(callback);
C
chengduoZH 已提交
185 186 187
    }
  }
}
C
chengduoZH 已提交
188 189
// No-op for this layer: backward() computes the weight gradient inline.
void DeConv3DLayer::bpropWeights(int /* i */) {}
// No-op for this layer: backward() computes the input gradient inline.
void DeConv3DLayer::bpropData(int /* i */) {}
C
chengduoZH 已提交
190 191

void DeConv3DLayer::bpropBiases() {
C
chengduoZH 已提交
192 193 194 195 196
  MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(),
                                    1,
                                    biases_->getWGrad()->getElementCnt(),
                                    false,
                                    useGpu_);
C
chengduoZH 已提交
197
  const MatrixPtr &outGradMat = getOutputGrad();
C
chengduoZH 已提交
198 199

  if (this->sharedBiases_) {
C
chengduoZH 已提交
200
    biases->collectSharedBias(*outGradMat, 1.0f);
C
chengduoZH 已提交
201
  } else {
C
chengduoZH 已提交
202
    biases->collectBias(*outGradMat, 1.0f);
C
chengduoZH 已提交
203 204 205 206 207
  }
}

void DeConv3DLayer::addBias() {
  MatrixPtr outMat = getOutputValue();
C
chengduoZH 已提交
208 209 210 211 212
  MatrixPtr bias = Matrix::create(biases_->getW()->getData(),
                                  1,
                                  biases_->getW()->getElementCnt(),
                                  false,
                                  useGpu_);
C
chengduoZH 已提交
213
  if (this->sharedBiases_) {
C
chengduoZH 已提交
214
    outMat->addSharedBias(*(bias), 1.0f);
C
chengduoZH 已提交
215
  } else {
C
chengduoZH 已提交
216
    outMat->addBias(*(bias), 1.0f);
C
chengduoZH 已提交
217 218 219 220
  }
}

}  // namespace paddle