Commit a55dd226 authored by chengduoZH

fix conflict

......@@ -173,6 +173,96 @@ extern void hl_avgpool_backward(const int frameCnt,
real* backGrad,
const int outStride);
extern void hl_maxpool3D_forward(const int frameCnt,
const real* inputData,
const int channels,
const int depth,
const int height,
const int width,
const int pooledD,
const int pooledH,
const int pooledW,
const int sizeZ,
const int sizeY,
const int sizeX,
const int strideD,
const int strideH,
const int strideW,
const int paddingD,
const int paddingH,
const int paddingW,
real* tgtData,
real* maxPoolIdxData,
const int tgtStride);
extern void hl_maxpool3D_backward(const int frameCnt,
const real* outGrad,
const int channels,
const int depth,
const int height,
const int width,
const int pooledD,
const int pooledH,
const int pooledW,
const int sizeZ,
const int sizeY,
const int sizeX,
const int strideD,
const int strideH,
const int strideW,
const int paddingD,
const int paddingH,
const int paddingW,
real scaleA,
real scaleB,
real* targetGrad,
real* maxPoolIdxData,
const int outStride);
extern void hl_avgpool3D_forward(const int frameCnt,
const real* inputData,
const int channels,
const int depth,
const int height,
const int width,
const int pooledD,
const int pooledH,
const int pooledW,
const int sizeZ,
const int sizeY,
const int sizeX,
const int strideD,
const int strideH,
const int strideW,
const int paddingD,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride);
extern void hl_avgpool3D_backward(const int frameCnt,
const real* outGrad,
const int channels,
const int depth,
const int height,
const int width,
const int pooledD,
const int pooledH,
const int pooledW,
const int sizeZ,
const int sizeY,
const int sizeX,
const int strideD,
const int strideH,
const int strideW,
const int paddingD,
const int paddingH,
const int paddingW,
real scaleA,
real scaleB,
real* backGrad,
const int outStride);
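// Note (our reading of the convention; not stated in this header): in the
// two 3D backward declarations above, scaleA scales the newly computed
// gradient and scaleB the existing target gradient before they are summed.
// The layers in this diff pass 1.0 / 1.0 to accumulate.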
/**
* @brief Bilinear interpolation forward.
*
......@@ -275,4 +365,4 @@ extern void hl_maxout_backward(real* inGrad,
size_t featLen,
size_t groups);
#endif // HL_CNN_H_
......@@ -224,4 +224,80 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
extern void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise);
/**
 * @brief Matrix vol2Col: convert a 3D volume into a column matrix.
 *
 * @param[in]   dataSrc    input volume data.
 * @param[in]   channels   number of channels of dataSrc.
 * @param[in]   depth      depth of dataSrc.
 * @param[in]   height     height of dataSrc.
 * @param[in]   width      width of dataSrc.
 * @param[in]   filterD    depth of the filter.
 * @param[in]   filterH    height of the filter.
 * @param[in]   filterW    width of the filter.
 * @param[in]   strideD    stride in the depth dimension.
 * @param[in]   strideH    stride in the height dimension.
 * @param[in]   strideW    stride in the width dimension.
 * @param[in]   paddingD   padding in the depth dimension.
 * @param[in]   paddingH   padding in the height dimension.
 * @param[in]   paddingW   padding in the width dimension.
 * @param[out]  dataDst    output column matrix.
 *
 */
extern void hl_matrix_vol2Col(const real* dataSrc,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst);
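// A minimal sizing sketch (ours, not part of this header): dataDst must be
// pre-allocated as a
//   (channels * filterD * filterH * filterW) x (depth_col * height_col * width_col)
// column buffer, where each *_col dimension follows the convolution output
// formula used by the implementation in hl_matrix.cu:
//   depth_col  = (depth  + 2 * paddingD - filterD) / strideD + 1;
//   height_col = (height + 2 * paddingH - filterH) / strideH + 1;
//   width_col  = (width  + 2 * paddingW - filterW) / strideW + 1;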
/**
 * @brief Matrix col2Vol: convert a column matrix back into a 3D volume.
 *
 * @param[out]  dataDst    output volume data.
 * @param[in]   channels   number of channels of dataDst.
 * @param[in]   depth      depth of dataDst.
 * @param[in]   height     height of dataDst.
 * @param[in]   width      width of dataDst.
 * @param[in]   filterD    depth of the filter.
 * @param[in]   filterH    height of the filter.
 * @param[in]   filterW    width of the filter.
 * @param[in]   strideD    stride in the depth dimension.
 * @param[in]   strideH    stride in the height dimension.
 * @param[in]   strideW    stride in the width dimension.
 * @param[in]   paddingD   padding in the depth dimension.
 * @param[in]   paddingH   padding in the height dimension.
 * @param[in]   paddingW   padding in the width dimension.
 * @param[in]   dataSrc    input column matrix.
 * @param[in]   alpha      scale applied to the accumulated column values.
 * @param[in]   beta       scale applied to the existing dataDst values.
 *
 */
extern void hl_matrix_col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
const real* dataSrc,
real alpha,
real beta);
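// Per-element semantics (see keMatrixCol2Vol in hl_matrix.cu later in this
// diff): every voxel of dataDst gathers all column entries that vol2Col
// would have copied from it, then blends
//   dataDst[v] = alpha * gatheredSum + beta * dataDst[v];
// so alpha = 1, beta = 0 overwrites, while alpha = 1, beta = 1 accumulates
// (the convolution layers below use the latter for gradients).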
#endif /* HL_MATRIX_H_ */
......@@ -87,6 +87,96 @@ inline void hl_avgpool_backward(const int frameCnt,
real* backGrad,
const int outStride) {}
inline void hl_maxpool3D_forward(const int frameCnt,
const real* inputData,
const int channels,
const int depth,
const int height,
const int width,
const int pooledD,
const int pooledH,
const int pooledW,
const int sizeZ,
const int sizeY,
const int sizeX,
const int strideD,
const int strideH,
const int strideW,
const int paddingD,
const int paddingH,
const int paddingW,
real* tgtData,
real* maxPoolIdxData,
const int tgtStride) {}
inline void hl_maxpool3D_backward(const int frameCnt,
const real* outGrad,
const int channels,
const int depth,
const int height,
const int width,
const int pooledD,
const int pooledH,
const int pooledW,
const int sizeZ,
const int sizeY,
const int sizeX,
const int strideD,
const int strideH,
const int strideW,
const int paddingD,
const int paddingH,
const int paddingW,
real scaleA,
real scaleB,
real* targetGrad,
real* maxPoolIdxData,
const int outStride) {}
inline void hl_avgpool3D_forward(const int frameCnt,
const real* inputData,
const int channels,
const int depth,
const int height,
const int width,
const int pooledD,
const int pooledH,
const int pooledW,
const int sizeZ,
const int sizeY,
const int sizeX,
const int strideD,
const int strideH,
const int strideW,
const int paddingD,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {}
inline void hl_avgpool3D_backward(const int frameCnt,
const real* outGrad,
const int channels,
const int depth,
const int height,
const int width,
const int pooledD,
const int pooledH,
const int pooledW,
const int sizeZ,
const int sizeY,
const int sizeX,
const int strideD,
const int strideH,
const int strideW,
const int paddingD,
const int paddingH,
const int paddingW,
real scaleA,
real scaleB,
real* backGrad,
const int outStride) {}
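// The empty bodies above are intentional: these stubs satisfy the linker in
// CPU-only builds, where the GPU pooling code paths are never exercised.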
inline void hl_bilinear_forward(const real* inData,
const size_t inImgH,
const size_t inImgW,
......
......@@ -99,4 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
inline void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise) {}
inline void hl_matrix_vol2Col(const real* dataSrc,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst) {}
inline void hl_matrix_col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
const real* dataSrc,
real alpha,
real beta) {}
#endif // HL_MATRIX_STUB_H_
This diff is collapsed.
......@@ -592,3 +592,204 @@ void hl_matrix_rotate(
mat, matRot, dimM, dimN, clockWise);
CHECK_SYNC("hl_matrix_rotate failed");
}
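// One thread per column-buffer element: thread `index` owns one
// (channel, d_out, h_out, w_out) output position and copies its
// filterD x filterH x filterW input window into the column buffer,
// writing zeros wherever the window falls into the padding.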
__global__ void keMatrixVol2Col(int num_kernels,
const real* dataSrc,
real* dataDst,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
int depth_col,
int height_col,
int width_col) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int d_out = (index / width_col / height_col) % depth_col;
int channel_in = index / width_col / height_col / depth_col;
int channel_out = channel_in * filterD * filterH * filterW;
int w_in = w_out * strideW - paddingW;
int h_in = h_out * strideH - paddingH;
int d_in = d_out * strideD - paddingD;
dataDst +=
((channel_out * depth_col + d_out) * height_col + h_out) * width_col +
w_out;
dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
for (int k = 0; k < filterD; ++k) {
for (int i = 0; i < filterH; ++i) {
for (int j = 0; j < filterW; ++j) {
int d = d_in + k;
int h = h_in + i;
int w = w_in + j;
*dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
w < width)
? dataSrc[(k * height + i) * width + j]
: 0;
dataDst += depth_col * height_col * width_col;
}
}
}
}
}
void hl_matrix_vol2Col(const real* dataSrc,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
int num_kernels = channels * depth_col * height_col * width_col;
const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);
keMatrixVol2Col<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
dataSrc,
dataDst,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
depth_col,
height_col,
width_col);
CHECK_SYNC("hl_matrix_vol2Col failed");
}
__global__ void keMatrixCol2Vol(int num_kernels,
real* dataDst,
const real* dataSrc,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
int depth_col,
int height_col,
int width_col,
real alpha,
real beta) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
real srcVal = 0;
real dstVal = dataDst[index];
int w = index % width + paddingW;
int h = (index / width) % height + paddingH;
int d = (index / width / height) % depth + paddingD;
int c = index / width / height / depth;
// compute the start and end of the output
int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1;
int w_col_end = min(w / strideW + 1, width_col);
int h_col_start = (h < filterH) ? 0 : (h - filterH) / strideH + 1;
int h_col_end = min(h / strideH + 1, height_col);
int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1;
int d_col_end = min(d / strideD + 1, depth_col);
int offset = (c * filterD * filterW * filterH + d * filterW * filterH +
h * filterW + w) *
depth_col * height_col * width_col;
int coeff_d_col =
(1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
int coeff_h_col =
(1 - strideH * filterW * depth_col * height_col) * width_col;
int coeff_w_col = (1 - strideW * depth_col * height_col * width_col);
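// offset + d_col * coeff_d_col + h_col * coeff_h_col + w_col * coeff_w_col
// is the column-buffer index of the entry vol2Col copied from this voxel
// for window (d_col, h_col, w_col); the loops below gather every
// overlapping window's contribution in one pass.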
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col +
w_col * coeff_w_col];
}
}
}
dataDst[index] = alpha * srcVal + beta * dstVal;
}
}
void hl_matrix_col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
const real* dataSrc,
real alpha,
real beta) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
int num_kernels = channels * depth * height * width;
const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);
keMatrixCol2Vol<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
dataDst,
dataSrc,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
depth_col,
height_col,
width_col,
alpha,
beta);
CHECK_SYNC("hl_matrix_col2Vol failed");
}
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Conv3DLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(conv3d, Conv3DLayer);
bool Conv3DLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
int index = 0;
for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf();
M_.push_back(numFilters_ / conf.groups());
K_.push_back(filterPixels_[index] * filterChannels_[index]);
// create a new weight
size_t height, width;
width = filterPixels_[index] * filterChannels_[index];
height = numFilters_;
CHECK_EQ(parameters_[index]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
++index;
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
}
}
return true;
}
size_t Conv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL);
outputH_.clear();
outputW_.clear();
outputD_.clear();
N_.clear();
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); ++i) {
outputW_.push_back(outputSize(
imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
outputH_.push_back(outputSize(
imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
outputD_.push_back(outputSize(
imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
layerSize += N_[i] * numFilters_;
}
getOutput().setFrameHeight(outputH_[0]);
getOutput().setFrameWidth(outputW_[0]);
getOutput().setFrameDepth(outputD_[0]);
return layerSize;
}
void Conv3DLayer::forward(PassType passType) {
Layer::forward(passType);
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int outWidth = getSize();
resetOutput(batchSize, outWidth);
for (size_t i = 0; i != inputLayers_.size(); ++i) {
REGISTER_TIMER_INFO("FwdConv3D", getName().c_str());
const MatrixPtr &inMat = getInputValue(i);
const MatrixPtr &outMat = getOutputValue();
int M = M_[i];
int N = N_[i];
int K = K_[i];
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wMat = weights_[i]->getW();
for (int n = 0; n < batchSize; ++n) {
colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(),
channels_[i],
imgSizeD_[i],
imgSizeH_[i],
imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
real *outData = outMat->getData() + n * outMat->getStride();
MatrixPtr outMatSub =
Matrix::create(outData, groups_[i] * M, N, false, useGpu_);
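// Per-group GEMM: out (M x N) += wSub (M x K) * colSub (K x N), with
// M = numFilters_ / groups, K = filterChannels * filterD * filterH * filterW,
// and N = outD * outH * outW for this input.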
for (int g = 0; g < groups_[i]; g++) {
MatrixPtr wMatSub = wMat->subMatrix(g * M, M);
MatrixPtr in = colBuf_->subMatrix(g * K, K);
MatrixPtr out = outMatSub->subMatrix(g * M, M);
out->mul(*wMatSub, *in, 1.0, 1.0);
}
}
}
if (nullptr != this->biasParameter_) {
REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
this->addBias();
}
forwardActivation();
}
void Conv3DLayer::backward(const UpdateCallback &callback) {
backwardActivation();
if (biases_ && biases_->getWGrad()) {
bpropBiases();
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
REGISTER_TIMER_INFO("BwdConv3D", getName().c_str());
if (weights_[i]->getWGrad()) {
bpropWeights(i);
}
if (getInputGrad(i)) {
bpropData(i);
}
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
weights_[i]->getParameterPtr()->incUpdate(callback);
}
}
void Conv3DLayer::bpropWeights(int i) {
int M = M_[i];
int N = N_[i];
int K = K_[i];
const MatrixPtr &inMat = getInputValue(i);
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wGradMat = weights_[i]->getWGrad();
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
for (int n = 0; n < batchSize; ++n) {
colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(),
channels_[i],
imgSizeD_[i],
imgSizeH_[i],
imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
real *outGradData =
getOutputGrad()->getData() + n * getOutputGrad()->getStride();
MatrixPtr outGradSub =
Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_);
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr inMatSub = colBuf_->subMatrix(g * K, K);
MatrixPtr outG = outGradSub->subMatrix(g * M, M);
MatrixPtr wGradSub = wGradMat->subMatrix(g * M, M);
wGradSub->mul(*outG, *(inMatSub->getTranspose()), 1.0, 1.0);
}
}
}
void Conv3DLayer::bpropData(int i) {
int M = M_[i];
int N = N_[i];
int K = K_[i];
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wMat = weights_[i]->getW();
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
for (int n = 0; n < batchSize; ++n) {
real *outGradData =
getOutputGrad()->getData() + n * getOutputGrad()->getStride();
real *preGradData =
getInputGrad(i)->getData() + n * getInputGrad(i)->getStride();
MatrixPtr outGradSub =
Matrix::create(outGradData, M * groups_[i], N, false, useGpu_);
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr wMatSub = wMat->subMatrix(g * M, M);
MatrixPtr outG = outGradSub->subMatrix(g * M, M);
MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K);
inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0);
}
colBuf_->col2Vol(preGradData,
channels_[i],
imgSizeD_[i],
imgSizeH_[i],
imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i],
1.0,
1.0);
}
}
void Conv3DLayer::bpropBiases() {
MatrixPtr outGradMat = getOutputGrad();
if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
} else {
biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
}
}
void Conv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue();
if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f);
} else {
outMat->addBias(*(biases_->getW()), 1.0f);
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "ConvBaseLayer.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief A 3D convolution layer.
* This layer expands the input via vol2Col and uses matrix
* multiplication to compute the convolution.
*/
class Conv3DLayer : public ConvBaseLayer {
public:
explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
~Conv3DLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void addBias();
void backward(const UpdateCallback& callback);
void bpropBiases();
void bpropData(int i);
void bpropWeights(int i);
size_t getSize();
protected:
// Figure out the dimensions for individual gemms.
IntV M_; /// numFilters_ / groups_
IntV N_; /// outputD_ * outputH_ * outputW_
IntV K_; /// filterPixels_ * filterChannels_
MatrixPtr colBuf_;
};
} // namespace paddle
......@@ -38,7 +38,6 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
strideY_.push_back(conf.stride_y());
dilationY_.push_back(conf.dilation_y());
filterSizeY_.push_back(conf.filter_size_y());
channels_.push_back(conf.channels());
imgSizeH_.push_back(conf.has_img_size_y() ? conf.img_size_y()
: conf.img_size());
......@@ -47,31 +46,20 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
filterChannels_.push_back(conf.filter_channels());
outputH_.push_back(conf.has_output_y() ? conf.output_y() : conf.output_x());
outputW_.push_back(conf.output_x());
paddingZ_.push_back(conf.padding_z());
strideZ_.push_back(conf.stride_z());
filterSizeZ_.push_back(conf.filter_size_z());
imgSizeD_.push_back(conf.img_size_z());
outputD_.push_back(conf.output_z());
filterPixels_.push_back(filterSize_.back() * filterSizeY_.back() *
filterSizeZ_.back());
}
CHECK(inputLayers_.size() == parameters_.size());
for (size_t i = 0; i < inputLayers_.size(); i++) {
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = (!isDeconv_) ? numFilters_ : channels_[i];
// create a new weight
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight* w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
}
/* initialize the biases_ */
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
// create new weights_ in derived class
// create new biases_ in derived class
// default caffe model
caffeMode_ = true;
......
......@@ -62,6 +62,13 @@ protected:
IntV outputH_;
/// The spatial dimensions of output feature map width.
IntV outputW_;
IntV outputD_;
IntV imgSizeD_;
IntV filterSizeZ_;
IntV strideZ_;
IntV paddingZ_;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels,
......
......@@ -46,8 +46,26 @@ bool CudnnConvBaseLayer::init(const LayerMap &layerMap,
projConf_.emplace_back(conf);
projections_.emplace_back(
Projection::create(*projConf_[i], parameters_[i], useGpu_));
// create a new weight
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = (!isDeconv_) ? numFilters_ : channels_[i];
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
if (biases_.get() && sharedBiases_) {
hl_create_tensor_descriptor(&biasDesc_);
hl_create_tensor_descriptor(&outputDesc_);
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "DeConv3DLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace paddle {
REGISTER_LAYER(deconv3d, DeConv3DLayer);
bool DeConv3DLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) {
if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
// For deconv, the kernel dimensions are
// channel * output * depth * height * width.
// Matrix storage format: (output * depth * height * width) x channel
for (int index = 0; index < config_.inputs().size(); ++index) {
M_.push_back(filterChannels_[index]);
K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index]));
// create a new weight
size_t height, width;
height = filterPixels_[index] * numFilters_;
width = filterChannels_[index];
CHECK_EQ(parameters_[index]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
}
}
return true;
}
size_t DeConv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL);
outputH_.clear();
outputW_.clear();
outputD_.clear();
N_.clear();
NOut_.clear();
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); ++i) {
outputW_.push_back(
imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
outputH_.push_back(imageSize(
imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
outputD_.push_back(imageSize(
imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
CHECK(layerSize == 0 || NOut_[i] * size_t(numFilters_) == layerSize);
layerSize += NOut_[i] * numFilters_;
}
getOutput().setFrameHeight(outputH_[0]);
getOutput().setFrameWidth(outputW_[0]);
getOutput().setFrameDepth(outputD_[0]);
return layerSize;
}
void DeConv3DLayer::forward(PassType passType) {
Layer::forward(passType);
int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int outWidth = getSize();
resetOutput(batchSize, outWidth);
const MatrixPtr outMat = getOutputValue();
for (size_t i = 0; i != inputLayers_.size(); ++i) {
REGISTER_TIMER_INFO("FwdDeConv3D", getName().c_str());
const MatrixPtr &inMat = getInputValue(i);
int M = M_[i];
int N = N_[i];
int K = K_[i];
MatrixPtr wMat = weights_[i]->getW();
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
for (int n = 0; n < batchSize; ++n) {
real *inData = inMat->getData() + n * inMat->getStride();
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_);
MatrixPtr wMatSub = wMat->subMatrix(g * K, K);
MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K);
colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0);
inData += M * N;
}
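// Scatter-add the expanded columns back into the output volume: deconv
// forward is the adjoint of conv's vol2Col + GEMM pipeline, hence col2Vol
// with alpha = beta = 1.0 to accumulate overlapping windows.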
colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(),
numFilters_,
outputD_[i],
outputH_[i],
outputW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i],
1.0,
1.0);
}
}
if (nullptr != this->biasParameter_) {
REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
this->addBias();
}
forwardActivation();
}
void DeConv3DLayer::backward(const UpdateCallback &callback) {
backwardActivation();
int batchSize = getOutputGrad()->getHeight();
if (biases_ && biases_->getWGrad()) {
bpropBiases();
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i < inputLayers_.size(); ++i) {
if (weights_[i]->getWGrad() || this->needGradient_) {
int M = M_[i];
int N = N_[i];
int K = K_[i];
REGISTER_TIMER_INFO("BwdDeConv3D", getName().c_str());
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
const MatrixPtr &inMat = getInputValue(i);
for (int n = 0; n < batchSize; ++n) {
colBuf_->vol2Col(
getOutputGrad()->getData() + n * getOutputGrad()->getStride(),
numFilters_,
outputD_[i],
outputH_[i],
outputW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
if (weights_[i]->getWGrad()) {
real *inData = inMat->getData() + n * inMat->getStride();
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K);
MatrixPtr wGradMatSub =
weights_[i]->getWGrad()->subMatrix(g * K, K);
MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_);
wGradMatSub->mul(
*colBufDataSub, *(inMatSub->getTranspose()), 1.0, 1.0);
inData += M * N;
}
}
if (getInputGrad(i)) {
real *preGrad =
getInputGrad(i)->getData() + n * getInputGrad(i)->getStride();
for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K);
MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K);
MatrixPtr inGradMatSub =
Matrix::create(preGrad, M, N, false, useGpu_);
inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 1.0);
preGrad += M * N;
}
}
}
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
weights_[i]->getParameterPtr()->incUpdate(callback);
}
}
}
void DeConv3DLayer::bpropWeights(int i) {}
void DeConv3DLayer::bpropData(int i) {}
void DeConv3DLayer::bpropBiases() {
const MatrixPtr &outGradMat = getOutputGrad();
if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
} else {
biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
}
}
void DeConv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue();
if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f);
} else {
outMat->addBias(*(biases_->getW()), 1.0f);
}
}
} // namespace paddle
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "ConvBaseLayer.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief A 3D deconvolution (transposed convolution) layer.
* This layer expands the input and uses matrix multiplication to
* compute the deconvolution.
*/
class DeConv3DLayer : public ConvBaseLayer {
public:
explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
~DeConv3DLayer() {}
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
void forward(PassType passType);
void addBias();
void backward(const UpdateCallback& callback);
void bpropBiases();
void bpropData(int i);
void bpropWeights(int i);
size_t getSize();
protected:
// Figure out the dimensions for individual gemms.
IntV M_; /// filterChannels_
IntV N_; /// imgSizeD_ * imgSizeH_ * imgSizeW_
IntV K_; /// filterPixels_ * (numFilters_ / groups_)
IntV NOut_; /// outputD_ * outputH_ * outputW_
MatrixPtr colBuf_;
};
} // namespace paddle
......@@ -22,12 +22,31 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
/* Initialize the basic convolutional parent class */
ConvBaseLayer::init(layerMap, parameterMap);
int index = 0;
for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf();
/* Consistent caffe mode for multiple inputs */
caffeMode_ = conf.caffe_mode();
// create a new weight
size_t height, width;
height = filterPixels_[index] * filterChannels_[index];
width = (!isDeconv_) ? numFilters_ : channels_[index];
CHECK_EQ(parameters_[index]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
index++;
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
getOutputSize();
return true;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Pool3DLayer.h"
#include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h"
namespace paddle {
REGISTER_LAYER(pool3d, Pool3DLayer);
bool Pool3DLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);
/* the size of inputs for pool-layer is 1 */
CHECK_EQ(config_.inputs_size(), 1);
const PoolConfig& conf = config_.inputs(0).pool_conf();
poolType_ = conf.pool_type();
channels_ = conf.channels();
sizeX_ = conf.size_x();
sizeY_ = conf.size_y();
sizeZ_ = conf.size_z();
strideW_ = conf.stride();
strideH_ = conf.stride_y();
strideD_ = conf.stride_z();
imgSizeW_ = conf.img_size();
imgSizeH_ = conf.img_size_y();
imgSizeD_ = conf.img_size_z();
paddingW_ = conf.padding();
paddingH_ = conf.padding_y();
paddingD_ = conf.padding_z();
outputW_ = conf.output_x();
outputH_ = conf.output_y();
outputD_ = conf.output_z();
return true;
}
size_t Pool3DLayer::getSize() {
CHECK_EQ(inputLayers_.size(), 1UL);
size_t layerSize = 0;
outputD_ = outputSize(imgSizeD_, sizeZ_, paddingD_, strideD_, false);
outputH_ = outputSize(imgSizeH_, sizeY_, paddingH_, strideH_, false);
outputW_ = outputSize(imgSizeW_, sizeX_, paddingW_, strideW_, false);
layerSize = outputD_ * outputH_ * outputW_ * channels_;
getOutput().setFrameHeight(outputH_);
getOutput().setFrameWidth(outputW_);
getOutput().setFrameDepth(outputD_);
return layerSize;
}
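// Worked example (ours): outputSize with caffeMode = false uses the ceil
// convention, (imgSize - size + 2 * padding + stride - 1) / stride + 1, so
// imgSize 9, window 3, stride 2, padding 0 gives 4 outputs per dimension.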
void Pool3DLayer::forward(PassType passType) {
Layer::forward(passType);
const MatrixPtr& inMat = inputLayers_[0]->getOutputValue();
size_t batchSize = inMat->getHeight();
size_t outWidth = getSize();
resetOutput(batchSize, outWidth);
Matrix::resizeOrCreate(maxPoolIdx_, batchSize, outWidth, false, useGpu_);
const MatrixPtr outMat = getOutputValue();
if (poolType_ == "avg") {
outMat->avgPool3DForward(*inMat,
channels_,
imgSizeD_,
imgSizeH_,
imgSizeW_,
outputD_,
outputH_,
outputW_,
sizeZ_,
sizeY_,
sizeX_,
strideD_,
strideH_,
strideW_,
paddingD_,
paddingH_,
paddingW_);
} else if (poolType_ == "max") {
outMat->maxPool3DForward(*inMat,
*maxPoolIdx_,
channels_,
imgSizeD_,
imgSizeH_,
imgSizeW_,
outputD_,
outputH_,
outputW_,
sizeZ_,
sizeY_,
sizeX_,
strideD_,
strideH_,
strideW_,
paddingD_,
paddingH_,
paddingW_);
} else {
LOG(FATAL) << "Unknown pool type: " << poolType_;
}
forwardActivation();
}
void Pool3DLayer::backward(const UpdateCallback& callback) {
backwardActivation();
(void)callback;
if (NULL == getInputGrad(0)) return;
MatrixPtr inMat = inputLayers_[0]->getOutputValue();
MatrixPtr inGradMat = inputLayers_[0]->getOutputGrad();
MatrixPtr outMat = getOutputValue();
MatrixPtr outGradMat = getOutputGrad();
if (poolType_ == "avg") {
inGradMat->avgPool3DBackward(*outGradMat,
imgSizeD_,
imgSizeH_,
imgSizeW_,
outputD_,
outputH_,
outputW_,
sizeZ_,
sizeY_,
sizeX_,
strideD_,
strideH_,
strideW_,
paddingD_,
paddingH_,
paddingW_,
1.0,
1.0);
} else if (poolType_ == "max") {
inGradMat->maxPool3DBackward(*outGradMat,
*maxPoolIdx_,
imgSizeD_,
imgSizeH_,
imgSizeW_,
outputD_,
outputH_,
outputW_,
sizeZ_,
sizeY_,
sizeX_,
strideD_,
strideH_,
strideW_,
paddingD_,
paddingH_,
paddingW_,
1.0,
1.0);
} else {
LOG(FATAL) << "Unknown pool type: " << poolType_;
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "Layer.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
namespace paddle {
/**
* @brief 3D pooling layer.
* Pools the input within 3D regions.
*/
class Pool3DLayer : public Layer {
public:
explicit Pool3DLayer(const LayerConfig& config) : Layer(config) {}
~Pool3DLayer() {}
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;
size_t getSize();
protected:
int channels_;
int sizeX_, sizeY_, sizeZ_;
int strideW_, strideH_, strideD_;
int paddingW_, paddingH_, paddingD_;
int imgSizeW_, imgSizeH_, imgSizeD_;
int outputW_, outputH_, outputD_;
std::string poolType_;
MatrixPtr maxPoolIdx_;
};
} // namespace paddle
......@@ -1246,6 +1246,75 @@ TEST(Layer, PoolLayer) {
#endif
}
void setPool3DConfig(TestConfig* config,
PoolConfig* pool,
const string& poolType) {
// filter size
const int NUM_FILTERS = 16;
const int FILTER_SIZE = 3;
const int FILTER_SIZE_Y = 3;
const int FILTER_SIZE_Z = 3;
const int CHANNELS = 16;
config->biasSize = 0;
config->layerConfig.set_type("pool3d");
config->layerConfig.set_num_filters(NUM_FILTERS);
int kw = FILTER_SIZE, kh = FILTER_SIZE_Y, kd = FILTER_SIZE_Z;
int pw = 0, ph = 0, pd = 0;
int sw = 2, sh = 2, sd = 2;
pool->set_pool_type(poolType);
pool->set_channels(CHANNELS);
pool->set_size_x(kw);
pool->set_size_y(kh);
pool->set_size_z(kd);
pool->set_padding(0);
pool->set_padding_y(0);
pool->set_padding_z(0);
pool->set_stride(sw);
pool->set_stride_y(sh);
pool->set_stride_z(sd);
pool->set_start(0);
int ow = outputSize(pool->img_size(), kw, pw, sw, /* caffeMode */ false);
int oh = outputSize(pool->img_size_y(), kh, ph, sh, /* caffeMode */ false);
int od = outputSize(pool->img_size_z(), kd, pd, sd, /* caffeMode */ false);
pool->set_output_x(ow);
pool->set_output_y(oh);
pool->set_output_z(od);
}
void testPool3DLayer(const string& poolType, bool trans, bool useGpu) {
TestConfig config;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 11664, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
PoolConfig* pool = input->mutable_pool_conf();
const int IMAGE_SIZE = 9;
const int IMAGE_SIZE_Y = 9;
const int IMAGE_SIZE_Z = 9;
pool->set_img_size(IMAGE_SIZE);
pool->set_img_size_y(IMAGE_SIZE_Y);
pool->set_img_size_z(IMAGE_SIZE_Z);
setPool3DConfig(&config, pool, poolType);
config.layerConfig.set_size(pool->output_x() * pool->output_y() *
pool->output_z() * pool->channels());
testLayerGrad(config, "pool3d", 100, trans, useGpu);
}
TEST(Layer, Pool3DLayer) {
testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ false);
testPool3DLayer("max", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
testPool3DLayer("avg", /* trans= */ false, /* useGpu= */ true);
testPool3DLayer("max", /* trans= */ false, /* useGpu= */ true);
#endif
}
void testSppLayer(const string& poolType,
const int pyramidHeight,
bool trans,
......@@ -2096,6 +2165,159 @@ TEST(Layer, RowL2NormLayer) {
}
}
void test3DConvLayer(const string& type, bool trans, bool useGpu) {
// filter size
const int NUM_FILTERS = 6;
// const int CHANNELS = 3;
const int FILTER_SIZE = 3;
const int FILTER_SIZE_Y = 3;
const int FILTER_SIZE_Z = 3;
// input image
const int CHANNELS = 3;
const int IMAGE_SIZE = 9;
const int IMAGE_SIZE_Y = 9;
const int IMAGE_SIZE_Z = 9;
TestConfig config;
config.biasSize = NUM_FILTERS;
config.layerConfig.set_type(type);
config.layerConfig.set_num_filters(NUM_FILTERS);
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
// Set up the conv3D layer
LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_channels(CHANNELS);
conv->set_filter_size(FILTER_SIZE);
conv->set_filter_size_y(FILTER_SIZE_Y);
conv->set_filter_size_z(FILTER_SIZE_Z);
conv->set_padding(0);
conv->set_padding_y(0);
conv->set_padding_z(0);
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_stride_z(2);
conv->set_img_size(IMAGE_SIZE);
conv->set_img_size_y(IMAGE_SIZE_Y);
conv->set_img_size_z(IMAGE_SIZE_Z);
conv->set_output_x(outputSize(conv->img_size(),
conv->filter_size(),
conv->padding(),
conv->stride(),
/* caffeMode */ true));
conv->set_output_y(outputSize(conv->img_size_y(),
conv->filter_size_y(),
conv->padding_y(),
conv->stride_y(),
/* caffeMode */ true));
conv->set_output_z(outputSize(conv->img_size_z(),
conv->filter_size_z(),
conv->padding_z(),
conv->stride_z(),
/* caffeMode */ true));
config.layerConfig.set_size(conv->output_x() * conv->output_y() *
conv->output_z() * NUM_FILTERS);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
config.inputDefs.push_back(
{INPUT_DATA,
"layer_0",
CHANNELS * IMAGE_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_Z,
conv->filter_channels() * FILTER_SIZE * FILTER_SIZE_Y * FILTER_SIZE_Z *
NUM_FILTERS});
testLayerGrad(config, "conv3D", 10, trans, useGpu);
// Use small batch_size and useWeight=true to test biasGrad
testLayerGrad(config, "conv3D", 2, trans, useGpu, true, 0.02);
}
TEST(Layer, test3DConvLayer) {
test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
test3DConvLayer("conv3d", /* trans= */ false, /* useGpu= */ true);
#endif
}
void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
// filter size
const int NUM_FILTERS = 6;
// const int CHANNELS = 3;
const int FILTER_SIZE = 3;
const int FILTER_SIZE_Y = 3;
const int FILTER_SIZE_Z = 3;
// input image
const int CHANNELS = 3;
const int IMAGE_SIZE = 4;
const int IMAGE_SIZE_Y = 6;
const int IMAGE_SIZE_Z = 6;
// Set up the deconv3D layer
TestConfig config;
config.biasSize = NUM_FILTERS;
config.layerConfig.set_type("deconv3d");
config.layerConfig.set_num_filters(NUM_FILTERS);
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_channels(CHANNELS);
conv->set_filter_size(FILTER_SIZE);
conv->set_filter_size_y(FILTER_SIZE_Y);
conv->set_filter_size_z(FILTER_SIZE_Z);
conv->set_padding(0);
conv->set_padding_y(0);
conv->set_padding_z(0);
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_stride_z(2);
conv->set_img_size(IMAGE_SIZE);
conv->set_img_size_y(IMAGE_SIZE_Y);
conv->set_img_size_z(IMAGE_SIZE_Z);
conv->set_output_x(imageSize(conv->img_size(),
conv->filter_size(),
conv->padding(),
conv->stride(),
true));
conv->set_output_y(imageSize(conv->img_size_y(),
conv->filter_size_y(),
conv->padding_y(),
conv->stride_y(),
true));
conv->set_output_z(imageSize(conv->img_size_z(),
conv->filter_size_z(),
conv->padding_z(),
conv->stride_z(),
true));
config.layerConfig.set_size(conv->output_x() * conv->output_y() *
conv->output_z() * NUM_FILTERS);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
config.inputDefs.push_back(
{INPUT_DATA,
"layer_0",
CHANNELS * IMAGE_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_Z,
conv->filter_channels() * FILTER_SIZE * FILTER_SIZE_Y * FILTER_SIZE_Z *
NUM_FILTERS});
testLayerGrad(config, "deconv3D", 10, trans, useGpu);
// Use small batch_size and useWeight=true to test biasGrad
testLayerGrad(config, "deconv3D", 2, trans, useGpu, true, 0.02);
}
TEST(Layer, test3DDeConvLayer) {
test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ false);
#ifndef PADDLE_ONLY_CPU
test3DDeConvLayer("deconv3d", /* trans= */ false, /* useGpu= */ true);
#endif
}
TEST(Layer, ScaleShiftLayer) {
const size_t batchSize = 16;
const size_t size = 32;
......
This diff is collapsed.
......@@ -928,15 +928,102 @@ public:
size_t paddingW) {
LOG(FATAL) << "Not implemented";
}
/**
 * Pooling 3D forward operation: pick the largest element in each
 * sizeZ x sizeY x sizeX window of the input and record its index
 * in maxPoolIdx.
 */
virtual void maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
size_t channels,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW) {
LOG(FATAL) << "Not implemented";
}
virtual void maxPool3DBackward(Matrix& outGrad,
Matrix& maxPoolIdx,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW,
real scaleTargets,
real scaleOutput) {
LOG(FATAL) << "Not implemented";
}
virtual void avgPool3DForward(Matrix& input,
size_t channels,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW) {
LOG(FATAL) << "Not implemented";
}
virtual void avgPool3DBackward(Matrix& input,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW,
real scaleTargets,
real scaleOutput) {
LOG(FATAL) << "Not implemented";
}
/**
* Input: one or more sequences. Each sequence contains some instances.
*
* Output: output size is the number of input sequences (NOT input
* instances).
*
* output[i] is set to max_input[i].
*/
virtual void maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index) {
......@@ -1039,6 +1126,42 @@ public:
LOG(FATAL) << "Not implemented";
}
virtual void vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW) {
LOG(FATAL) << "Not implemented";
}
virtual void col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta) {
LOG(FATAL) << "Not implemented";
}
virtual void bilinearForward(const Matrix& in,
const size_t inImgH,
const size_t inImgW,
......@@ -1348,6 +1471,82 @@ public:
size_t paddingH,
size_t paddingW);
void maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
size_t channels,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW);
void maxPool3DBackward(Matrix& outGrad,
Matrix& maxPoolIdx,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW,
real scaleTargets,
real scaleOutput);
void avgPool3DForward(Matrix& input,
size_t channels,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW);
void avgPool3DBackward(Matrix& input,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW,
real scaleTargets,
real scaleOutput);
void maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index);
......@@ -1374,6 +1573,38 @@ public:
const real ratioH,
const real ratioW);
void vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
......@@ -1507,6 +1738,82 @@ public:
size_t paddingH,
size_t paddingW);
void maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
size_t channels,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW);
void maxPool3DBackward(Matrix& outGrad,
Matrix& maxPoolIdx,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW,
real scaleTargets,
real scaleOutput);
void avgPool3DForward(Matrix& input,
size_t channels,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW);
void avgPool3DBackward(Matrix& input,
size_t imgSizeD,
size_t imgSizeH,
size_t imgSizeW,
size_t outputD,
size_t outputH,
size_t outputW,
size_t sizeZ,
size_t sizeY,
size_t sizeX,
size_t strideD,
size_t strideH,
size_t strideW,
size_t paddingD,
size_t paddingH,
size_t paddingW,
real scaleTargets,
real scaleOutput);
void maxSequenceForward(Matrix& input,
const IVector& sequence,
IVector& index);
......@@ -1715,6 +2022,38 @@ public:
const real ratioH,
const real ratioW);
void vol2Col(real* data,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW);
void col2Vol(real* trg,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real alpha,
real beta);
template <typename ExpressionType>
void operator=(const ExpressionType& expr) {
TensorCpuApply<real>(*this, expr);
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include "TensorCheck.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h"
#include "paddle/testing/TestUtil.h"
......@@ -1203,4 +1204,497 @@ TEST(Matrix, warpCTC) {
}
}
void testMaxPool3DFwdBwd(int numSamples,
int channels,
int imgSizeD,
int imgSizeH,
int imgSizeW,
int ksizeD,
int ksizeH,
int ksizeW,
int strideD,
int strideH,
int strideW,
int padD,
int padH,
int padW) {
int outD = outputSize(imgSizeD, ksizeD, padD, strideD, true);
int outH = outputSize(imgSizeH, ksizeH, padH, strideH, true);
int outW = outputSize(imgSizeW, ksizeW, padW, strideW, true);
int inWidth = channels * imgSizeD * imgSizeH * imgSizeW;
MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true);
int outWidth = channels * outD * outH * outW;
MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true);
MatrixPtr maxIdx = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr maxIdxGpu = GpuMatrix::create(numSamples, outWidth, false, true);
input->randomizeUniform();
target->randomizeUniform();
inputGpu->copyFrom(*input);
targetGpu->copyFrom(*target);
target->maxPool3DForward(*input,
*maxIdx,
channels,
imgSizeD,
imgSizeH,
imgSizeW,
outD,
outH,
outW,
ksizeD,
ksizeH,
ksizeW,
strideD,
strideH,
strideW,
padD,
padH,
padW);
targetGpu->maxPool3DForward(*inputGpu,
*maxIdxGpu,
channels,
imgSizeD,
imgSizeH,
imgSizeW,
outD,
outH,
outW,
ksizeD,
ksizeH,
ksizeW,
strideD,
strideH,
strideW,
padD,
padH,
padW);
MatrixPtr targetCheck = CpuMatrix::create(numSamples, outWidth, false, false);
targetCheck->copyFrom(*targetGpu);
checkMatrixEqual(target, targetCheck);
MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true);
MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpuGrad =
GpuMatrix::create(numSamples, outWidth, false, true);
inputGrad->randomizeUniform();
targetGrad->randomizeUniform();
inputGpuGrad->copyFrom(*inputGrad);
targetGpuGrad->copyFrom(*targetGrad);
inputGrad->maxPool3DBackward(*targetGrad,
*maxIdx,
imgSizeD,
imgSizeH,
imgSizeW,
outD,
outH,
outW,
ksizeD,
ksizeH,
ksizeW,
strideD,
strideH,
strideW,
padD,
padH,
padW,
1.0,
1.0);
inputGpuGrad->maxPool3DBackward(*targetGpuGrad,
*maxIdxGpu,
imgSizeD,
imgSizeH,
imgSizeW,
outD,
outH,
outW,
ksizeD,
ksizeH,
ksizeW,
strideD,
strideH,
strideW,
padD,
padH,
padW,
1.0,
1.0);
MatrixPtr targetBwdCheck =
CpuMatrix::create(numSamples, inWidth, false, false);
targetBwdCheck->copyFrom(*inputGpuGrad);
checkMatrixEqual(inputGrad, targetBwdCheck);
}
void testAvgPool3DFwdBwd(int numSamples,
int channels,
int imgSizeD,
int imgSizeH,
int imgSizeW,
int ksizeD,
int ksizeH,
int ksizeW,
int strideD,
int strideH,
int strideW,
int padD,
int padH,
int padW) {
int outD = outputSize(imgSizeD, ksizeD, padD, strideD, true);
int outH = outputSize(imgSizeH, ksizeH, padH, strideH, true);
int outW = outputSize(imgSizeW, ksizeW, padW, strideW, true);
int inWidth = imgSizeD * imgSizeH * imgSizeW * channels;
MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true);
int outWidth = channels * outD * outH * outW;
MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true);
input->randomizeUniform();
target->randomizeUniform();
inputGpu->copyFrom(*input);
targetGpu->copyFrom(*target);
target->avgPool3DForward(*input,
channels,
imgSizeD,
imgSizeH,
imgSizeW,
outD,
outH,
outW,
ksizeD,
ksizeH,
ksizeW,
strideD,
strideH,
strideW,
padD,
padH,
padW);
targetGpu->avgPool3DForward(*inputGpu,
channels,
imgSizeD,
imgSizeH,
imgSizeW,
outD,
outH,
outW,
ksizeD,
ksizeH,
ksizeW,
strideD,
strideH,
strideW,
padD,
padH,
padW);
TensorCheckErr(*target, *targetGpu);
MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false);
MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true);
MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false);
MatrixPtr targetGpuGrad =
GpuMatrix::create(numSamples, outWidth, false, true);
inputGrad->randomizeUniform();
targetGrad->randomizeUniform();
inputGpuGrad->copyFrom(*inputGrad);
targetGpuGrad->copyFrom(*targetGrad);
inputGrad->avgPool3DBackward(*targetGrad,
imgSizeD,
imgSizeH,
imgSizeW,
outD,
outH,
outW,
ksizeD,
ksizeH,
ksizeW,
strideD,
strideH,
strideW,
padD,
padH,
padW,
1.0,
1.0);
inputGpuGrad->avgPool3DBackward(*targetGpuGrad,
imgSizeD,
imgSizeH,
imgSizeW,
outD,
outH,
outW,
ksizeD,
ksizeH,
ksizeW,
strideD,
strideH,
strideW,
padD,
padH,
padW,
1.0,
1.0);
TensorCheckErr(*inputGrad, *inputGpuGrad);
}
// TODO(yi): I noticed many such blindly combinatorial tests in this
// file. They are no help to locate defects at all.
TEST(Matrix, Pool3DFwdBwd) {
for (auto numSamples : {1, 3}) {
for (auto channels : {3}) {
for (auto imgSizeD : {9, 16}) {
for (auto imgSizeH : {9, 32}) {
for (auto imgSizeW : {9, 32}) {
for (auto sizeX : {3}) {
for (auto sizeY : {3}) {
for (auto sizeZ : {3}) {
for (auto sD : {2}) {
for (auto sH : {2}) {
for (auto sW : {2}) {
for (auto pD : {0, (sizeZ - 1) / 2}) {
for (auto pH : {0, (sizeY - 1) / 2}) {
for (auto pW : {0, (sizeX - 1) / 2}) {
VLOG(3) << " numSamples=" << numSamples
<< " channels=" << channels
<< " imgSizeD=" << imgSizeD
<< " imgSizeH=" << imgSizeH
<< " imgSizeW=" << imgSizeW
<< " sizeX=" << sizeX
<< " sizeY=" << sizeY
<< " sizeZ=" << sizeZ << " strideD=" << sD
<< " strideH=" << sH << " strideW=" << sW
<< " padingD=" << pD << " padingH=" << pH
<< " padingW=" << pW;
testMaxPool3DFwdBwd(numSamples,
channels,
imgSizeD,
imgSizeH,
imgSizeW,
sizeX,
sizeY,
sizeZ,
sD,
sH,
sW,
pD,
pH,
pW);
testAvgPool3DFwdBwd(numSamples,
channels,
imgSizeD,
imgSizeH,
imgSizeW,
sizeX,
sizeY,
sizeZ,
sD,
sH,
sW,
pD,
pH,
pW);
}
}
}
}
}
}
}
}
}
}
}
}
}
}
// for (auto numSamples : {1, 3}) {
// for (auto channels : {1, 3}) {
// for (auto imgSizeD : {9,16}) {
// for (auto imgSizeH : {9, 32}) {
// for (auto imgSizeW : {9, 32}) {
// for (auto sizeX : {2, 3}) {
// for (auto sizeY : {2, 3}) {
// for (auto sizeZ : {2,3}){
// for (auto sD : {1, 2}) {
// for (auto sH : {1, 2}) {
// for (auto sW : {1, 2}) {
// for (auto pD : {0, (sizeZ - 1) / 2}){
// for (auto pH : {0, (sizeY - 1) / 2}) {
// for (auto pW : {0, (sizeX - 1) / 2}) {
// VLOG(3) << " numSamples=" << numSamples
// << " channels=" << channels
// << " imgSizeD=" << imgSizeD
// << " imgSizeH=" << imgSizeH
// << " imgSizeW=" << imgSizeW
// << " sizeX=" << sizeX
// << " sizeY=" << sizeY
// << " sizeZ=" << sizeZ
// << " strideD=" << sD
// << " strideH=" << sH
// << " strideW=" << sW
// << " padingD=" << pD
// << " padingH=" << pH
// << " padingW=" << pW;
//
// testMaxPool3DFwdBwd(numSamples,
// channels,
// imgSizeD,
// imgSizeH,
// imgSizeW,
// sizeX,
// sizeY,
// sizeZ,
// sD,
// sH,
// sW,
// pD,
// pH,
// pW);
// testAvgPool3DFwdBwd(numSamples,
// channels,
// imgSizeD,
// imgSizeH,
// imgSizeW,
// sizeX,
// sizeY,
// sizeZ,
// sD,
// sH,
// sW,
// pD,
// pH,
// pW);
// }
// }
// }
// }
// }
// }
// }
// }
// }
// }
// }
// }
// }
// }
}
void testMatrixCol2Vol(int depth, int height, int width) {
int channel = 3;
int filterX = 3, filterY = 4, filterZ = 5;
int strideX = 2, strideY = 2, strideZ = 2;
int padX = 1, padY = 1, padZ = 1;
MatrixPtr cpuImage =
std::make_shared<CpuMatrix>(channel, depth * height * width);
MatrixPtr gpuImage =
std::make_shared<GpuMatrix>(channel, depth * height * width);
cpuImage->randomizeUniform();
gpuImage->copyFrom(*cpuImage);
int outD = outputSize(depth, filterZ, padZ, strideZ, true);
int outH = outputSize(height, filterY, padY, strideY, true);
int outW = outputSize(width, filterX, padX, strideX, true);
int colBufHeight = channel * filterZ * filterY * filterX;
int colBufWidth = outD * outH * outW;
MatrixPtr cpuColBuf = std::make_shared<CpuMatrix>(colBufHeight, colBufWidth);
MatrixPtr gpuColBuf = std::make_shared<GpuMatrix>(colBufHeight, colBufWidth);
cpuColBuf->vol2Col(cpuImage->getData(),
channel,
depth,
height,
width,
filterZ,
filterY,
filterX,
strideZ,
strideY,
strideX,
padZ,
padY,
padX);
gpuColBuf->vol2Col(gpuImage->getData(),
channel,
depth,
height,
width,
filterZ,
filterY,
filterX,
strideZ,
strideY,
strideX,
padZ,
padY,
padX);
TensorCheckEqual(*cpuColBuf, *gpuColBuf);
cpuColBuf->randomizeUniform();
gpuColBuf->copyFrom(*cpuColBuf);
cpuColBuf->col2Vol(cpuImage->getData(),
channel,
depth,
height,
width,
filterZ,
filterY,
filterX,
strideZ,
strideY,
strideX,
padZ,
padY,
padX,
1.0,
1.0);
gpuColBuf->col2Vol(gpuImage->getData(),
channel,
depth,
height,
width,
filterZ,
filterY,
filterX,
strideZ,
strideY,
strideX,
padZ,
padY,
padX,
1.0,
1.0);
TensorCheckErr(*cpuImage, *gpuImage);
}
TEST(Matrix, col2Vol) {
for (auto depth : {9, 16, 64}) {
for (auto height : {9, 11, 128}) {
for (auto width : {9, 32, 128}) {
VLOG(3) << "depth=" << depth << " height=" << height
<< " width=" << width;
testMatrixCol2Vol(depth, height, width);
}
}
}
}
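The col buffer shape used above (colBufHeight = channel * filterZ * filterY * filterX, colBufWidth = outD * outH * outW) is the usual im2col generalization to volumes. A hedged NumPy sketch of the unfold step (assuming caffe-style floor output sizing and row-major patch flattening; the library's exact element order may differ):

.. code-block:: python

    import numpy as np

    def vol2col_ref(vol, filter_dhw, stride_dhw, pad_dhw):
        """Unfold a (C, D, H, W) volume into a
        (C * kD * kH * kW, outD * outH * outW) column buffer."""
        c, d, h, w = vol.shape
        kd, kh, kw = filter_dhw
        sd, sh, sw = stride_dhw
        pd, ph, pw = pad_dhw
        od = (d + 2 * pd - kd) // sd + 1
        oh = (h + 2 * ph - kh) // sh + 1
        ow = (w + 2 * pw - kw) // sw + 1
        padded = np.pad(vol, ((0, 0), (pd, pd), (ph, ph), (pw, pw)),
                        mode="constant")
        cols = np.empty((c * kd * kh * kw, od * oh * ow), dtype=vol.dtype)
        for zo in range(od):
            for yo in range(oh):
                for xo in range(ow):
                    patch = padded[:, zo * sd:zo * sd + kd,
                                   yo * sh:yo * sh + kh,
                                   xo * sw:xo * sw + kw]
                    cols[:, (zo * oh + yo) * ow + xo] = patch.ravel()
        return cols

col2Vol performs the adjoint scatter-add of this gather, accumulating overlapping patches, which is why the round trip above compares CPU against GPU accumulation results rather than expecting the original volume back.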
#endif
......@@ -515,6 +515,7 @@ message LayerConfig {
// for HuberRegressionLoss
optional double delta = 57 [ default = 1.0 ];
// for 3D data
optional double depth = 58 [ default = 1 ];
}
......
......@@ -886,6 +886,36 @@ class Conv(Cfg):
config_assert(output_x <= 0)
# please refer to the comments in proto/ModelConfig.proto
@config_class
class Conv3D(Cfg):
def __init__(self,
filter_size,
channels,
padding=None,
stride=None,
groups=None,
filter_channels=None,
output_x=None,
img_size=None,
caffe_mode=True,
filter_size_y=None,
padding_y=None,
stride_y=None,
filter_size_z=None,
padding_z=None,
stride_z=None):
self.add_keys(locals())
self.filter_size_y = filter_size_y if filter_size_y else filter_size
self.filter_size_z = filter_size_z if filter_size_z else filter_size
self.padding_y = padding_y if padding_y else padding
self.padding_z = padding_z if padding_z else padding
self.stride_y = stride_y if stride_y else stride
self.stride_z = stride_z if stride_z else stride
if output_x is not None:
config_assert(output_x <= 0)
@config_class
class BilinearInterp(Cfg):
def __init__(self, out_size_x=None, out_size_y=None, channels=None):
......@@ -908,6 +938,31 @@ class Pool(Cfg):
self.add_keys(locals())
@config_class
class Pool3d(Cfg):
def __init__(
self,
pool_type,
channels,
size_x,
size_y=None,
size_z=None,
start=None,
            stride=None,  # 1 by default in protobuf
stride_y=None,
stride_z=None,
            padding=None,  # 0 by default in protobuf
padding_y=None,
padding_z=None):
self.add_keys(locals())
self.filter_size_y = size_y if size_y else size_x
self.filter_size_z = size_z if size_z else size_x
self.padding_y = padding_y if padding_y else padding
self.padding_z = padding_z if padding_z else padding
self.stride_y = stride_y if stride_y else stride
self.stride_z = stride_z if stride_z else stride
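Both Conv3D above and Pool3d fall back to the x-dimension value when a y/z variant is omitted. Note that the `v_y if v_y else v_x` idiom treats an explicit 0 as falsy, so 0 also falls through; a self-contained sketch of the pattern (local helper, for illustration only):

.. code-block:: python

    def fill_3d(v_x, v_y=None, v_z=None):
        # mirrors the fallback pattern in Conv3D / Pool3d above
        return v_x, (v_y if v_y else v_x), (v_z if v_z else v_x)

    assert fill_3d(3) == (3, 3, 3)     # scalar broadcast to all three axes
    assert fill_3d(3, 5) == (3, 5, 3)  # z still falls back to x
    assert fill_3d(3, 0) == (3, 3, 3)  # caveat: an explicit 0 is falsy too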
@config_class
class SpatialPyramidPool(Cfg):
def __init__(self, pool_type, pyramid_height, channels):
......@@ -1223,6 +1278,45 @@ def parse_pool(pool, input_layer_name, pool_conf, ceil_mode):
pool_conf.stride_y, not ceil_mode)
def parse_pool3d(pool, input_layer_name, pool_conf, ceil_mode):
pool_conf.pool_type = pool.pool_type
config_assert(pool.pool_type in ['max-projection', 'avg-projection'],
"pool-type %s is not in "
"['max-projection', 'avg-projection']" % pool.pool_type)
pool_conf.channels = pool.channels
pool_conf.size_x = pool.size_x
pool_conf.stride = pool.stride
pool_conf.padding = pool.padding
pool_conf.size_y = default(pool.size_y, pool_conf.size_x)
pool_conf.size_z = default(pool.size_z, pool_conf.size_x)
pool_conf.stride_y = default(pool.stride_y, pool_conf.stride)
pool_conf.stride_z = default(pool.stride_z, pool_conf.stride)
pool_conf.padding_y = default(pool.padding_y, pool_conf.padding)
pool_conf.padding_z = default(pool.padding_z, pool_conf.padding)
pool_conf.img_size, pool_conf.img_size_y, pool_conf.img_size_z = \
get_img3d_size(input_layer_name, pool.channels)
config_assert(not pool.start, "start is deprecated in pooling.")
if pool.padding is not None:
pool_conf.padding = pool.padding
pool_conf.padding_y = default(pool.padding_y, pool_conf.padding)
pool_conf.padding_z = default(pool.padding_z, pool_conf.padding)
pool_conf.output_x = cnn_output_size(pool_conf.img_size, pool_conf.size_x,
pool_conf.padding, pool_conf.stride,
not ceil_mode)
pool_conf.output_y = cnn_output_size(pool_conf.img_size_y, pool_conf.size_y,
pool_conf.padding_y,
pool_conf.stride_y, not ceil_mode)
pool_conf.output_z = cnn_output_size(pool_conf.img_size_z, pool_conf.size_z,
pool_conf.padding_z,
pool_conf.stride_z, not ceil_mode)
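With ceil_mode=True (passed on as caffe_mode=False to cnn_output_size), the output extents follow the ceil formula documented in img_pool3d_layer below. Plugging in the pool3d test config from later in this commit (10 x 20 x 10 input, size 5, stride 3, padding 1) reproduces the serialized output sizes; a small sketch with a local helper, not the library function:

.. code-block:: python

    import math

    def ceil_out(img, k, pad, s):
        # ceil-mode output size, matching the serialized pool3d configs
        return 1 + int(math.ceil((img + 2 * pad - k) / float(s)))

    assert ceil_out(10, 5, 1, 3) == 4   # output_x and output_z
    assert ceil_out(20, 5, 1, 3) == 7   # output_y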
def parse_spp(spp, input_layer_name, spp_conf):
parse_image(spp, input_layer_name, spp_conf.image_conf)
spp_conf.pool_type = spp.pool_type
......@@ -1302,6 +1396,50 @@ def parse_conv(conv, input_layer_name, conv_conf, num_filters, trans=False):
conv_conf.stride_y, conv_conf.caffe_mode)
#caffe_mode: compute the output size using floor instead of ceil,
# which is consistent of caffe and CuDNN's convention.
def parse_conv3d(conv, input_layer_name, conv_conf, num_filters, trans=False):
conv_conf.filter_size = conv.filter_size
conv_conf.filter_size_y = conv.filter_size_y
conv_conf.filter_size_z = conv.filter_size_z
conv_conf.channels = conv.channels
conv_conf.padding = conv.padding
conv_conf.padding_y = conv.padding_y
conv_conf.padding_z = conv.padding_z
conv_conf.stride = conv.stride
conv_conf.stride_y = conv.stride_y
conv_conf.stride_z = conv.stride_z
conv_conf.groups = conv.groups
conv_conf.caffe_mode = conv.caffe_mode
if not trans:
conv_conf.filter_channels = conv.channels / conv.groups
conv_conf.img_size, conv_conf.img_size_y, conv_conf.img_size_z = \
get_img3d_size(input_layer_name, conv.channels)
conv_conf.output_x = cnn_output_size(
conv_conf.img_size, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode)
conv_conf.output_y = cnn_output_size(
conv_conf.img_size_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode)
conv_conf.output_z = cnn_output_size(
conv_conf.img_size_z, conv_conf.filter_size_z, conv_conf.padding_z,
conv_conf.stride_z, conv_conf.caffe_mode)
else:
conv_conf.filter_channels = num_filters / conv.groups
conv_conf.output_x, conv_conf.output_y, conv_conf.output_z = \
get_img3d_size(input_layer_name, conv.channels)
conv_conf.img_size = cnn_image_size(
conv_conf.output_x, conv_conf.filter_size, conv_conf.padding,
conv_conf.stride, conv_conf.caffe_mode)
conv_conf.img_size_y = cnn_image_size(
conv_conf.output_y, conv_conf.filter_size_y, conv_conf.padding_y,
conv_conf.stride_y, conv_conf.caffe_mode)
conv_conf.img_size_z = cnn_image_size(
conv_conf.output_z, conv_conf.filter_size_z, conv_conf.padding_z,
conv_conf.stride_z, conv_conf.caffe_mode)
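In caffe_mode the output extent uses floor (see the comment above parse_conv3d), and the trans branch inverts the relation via cnn_image_size. Both can be checked against the conv3d/deconv3d test protobufs later in this commit (filter 3, stride 2, padding 1); the helpers below are local sketches of those relations, not the library functions:

.. code-block:: python

    def conv_out(img, k, pad, s):
        # caffe-mode (floor) output size
        return (img + 2 * pad - k) // s + 1

    def conv_img(out, k, pad, s):
        # the inverse relation used when trans=True
        return s * (out - 1) + k - 2 * pad

    # conv3d: 42 x 48 x 6 input -> output_x/output_y/output_z = 21, 24, 3
    assert [conv_out(i, 3, 1, 2) for i in (42, 48, 6)] == [21, 24, 3]
    # deconv3d: output 42 x 48 x 6 -> img_size 83, 95, 11
    assert [conv_img(o, 3, 1, 2) for o in (42, 48, 6)] == [83, 95, 11]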
def parse_block_expand(block_expand, input_layer_name, block_expand_conf):
block_expand_conf.channels = block_expand.channels
block_expand_conf.stride_x = block_expand.stride_x
......@@ -1953,7 +2091,7 @@ class ConvLayerBase(LayerBase):
def calc_parameter_size(self, conv_conf):
return self.config.num_filters * conv_conf.filter_channels \
               * (conv_conf.filter_size * conv_conf.filter_size_y)
@config_layer('exconv')
......@@ -2037,6 +2175,87 @@ class ConvTransLayer(ConvTransLayerBase):
layer_type = 'cudnn_convt'
@config_layer('conv_3d')
class Conv3DLayerBase(LayerBase):
def __init__(self,
name,
inputs=[],
bias=True,
num_filters=None,
shared_biases=True,
**xargs):
super(Conv3DLayerBase, self).__init__(
name, self.layer_type, 0, inputs=inputs, **xargs)
if num_filters is not None:
self.config.num_filters = num_filters
# need to specify layer in config
self.config.type = self.layer_type
trans = False
if self.config.type == "deconv3d":
trans = True
if shared_biases is not None:
self.config.shared_biases = shared_biases
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
conv_conf = self.config.inputs[input_index].conv_conf
parse_conv3d(
self.inputs[input_index].conv,
input_layer.name,
conv_conf,
num_filters,
trans=trans
            )  # for z-axis: pad=0, stride=1, filter_size=1, img_size=1
psize = self.calc_parameter_size(conv_conf)
self.create_input_parameter(input_index, psize)
if trans:
self.set_cnn_layer(name, conv_conf.img_size_z,
conv_conf.img_size_y, conv_conf.img_size,
self.config.num_filters)
else:
self.set_cnn_layer(name, conv_conf.output_z, conv_conf.output_y,
conv_conf.output_x, self.config.num_filters)
psize = self.config.size
if shared_biases:
psize = self.config.num_filters
self.create_bias_parameter(bias, psize, [psize, 1])
def calc_parameter_size(self, conv_conf):
return self.config.num_filters * conv_conf.filter_channels \
* (conv_conf.filter_size * conv_conf.filter_size_y \
* conv_conf.filter_size_z)
def set_cnn_layer(self,
input_layer_name,
depth,
height,
width,
channels,
is_print=True):
size = depth * height * width * channels
self.set_layer_size(size)
self.set_layer_height_width(height, width)
self.set_layer_depth(depth)
if is_print:
print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
(input_layer_name, channels, depth, height, width, size))
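calc_parameter_size above multiplies num_filters, filter_channels, and the three kernel extents; the parameter sizes in this commit's expected protobufs follow directly from it:

.. code-block:: python

    kdhw = 3 * 3 * 3  # 3x3x3 kernels in the test configs
    # conv3d: filter_channels = channels / groups = 3, num_filters = 16
    assert 16 * 3 * kdhw == 1296    # _conv3d_1.w0 size
    # deconv3d (trans): filter_channels = num_filters / groups = 16
    assert 16 * 16 * kdhw == 6912   # _deconv3d_1.w0 size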
@config_layer('conv3d')
class Conv3DLayer(Conv3DLayerBase):
layer_type = 'conv3d'
@config_layer('deconv3d')
class DeConv3DLayer(Conv3DLayerBase):
    layer_type = 'deconv3d'
@config_layer('norm')
class NormLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
......@@ -2066,6 +2285,35 @@ class PoolLayer(LayerBase):
pool_conf.channels)
@config_layer('pool3d')
class Pool3DLayer(LayerBase):
def __init__(self, name, inputs, ceil_mode=True, **xargs):
super(Pool3DLayer, self).__init__(
name, 'pool3d', 0, inputs=inputs, **xargs)
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
pool_conf = self.config.inputs[input_index].pool_conf
parse_pool3d(self.inputs[input_index].pool, input_layer.name,
pool_conf, ceil_mode)
self.set_cnn_layer(name, pool_conf.output_z, pool_conf.output_y,
pool_conf.output_x, pool_conf.channels)
def set_cnn_layer(self,
input_layer_name,
depth,
height,
width,
channels,
is_print=True):
size = depth * height * width * channels
self.set_layer_size(size)
self.set_layer_height_width(height, width)
self.set_layer_depth(depth)
if is_print:
print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
(input_layer_name, channels, depth, height, width, size))
@config_layer('spp')
class SpatialPyramidPoolLayer(LayerBase):
def __init__(self, name, inputs, **xargs):
......
File mode changed from 100755 to 100644
......@@ -138,7 +138,9 @@ __all__ = [
'slice_projection',
'seq_slice_layer',
'kmax_sequence_score_layer',
'img_pool3d_layer',
'scale_shift_layer',
'img_conv3d_layer',
]
......@@ -221,6 +223,9 @@ class LayerType(object):
CRF_DECODING_LAYER = 'crf_decoding'
NCE_LAYER = 'nce'
CONV3D_LAYER = 'conv3d'
DECONV3D_LAYER = 'deconv3d'
RANK_COST = 'rank-cost'
LAMBDA_COST = 'lambda_cost'
HUBER_REGRESSION = 'huber_regression'
......@@ -936,7 +941,7 @@ def data_layer(name, size, depth=None, height=None, width=None,
if height is not None and width is not None:
num_filters = size / (width * height * depth)
assert num_filters * width * height * depth == size, \
"size=%s width=%s height=%s depth=%s" % (size, width, height, depth)
"size=%s width=%s height=%s depth=%s" % (size, width, height, depth)
return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters)
......@@ -2660,6 +2665,146 @@ def img_pool_layer(input,
size=l.config.size)
@wrap_name_default("pool3d")
@layer_support()
def img_pool3d_layer(input,
pool_size,
name=None,
num_channels=None,
pool_type=None,
stride=1,
padding=0,
layer_attr=None,
pool_size_y=None,
stride_y=None,
padding_y=None,
pool_size_z=None,
stride_z=None,
padding_z=None,
ceil_mode=True):
"""
    Image pooling layer for 3D data.
    For details of the pooling operation, please refer to UFLDL's pooling_.
.. _pooling: http://ufldl.stanford.edu/tutorial/supervised/Pooling/
- ceil_mode=True:
.. math::
      w = 1 + int(ceil((input\_width + 2 * padding - pool\_size) / float(stride)))
      h = 1 + int(ceil((input\_height + 2 * padding\_y - pool\_size\_y) / float(stride\_y)))
      d = 1 + int(ceil((input\_depth + 2 * padding\_z - pool\_size\_z) / float(stride\_z)))
- ceil_mode=False:
.. math::
      w = 1 + int(floor((input\_width + 2 * padding - pool\_size) / float(stride)))
      h = 1 + int(floor((input\_height + 2 * padding\_y - pool\_size\_y) / float(stride\_y)))
      d = 1 + int(floor((input\_depth + 2 * padding\_z - pool\_size\_z) / float(stride\_z)))
The example usage is:
.. code-block:: python
maxpool = img_pool3d_layer(input=conv,
pool_size=3,
num_channels=8,
stride=1,
padding=1,
pool_type=MaxPooling())
:param padding: pooling padding width.
:type padding: int|tuple|list
:param name: name of pooling layer
:type name: basestring.
:param input: layer's input
:type input: LayerOutput
:param pool_size: pooling window width
:type pool_size: int|tuple|list
:param num_channels: number of input channel.
:type num_channels: int
:param pool_type: pooling type. MaxPooling or AvgPooling. Default is
MaxPooling.
:type pool_type: BasePoolingType
:param stride: stride width of pooling.
:type stride: int|tuple|list
:param layer_attr: Extra Layer attribute.
:type layer_attr: ExtraLayerAttribute
    :param ceil_mode: Whether to use the ceil mode to calculate output height
                      and width. Default is True. If set to False, the floor
                      mode is used instead.
:type ceil_mode: bool
:return: LayerOutput object.
:rtype: LayerOutput
"""
if num_channels is None:
assert input.num_filters is not None
num_channels = input.num_filters
if pool_type is None:
pool_type = MaxPooling()
elif isinstance(pool_type, AvgPooling):
pool_type.name = 'avg'
type_name = pool_type.name + '-projection' \
if (
isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \
else pool_type.name
if isinstance(pool_size, collections.Sequence):
assert len(pool_size) == 3
pool_size, pool_size_y, pool_size_z = pool_size
else:
pool_size_y = pool_size
pool_size_z = pool_size
if isinstance(stride, collections.Sequence):
assert len(stride) == 3
stride, stride_y, stride_z = stride
else:
stride_y = stride
stride_z = stride
if isinstance(padding, collections.Sequence):
assert len(padding) == 3
        padding, padding_y, padding_z = padding
else:
padding_y = padding
padding_z = padding
l = Layer(
name=name,
type=LayerType.POOL3D_LAYER,
inputs=[
Input(
input.name,
pool=Pool3d(
pool_type=type_name,
channels=num_channels,
size_x=pool_size,
start=None,
stride=stride,
padding=padding,
size_y=pool_size_y,
stride_y=stride_y,
padding_y=padding_y,
size_z=pool_size_z,
stride_z=stride_z,
padding_z=padding_z))
],
ceil_mode=ceil_mode,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name,
        LayerType.POOL3D_LAYER,
parents=[input],
num_filters=num_channels,
size=l.config.size)
@wrap_name_default("spp")
@layer_support()
def spp_layer(input,
......@@ -6490,6 +6635,149 @@ def kmax_sequence_score_layer(input, name=None, beam_size=1):
name, LayerType.KMAX_SEQ_SCORE, parents=[input], size=input.size)
@wrap_name_default("conv3d")
@wrap_param_attr_default()
@wrap_bias_attr_default()
@wrap_act_default(act=ReluActivation())
@layer_support(DROPOUT)
def img_conv3d_layer(input,
filter_size,
num_filters,
name=None,
num_channels=None,
act=None,
groups=1,
stride=1,
padding=0,
bias_attr=None,
param_attr=None,
shared_biases=True,
layer_attr=None,
trans=False,
layer_type=None):
"""
The example usage is:
.. code-block:: python
conv = img_conv3d_layer(input=data, filter_size=1,
num_channels=8,
num_filters=16, stride=1,
bias_attr=False,
act=ReluActivation())
:param name: Layer name.
:type name: basestring
:param input: Layer Input.
:type input: LayerOutput
    :param filter_size: The x dimension of a filter kernel. Or input a list of
                        three integers for the three image dimensions.
    :type filter_size: int|tuple|list
    :param num_filters: The number of filters in each group.
    :type num_filters: int
    :param act: Activation type. ReluActivation is the default.
:type act: BaseActivation
:param groups: Group size of filters.
:type groups: int
    :param stride: The x dimension of the stride. Or input a list of three
                   integers for the three image dimensions.
    :type stride: int|tuple|list
    :param padding: The x dimension of the padding. Or input a list of three
                    integers for the three image dimensions.
:type padding: int|tuple|list
:param bias_attr: Convolution bias attribute. None means default bias.
False means no bias.
:type bias_attr: ParameterAttribute|False
:param num_channels: number of input channels. If None will be set
automatically from previous output.
:type num_channels: int
:param param_attr: Convolution param attribute. None means default attribute
:type param_attr: ParameterAttribute
    :param shared_biases: Whether biases are shared between filters.
:type shared_biases: bool
:param layer_attr: Layer Extra Attribute.
:type layer_attr: ExtraLayerAttribute
    :param trans: True for a transposed 3D convolution (deconv3d), False for a
                  normal 3D convolution.
:type trans: bool
    :param layer_type: specify the layer type, default is None. If trans=True,
                       layer_type has to be "deconv3d"; otherwise layer_type
                       has to be "conv3d".
:type layer_type: String
:return: LayerOutput object.
:rtype: LayerOutput
"""
if num_channels is None:
assert input.num_filters is not None
num_channels = input.num_filters
if isinstance(filter_size, collections.Sequence):
assert len(filter_size) == 3
filter_size, filter_size_y, filter_size_z = filter_size
else:
filter_size_y = filter_size
filter_size_z = filter_size
if isinstance(stride, collections.Sequence):
assert len(stride) == 3
stride, stride_y, stride_z = stride
else:
stride_y = stride
stride_z = stride
if isinstance(padding, collections.Sequence):
assert len(padding) == 3
padding, padding_y, padding_z = padding
else:
padding_y = padding
padding_z = padding
if param_attr.attr.get('initial_smart'):
# special initial for conv layers.
init_w = (2.0 / (filter_size**2 * num_channels))**0.5
param_attr.attr["initial_mean"] = 0.0
param_attr.attr["initial_std"] = init_w
param_attr.attr["initial_strategy"] = 0
param_attr.attr["initial_smart"] = False
    if layer_type:
        if trans:
            assert layer_type in ["deconv3d"]
        else:
            assert layer_type in ["conv3d"]
        lt = layer_type
else:
lt = LayerType.DECONV3D_LAYER if trans else LayerType.CONV3D_LAYER
l = Layer(
name=name,
inputs=Input(
input.name,
conv=Conv3D(
filter_size=filter_size,
padding=padding,
stride=stride,
channels=num_channels,
groups=groups,
filter_size_y=filter_size_y,
padding_y=padding_y,
stride_y=stride_y,
filter_size_z=filter_size_z,
padding_z=padding_z,
stride_z=stride_z),
**param_attr.attr),
active_type=act.name,
num_filters=num_filters,
bias=ParamAttr.to_bias(bias_attr),
shared_biases=shared_biases,
type=lt,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
name,
lt,
parents=[input],
activation=act,
num_filters=num_filters,
size=l.config.size)
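The initial_smart branch above also explains the initial_std values serialized in this commit's conv3d/deconv3d test configs (filter_size 3, num_channels 3). Note it keeps the 2D formula filter_size**2 rather than filter_size**3; the serialized value confirms this:

.. code-block:: python

    init_w = (2.0 / (3 ** 2 * 3)) ** 0.5         # filter_size**2 * num_channels
    assert abs(init_w - 0.272165526976) < 1e-9   # initial_std in the protobufs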
@wrap_name_default("scale_shift")
@wrap_param_attr_default()
@wrap_bias_attr_default()
......
File mode changed from 100755 to 100644
......@@ -9,6 +9,7 @@ test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_seq_select_layers test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
test_conv3d_layer test_deconv3d_layer test_BatchNorm3D)
export whole_configs=(test_split_datasource)
type: "nn"
layers {
name: "data"
type: "data"
size: 36288
active_type: ""
height: 48
width: 42
depth: 6
}
layers {
name: "conv3d_1"
type: "conv3d"
size: 24192
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "_conv3d_1.w0"
conv_conf {
filter_size: 3
channels: 3
stride: 2
padding: 1
groups: 1
filter_channels: 3
output_x: 21
img_size: 42
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 24
img_size_y: 48
filter_size_z: 3
padding_z: 1
stride_z: 2
output_z: 3
img_size_z: 6
}
}
bias_parameter_name: "_conv3d_1.wbias"
num_filters: 16
shared_biases: true
height: 24
width: 21
depth: 3
}
layers {
name: "conv3d_2"
type: "conv3d"
size: 24192
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "_conv3d_2.w0"
conv_conf {
filter_size: 3
channels: 3
stride: 2
padding: 1
groups: 1
filter_channels: 3
output_x: 21
img_size: 42
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 24
img_size_y: 48
filter_size_z: 3
padding_z: 1
stride_z: 2
output_z: 3
img_size_z: 6
}
}
bias_parameter_name: "_conv3d_2.wbias"
num_filters: 16
shared_biases: true
height: 24
width: 21
depth: 3
}
parameters {
name: "_conv3d_1.w0"
size: 1296
initial_mean: 0.0
initial_std: 0.272165526976
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_conv3d_1.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_conv3d_2.w0"
size: 1296
initial_mean: 0.0
initial_std: 0.272165526976
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_conv3d_2.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
input_layer_names: "data"
output_layer_names: "conv3d_2"
sub_models {
name: "root"
layer_names: "data"
layer_names: "conv3d_1"
layer_names: "conv3d_2"
input_layer_names: "data"
output_layer_names: "conv3d_2"
is_recurrent_layer_group: false
}
type: "nn"
layers {
name: "data"
type: "data"
size: 36288
active_type: ""
height: 48
width: 42
depth: 6
}
layers {
name: "deconv3d_1"
type: "deconv3d"
size: 1387760
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "_deconv3d_1.w0"
conv_conf {
filter_size: 3
channels: 3
stride: 2
padding: 1
groups: 1
filter_channels: 16
output_x: 42
img_size: 83
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 48
img_size_y: 95
filter_size_z: 3
padding_z: 1
stride_z: 2
output_z: 6
img_size_z: 11
}
}
bias_parameter_name: "_deconv3d_1.wbias"
num_filters: 16
shared_biases: true
height: 95
width: 83
depth: 11
}
layers {
name: "deconv3d_2"
type: "deconv3d"
size: 1387760
active_type: ""
inputs {
input_layer_name: "data"
input_parameter_name: "_deconv3d_2.w0"
conv_conf {
filter_size: 3
channels: 3
stride: 2
padding: 1
groups: 1
filter_channels: 16
output_x: 42
img_size: 83
caffe_mode: true
filter_size_y: 3
padding_y: 1
stride_y: 2
output_y: 48
img_size_y: 95
filter_size_z: 3
padding_z: 1
stride_z: 2
output_z: 6
img_size_z: 11
}
}
bias_parameter_name: "_deconv3d_2.wbias"
num_filters: 16
shared_biases: true
height: 95
width: 83
depth: 11
}
parameters {
name: "_deconv3d_1.w0"
size: 6912
initial_mean: 0.0
initial_std: 0.272165526976
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_deconv3d_1.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_deconv3d_2.w0"
size: 6912
initial_mean: 0.0
initial_std: 0.272165526976
initial_strategy: 0
initial_smart: false
}
parameters {
name: "_deconv3d_2.wbias"
size: 16
initial_mean: 0.0
initial_std: 0.0
dims: 16
dims: 1
initial_strategy: 0
initial_smart: false
}
input_layer_names: "data"
output_layer_names: "deconv3d_2"
sub_models {
name: "root"
layer_names: "data"
layer_names: "deconv3d_1"
layer_names: "deconv3d_2"
input_layer_names: "data"
output_layer_names: "deconv3d_2"
is_recurrent_layer_group: false
}
type: "nn"
layers {
name: "data_2d"
type: "data"
size: 6000
active_type: ""
height: 20
width: 10
}
layers {
name: "pool___2d"
type: "pool"
size: 840
active_type: ""
inputs {
input_layer_name: "data_2d"
pool_conf {
pool_type: "avg-projection"
channels: 30
size_x: 5
stride: 3
output_x: 4
img_size: 10
padding: 1
size_y: 5
stride_y: 3
output_y: 7
img_size_y: 20
padding_y: 1
}
}
height: 7
width: 4
}
layers {
name: "data_3d_1"
type: "data"
size: 60000
active_type: ""
height: 20
width: 10
depth: 10
}
layers {
name: "pool_3d_1"
type: "pool3d"
size: 3360
active_type: ""
inputs {
input_layer_name: "data_3d_1"
pool_conf {
pool_type: "avg-projection"
channels: 30
size_x: 5
stride: 3
output_x: 4
img_size: 10
padding: 1
size_y: 5
stride_y: 3
output_y: 7
img_size_y: 20
padding_y: 1
size_z: 5
stride_z: 3
output_z: 4
img_size_z: 10
padding_z: 1
}
}
height: 7
width: 4
depth: 4
}
layers {
name: "pool_3d_2"
type: "pool3d"
size: 3360
active_type: ""
inputs {
input_layer_name: "data_3d_1"
pool_conf {
pool_type: "max-projection"
channels: 30
size_x: 5
stride: 3
output_x: 4
img_size: 10
padding: 1
size_y: 5
stride_y: 3
output_y: 7
img_size_y: 20
padding_y: 1
size_z: 5
stride_z: 3
output_z: 4
img_size_z: 10
padding_z: 1
}
}
height: 7
width: 4
depth: 4
}
input_layer_names: "data_2d"
output_layer_names: "pool___2d"
output_layer_names: "pool_3d_1"
output_layer_names: "pool_3d_2"
sub_models {
name: "root"
layer_names: "data_2d"
layer_names: "pool___2d"
layer_names: "data_3d_1"
layer_names: "pool_3d_1"
layer_names: "pool_3d_2"
input_layer_names: "data_2d"
output_layer_names: "pool___2d"
output_layer_names: "pool_3d_1"
output_layer_names: "pool_3d_2"
is_recurrent_layer_group: false
}
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
num_channels = 3
filter_size = 3
filter_size_y = 3
filter_size_z = 3
stride = 2
stride_y = 2
stride_z = 2
padding = 1
padding_y = 1
padding_z = 1
groups = 1
data = data_layer(
name='data', size=12096 * num_channels, height=48, width=42, depth=6)
# first
conv3d_1 = img_conv3d_layer(
input=data,
name='conv3d_1',
num_filters=16,
num_channels=num_channels,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=True,
shared_biases=True,
trans=False,
layer_type="conv3d",
act=LinearActivation())
# second
conv3d_2 = img_conv3d_layer(
input=data,
name='conv3d_2',
num_filters=16,
num_channels=num_channels,
filter_size=[filter_size, filter_size_y, filter_size_z],
stride=[stride, stride_y, stride_z],
padding=[padding, padding_y, padding_z],
groups=groups,
bias_attr=True,
shared_biases=True,
trans=False,
layer_type="conv3d",
act=LinearActivation())
outputs(conv3d_2)
from paddle.trainer_config_helpers import *
settings(batch_size=1000, learning_rate=1e-5)
num_channels = 3
filter_size = 3
filter_size_y = 3
filter_size_z = 3
stride = 2
stride_y = 2
stride_z = 2
padding = 1
padding_y = 1
padding_z = 1
groups = 1
data = data_layer(
name='data', size=12096 * num_channels, height=48, width=42, depth=6)
# first
deconv3d_1 = img_conv3d_layer(
input=data,
name='deconv3d_1',
num_filters=16,
num_channels=num_channels,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
bias_attr=True,
shared_biases=True,
trans=True,
layer_type="deconv3d",
act=LinearActivation())
# second
deconv3d_2 = img_conv3d_layer(
input=data,
name='deconv3d_2',
num_filters=16,
num_channels=num_channels,
filter_size=[filter_size, filter_size_y, filter_size_z],
stride=[stride, stride_y, stride_z],
padding=[padding, padding_y, padding_z],
groups=groups,
bias_attr=True,
shared_biases=True,
trans=True,
layer_type="deconv3d",
act=LinearActivation())
outputs(deconv3d_2)
from paddle.trainer_config_helpers import *
settings(batch_size=100, learning_rate=1e-5)
data_2d = data_layer(name='data_2d', size=6000, height=20, width=10)
pool_2d = img_pool_layer(
name="pool___2d",
input=data_2d,
num_channels=30,
pool_size=5,
stride=3,
padding=1,
pool_type=AvgPooling())
outputs(pool_2d)
data_3d = data_layer(
name='data_3d_1', size=60000, depth=10, height=20, width=10)
pool_3d_1 = img_pool3d_layer(
name="pool_3d_1",
input=data_3d,
num_channels=30,
pool_size=5,
stride=3,
padding=1,
pool_type=AvgPooling())
outputs(pool_3d_1)
pool_3d_2 = img_pool3d_layer(
name="pool_3d_2",
input=data_3d,
num_channels=30,
pool_size=[5, 5, 5],
stride=[3, 3, 3],
padding=[1, 1, 1],
pool_type=MaxPooling())
outputs(pool_3d_2)
......@@ -17,3 +17,4 @@ from paddle.trainer.config_parser import parse_config_and_serialize
if __name__ == '__main__':
parse_config_and_serialize(
'trainer_config_helpers/tests/layers_test_config.py', '')
# layers_test_config.py