提交 e9a1f2ad 编写于 作者: C chengduoZH

fix DeConv3D switch(imgSize*_, output*_)

上级 4fbc03d3
...@@ -53,27 +53,27 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, ...@@ -53,27 +53,27 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
size_t DeConv3DLayer::getSize() { size_t DeConv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL); CHECK_NE(inputLayers_.size(), 0UL);
outputH_.clear(); imgSizeW_.clear();
outputW_.clear(); imgSizeH_.clear();
outputD_.clear(); imgSizeD_.clear();
N_.clear(); N_.clear();
NOut_.clear(); NOut_.clear();
size_t layerSize = 0; size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); ++i) { for (size_t i = 0; i < inputLayers_.size(); ++i) {
outputW_.push_back( imgSizeW_.push_back(
imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true)); imageSize(outputW_[i], filterSize_[i], padding_[i], stride_[i], true));
outputH_.push_back(imageSize( imgSizeH_.push_back(imageSize(
imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true)); outputH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
outputD_.push_back(imageSize( imgSizeD_.push_back(imageSize(
imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true)); outputD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); NOut_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
layerSize += NOut_[i] * numFilters_; layerSize += NOut_[i] * numFilters_;
} }
getOutput().setFrameHeight(outputH_[0]); getOutput().setFrameHeight(imgSizeH_[0]);
getOutput().setFrameWidth(outputW_[0]); getOutput().setFrameWidth(imgSizeW_[0]);
getOutput().setFrameDepth(outputD_[0]); getOutput().setFrameDepth(imgSizeD_[0]);
return layerSize; return layerSize;
} }
...@@ -103,9 +103,9 @@ void DeConv3DLayer::forward(PassType passType) { ...@@ -103,9 +103,9 @@ void DeConv3DLayer::forward(PassType passType) {
} }
colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(), colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(),
numFilters_, numFilters_,
outputD_[i], imgSizeD_[i],
outputH_[i], imgSizeH_[i],
outputW_[i], imgSizeW_[i],
filterSizeZ_[i], filterSizeZ_[i],
filterSizeY_[i], filterSizeY_[i],
filterSize_[i], filterSize_[i],
...@@ -144,9 +144,9 @@ void DeConv3DLayer::backward(const UpdateCallback &callback) { ...@@ -144,9 +144,9 @@ void DeConv3DLayer::backward(const UpdateCallback &callback) {
colBuf_->vol2Col( colBuf_->vol2Col(
getOutputGrad()->getData() + n * getOutputGrad()->getStride(), getOutputGrad()->getData() + n * getOutputGrad()->getStride(),
numFilters_, numFilters_,
outputD_[i], imgSizeD_[i],
outputH_[i], imgSizeH_[i],
outputW_[i], imgSizeW_[i],
filterSizeZ_[i], filterSizeZ_[i],
filterSizeY_[i], filterSizeY_[i],
filterSize_[i], filterSize_[i],
......
...@@ -2302,26 +2302,27 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) { ...@@ -2302,26 +2302,27 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
conv->set_stride(2); conv->set_stride(2);
conv->set_stride_y(2); conv->set_stride_y(2);
conv->set_stride_z(2); conv->set_stride_z(2);
conv->set_img_size(IMAGE_SIZE); conv->set_output_x(IMAGE_SIZE);
conv->set_img_size_y(IMAGE_SIZE_Y); conv->set_output_y(IMAGE_SIZE_Y);
conv->set_img_size_z(IMAGE_SIZE_Z); conv->set_output_z(IMAGE_SIZE_Z);
conv->set_output_x(imageSize(conv->img_size(),
conv->set_img_size(imageSize(conv->output_x(),
conv->filter_size(), conv->filter_size(),
conv->padding(), conv->padding(),
conv->stride(), conv->stride(),
true)); true));
conv->set_output_y(imageSize(conv->img_size_y(), conv->set_img_size_y(imageSize(conv->output_y(),
conv->filter_size_y(), conv->filter_size_y(),
conv->padding_y(), conv->padding_y(),
conv->stride_y(), conv->stride_y(),
true)); true));
conv->set_output_z(imageSize(conv->img_size_z(), conv->set_img_size_z(imageSize(conv->output_z(),
conv->filter_size_z(), conv->filter_size_z(),
conv->padding_z(), conv->padding_z(),
conv->stride_z(), conv->stride_z(),
true)); true));
config.layerConfig.set_size(conv->output_x() * conv->output_y() * config.layerConfig.set_size(conv->img_size() * conv->img_size_y() *
conv->output_z() * NUM_FILTERS); conv->img_size_z() * NUM_FILTERS);
conv->set_groups(1); conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups()); conv->set_filter_channels(conv->channels() / conv->groups());
config.inputDefs.push_back( config.inputDefs.push_back(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册