提交 63748319 编写于 作者: C chengduoZH

fix conflict

......@@ -224,4 +224,80 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
extern void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise);
/**
* @brief Matrix vol2Col: Convert 3D volume into col matrix
*
* @param[in] matSrc input matrix.
* @param[in] channel channel of matSrc.
* @param[in] depth depth of matSrc.
* @param[in] height height of matSrc.
* @param[in] width width of matSrc.
* @param[in] filterD depth of filter.
* @param[in] filterH height of filter.
* @param[in] filterW width of filter.
* @param[in] strideD stride in the depth.
* @param[in] strideH stride in the height.
* @param[in] strideW stride in the width.
* @param[in] paddingD padding in the depth.
* @param[in] paddingH padding in the height.
* @param[in] paddingW padding in the width.
* @param[out] dataDst output matrix.
*
*/
extern void hl_matrix_vol2Col(const real* dataSrc,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst);
/**
* @brief Matrix col2Vol: Convert col matrix into 3D volume
*
* @param[out] matDst output matrix.
* @param[in] channel channel of matDst.
* @param[in] depth depth of matDst.
* @param[in] height height of matDst.
* @param[in] width width of matDst.
* @param[in] filterD depth of filter.
* @param[in] filterH height of filter.
* @param[in] filterW width of filter.
* @param[in] strideD stride in the depth.
* @param[in] strideH stride in the height.
* @param[in] strideW stride in the width.
* @param[in] paddingD padding in the depth.
* @param[in] paddingH padding in the height.
* @param[in] paddingW padding in the width.
* @param[in] matSrc input matrix.
* @param[in] beta input
* @param[in] alpha input
*
*/
extern void hl_matrix_col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
const real* dataSrc,
real alpha,
real beta);
#endif /* HL_MATRIX_H_ */
......@@ -99,4 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
inline void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise) {}
inline void hl_matrix_vol2Col(const real* dataSrc,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst) {}
inline void hl_matrix_col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
const real* dataSrc,
real alpha,
real beta) {}
#endif // HL_MATRIX_STUB_H_
......@@ -592,3 +592,204 @@ void hl_matrix_rotate(
mat, matRot, dimM, dimN, clockWise);
CHECK_SYNC("hl_matrix_rotate failed");
}
__global__ void keMatrixVol2Col(int num_kernels,
const real* dataSrc,
real* dataDst,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
int depth_col,
int height_col,
int width_col) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int d_out = (index / width_col / height_col) % depth_col;
int channel_in = index / width_col / height_col / depth_col;
int channel_out = channel_in * filterD * filterH * filterW;
int w_in = w_out * strideW - paddingW;
int h_in = h_out * strideH - paddingH;
int d_in = d_out * strideD - paddingD;
dataDst +=
((channel_out * depth_col + d_out) * height_col + h_out) * width_col +
w_out;
dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
for (int k = 0; k < filterD; ++k) {
for (int i = 0; i < filterH; ++i) {
for (int j = 0; j < filterW; ++j) {
int d = d_in + k;
int h = h_in + i;
int w = w_in + j;
*dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
w < width)
? dataSrc[(k * height + i) * width + j]
: 0;
dataDst += depth_col * height_col * width_col;
}
}
}
}
}
void hl_matrix_vol2Col(const real* dataSrc,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
int num_kernels = channels * depth_col * height_col * width_col;
const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);
keMatrixVol2Col<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
dataSrc,
dataDst,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
depth_col,
height_col,
width_col);
CHECK_SYNC("hl_matrix_vol2Col failed");
}
__global__ void keMatrixCol2Vol(int num_kernels,
real* dataDst,
const real* dataSrc,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
int depth_col,
int height_col,
int width_col,
real alpha,
real beta) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
real srcVal = 0;
real dstVal = dataDst[index];
int w = index % width + paddingW;
int h = (index / width) % height + paddingH;
int d = (index / width / height) % depth + paddingD;
int c = index / width / height / depth;
// compute the start and end of the output
int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1;
int w_col_end = min(w / strideW + 1, width_col);
int h_col_start = (h < filterH) ? 0 : (h - filterH) / strideH + 1;
int h_col_end = min(h / strideH + 1, height_col);
int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1;
int d_col_end = min(d / strideD + 1, depth_col);
int offset = (c * filterD * filterW * filterH + d * filterW * filterH +
h * filterW + w) *
depth_col * height_col * width_col;
int coeff_d_col =
(1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
int coeff_h_col =
(1 - strideH * filterW * depth_col * height_col) * width_col;
int coeff_w_col = (1 - strideW * depth_col * height_col * width_col);
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col +
w_col * coeff_w_col];
}
}
}
dataDst[index] = alpha * srcVal + beta * dstVal;
}
}
void hl_matrix_col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
const real* dataSrc,
real alpha,
real beta) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
int num_kernels = channels * depth * height * width;
const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);
keMatrixCol2Vol<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
dataDst,
dataSrc,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
depth_col,
height_col,
width_col,
alpha,
beta);
CHECK_SYNC("hl_matrix_col2Vol failed");
}
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Conv3DLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(conv3d, Conv3DLayer);
bool Conv3DLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
int index = 0;
for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf();
M_.push_back(numFilters_ / conf.groups());
K_.push_back(filterPixels_[index] * filterChannels_[index]);
// create a new weight
size_t height, width;
width = filterPixels_[index] * filterChannels_[index];
height = numFilters_;
CHECK_EQ(parameters_[index]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
++index;
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
}
}
return true;
}
size_t Conv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL);
outputH_.clear();
outputW_.clear();
outputD_.clear();
N_.clear();
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); ++i) {
outputW_.push_back(outputSize(
imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
outputH_.push_back(outputSize(
imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
outputD_.push_back(outputSize(
imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
layerSize += N_[i] * numFilters_;
}
getOutput().setFrameHeight(outputH_[0]);
getOutput().setFrameWidth(outputW_[0]);
getOutput().setFrameDepth(outputD_[0]);
return layerSize;
}
void Conv3DLayer::forward(PassType passType) {
Layer::forward(passType);
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int outWidth = getSize();
resetOutput(batchSize, outWidth);
for (size_t i = 0; i != inputLayers_.size(); ++i) {
REGISTER_TIMER_INFO("FwdConv3D", getName().c_str());
const MatrixPtr &inMat = getInputValue(i);
const MatrixPtr &outMat = getOutputValue();
int M = M_[i];
int N = N_[i];
int K = K_[i];
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wMat = weights_[i]->getW();
for (int n = 0; n < batchSize; ++n) {
colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(),
channels_[i],
imgSizeD_[i],
imgSizeH_[i],
imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
real *outData = outMat->getData() + n * outMat->getStride();
MatrixPtr outMatSub =
Matrix::create(outData, groups_[i] * M, N, false, useGpu_);
for (int g = 0; g < groups_[i]; g++) {
MatrixPtr wMatSub = wMat->subMatrix(g * M, M);
MatrixPtr in = colBuf_->subMatrix(g * K, K);
MatrixPtr out = outMatSub->subMatrix(g * M, M);
out->mul(*wMatSub, *in, 1.0, 1.0);
}
}
}
if (nullptr != this->biasParameter_) {
REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
this->addBias();
}
forwardActivation();
}
void Conv3DLayer::backward(const UpdateCallback &callback) {
backwardActivation();
if (biases_ && biases_->getWGrad()) {
bpropBiases();
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
REGISTER_TIMER_INFO("BwdConv3D", getName().c_str());
if (weights_[i]->getWGrad()) {
bpropWeights(i);
}
if (getInputGrad(i)) {
bpropData(i);
}
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
weights_[i]->getParameterPtr()->incUpdate(callback);
}
}
void Conv3DLayer::bpropWeights(int i) {
int M = M_[i];
int N = N_[i];
int K = K_[i];
const MatrixPtr &inMat = getInputValue(i);
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wGradMat = weights_[i]->getWGrad();
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
for (int n = 0; n < batchSize; ++n) {
colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(),
channels_[i],
imgSizeD_[i],
imgSizeH_[i],
imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
real *outGradData =
getOutputGrad()->getData() + n * getOutputGrad()->getStride();
MatrixPtr outGradSub =
Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_);
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr inMatSub = colBuf_->subMatrix(g * K, K);
MatrixPtr outG = outGradSub->subMatrix(g * M, M);
MatrixPtr wGradSub = wGradMat->subMatrix(g * M, M);
wGradSub->mul(*outG, *(inMatSub->getTranspose()), 1.0, 1.0);
}
}
}
void Conv3DLayer::bpropData(int i) {
int M = M_[i];
int N = N_[i];
int K = K_[i];
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wMat = weights_[i]->getW();
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
for (int n = 0; n < batchSize; ++n) {
real *outGradData =
getOutputGrad()->getData() + n * getOutputGrad()->getStride();
real *preGradData =
getInputGrad(i)->getData() + n * getInputGrad(i)->getStride();
MatrixPtr outGradSub =
Matrix::create(outGradData, M * groups_[i], N, false, useGpu_);
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr wMatSub = wMat->subMatrix(g * M, M);
MatrixPtr outG = outGradSub->subMatrix(g * M, M);
MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K);
inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0);
}
colBuf_->col2Vol(preGradData,
channels_[i],
imgSizeD_[i],
imgSizeH_[i],
imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i],
1.0,
1.0);
}
}
void Conv3DLayer::bpropBiases() {
MatrixPtr outGradMat = getOutputGrad();
if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
} else {
biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
}
}
void Conv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue();
if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f);
} else {
outMat->addBias(*(biases_->getW()), 1.0f);
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "ConvBaseLayer.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief A subclass of convolution layer.
* This layer expands input and use matrix multiplication to
* calculate convolution operation.
*/
class Conv3DLayer : public ConvBaseLayer {
public:
explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
~Conv3DLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void addBias();
void backward(const UpdateCallback& callback);
void bpropBiases();
void bpropData(int i);
void bpropWeights(int i);
size_t getSize();
protected:
// Figure out the dimensions for individual gemms.
IntV M_; /// numFilters_ / filter_group_;
IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_
IntV K_; /// outputD_ * outputH_ * outputW_
MatrixPtr colBuf_;
};
} // namespace paddle
......@@ -38,7 +38,6 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
strideY_.push_back(conf.stride_y());
dilationY_.push_back(conf.dilation_y());
filterSizeY_.push_back(conf.filter_size_y());
filterPixels_.push_back(filterSize_.back() * filterSizeY_.back());
channels_.push_back(conf.channels());
imgSizeH_.push_back(conf.has_img_size_y() ? conf.img_size_y()
: conf.img_size());
......@@ -47,31 +46,20 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
filterChannels_.push_back(conf.filter_channels());
outputH_.push_back(conf.has_output_y() ? conf.output_y() : conf.output_x());
outputW_.push_back(conf.output_x());
paddingZ_.push_back(conf.padding_z());
strideZ_.push_back(conf.stride_z());
filterSizeZ_.push_back(conf.filter_size_z());
imgSizeD_.push_back(conf.img_size_z());
outputD_.push_back(conf.output_z());
filterPixels_.push_back(filterSize_.back() * filterSizeY_.back() *
filterSizeZ_.back());
}
CHECK(inputLayers_.size() == parameters_.size());
for (size_t i = 0; i < inputLayers_.size(); i++) {
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = (!isDeconv_) ? numFilters_ : channels_[i];
// create a new weight
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight* w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
}
/* initialize the biases_ */
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
// create new weights_ in derived class
// create new biases_ in derived class
// default caffe model
caffeMode_ = true;
......
......@@ -62,6 +62,13 @@ protected:
IntV outputH_;
/// The spatial dimensions of output feature map width.
IntV outputW_;
IntV outputD_;
IntV imgSizeD_;
IntV filterSizeZ_;
IntV strideZ_;
IntV paddingZ_;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels,
......
......@@ -46,8 +46,26 @@ bool CudnnConvBaseLayer::init(const LayerMap &layerMap,
projConf_.emplace_back(conf);
projections_.emplace_back(
Projection::create(*projConf_[i], parameters_[i], useGpu_));
// create a new weight
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = (!isDeconv_) ? numFilters_ : channels_[i];
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
if (biases_.get() && sharedBiases_) {
hl_create_tensor_descriptor(&biasDesc_);
hl_create_tensor_descriptor(&outputDesc_);
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DeConv3DLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(deconv3d, DeConv3DLayer);
bool DeConv3DLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
// for Deconv, the dimension of Kernel is
// channel * output * depth * height * weigth
// Matrix storage format: (output * depth * height * weigth) x channel
for (int index = 0; index < config_.inputs().size(); ++index) {
M_.push_back(filterChannels_[index]);
K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index]));
// create a new weight
size_t height, width;
height = filterPixels_[index] * numFilters_;
width = filterChannels_[index];
CHECK_EQ(parameters_[index]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
}
}
return true;
}
size_t DeConv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL);
outputH_.clear();
outputW_.clear();
outputD_.clear();
N_.clear();
NOut_.clear();
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); ++i) {
outputW_.push_back(
imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
outputH_.push_back(imageSize(
imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
outputD_.push_back(imageSize(
imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
layerSize += NOut_[i] * numFilters_;
}
getOutput().setFrameHeight(outputH_[0]);
getOutput().setFrameWidth(outputW_[0]);
getOutput().setFrameDepth(outputD_[0]);
return layerSize;
}
void DeConv3DLayer::forward(PassType passType) {
Layer::forward(passType);
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int outWidth = getSize();
resetOutput(batchSize, outWidth);
const MatrixPtr outMat = getOutputValue();
for (size_t i = 0; i != inputLayers_.size(); ++i) {
REGISTER_TIMER_INFO("FwdDeConv3D", getName().c_str());
const MatrixPtr &inMat = getInputValue(i);
int M = M_[i];
int N = N_[i];
int K = K_[i];
MatrixPtr wMat = weights_[i]->getW();
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
for (int n = 0; n < batchSize; ++n) {
real *inData = inMat->getData() + n * inMat->getStride();
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_);
MatrixPtr wMatSub = wMat->subMatrix(g * K, K);
MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K);
colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0);
inData += M * N;
}
colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(),
numFilters_,
outputD_[i],
outputH_[i],
outputW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i],
1.0,
1.0);
}
}
if (nullptr != this->biasParameter_) {
REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
this->addBias();
}
forwardActivation();
}
void DeConv3DLayer::backward(const UpdateCallback &callback) {
backwardActivation();
int batchSize = getOutputGrad()->getHeight();
if (biases_ && biases_->getWGrad()) {
bpropBiases();
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i < inputLayers_.size(); ++i) {
if (weights_[i]->getWGrad() || this->needGradient_) {
int M = M_[i];
int N = N_[i];
int K = K_[i];
REGISTER_TIMER_INFO("BwdDeConv3D", getName().c_str());
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
const MatrixPtr &inMat = getInputValue(i);
for (int n = 0; n < batchSize; ++n) {
colBuf_->vol2Col(
getOutputGrad()->getData() + n * getOutputGrad()->getStride(),
numFilters_,
outputD_[i],
outputH_[i],
outputW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
if (weights_[i]->getWGrad()) {
real *inData = inMat->getData() + n * inMat->getStride();
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K);
MatrixPtr wGradMatSub =
weights_[i]->getWGrad()->subMatrix(g * K, K);
MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_);
wGradMatSub->mul(
*colBufDataSub, *(inMatSub->getTranspose()), 1.0, 1.0);
inData += M * N;
}
}
if (getInputGrad(i)) {
real *preGrad =
getInputGrad(i)->getData() + n * getInputGrad(i)->getStride();
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K);
MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K);
MatrixPtr inGradMatSub =
Matrix::create(preGrad, M, N, false, useGpu_);
inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 1.0);
preGrad += M * N;
}
}
}
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
weights_[i]->getParameterPtr()->incUpdate(callback);
}
}
}
void DeConv3DLayer::bpropWeights(int i) {}
void DeConv3DLayer::bpropData(int i) {}
void DeConv3DLayer::bpropBiases() {
const MatrixPtr &outGradMat = getOutputGrad();
if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
} else {
biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
}
}
void DeConv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue();
if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f);
} else {
outMat->addBias(*(biases_->getW()), 1.0f);
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "ConvBaseLayer.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief A subclass of deconvolution3D layer.
* This layer expands input and use matrix multiplication to
* calculate deconvolution3D operation.
*/
class DeConv3DLayer : public ConvBaseLayer {
public:
explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
~DeConv3DLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void addBias();
void backward(const UpdateCallback& callback);
void bpropBiases();
void bpropData(int i);
void bpropWeights(int i);
size_t getSize();
protected:
// Figure out the dimensions for individual gemms.
IntV M_; /// numFilters_ / filter_group_;
IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_
IntV K_; /// outputD_ * outputH_ * outputW_
IntV NOut_;
MatrixPtr colBuf_;
};
} // namespace paddle
......@@ -22,12 +22,31 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
/* Initialize the basic convolutional parent class */
ConvBaseLayer::init(layerMap, parameterMap);
int index = 0;
for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf();
/* Consistent caffe mode for multiple input */
caffeMode_ = conf.caffe_mode();
}
// create a new weight
size_t height, width;
height = filterPixels_[index] * filterChannels_[index];
width = (!isDeconv_) ? numFilters_ : channels_[index];
CHECK_EQ(parameters_[index]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
index++;
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
getOutputSize();
return true;
......
......@@ -2116,6 +2116,159 @@ TEST(Layer, RowL2NormLayer) {
}
}
void test3DConvLayer(const string& type, bool trans, bool useGpu) {
// filter size
const int NUM_FILTERS = 6;
// const int CHANNELS = 3;
const int FILTER_SIZE = 3;
const int FILTER_SIZE_Y = 3;
const int FILTER_SIZE_Z = 3;
// input image
const int CHANNELS = 3;
const int IMAGE_SIZE = 9;
const int IMAGE_SIZE_Y = 9;
const int IMAGE_SIZE_Z = 9;
TestConfig config;
config.biasSize = NUM_FILTERS;
config.layerConfig.set_type(type);
config.layerConfig.set_num_filters(NUM_FILTERS);
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
// Setting up conv3D-trans layer
LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_channels(CHANNELS);
conv->set_filter_size(FILTER_SIZE);
conv->set_filter_size_y(FILTER_SIZE_Y);
conv->set_filter_size_z(FILTER_SIZE_Z);
conv->set_padding(0);
conv->set_padding_y(0);
conv->set_padding_z(0);
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_stride_z(2);
conv->set_img_size(IMAGE_SIZE);
conv->set_img_size_y(IMAGE_SIZE_Y);
conv->set_img_size_z(IMAGE_SIZE_Z);
conv->set_output_x(outputSize(conv->img_size(),
conv->filter_size(),
conv->padding(),
conv->stride(),
/* caffeMode */ true));
conv->set_output_y(outputSize(conv->img_size_y(),
conv->filter_size_y(),
conv->padding_y(),
conv->stride_y(),
/* caffeMode */ true));
conv->set_output_z(outputSize(conv->img_size_z(),
conv->filter_size_z(),
conv->padding_z(),
conv->stride_z(),
/* caffeMode */ true));
config.layerConfig.set_size(conv->output_x() * conv->output_y() *
conv->output_z() * NUM_FILTERS);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
config.inputDefs.push_back(
{INPUT_DATA,
"layer_0",
CHANNELS * IMAGE_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_Z,
conv->filter_channels() * FILTER_SIZE * FILTER_SIZE_Y * FILTER_SIZE_Z *
NUM_FILTERS});
testLayerGrad(config, "conv3D", 10, trans, useGpu);
// Use small batch_size and useWeight=true to test biasGrad
testLayerGrad(config, "conv3D", 2, trans, useGpu, true, 0.02);
}
TEST(Layer, test3DConvLayer) {
test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ true);
#endif
}
void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
// filter size
const int NUM_FILTERS = 6;
// const int CHANNELS = 3;
const int FILTER_SIZE = 3;
const int FILTER_SIZE_Y = 3;
const int FILTER_SIZE_Z = 3;
// input image
const int CHANNELS = 3;
const int IMAGE_SIZE = 4;
const int IMAGE_SIZE_Y = 6;
const int IMAGE_SIZE_Z = 6;
// Setting up conv-trans layer
TestConfig config;
config.biasSize = NUM_FILTERS;
config.layerConfig.set_type("deconv3d");
config.layerConfig.set_num_filters(NUM_FILTERS);
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_channels(CHANNELS);
conv->set_filter_size(FILTER_SIZE);
conv->set_filter_size_y(FILTER_SIZE_Y);
conv->set_filter_size_z(FILTER_SIZE_Z);
conv->set_padding(0);
conv->set_padding_y(0);
conv->set_padding_z(0);
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_stride_z(2);
conv->set_img_size(IMAGE_SIZE);
conv->set_img_size_y(IMAGE_SIZE_Y);
conv->set_img_size_z(IMAGE_SIZE_Z);
conv->set_output_x(imageSize(conv->img_size(),
conv->filter_size(),
conv->padding(),
conv->stride(),
true));
conv->set_output_y(imageSize(conv->img_size_y(),
conv->filter_size_y(),
conv->padding_y(),
conv->stride_y(),
true));
conv->set_output_z(imageSize(conv->img_size_z(),
conv->filter_size_z(),
conv->padding_z(),
conv->stride_z(),
true));
config.layerConfig.set_size(conv->output_x() * conv->output_y() *
conv->output_z() * NUM_FILTERS);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
config.inputDefs.push_back(
{INPUT_DATA,
"layer_0",
CHANNELS * IMAGE_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_Z,
conv->filter_channels() * FILTER_SIZE * FILTER_SIZE_Y * FILTER_SIZE_Z *
NUM_FILTERS});
testLayerGrad(config, "deconv3D", 10, trans, useGpu);
// Use small batch_size and useWeight=true to test biasGrad
testLayerGrad(config, "deconv3D", 2, trans, useGpu, true, 0.02);
}
TEST(Layer, test3DDeConvLayer) {
test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ true);
#endif
}
TEST(Layer, ScaleShiftLayer) {
const size_t batchSize = 16;
const size_t size = 32;
......
......@@ -1604,6 +1604,72 @@ void GpuMatrix::multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label) {
output_d, grad_d, mat_d, height_, width_);
}
void GpuMatrix::vol2Col(real* dataSrc,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW) {
hl_matrix_vol2Col(dataSrc,
channels,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
getData());
}
void GpuMatrix::col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
hl_matrix_col2Vol(dataDst,
channels,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
getData(),
alpha,
beta);
}
/**
* CpuMatrix
*/
......@@ -4460,6 +4526,95 @@ void CpuMatrix::bilinearBackward(const Matrix& out,
}
}
void CpuMatrix::vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW) {
real* outData = getData();
int outHeight = (height + 2 * paddingH - filterH) / strideH + 1;
int outWidth = (width + 2 * paddingW - filterW) / strideW + 1;
int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1;
int channelsCol = channels * filterD * filterH * filterW;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterW;
int hOffset = (c / filterW) % filterH;
int dOffset = (c / filterW / filterH) % filterD;
int cIn = c / filterW / filterH / filterD;
for (int d = 0; d < outDepth; ++d) {
for (int h = 0; h < outHeight; ++h) {
for (int w = 0; w < outWidth; ++w) {
int dPad = d * strideD - paddingD + dOffset;
int hPad = h * strideH - paddingH + hOffset;
int wPad = w * strideW - paddingW + wOffset;
if (hPad >= 0 && hPad < height && wPad >= 0 && wPad < width &&
dPad >= 0 && dPad < depth)
outData[((c * outDepth + d) * outHeight + h) * outWidth + w] =
data[((cIn * depth + dPad) * height + hPad) * width + wPad];
else
outData[((c * outDepth + d) * outHeight + h) * outWidth + w] = 0;
}
}
}
}
}
void CpuMatrix::col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
real* src = getData();
int outDepth = (depth + 2 * paddingD - filterD) / strideD + 1;
int outHeight = (height + 2 * paddingH - filterH) / strideH + 1;
int outWidth = (width + 2 * paddingW - filterW) / strideW + 1;
int channelsCol = channels * filterD * filterH * filterW;
for (int c = 0; c < channelsCol; ++c) {
int wOffset = c % filterW;
int hOffset = (c / filterW) % filterH;
int dOffset = (c / filterW / filterH) % filterD;
int cIm = c / filterW / filterH / filterD;
for (int d = 0; d < outDepth; ++d) {
for (int h = 0; h < outHeight; ++h) {
for (int w = 0; w < outWidth; ++w) {
int dPad = d * strideD - paddingD + dOffset;
int hPad = h * strideH - paddingH + hOffset;
int wPad = w * strideW - paddingW + wOffset;
if (hPad >= 0 && hPad < height && wPad >= 0 && wPad < width &&
dPad >= 0 && dPad < depth)
trg[((cIm * depth + dPad) * height + hPad) * width + wPad] =
alpha *
src[((c * outDepth + d) * outHeight + h) * outWidth + w] +
beta *
trg[((cIm * depth + dPad) * height + hPad) * width + wPad];
}
}
}
}
}
////////////////////////////////////////////////////////////////
// functions executed via cpu //
////////////////////////////////////////////////////////////////
......
......@@ -1126,6 +1126,42 @@ public:
LOG(FATAL) << "Not implemented";
}
virtual void vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW) {
LOG(FATAL) << "Not implemeted";
}
virtual void col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
LOG(FATAL) << "Not implemeted";
}
virtual void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
......@@ -1537,6 +1573,38 @@ public:
const real ratioH,
const real ratioW);
void vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
......@@ -1954,6 +2022,38 @@ public:
const real ratioH,
const real ratioW);
void vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
template <typename ExpressionType>
void operator=(const ExpressionType& expr) {
TensorCpuApply<real>(*this, expr);
......
......@@ -1524,7 +1524,6 @@ TEST(Matrix, Pool3DFwdBwd) {
}
}
//
// for (auto numSamples : {1, 3}) {
// for (auto channels : {1, 3}) {
// for (auto imgSizeD : {9,16}) {
......@@ -1598,4 +1597,104 @@ TEST(Matrix, Pool3DFwdBwd) {
// }
}
void testMatrixCol2Vol(int depth, int height, int width) {
int channel = 3;
int filterX = 3, filterY = 4, filterZ = 5;
int strideX = 2, strideY = 2, strideZ = 2;
int padX = 1, padY = 1, padZ = 1;
MatrixPtr cpuImage =
std::make_shared<CpuMatrix>(channel, depth * height * width);
MatrixPtr gpuImage =
std::make_shared<GpuMatrix>(channel, depth * height * width);
cpuImage->randomizeUniform();
gpuImage->copyFrom(*cpuImage);
int outD = outputSize(depth, filterZ, padZ, strideZ, true);
int outH = outputSize(height, filterY, padY, strideY, true);
int outW = outputSize(width, filterX, padX, strideX, true);
int colBufHeight = channel * filterZ * filterY * filterX;
int colBufWidth = outD * outH * outW;
MatrixPtr cpuColBuf = std::make_shared<CpuMatrix>(colBufHeight, colBufWidth);
MatrixPtr gpuColBuf = std::make_shared<GpuMatrix>(colBufHeight, colBufWidth);
cpuColBuf->vol2Col(cpuImage->getData(),
channel,
depth,
height,
width,
filterZ,
filterY,
filterX,
strideZ,
strideY,
strideX,
padZ,
padY,
padX);
gpuColBuf->vol2Col(gpuImage->getData(),
channel,
depth,
height,
width,
filterZ,
filterY,
filterX,
strideZ,
strideY,
strideX,
padZ,
padY,
padX);
TensorCheckEqual(*cpuColBuf, *gpuColBuf);
cpuColBuf->randomizeUniform();
gpuColBuf->copyFrom(*cpuColBuf);
cpuColBuf->col2Vol(cpuImage->getData(),
channel,
depth,
height,
width,
filterZ,
filterY,
filterX,
strideZ,
strideY,
strideX,
padZ,
padY,
padX,
1.0,
1.0);
gpuColBuf->col2Vol(gpuImage->getData(),
channel,
depth,
height,
width,
filterZ,
filterY,
filterX,
strideZ,
strideY,
strideX,
padZ,
padY,
padX,
1.0,
1.0);
TensorCheckErr(*cpuImage, *gpuImage);
}
TEST(Matrix, col2Vol) {
for (auto depth : {9, 16, 64}) {
for (auto height : {9, 11, 128}) {
for (auto width : {9, 32, 128}) {
VLOG(3) << "depth=" << depth << " height=" << height
<< " width=" << width;
testMatrixCol2Vol(depth, height, width);
}
}
}
}
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
......@@ -886,6 +886,36 @@ class Conv(Cfg):
config_assert(output_x <= 0)
# please refer to the comments in proto/ModelConfig.proto
@config_class
class Conv3D(Cfg):
def __init__(self,
filter_size,
channels,
padding=None,
stride=None,
groups=None,
filter_channels=None,
output_x=None,
img_size=None,
caffe_mode=True,
filter_size_y=None,
padding_y=None,
stride_y=None,
filter_size_z=None,
padding_z=None,
stride_z=None):
self.add_keys(locals())
self.filter_size_y = filter_size_y if filter_size_y else filter_size
self.filter_size_z = filter_size_z if filter_size_z else filter_size
self.padding_y = padding_y if padding_y else padding
self.padding_z = padding_z if padding_z else padding
self.stride_y = stride_y if stride_y else stride
self.stride_z = stride_z if stride_z else stride
if output_x is not None:
config_assert(output_x <= 0)
@config_class
class BilinearInterp(Cfg):
def __init__(self, out_size_x=None, out_size_y=None, channels=None):
......@@ -1360,6 +1390,50 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
conv_conf.stride_y, conv_conf.caffe_mode)
#caffe_mode: compute the output size using floor instead of ceil,
# which is consistent of caffe and CuDNN's convention.
def parse_conv3d(conv, input_layer_name, conv_conf, num_filters, trans=False):
conv_conf.filter_size = conv.filter_size
conv_conf.filter_size_y = conv.filter_size_y
conv_conf.filter_size_z = conv.filter_size_z
conv_conf.channels = conv.channels
conv_conf.padding = conv.padding
conv_conf.padding_y = conv.padding_y
conv_conf.padding_z = conv.padding_z
conv_conf.stride = conv.stride
conv_conf.stride_y = conv.stride_y
conv_conf.stride_z = conv.stride_z
conv_conf.groups = conv.groups
conv_conf.caffe_mode = conv.caffe_mode
if not trans:
conv_conf.filter_channels = conv.channels / conv.groups
conv_conf.img_size, conv_conf.img_size_y, conv_conf.img_size_z = \
get_img3d_size(input_layer_name, conv.channels)
conv_conf.output_x = cnn_output_size(
conv_conf.img_size, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode)
conv_conf.output_y = cnn_output_size(
conv_conf.img_size_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode)
conv_conf.output_z = cnn_output_size(
conv_conf.img_size_z, conv_conf.filter_size_z, conv_conf.padding_z,
conv_conf.stride_z, conv_conf.caffe_mode)
else:
conv_conf.filter_channels = num_filters / conv.groups
conv_conf.output_x, conv_conf.output_y, conv_conf.output_z = \
get_img3d_size(input_layer_name, conv.channels)
conv_conf.img_size = cnn_image_size(
conv_conf.output_x, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode)
conv_conf.img_size_y = cnn_image_size(
conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode)
conv_conf.img_size_z = cnn_image_size(
conv_conf.output_z, conv_conf.filter_size_z, conv_conf.padding_z,
conv_conf.stride_z, conv_conf.caffe_mode)
def parse_block_expand(block_expand, input_layer_name, block_expand_conf):
block_expand_conf.channels = block_expand.channels
block_expand_conf.stride_x = block_expand.stride_x
......@@ -2011,7 +2085,7 @@ class ConvLayerBase(LayerBase):
def calc_parameter_size(self, conv_conf):
return self.config.num_filters * conv_conf.filter_channels \
* (conv_conf.filter_size * conv_conf.filter_size_y)
* (conv_conf.filter_size * conv_conf.filter_size_y)
@config_layer('exconv')
......@@ -2095,6 +2169,87 @@ class ConvTransLayer(ConvTransLayerBase):
layer_type = 'cudnn_convt'
@config_layer('conv_3d')
class Conv3DLayerBase(LayerBase):
def __init__(self,
name,
inputs=[],
bias=True,
num_filters=None,
shared_biases=True,
**xargs):
super(Conv3DLayerBase, self).__init__(
name, self.layer_type, 0, inputs=inputs, **xargs)
if num_filters is not None:
self.config.num_filters = num_filters
# need to specify layer in config
self.config.type = self.layer_type
trans = False
if self.config.type == "deconv3d":
trans = True
if shared_biases is not None:
self.config.shared_biases = shared_biases
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
conv_conf = self.config.inputs[input_index].conv_conf
parse_conv3d(
self.inputs[input_index].conv,
input_layer.name,
conv_conf,
num_filters,
trans=trans
) # for z-axis pad:0, strid:1, filter_size:1, img_size:1
psize = self.calc_parameter_size(conv_conf)
self.create_input_parameter(input_index, psize)
if trans:
self.set_cnn_layer(name, conv_conf.img_size_z,
conv_conf.img_size_y, conv_conf.img_size,
self.config.num_filters)
else:
self.set_cnn_layer(name, conv_conf.output_z, conv_conf.output_y,
conv_conf.output_x, self.config.num_filters)
psize = self.config.size
if shared_biases:
psize = self.config.num_filters
self.create_bias_parameter(bias, psize, [psize, 1])
def calc_parameter_size(self, conv_conf):
return self.config.num_filters * conv_conf.filter_channels \
* (conv_conf.filter_size * conv_conf.filter_size_y \
* conv_conf.filter_size_z)
def set_cnn_layer(self,
input_layer_name,
depth,
height,
width,
channels,
is_print=True):
size = depth * height * width * channels
self.set_layer_size(size)
self.set_layer_height_width(height, width)
self.set_layer_depth(depth)
if is_print:
print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
(input_layer_name, channels, depth, height, width, size))
@config_layer('conv3d')
class Conv3DLayer(Conv3DLayerBase):
layer_type = 'conv3d'
@config_layer('deconv3d')
class Conv3DLayer(Conv3DLayerBase):
layer_type = 'deconv3d'
@config_layer('norm')
class NormLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
......
文件模式从 100755 更改为 100644
......@@ -140,6 +140,7 @@ __all__ = [
'kmax_sequence_score_layer',
'img_pool3d_layer',
'scale_shift_layer',
'img_conv3d_layer',
]
......@@ -222,6 +223,9 @@ class LayerType(object):
CRF_DECODING_LAYER = 'crf_decoding'
NCE_LAYER = 'nce'
CONV3D_LAYER = 'conv3d'
DECONV3D_LAYER = 'deconv3d'
RANK_COST = 'rank-cost'
LAMBDA_COST = 'lambda_cost'
HUBER_REGRESSION = 'huber_regression'
......@@ -6629,6 +6633,149 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1):
name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size)
@wrap_name_default("conv3d")
@wrap_param_attr_default()
@wrap_bias_attr_default()
@wrap_act_default(act=ReluActivation())
@layer_support(DROPOUT)
def img_conv3d_layer(input,
filter_size,
num_filters,
name=None,
num_channels=None,
act=None,
groups=1,
stride=1,
padding=0,
bias_attr=None,
param_attr=None,
shared_biases=True,
layer_attr=None,
trans=False,
layer_type=None):
"""
The example usage is:
.. code-block:: python
conv = img_conv3d_layer(input=data, filter_size=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=False,
act=ReluActivation())
:param name: Layer name.
:type name: basestring
:param input: Layer Input.
:type input: LayerOutput
:param filter_size: The x dimension of a filter kernel. Or input a list.
:type filter_size: int|tuple|list
:param num_filters: Each filter group's number of filter
:param act: Activation type. Default is tanh
:type act: BaseActivation
:param groups: Group size of filters.
:type groups: int
:param stride: The x dimension of the stride. Or input a tuple for two image
dimension.
:type stride: int|tuple|list
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension
:type padding: int|tuple|list
:param bias_attr: Convolution bias attribute. None means default bias.
False means no bias.
:type bias_attr: ParameterAttribute|False
:param num_channels: number of input channels. If None will be set
automatically from previous output.
:type num_channels: int
:param param_attr: Convolution param attribute. None means default attribute
:type param_attr: ParameterAttribute
:param shared_biases: Is biases will be shared between filters or not.
:type shared_biases: bool
:param layer_attr: Layer Extra Attribute.
:type layer_attr: ExtraLayerAttribute
:param trans: true if it is a convTransLayer, false if it is a convLayer
:type trans: bool
:param layer_type: specify the layer_type, default is None. If trans=True,
layer_type has to be "exconvt" or "cudnn_convt",
otherwise layer_type has to be either "exconv" or
"cudnn_conv"
:type layer_type: String
:return: LayerOutput object.
:rtype: LayerOutput
"""
if num_channels is None:
assert input.num_filters is not None
num_channels = input.num_filters
if isinstance(filter_size, collections.Sequence):
assert len(filter_size) == 3
filter_size, filter_size_y, filter_size_z = filter_size
else:
filter_size_y = filter_size
filter_size_z = filter_size
if isinstance(stride, collections.Sequence):
assert len(stride) == 3
stride, stride_y, stride_z = stride
else:
stride_y = stride
stride_z = stride
if isinstance(padding, collections.Sequence):
assert len(padding) == 3
padding, padding_y, padding_z = padding
else:
padding_y = padding
padding_z = padding
if param_attr.attr.get('initial_smart'):
# special initial for conv layers.
init_w = (2.0 / (filter_size**2 * num_channels))**0.5
param_attr.attr["initial_mean"] = 0.0
param_attr.attr["initial_std"] = init_w
param_attr.attr["initial_strategy"] = 0
param_attr.attr["initial_smart"] = False
if layer_type:
if trans:
assert layer_type in ["deconv3d"]
lt = layer_type
else:
lt = LayerType.DECONV3D_LAYER if trans else LayerType.CONV3D_LAYER
l = Layer(
name=name,
inputs=Input(
input.name,
conv=Conv3D(
filter_size=filter_size,
padding=padding,
stride=stride,
channels=num_channels,
groups=groups,
filter_size_y=filter_size_y,
padding_y=padding_y,
stride_y=stride_y,
filter_size_z=filter_size_z,
padding_z=padding_z,
stride_z=stride_z),
**param_attr.attr),
active_type=act.name,
num_filters=num_filters,
bias=ParamAttr.to_bias(bias_attr),
shared_biases=shared_biases,
type=lt,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name,
lt,
parents=[input],
activation=act,
num_filters=num_filters,
size=l.config.size)
@wrap_name_default("scale_shift")
@wrap_param_attr_default()
@wrap_bias_attr_default()
......
文件模式从 100755 更改为 100644
......@@ -9,6 +9,7 @@ test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer)
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
test_conv3d_layer test_deconv3d_layer)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "data"
type: "data"
size: 36288
active_type: ""
height: 48
width: 42
depth: 6
}
layers {
name: "conv3d_1"
type: "conv3d"
size: 24192
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "_conv3d_1.w0"
conv_conf {
filter_size: 3
channels: 3
stride: 2
padding: 1
groups: 1
filter_channels: 3
output_x: 21
img_size: 42
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 24
img_size_y: 48
filter_size_z: 3
padding_z: 1
stride_z: 2
output_z: 3
img_size_z: 6
}
}
bias_parameter_name: "_conv3d_1.wbias"
num_filters: 16
shared_biases: true
height: 24
width: 21
depth: 3
}
layers {
name: "conv3d_2"
type: "conv3d"
size: 24192
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "_conv3d_2.w0"
conv_conf {
filter_size: 3
channels: 3
stride: 2
padding: 1
groups: 1
filter_channels: 3
output_x: 21
img_size: 42
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 24
img_size_y: 48
filter_size_z: 3
padding_z: 1
stride_z: 2
output_z: 3
img_size_z: 6
}
}
bias_parameter_name: "_conv3d_2.wbias"
num_filters: 16
shared_biases: true
height: 24
width: 21
depth: 3
}
parameters {
name: "_conv3d_1.w0"
size: 1296
initial_mean: 0.0
initial_std: 0.272165526976
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_conv3d_1.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_conv3d_2.w0"
size: 1296
initial_mean: 0.0
initial_std: 0.272165526976
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_conv3d_2.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
input_layer_names: "data"
output_layer_names: "conv3d_2"
sub_models {
name: "root"
layer_names: "data"
layer_names: "conv3d_1"
layer_names: "conv3d_2"
input_layer_names: "data"
output_layer_names: "conv3d_2"
is_recurrent_layer_group: false
}
type: "nn"
layers {
name: "data"
type: "data"
size: 36288
active_type: ""
height: 48
width: 42
depth: 6
}
layers {
name: "deconv3d_1"
type: "deconv3d"
size: 1387760
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "_deconv3d_1.w0"
conv_conf {
filter_size: 3
channels: 3
stride: 2
padding: 1
groups: 1
filter_channels: 16
output_x: 42
img_size: 83
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 48
img_size_y: 95
filter_size_z: 3
padding_z: 1
stride_z: 2
output_z: 6
img_size_z: 11
}
}
bias_parameter_name: "_deconv3d_1.wbias"
num_filters: 16
shared_biases: true
height: 95
width: 83
depth: 11
}
layers {
name: "deconv3d_2"
type: "deconv3d"
size: 1387760
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "_deconv3d_2.w0"
conv_conf {
filter_size: 3
channels: 3
stride: 2
padding: 1
groups: 1
filter_channels: 16
output_x: 42
img_size: 83
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 48
img_size_y: 95
filter_size_z: 3
padding_z: 1
stride_z: 2
output_z: 6
img_size_z: 11
}
}
bias_parameter_name: "_deconv3d_2.wbias"
num_filters: 16
shared_biases: true
height: 95
width: 83
depth: 11
}
parameters {
name: "_deconv3d_1.w0"
size: 6912
initial_mean: 0.0
initial_std: 0.272165526976
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_deconv3d_1.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_deconv3d_2.w0"
size: 6912
initial_mean: 0.0
initial_std: 0.272165526976
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_deconv3d_2.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
input_layer_names: "data"
output_layer_names: "deconv3d_2"
sub_models {
name: "root"
layer_names: "data"
layer_names: "deconv3d_1"
layer_names: "deconv3d_2"
input_layer_names: "data"
output_layer_names: "deconv3d_2"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
num_channels = 3
filter_size = 3
filter_size_y = 3
filter_size_z = 3
stride = 2
stride_y = 2
stride_z = 2
padding = 1
padding_y = 1
padding_z = 1
groups = 1
data = data_layer(
name='data', size=12096 * num_channels, height=48, width=42, depth=6)
# first
conv3d_1 = img_conv3d_layer(
input=data,
name='conv3d_1',
num_filters=16,
num_channels=num_channels,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=True,
shared_biases=True,
trans=False,
layer_type="conv3d",
act=LinearActivation())
# second
conv3d_2 = img_conv3d_layer(
input=data,
name='conv3d_2',
num_filters=16,
num_channels=num_channels,
filter_size=[filter_size, filter_size_y, filter_size_z],
stride=[stride, stride_y, stride_z],
padding=[padding, padding_y, padding_z],
groups=groups,
bias_attr=True,
shared_biases=True,
trans=False,
layer_type="conv3d",
act=LinearActivation())
outputs(conv3d_2)
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
num_channels = 3
filter_size = 3
filter_size_y = 3
filter_size_z = 3
stride = 2
stride_y = 2
stride_z = 2
padding = 1
padding_y = 1
padding_z = 1
groups = 1
data = data_layer(
name='data', size=12096 * num_channels, height=48, width=42, depth=6)
# first
deconv3d_1 = img_conv3d_layer(
input=data,
name='deconv3d_1',
num_filters=16,
num_channels=num_channels,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=True,
shared_biases=True,
trans=True,
layer_type="deconv3d",
act=LinearActivation())
# second
deconv3d_2 = img_conv3d_layer(
input=data,
name='deconv3d_2',
num_filters=16,
num_channels=num_channels,
filter_size=[filter_size, filter_size_y, filter_size_z],
stride=[stride, stride_y, stride_z],
padding=[padding, padding_y, padding_z],
groups=groups,
bias_attr=True,
shared_biases=True,
trans=True,
layer_type="deconv3d",
act=LinearActivation())
outputs(deconv3d_2)
......@@ -17,3 +17,4 @@ from paddle.trainer.config_parser import parse_config_and_serialize
if __name__ == '__main__':
parse_config_and_serialize(
'trainer_config_helpers/tests/layers_test_config.py', '')
# layers_test_config.py
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册