提交 c792ef7d 编写于 作者: C chengduoZH

fix DeConv3D, Conv3D

上级 424b325d
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "Conv3DLayer.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#include "Conv3DLayer.h"
namespace paddle { namespace paddle {
...@@ -22,32 +22,30 @@ REGISTER_LAYER(conv3d, Conv3DLayer); ...@@ -22,32 +22,30 @@ REGISTER_LAYER(conv3d, Conv3DLayer);
bool Conv3DLayer::init(const LayerMap &layerMap, bool Conv3DLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) { const ParameterMap &parameterMap) {
if (!ConvBaseLayer::init(layerMap, parameterMap)) if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
return false;
int index = 0; int index = 0;
for (auto &inputConfig : config_.inputs()) { for (auto &inputConfig : config_.inputs()) {
const ConvConfig &conf = inputConfig.conv_conf(); const ConvConfig &conf = inputConfig.conv_conf();
M_.push_back(numFilters_ / conf.groups()); M_.push_back(numFilters_ / conf.groups());
K_.push_back( K_.push_back(filterPixels_[index] * filterChannels_[index]);
conf.filter_channels() * conf.filter_size_z() * \ if (nullptr != weights_[index]->getW())
conf.filter_size_y() * conf.filter_size()); weights_[index]->getW()->reshape(weights_[index]->getW()->getWidth(),
weights_[index]->getW()->reshape(
weights_[index]->getW()->getWidth(),
weights_[index]->getW()->getHeight()); weights_[index]->getW()->getHeight());
if (nullptr != weights_[index]->getWGrad())
weights_[index]->getWGrad()->reshape( weights_[index]->getWGrad()->reshape(
weights_[index]->getWGrad()->getWidth(), weights_[index]->getWGrad()->getWidth(),
weights_[index]->getWGrad()->getHeight()); weights_[index]->getWGrad()->getHeight());
++index; ++index;
} }
biases_->getWGrad()->reshape( if (nullptr != biases_->getWGrad())
biases_->getWGrad()->width_, biases_->getWGrad()->height_); biases_->getWGrad()->reshape(biases_->getWGrad()->width_,
biases_->getW()->reshape( biases_->getWGrad()->height_);
biases_->getW()->width_, biases_->getW()->height_); if (nullptr != biases_->getW())
biases_->getW()->reshape(biases_->getW()->width_, biases_->getW()->height_);
CHECK(inputLayers_.size() == parameters_.size()); CHECK(inputLayers_.size() == parameters_.size());
return true; return true;
} }
size_t Conv3DLayer::getSize() { size_t Conv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL); CHECK_NE(inputLayers_.size(), 0UL);
// imgSizeH_.clear(); // imgSizeH_.clear();
...@@ -63,14 +61,11 @@ size_t Conv3DLayer::getSize() { ...@@ -63,14 +61,11 @@ size_t Conv3DLayer::getSize() {
// imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
// imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth());
outputW_.push_back(outputSize( outputW_.push_back(outputSize(
imgSizeW_[i], filterSize_[i], imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
padding_[i], stride_[i], true));
outputH_.push_back(outputSize( outputH_.push_back(outputSize(
imgSizeH_[i], filterSizeY_[i], imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
paddingY_[i], strideY_[i], true));
outputD_.push_back(outputSize( outputD_.push_back(outputSize(
imgSizeD_[i], filterSizeZ_[i], imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
paddingZ_[i], strideZ_[i], true));
N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
...@@ -88,32 +83,40 @@ void Conv3DLayer::forward(PassType passType) { ...@@ -88,32 +83,40 @@ void Conv3DLayer::forward(PassType passType) {
int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
int outWidth = getSize(); int outWidth = getSize();
resetOutput(batchSize, outWidth); resetOutput(batchSize, outWidth);
const MatrixPtr outMat = getOutputValue();
for (size_t i = 0; i != inputLayers_.size(); ++i) { for (size_t i = 0; i != inputLayers_.size(); ++i) {
REGISTER_TIMER_INFO("FwdConv3D", getName().c_str()); REGISTER_TIMER_INFO("FwdConv3D", getName().c_str());
const MatrixPtr& inMat = getInputValue(i); const MatrixPtr &inMat = getInputValue(i);
int width = inMat->getWidth(); const MatrixPtr &outMat = getOutputValue();
int M = M_[i]; int M = M_[i];
int N = N_[i]; int N = N_[i];
int K = K_[i]; int K = K_[i];
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wMat = weights_[i]->getW(); MatrixPtr wMat = weights_[i]->getW();
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
colBuf_->vol2Col(inMat->getData() + n * width, channels_[i], colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(),
imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], channels_[i],
filterSizeZ_[i], filterSizeY_[i], filterSize_[i], imgSizeD_[i],
strideZ_[i], strideY_[i], stride_[i], imgSizeH_[i],
paddingZ_[i], paddingY_[i], padding_[i]); imgSizeW_[i],
filterSizeZ_[i],
real *outData = outMat->getData() + n * outWidth; filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
real *outData = outMat->getData() + n * outMat->getStride();
MatrixPtr outMatSub = MatrixPtr outMatSub =
Matrix::create(outData, groups_[i] * M, N, false, useGpu_); Matrix::create(outData, groups_[i] * M, N, false, useGpu_);
for (int g = 0; g < groups_[i]; g++) { for (int g = 0; g < groups_[i]; g++) {
MatrixPtr wMatSub = wMat->subMatrix(g * M, M); MatrixPtr wMatSub = wMat->subMatrix(g * M, M);
MatrixPtr in = colBuf_->subMatrix(g * K, K); MatrixPtr in = colBuf_->subMatrix(g * K, K);
MatrixPtr out = outMatSub->subMatrix(g * M, M); MatrixPtr out = outMatSub->subMatrix(g * M, M);
out->mul(*wMatSub, *in, 1.0, 0.0); out->mul(*wMatSub, *in, 1.0, 1.0);
} }
} }
} }
...@@ -137,7 +140,7 @@ void Conv3DLayer::backward(const UpdateCallback &callback) { ...@@ -137,7 +140,7 @@ void Conv3DLayer::backward(const UpdateCallback &callback) {
if (weights_[i]->getWGrad()) { if (weights_[i]->getWGrad()) {
bpropWeights(i); bpropWeights(i);
} }
if (this->needGradient_) { if (getInputGrad(i)) {
bpropData(i); bpropData(i);
} }
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
...@@ -149,20 +152,28 @@ void Conv3DLayer::bpropWeights(int i) { ...@@ -149,20 +152,28 @@ void Conv3DLayer::bpropWeights(int i) {
int M = M_[i]; int M = M_[i];
int N = N_[i]; int N = N_[i];
int K = K_[i]; int K = K_[i];
const MatrixPtr& inMat = getInputValue(i); const MatrixPtr &inMat = getInputValue(i);
int width = inMat->getWidth();
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wGradMat = weights_[i]->getWGrad(); MatrixPtr wGradMat = weights_[i]->getWGrad();
real* outGradData = getOutputGrad()->getData();
int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
colBuf_->vol2Col(inMat->getData() + n * width, channels_[i], colBuf_->vol2Col(inMat->getData() + n * inMat->getStride(),
imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], channels_[i],
filterSizeZ_[i], filterSizeY_[i], filterSize_[i], imgSizeD_[i],
strideZ_[i], strideY_[i], stride_[i], imgSizeH_[i],
paddingZ_[i], paddingY_[i], padding_[i]); imgSizeW_[i],
outGradData += n * getOutputGrad()->getWidth(); filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
real *outGradData =
getOutputGrad()->getData() + n * getOutputGrad()->getStride();
MatrixPtr outGradSub = MatrixPtr outGradSub =
Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_); Matrix::create(outGradData, groups_[i] * M, N, false, useGpu_);
for (int g = 0; g < groups_[i]; ++g) { for (int g = 0; g < groups_[i]; ++g) {
...@@ -180,12 +191,12 @@ void Conv3DLayer::bpropData(int i) { ...@@ -180,12 +191,12 @@ void Conv3DLayer::bpropData(int i) {
int K = K_[i]; int K = K_[i];
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
MatrixPtr wMat = weights_[i]->getW(); MatrixPtr wMat = weights_[i]->getW();
real* outGradData = getOutputGrad()->getData();
real* preGradData = getInputGrad(i)->getData();
int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
outGradData += n * getOutputGrad()->getWidth(); real *outGradData =
preGradData += n * getInputGrad(i)->getWidth(); getOutputGrad()->getData() + n * getOutputGrad()->getStride();
real *preGradData =
getInputGrad(i)->getData() + n * getInputGrad(i)->getStride();
MatrixPtr outGradSub = MatrixPtr outGradSub =
Matrix::create(outGradData, M * groups_[i], N, false, useGpu_); Matrix::create(outGradData, M * groups_[i], N, false, useGpu_);
for (int g = 0; g < groups_[i]; ++g) { for (int g = 0; g < groups_[i]; ++g) {
...@@ -194,12 +205,22 @@ void Conv3DLayer::bpropData(int i) { ...@@ -194,12 +205,22 @@ void Conv3DLayer::bpropData(int i) {
MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K); MatrixPtr inGradMatSub = colBuf_->subMatrix(g * K, K);
inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0); inGradMatSub->mul(*(wMatSub->getTranspose()), *outG, 1.0, 0.0);
} }
colBuf_->col2Vol(preGradData, channels_[i], colBuf_->col2Vol(preGradData,
imgSizeD_[i], imgSizeH_[i], imgSizeW_[i], channels_[i],
filterSizeZ_[i], filterSizeY_[i], filterSize_[i], imgSizeD_[i],
strideZ_[i], strideY_[i], stride_[i], imgSizeH_[i],
paddingZ_[i], paddingY_[i], padding_[i], imgSizeW_[i],
1.0, 1.0); filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i],
1.0,
1.0);
} }
} }
...@@ -214,7 +235,6 @@ void Conv3DLayer::bpropBiases() { ...@@ -214,7 +235,6 @@ void Conv3DLayer::bpropBiases() {
void Conv3DLayer::addBias() { void Conv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue(); MatrixPtr outMat = getOutputValue();
if (this->sharedBiases_) { if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f); outMat->addSharedBias(*(biases_->getW()), 1.0f);
} else { } else {
......
...@@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "DeConv3DLayer.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#include "DeConv3DLayer.h"
namespace paddle { namespace paddle {
REGISTER_LAYER(deconv3d, DeConv3DLayer); REGISTER_LAYER(deconv3d, DeConv3DLayer);
#define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \ #define DECONV_OUTPUT_SIZE(IN_SIZE, STRID, PAD, KSIZE) \
(((IN_SIZE) - 1) * (STRID) - 2 * (PAD) + (KSIZE)) (((IN_SIZE)-1) * (STRID)-2 * (PAD) + (KSIZE))
bool DeConv3DLayer::init(const LayerMap &layerMap, bool DeConv3DLayer::init(const LayerMap &layerMap,
const ParameterMap &parameterMap) { const ParameterMap &parameterMap) {
...@@ -31,24 +31,23 @@ bool DeConv3DLayer::init(const LayerMap &layerMap, ...@@ -31,24 +31,23 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
// Matrix storage format: (output * depth * height * width) x channel // Matrix storage format: (output * depth * height * width) x channel
for (int index = 0; index < config_.inputs().size(); ++index) { for (int index = 0; index < config_.inputs().size(); ++index) {
M_.push_back(filterChannels_[index]); M_.push_back(filterChannels_[index]);
K_.push_back( K_.push_back(filterPixels_[index] * (numFilters_ / groups_[index]));
filterPixels_[index] * (numFilters_/groups_[index])); if (weights_[index]->getW())
weights_[index]->getW()->reshape( weights_[index]->getW()->reshape(filterPixels_[index] * numFilters_,
filterPixels_[index] * numFilters_,
filterChannels_[index]); filterChannels_[index]);
weights_[index]->getWGrad()->reshape( if (weights_[index]->getWGrad())
filterPixels_[index] * numFilters_, weights_[index]->getWGrad()->reshape(filterPixels_[index] * numFilters_,
filterChannels_[index]); filterChannels_[index]);
} }
biases_->getWGrad()->reshape( if (biases_->getWGrad())
biases_->getWGrad()->width_, biases_->getWGrad()->height_); biases_->getWGrad()->reshape(biases_->getWGrad()->width_,
biases_->getW()->reshape( biases_->getWGrad()->height_);
biases_->getW()->width_, biases_->getW()->height_); if (biases_->getW())
biases_->getW()->reshape(biases_->getW()->width_, biases_->getW()->height_);
CHECK(inputLayers_.size() == parameters_.size()); CHECK(inputLayers_.size() == parameters_.size());
return true; return true;
} }
size_t DeConv3DLayer::getSize() { size_t DeConv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL); CHECK_NE(inputLayers_.size(), 0UL);
// imgSizeH_.clear(); // imgSizeH_.clear();
...@@ -64,18 +63,12 @@ size_t DeConv3DLayer::getSize() { ...@@ -64,18 +63,12 @@ size_t DeConv3DLayer::getSize() {
// imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); // imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight());
// imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); // imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
// imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth()); // imgSizeD_.push_back(inputLayers_[i]->getOutput().getFrameDepth());
outputW_.push_back( outputW_.push_back(DECONV_OUTPUT_SIZE(
DECONV_OUTPUT_SIZE( imgSizeW_[i], stride_[i], padding_[i], filterSize_[i]));
imgSizeW_[i], stride_[i], outputH_.push_back(DECONV_OUTPUT_SIZE(
padding_[i], filterSize_[i])); imgSizeH_[i], strideY_[i], paddingY_[i], filterSizeY_[i]));
outputH_.push_back( outputD_.push_back(DECONV_OUTPUT_SIZE(
DECONV_OUTPUT_SIZE( imgSizeD_[i], strideZ_[i], paddingZ_[i], filterSizeZ_[i]));
imgSizeH_[i], strideY_[i],
paddingY_[i], filterSizeY_[i]));
outputD_.push_back(
DECONV_OUTPUT_SIZE(
imgSizeD_[i], strideZ_[i],
paddingZ_[i], filterSizeZ_[i]));
No_.push_back(outputD_[i] * outputH_[i] * outputW_[i]); No_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]); N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize); CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
...@@ -96,32 +89,37 @@ void DeConv3DLayer::forward(PassType passType) { ...@@ -96,32 +89,37 @@ void DeConv3DLayer::forward(PassType passType) {
for (size_t i = 0; i != inputLayers_.size(); ++i) { for (size_t i = 0; i != inputLayers_.size(); ++i) {
REGISTER_TIMER_INFO("FwdDeConv3D", getName().c_str()); REGISTER_TIMER_INFO("FwdDeConv3D", getName().c_str());
const MatrixPtr& inMat = getInputValue(i); const MatrixPtr &inMat = getInputValue(i);
int width = inMat->getWidth();
int M = M_[i]; int M = M_[i];
int N = N_[i]; int N = N_[i];
int K = K_[i]; int K = K_[i];
MatrixPtr wMat = weights_[i]->getW(); MatrixPtr wMat = weights_[i]->getW();
Matrix::resizeOrCreate(colBuf_, K * groups_[i] , N, false, useGpu_); Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
real *inData = inMat->getData() + n * width; real *inData = inMat->getData() + n * inMat->getStride();
real *colBufData = colBuf_->getData(); for (int g = 0; g < groups_[i]; ++g) {
for (int g = 0; g < groups_[i]; g++) { MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_);
MatrixPtr wMatSub = wMat->subMatrix(g * K, K); MatrixPtr wMatSub = wMat->subMatrix(g * K, K);
MatrixPtr inMatSub = MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K);
Matrix::create(inData, M, N, false, useGpu_);
MatrixPtr colBufDataSub =
Matrix::create(colBufData, K, N, false, useGpu_);
colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0); colBufDataSub->mul(*wMatSub, *inMatSub, 1.0, 0.0);
colBufData += K * N;
inData += M * N; inData += M * N;
} }
colBuf_->col2Vol(outMat->getData()+ n * outMat->getWidth(), colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(),
numFilters_, outputD_[i], outputH_[i], outputW_[i], numFilters_,
filterSizeZ_[i], filterSizeY_[i], filterSize_[i], outputD_[i],
strideZ_[i], strideY_[i], stride_[i], outputH_[i],
paddingZ_[i], paddingY_[i], padding_[i], 1.0, 1.0); outputW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i],
1.0,
1.0);
} }
} }
if (nullptr != this->biasParameter_) { if (nullptr != this->biasParameter_) {
...@@ -134,63 +132,69 @@ void DeConv3DLayer::forward(PassType passType) { ...@@ -134,63 +132,69 @@ void DeConv3DLayer::forward(PassType passType) {
void DeConv3DLayer::backward(const UpdateCallback &callback) { void DeConv3DLayer::backward(const UpdateCallback &callback) {
backwardActivation(); backwardActivation();
int batchSize = getOutputGrad()->getHeight(); int batchSize = getOutputGrad()->getHeight();
int outputWidth = getOutputGrad()->getWidth();
if (biases_ && biases_->getWGrad()) { if (biases_ && biases_->getWGrad()) {
bpropBiases(); bpropBiases();
biases_->getParameterPtr()->incUpdate(callback); biases_->getParameterPtr()->incUpdate(callback);
} }
for (size_t i =0; i < inputLayers_.size(); ++i) { for (size_t i = 0; i < inputLayers_.size(); ++i) {
if (weights_[i]->getWGrad() || this->needGradient_) {
int M = M_[i]; int M = M_[i];
int N = N_[i]; int N = N_[i];
int K = K_[i]; int K = K_[i];
REGISTER_TIMER_INFO("BwdDeConv3D", getName().c_str());
Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_); Matrix::resizeOrCreate(colBuf_, K * groups_[i], N, false, useGpu_);
const MatrixPtr& inMat = getInputValue(i); const MatrixPtr &inMat = getInputValue(i);
for (int n = 0; n < batchSize; ++n) { for (int n = 0; n < batchSize; ++n) {
REGISTER_TIMER_INFO("BwdDeConv3D", getName().c_str()); colBuf_->vol2Col(
if (weights_[i]->getWGrad() || this->needGradient_) { getOutputGrad()->getData() + n * getOutputGrad()->getStride(),
colBuf_->vol2Col(getOutputGrad()->getData() + n * outputWidth, numFilters_,
numFilters_, outputD_[i], outputH_[i], outputW_[i], outputD_[i],
filterSizeZ_[i], filterSizeY_[i], filterSize_[i], outputH_[i],
strideZ_[i], strideY_[i], stride_[i], outputW_[i],
paddingZ_[i], paddingY_[i], padding_[i]); filterSizeZ_[i],
} filterSizeY_[i],
filterSize_[i],
strideZ_[i],
strideY_[i],
stride_[i],
paddingZ_[i],
paddingY_[i],
padding_[i]);
if (weights_[i]->getWGrad()) { if (weights_[i]->getWGrad()) {
real *inData = inMat->getData() + n * inMat->getWidth();; real *inData = inMat->getData() + n * inMat->getStride();
real *wGradData = weights_[i]->getWGrad()->getData(); for (int g = 0; g < groups_[i]; ++g) {
for (int g = 0; g < groups_[i]; g++) {
MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K); MatrixPtr colBufDataSub = colBuf_->subMatrix(g * K, K);
MatrixPtr inMatSub = Matrix::create( MatrixPtr wGradMatSub =
inData, M, N, false, useGpu_); weights_[i]->getWGrad()->subMatrix(g * K, K);
MatrixPtr wGradMatSub = Matrix::create( MatrixPtr inMatSub = Matrix::create(inData, M, N, false, useGpu_);
wGradData, K, M, false, useGpu_); wGradMatSub->mul(
wGradMatSub->mul(*colBufDataSub, *colBufDataSub, *(inMatSub->getTranspose()), 1.0, 1.0);
*(inMatSub->getTranspose()), 1.0, 1.0);
wGradData += K * M;
inData += M * N; inData += M * N;
} }
weights_[i]->getParameterPtr()->incUpdate(callback);
} }
if (this->needGradient_) { if (getInputGrad(i)) {
real* preGrad = getInputGrad(i)->getData(); real *preGrad =
getInputGrad(i)->getData() + n * getInputGrad(i)->getStride();
for (int g = 0; g < groups_[i]; ++g) { for (int g = 0; g < groups_[i]; ++g) {
MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K); MatrixPtr w = weights_[i]->getW()->subMatrix(g * K, K);
MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K); MatrixPtr outGradMat = colBuf_->subMatrix(g * K, K);
MatrixPtr inGradMatSub = Matrix::create( MatrixPtr inGradMatSub =
preGrad, M, N, false, useGpu_); Matrix::create(preGrad, M, N, false, useGpu_);
inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 0.0); inGradMatSub->mul(*(w->getTranspose()), *outGradMat, 1.0, 1.0);
preGrad += M * N; preGrad += M * N;
} }
} }
}
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str()); REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
weights_[i]->getParameterPtr()->incUpdate(callback);
} }
} }
} }
void DeConv3DLayer::bpropWeights(int i) {}
void DeConv3DLayer::bpropWeights(int i) { } void DeConv3DLayer::bpropData(int i) {}
void DeConv3DLayer::bpropData(int i) { }
void DeConv3DLayer::bpropBiases() { void DeConv3DLayer::bpropBiases() {
MatrixPtr outGradMat = getOutputGrad(); const MatrixPtr &outGradMat = getOutputGrad();
if (this->sharedBiases_) { if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f); biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册