diff --git a/paddle/gserver/layers/ConvBaseLayerCpu.cpp b/paddle/gserver/layers/ConvBaseLayerCpu.cpp
index 0da92bf0485b00898583eb1c4a4ce0387c376b36..13521bcd2b209eec78c3d7f7c5d03aeddb1da405 100644
--- a/paddle/gserver/layers/ConvBaseLayerCpu.cpp
+++ b/paddle/gserver/layers/ConvBaseLayerCpu.cpp
@@ -22,11 +22,20 @@ bool ConvBaseLayerCpu::init(const LayerMap &layerMap,
   /* Initialize the basic convolutional parent class */
   ConvBaseLayer::init(layerMap, parameterMap);
 
+  /* The class fields channels_ and numFilters_ are the same as in the config,
+   * i.e., channels_ is for the input and numFilters_ is for the output.
+   *
+   * But in order for the variables in convTrans to have the same semantic
+   * meaning as in conv, we need to swap channels_ and numFilters_ here for
+   * convTrans, and in other functions too.
+   */
   int channel;
+  int nf;
   /* Initialize the projection */
   for (auto &inputConfig : config_.inputs()) {
     const ConvConfig &conf = inputConfig.conv_conf();
-    subM_.push_back(numFilters_ / conf.groups());
+    nf = isConv_ ? numFilters_ : conf.channels();
+    subM_.push_back(nf / conf.groups());
     subN_.push_back(conf.output_x() * conf.output_x());
     channel = isConv_ ? conf.channels() : numFilters_;
     subK_.push_back(channel * conf.filter_size() * conf.filter_size() /
@@ -123,20 +132,19 @@ void ConvBaseLayerCpu::expandFwdOnce(MatrixPtr image, MatrixPtr out,
   }
 }
 
-void ConvBaseLayerCpu::bpropActs(MatrixPtr image, MatrixPtr out, int inpIdx) {
+void ConvBaseLayerCpu::bpropActs(MatrixPtr out, MatrixPtr image, int inpIdx) {
   int channel = isConv_ ? channels_[inpIdx] : numFilters_;
   int subM = subM_[inpIdx];
   int subN = subN_[inpIdx];
   int subK = subK_[inpIdx];
   size_t batchSize = image->getHeight();
-  MatrixPtr tgtGrad = out;
 
   /* reset the expand-grad memory */
   resetExpandInput(subK * groups_[inpIdx], subN);
 
-  real *localGradData = image->getData();
-  real *tgtGradData = tgtGrad->getData();
+  real *localGradData = out->getData();
+  real *tgtGradData = image->getData();
 
   for (size_t n = 0; n < batchSize; n++) {
     real *wgtData = weights_[inpIdx]->getW()->getData();
     real *expandInData = expandInput_->getData();
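
To make the first hunk easier to review, here is a minimal standalone C++ sketch of the subM/subN/subK computation and the conv/convTrans role swap it implements. This is not code from the repo: `ConvShape` and `computeGemmSizes` are hypothetical names introduced only to mirror the logic of `ConvBaseLayerCpu::init()` in the patch.

```cpp
#include <cstdio>

// Hypothetical stand-in for the relevant fields of ConvConfig.
struct ConvShape {
  int channels;    // input channels, as written in the config
  int numFilters;  // output channels, as written in the config
  int filterSize;
  int groups;
  int outputX;
};

// Mirrors the patched init() logic: for convTrans the roles of
// channels and numFilters are swapped so that subM/subK keep the
// same semantic meaning ("output channels per group" and
// "input patch size per group") in both cases.
void computeGemmSizes(const ConvShape &conf, bool isConv,
                      int *subM, int *subN, int *subK) {
  int nf = isConv ? conf.numFilters : conf.channels;
  int channel = isConv ? conf.channels : conf.numFilters;
  *subM = nf / conf.groups;
  *subN = conf.outputX * conf.outputX;
  *subK = channel * conf.filterSize * conf.filterSize / conf.groups;
}

int main() {
  ConvShape conf{/*channels=*/3, /*numFilters=*/64, /*filterSize=*/5,
                 /*groups=*/1, /*outputX=*/28};
  int m, n, k;
  computeGemmSizes(conf, /*isConv=*/true, &m, &n, &k);
  std::printf("conv:      subM=%d subN=%d subK=%d\n", m, n, k);  // 64 784 75
  computeGemmSizes(conf, /*isConv=*/false, &m, &n, &k);
  std::printf("convTrans: subM=%d subN=%d subK=%d\n", m, n, k);  // 3 784 1600
  return 0;
}
```

The second hunk follows the same idea at the call-signature level: `bpropActs` now takes the gradient source (`out`) first and the gradient target (`image`) second, so `localGradData`/`tgtGradData` can read directly from the parameters and the intermediate `tgtGrad` alias is no longer needed.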