提交 d5768ebc 编写于 作者: C chengduoZH

fix above comments

上级 38cc5dad
...@@ -240,16 +240,25 @@ extern void hl_matrix_rotate( ...@@ -240,16 +240,25 @@ extern void hl_matrix_rotate(
* @param[in] strideW stride in the width. * @param[in] strideW stride in the width.
* @param[in] paddingD padding in the depth. * @param[in] paddingD padding in the depth.
* @param[in] paddingH padding in the height. * @param[in] paddingH padding in the height.
* @param[in] paddingW padding in the width. * @param[in] paddingW padding in the width.
* @param[out] matDst output matrix. * @param[out] matDst output matrix.
* *
*/ */
extern void hl_matrix_vol2Col(real* matSrc, extern void hl_matrix_vol2Col(const real* dataSrc,
int channel, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
real* matDst); int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst);
/** /**
* @brief Matrix col2Vol: Convert col matrix into 3D volume * @brief Matrix col2Vol: Convert col matrix into 3D volume
...@@ -267,19 +276,28 @@ extern void hl_matrix_vol2Col(real* matSrc, ...@@ -267,19 +276,28 @@ extern void hl_matrix_vol2Col(real* matSrc,
* @param[in] strideW stride in the width. * @param[in] strideW stride in the width.
* @param[in] paddingD padding in the depth. * @param[in] paddingD padding in the depth.
* @param[in] paddingH padding in the height. * @param[in] paddingH padding in the height.
* @param[in] paddingW padding in the width. * @param[in] paddingW padding in the width.
* @param[in] matSrc input matrix. * @param[in] matSrc input matrix.
* @param[in] beta input * @param[in] beta input
* @param[in] alpha input * @param[in] alpha input
* *
*/ */
extern void hl_matrix_col2Vol(real* matDst, extern void hl_matrix_col2Vol(real* dataDst,
int channels, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
real* matSrc, int filterD,
real alpha, real beta); int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
const real* dataSrc,
real alpha,
real beta);
#endif /* HL_MATRIX_H_ */ #endif /* HL_MATRIX_H_ */
...@@ -99,19 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d, ...@@ -99,19 +99,38 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
inline void hl_matrix_rotate( inline void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise) {} real* mat, real* matRot, int dimM, int dimN, bool clockWise) {}
inline void hl_matrix_vol2Col(real* data, inline void hl_matrix_vol2Col(const real* dataSrc,
int channels, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
real* data_col) {} int filterD,
int filterH,
inline void hl_matrix_col2Vol(real* data, int filterW,
int channels, int depth, int height, int width, int strideD,
int filterD, int filterH, int filterW, int strideH,
int strideD, int strideH, int strideW, int strideW,
int paddingD, int paddingH, int paddingW, int paddingD,
real* data_Im, int paddingH,
real alpha, real beta) {} int paddingW,
real* dataDst) {}
inline void hl_matrix_col2Vol(real* dataDst,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
const real* dataSrc,
real alpha,
real beta) {}
#endif // HL_MATRIX_STUB_H_ #endif // HL_MATRIX_STUB_H_
...@@ -594,7 +594,7 @@ void hl_matrix_rotate( ...@@ -594,7 +594,7 @@ void hl_matrix_rotate(
} }
__global__ void keMatrixVol2Col(int num_kernels, __global__ void keMatrixVol2Col(int num_kernels,
real* dataSrc, const real* dataSrc,
real* dataDst, real* dataDst,
int depth, int depth,
int height, int height,
...@@ -643,7 +643,7 @@ __global__ void keMatrixVol2Col(int num_kernels, ...@@ -643,7 +643,7 @@ __global__ void keMatrixVol2Col(int num_kernels,
} }
} }
void hl_matrix_vol2Col(real* dataSrc, void hl_matrix_vol2Col(const real* dataSrc,
int channels, int channels,
int depth, int depth,
int height, int height,
...@@ -666,30 +666,30 @@ void hl_matrix_vol2Col(real* dataSrc, ...@@ -666,30 +666,30 @@ void hl_matrix_vol2Col(real* dataSrc,
const int threads = 512; const int threads = 512;
const int blocks = DIVUP(num_kernels, threads); const int blocks = DIVUP(num_kernels, threads);
keMatrixVol2Col<<<blocks, threads>>>(num_kernels, keMatrixVol2Col<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
dataSrc, dataSrc,
dataDst, dataDst,
depth, depth,
height, height,
width, width,
filterD, filterD,
filterH, filterH,
filterW, filterW,
strideD, strideD,
strideH, strideH,
strideW, strideW,
paddingD, paddingD,
paddingH, paddingH,
paddingW, paddingW,
depth_col, depth_col,
height_col, height_col,
width_col); width_col);
CHECK_SYNC("hl_matrix_vol2Col failed"); CHECK_SYNC("hl_matrix_vol2Col failed");
} }
__global__ void keMatrixCol2Vol(int num_kernels, __global__ void keMatrixCol2Vol(int num_kernels,
real* dataDst, real* dataDst,
real* dataSrc, const real* dataSrc,
int depth, int depth,
int height, int height,
int width, int width,
...@@ -759,7 +759,7 @@ void hl_matrix_col2Vol(real* dataDst, ...@@ -759,7 +759,7 @@ void hl_matrix_col2Vol(real* dataDst,
int paddingD, int paddingD,
int paddingH, int paddingH,
int paddingW, int paddingW,
real* dataSrc, const real* dataSrc,
real alpha, real alpha,
real beta) { real beta) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
...@@ -770,26 +770,26 @@ void hl_matrix_col2Vol(real* dataDst, ...@@ -770,26 +770,26 @@ void hl_matrix_col2Vol(real* dataDst,
const int threads = 512; const int threads = 512;
const int blocks = DIVUP(num_kernels, threads); const int blocks = DIVUP(num_kernels, threads);
keMatrixCol2Vol<<<blocks, threads>>>(num_kernels, keMatrixCol2Vol<<<blocks, threads, 0, STREAM_DEFAULT>>>(num_kernels,
dataDst, dataDst,
dataSrc, dataSrc,
depth, depth,
height, height,
width, width,
filterD, filterD,
filterH, filterH,
filterW, filterW,
strideD, strideD,
strideH, strideH,
strideW, strideW,
paddingD, paddingD,
paddingH, paddingH,
paddingW, paddingW,
depth_col, depth_col,
height_col, height_col,
width_col, width_col,
alpha, alpha,
beta); beta);
CHECK_SYNC("hl_matrix_col2Vol failed"); CHECK_SYNC("hl_matrix_col2Vol failed");
} }
...@@ -28,16 +28,26 @@ bool Conv3DLayer::init(const LayerMap &layerMap, ...@@ -28,16 +28,26 @@ bool Conv3DLayer::init(const LayerMap &layerMap,
const ConvConfig &conf = inputConfig.conv_conf(); const ConvConfig &conf = inputConfig.conv_conf();
M_.push_back(numFilters_ / conf.groups()); M_.push_back(numFilters_ / conf.groups());
K_.push_back(filterPixels_[index] * filterChannels_[index]); K_.push_back(filterPixels_[index] * filterChannels_[index]);
if (nullptr != weights_[index]->getW())
weights_[index]->getW()->reshape(weights_[index]->getW()->getWidth(), // create a new weight
weights_[index]->getW()->getHeight()); size_t height, width;
if (nullptr != weights_[index]->getWGrad()) width = filterPixels_[index] * filterChannels_[index];
weights_[index]->getWGrad()->reshape( height = numFilters_;
weights_[index]->getWGrad()->getWidth(), CHECK_EQ(parameters_[index]->getSize(), width * height);
weights_[index]->getWGrad()->getHeight()); Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
++index; ++index;
} }
CHECK(inputLayers_.size() == parameters_.size()); if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
}
}
return true; return true;
} }
......
...@@ -12,13 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,13 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "ConvBaseLayer.h" #include "ConvBaseLayer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/MathUtils.h" #include "paddle/math/MathUtils.h"
#include <vector> #include "paddle/math/Matrix.h"
namespace paddle { namespace paddle {
...@@ -30,21 +28,17 @@ namespace paddle { ...@@ -30,21 +28,17 @@ namespace paddle {
class Conv3DLayer : public ConvBaseLayer { class Conv3DLayer : public ConvBaseLayer {
public: public:
explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} explicit Conv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
~Conv3DLayer() {} ~Conv3DLayer() {}
bool init(const LayerMap &layerMap, const ParameterMap &parameterMap); bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
size_t getSize();
void forward(PassType passType); void forward(PassType passType);
void addBias(); void addBias();
void backward(const UpdateCallback& callback); void backward(const UpdateCallback& callback);
void bpropBiases(); void bpropBiases();
void bpropData(int i); void bpropData(int i);
void bpropWeights(int i); void bpropWeights(int i);
size_t getSize();
protected: protected:
// Figure out the dimensions for individual gemms. // Figure out the dimensions for individual gemms.
......
...@@ -21,8 +21,7 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, ...@@ -21,8 +21,7 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) {
/* Initialize the basic parent class */ /* Initialize the basic parent class */
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv" || isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv")
config_.type() == "conv3d" || config_.type() == "deconv3d")
? false ? false
: true; : true;
...@@ -56,28 +55,9 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, ...@@ -56,28 +55,9 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
} }
CHECK(inputLayers_.size() == parameters_.size()); CHECK(inputLayers_.size() == parameters_.size());
for (size_t i = 0; i < inputLayers_.size(); i++) {
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = (!isDeconv_) ? numFilters_ : channels_[i];
// create a new weight
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight* w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
}
/* initialize the biases_ */ // create new weights_ in derived class
if (biasParameter_.get()) { // create new biases_ in derived class
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
}
}
// default caffe model // default caffe model
caffeMode_ = true; caffeMode_ = true;
......
...@@ -23,7 +23,6 @@ namespace paddle { ...@@ -23,7 +23,6 @@ namespace paddle {
* with learned filters and (optionally) adds biases. * with learned filters and (optionally) adds biases.
*/ */
class ConvBaseLayer : public Layer { class ConvBaseLayer : public Layer {
protected: protected:
typedef std::vector<int> IntV; typedef std::vector<int> IntV;
......
...@@ -46,8 +46,26 @@ bool CudnnConvBaseLayer::init(const LayerMap &layerMap, ...@@ -46,8 +46,26 @@ bool CudnnConvBaseLayer::init(const LayerMap &layerMap,
projConf_.emplace_back(conf); projConf_.emplace_back(conf);
projections_.emplace_back( projections_.emplace_back(
Projection::create(*projConf_[i], parameters_[i], useGpu_)); Projection::create(*projConf_[i], parameters_[i], useGpu_));
// create a new weight
size_t height, width;
height = filterPixels_[i] * filterChannels_[i];
width = (!isDeconv_) ? numFilters_ : channels_[i];
CHECK_EQ(parameters_[i]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[i]);
weights_.emplace_back(w);
} }
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
if (biases_.get() && sharedBiases_) { if (biases_.get() && sharedBiases_) {
hl_create_tensor_descriptor(&biasDesc_); hl_create_tensor_descriptor(&biasDesc_);
hl_create_tensor_descriptor(&outputDesc_); hl_create_tensor_descriptor(&outputDesc_);
......
...@@ -20,9 +20,6 @@ namespace paddle { ...@@ -20,9 +20,6 @@ namespace paddle {
REGISTER_LAYER(deconv3d, DeConv3DLayer); REGISTER_LAYER(deconv3d, DeConv3DLayer);
// Deconvolution output size along one axis, computed as
//   (IN_SIZE - 1) * STRID - 2 * PAD + KSIZE.
// Every macro argument is parenthesized in the expansion, and the whole
// expression is wrapped in an outer pair of parentheses.
#define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \
(((IN_SIZE)-1) * (STRID)-2 * (PAD) + (KSIZE))
bool DeConv3DLayer::init(const LayerMap &layerMap, bool DeConv3DLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) { const ParameterMap &parameterMap) {
if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
...@@ -32,14 +29,25 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, ...@@ -32,14 +29,25 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
for (int index = 0; index < config_.inputs().size(); ++index) { for (int index = 0; index < config_.inputs().size(); ++index) {
M_.push_back(filterChannels_[index]); M_.push_back(filterChannels_[index]);
K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index])); K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index]));
if (weights_[index]->getW())
weights_[index]->getW()->reshape(filterPixels_[index] * numFilters_, // create a new weight
filterChannels_[index]); size_t height, width;
if (weights_[index]->getWGrad()) height = filterPixels_[index] * numFilters_;
weights_[index]->getWGrad()->reshape(filterPixels_[index] * numFilters_, width = filterChannels_[index];
filterChannels_[index]); CHECK_EQ(parameters_[index]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
}
} }
CHECK(inputLayers_.size() == parameters_.size());
return true; return true;
} }
...@@ -52,22 +60,22 @@ size_t DeConv3DLayer::getSize() { ...@@ -52,22 +60,22 @@ size_t DeConv3DLayer::getSize() {
outputW_.clear(); outputW_.clear();
outputD_.clear(); outputD_.clear();
N_.clear(); N_.clear();
No_.clear(); NOut_.clear();
size_t layerSize = 0; size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); ++i) { for (size_t i = 0; i < inputLayers_.size(); ++i) {
// imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight());
// imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
// imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth());
outputW_.push_back(DECONV_OUTPUT_SIZE( outputW_.push_back(
imgSizeW_[i], stride_[i], padding_[i], filterSize_[i])); imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
outputH_.push_back(DECONV_OUTPUT_SIZE( outputH_.push_back(imageSize(
imgSizeH_[i], strideY_[i], paddingY_[i], filterSizeY_[i])); imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
outputD_.push_back(DECONV_OUTPUT_SIZE( outputD_.push_back(imageSize(
imgSizeD_[i], strideZ_[i], paddingZ_[i], filterSizeZ_[i])); imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
No_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
layerSize += No_[i] * numFilters_; layerSize += NOut_[i] * numFilters_;
} }
getOutput().setFrameHeight(outputH_[0]); getOutput().setFrameHeight(outputH_[0]);
getOutput().setFrameWidth(outputW_[0]); getOutput().setFrameWidth(outputW_[0]);
......
...@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "ConvBaseLayer.h" #include "ConvBaseLayer.h"
#include "paddle/math/Matrix.h"
#include "paddle/math/MathUtils.h" #include "paddle/math/MathUtils.h"
#include <vector> #include "paddle/math/Matrix.h"
namespace paddle { namespace paddle {
...@@ -29,30 +28,25 @@ namespace paddle { ...@@ -29,30 +28,25 @@ namespace paddle {
*/ */
class DeConv3DLayer : public ConvBaseLayer { class DeConv3DLayer : public ConvBaseLayer {
public: public:
explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {} explicit DeConv3DLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
~DeConv3DLayer() {}
~DeConv3DLayer() {} bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
bool init(const LayerMap &layerMap, const ParameterMap &parameterMap); void forward(PassType passType);
void addBias();
size_t getSize(); void backward(const UpdateCallback& callback);
void bpropBiases();
void forward(PassType passType); void bpropData(int i);
void addBias(); void bpropWeights(int i);
size_t getSize();
void backward(const UpdateCallback& callback);
void bpropBiases();
void bpropData(int i);
void bpropWeights(int i);
protected: protected:
// Figure out the dimensions for individual gemms. // Figure out the dimensions for individual gemms.
IntV M_; /// numFilters_ / filter_group_; IntV M_; /// numFilters_ / filter_group_;
IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_ IntV N_; /// channels_ * filterSizeZ_ * filterSize_ * filterSizeY_
IntV K_; /// outputD_ * outputH_ * outputW_ IntV K_; /// outputD_ * outputH_ * outputW_
IntV No_; IntV NOut_;
MatrixPtr colBuf_; MatrixPtr colBuf_;
}; };
} // namespace paddle } // namespace paddle
...@@ -22,12 +22,31 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap, ...@@ -22,12 +22,31 @@ bool ExpandConvBaseLayer::init(const LayerMap &layerMap,
/* Initialize the basic convolutional parent class */ /* Initialize the basic convolutional parent class */
ConvBaseLayer::init(layerMap, parameterMap); ConvBaseLayer::init(layerMap, parameterMap);
int index = 0;
for (auto &inputConfig : config_.inputs()) { for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf(); const ConvConfig &conf = inputConfig.conv_conf();
/* Consistent caffe mode for multiple input */ /* Consistent caffe mode for multiple input */
caffeMode_ = conf.caffe_mode(); caffeMode_ = conf.caffe_mode();
}
// create a new weight
size_t height, width;
height = filterPixels_[index] * filterChannels_[index];
width = (!isDeconv_) ? numFilters_ : channels_[index];
CHECK_EQ(parameters_[index]->getSize(), width * height);
Weight *w = new Weight(height, width, parameters_[index]);
weights_.emplace_back(w);
index++;
}
if (biasParameter_.get()) {
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
getOutputSize(); getOutputSize();
return true; return true;
......
...@@ -2019,7 +2019,7 @@ void test3DConvLayer(const string& type, bool trans, bool useGpu) { ...@@ -2019,7 +2019,7 @@ void test3DConvLayer(const string& type, bool trans, bool useGpu) {
const int CHANNELS = 3; const int CHANNELS = 3;
const int IMAGE_SIZE = 9; const int IMAGE_SIZE = 9;
const int IMAGE_SIZE_Y = 9; const int IMAGE_SIZE_Y = 9;
const int IMAGE_SIZE_Z = 9; // 2, 3, 5, 5, 5 const int IMAGE_SIZE_Z = 9;
TestConfig config; TestConfig config;
config.biasSize = NUM_FILTERS; config.biasSize = NUM_FILTERS;
...@@ -2084,10 +2084,6 @@ TEST(Layer, test3DConvLayer) { ...@@ -2084,10 +2084,6 @@ TEST(Layer, test3DConvLayer) {
#endif #endif
} }
// Deconvolution output size along a single axis.
//
// inSize  input extent along the axis
// kSize   filter extent along the axis
// pad     padding applied on each side
// stride  step between consecutive filter placements
//
// Returns (inSize - 1) * stride - 2 * pad + kSize.
int deConvOutputSize(int inSize, int kSize, int pad, int stride) {
  // Distance covered by the first-to-last filter placements.
  const int spanned = (inSize - 1) * stride;
  // Both padded borders are removed from the covered distance.
  const int trimmed = spanned - 2 * pad;
  return trimmed + kSize;
}
void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
// filter size // filter size
const int NUM_FILTERS = 6; const int NUM_FILTERS = 6;
...@@ -2126,16 +2122,21 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { ...@@ -2126,16 +2122,21 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
conv->set_img_size(IMAGE_SIZE); conv->set_img_size(IMAGE_SIZE);
conv->set_img_size_y(IMAGE_SIZE_Y); conv->set_img_size_y(IMAGE_SIZE_Y);
conv->set_img_size_z(IMAGE_SIZE_Z); conv->set_img_size_z(IMAGE_SIZE_Z);
conv->set_output_x(deConvOutputSize( conv->set_output_x(imageSize(conv->img_size(),
conv->img_size(), conv->filter_size(), conv->padding(), conv->stride())); conv->filter_size(),
conv->set_output_y(deConvOutputSize(conv->img_size_y(), conv->padding(),
conv->filter_size_y(), conv->stride(),
conv->padding_y(), true));
conv->stride_y())); conv->set_output_y(imageSize(conv->img_size_y(),
conv->set_output_z(deConvOutputSize(conv->img_size_z(), conv->filter_size_y(),
conv->filter_size_z(), conv->padding_y(),
conv->padding_z(), conv->stride_y(),
conv->stride_z())); true));
conv->set_output_z(imageSize(conv->img_size_z(),
conv->filter_size_z(),
conv->padding_z(),
conv->stride_z(),
true));
config.layerConfig.set_size(conv->output_x() * conv->output_y() * config.layerConfig.set_size(conv->output_x() * conv->output_y() *
conv->output_z() * NUM_FILTERS); conv->output_z() * NUM_FILTERS);
conv->set_groups(1); conv->set_groups(1);
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "TensorCheck.h" #include "TensorCheck.h"
#include "paddle/math/MathUtils.h"
#include "paddle/math/Matrix.h" #include "paddle/math/Matrix.h"
#include "paddle/math/SparseMatrix.h" #include "paddle/math/SparseMatrix.h"
#include "paddle/testing/TestUtil.h" #include "paddle/testing/TestUtil.h"
...@@ -1203,19 +1204,6 @@ TEST(Matrix, warpCTC) { ...@@ -1203,19 +1204,6 @@ TEST(Matrix, warpCTC) {
} }
} }
// Number of filter placements along one axis for the col2Vol test.
//
// imageSize   input extent along the axis
// filterSize  filter extent along the axis
// padding     padding applied on each side
// stride      step between consecutive filter placements
// caffeMode   true: the division by stride is truncated;
//             false: stride - 1 is added before dividing (rounds up for
//             non-negative spans)
//
// Aborts through CHECK_GE when the computed size is smaller than 1.
int outputSizeCol2Vol(
    int imageSize, int filterSize, int padding, int stride, bool caffeMode) {
  // Padded extent that remains after placing the filter once.
  const int span = imageSize - filterSize + 2 * padding;
  const int steps = caffeMode ? span / stride : (span + stride - 1) / stride;
  const int outputSize = steps + 1;
  CHECK_GE(outputSize, 1);
  return outputSize;
}
void testMatrixCol2Vol(int depth, int height, int width) { void testMatrixCol2Vol(int depth, int height, int width) {
int channel = 3; int channel = 3;
int filterX = 3, filterY = 4, filterZ = 5; int filterX = 3, filterY = 4, filterZ = 5;
...@@ -1229,9 +1217,9 @@ void testMatrixCol2Vol(int depth, int height, int width) { ...@@ -1229,9 +1217,9 @@ void testMatrixCol2Vol(int depth, int height, int width) {
cpuImage->randomizeUniform(); cpuImage->randomizeUniform();
gpuImage->copyFrom(*cpuImage); gpuImage->copyFrom(*cpuImage);
int outD = outputSizeCol2Vol(depth, filterZ, padZ, strideZ, true); int outD = outputSize(depth, filterZ, padZ, strideZ, true);
int outH = outputSizeCol2Vol(height, filterY, padZ, strideY, true); int outH = outputSize(height, filterY, padY, strideY, true);
int outW = outputSizeCol2Vol(width, filterX, padZ, strideX, true); int outW = outputSize(width, filterX, padX, strideX, true);
int colBufHeight = channel * filterZ * filterY * filterX; int colBufHeight = channel * filterZ * filterY * filterX;
int colBufWidth = outD * outH * outW; int colBufWidth = outD * outH * outW;
...@@ -1305,11 +1293,9 @@ void testMatrixCol2Vol(int depth, int height, int width) { ...@@ -1305,11 +1293,9 @@ void testMatrixCol2Vol(int depth, int height, int width) {
} }
TEST(Matrix, col2Vol) { TEST(Matrix, col2Vol) {
for (auto depth : {9, 16, 64, 128}) { for (auto depth : {9, 16, 64}) {
for (auto height : {9, 11, 73, 128, 256}) { for (auto height : {9, 11, 128}) {
for (auto width : { for (auto width : {9, 32, 128}) {
9, 32, 100, 512,
}) {
VLOG(3) << "depth=" << depth << " height=" << height VLOG(3) << "depth=" << depth << " height=" << height
<< " width=" << width; << " width=" << width;
testMatrixCol2Vol(depth, height, width); testMatrixCol2Vol(depth, height, width);
......
...@@ -82,7 +82,7 @@ message ConvConfig { ...@@ -82,7 +82,7 @@ message ConvConfig {
// if not set, use img_size // if not set, use img_size
optional uint32 img_size_y = 14; optional uint32 img_size_y = 14;
optional uint32 filter_size_z = 15 [ default = 1 ]; optional uint32 filter_size_z = 15 [ default = 1 ];
optional uint32 padding_z = 16 [ default = 1 ]; optional uint32 padding_z = 16 [ default = 1 ];
optional uint32 stride_z = 17 [ default = 1 ]; optional uint32 stride_z = 17 [ default = 1 ];
...@@ -637,4 +637,4 @@ message ModelConfig { ...@@ -637,4 +637,4 @@ message ModelConfig {
// For External Machine, defining how to split a neural network // For External Machine, defining how to split a neural network
// into multiple parts. // into multiple parts.
optional ExternalConfig external_config = 9; optional ExternalConfig external_config = 9;
}; };
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册