Commit a9490a10 authored by tensor-tang

make output channels changeable in reshape function

Parent d2e30a2c
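The core of this change is turning the `oc` parameter of the virtual `reshape()` from pass-by-value (`int oc`) into a reference (`int& oc`), so a layer such as MKLDNNConcatLayer can write the output channel count back to the caller instead of stashing it in the member `oc_`. As a rough illustration (a minimal standalone sketch, not the PaddlePaddle code itself; `concatReshape` and its arguments are hypothetical), pass-by-reference lets the accumulated channel count propagate back:

// Minimal sketch (hypothetical, not PaddlePaddle code) of why `int& oc` matters:
// a concat-style reshape can now report the summed output channels back to the
// caller, which the old `int oc` signature could not.
#include <cassert>
#include <vector>

void concatReshape(const std::vector<int>& inputChannels, int& oc) {
  oc = inputChannels.empty() ? 0 : inputChannels[0];
  for (size_t i = 1; i < inputChannels.size(); ++i) {
    oc += inputChannels[i];  // accumulate each input's channels, as the concat layer does
  }
}

int main() {
  int oc = 0;  // plays the role of MKLDNNLayer::oc_ passed into reshape()
  concatReshape({3, 5, 8}, oc);
  assert(oc == 16);  // the caller now sees the updated output channel count
  return 0;
}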
@@ -38,7 +38,7 @@ bool MKLDNNAddtoLayer::init(const LayerMap& layerMap,
 }
 
 void MKLDNNAddtoLayer::reshape(
-    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+    int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
   CHECK_EQ(layerSize_, getSize()) << "this layer size can not be changed";
   reshapeInput(bs, ih, iw);
   ic = inputLayers_[0]->getSize() / ih / iw;
......
@@ -50,7 +50,7 @@ public:
             const ParameterMap& parameterMap) override;
 
   void reshape(
-      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
+      int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
 
   void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                 MKLDNNMatrixPtr& in,
......
@@ -116,7 +116,7 @@ void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
 }
 
 void MKLDNNBatchNormLayer::reshape(
-    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+    int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
   reshapeInput(bs, ih, iw);
   oh = ih;
   ow = iw;
......
@@ -73,7 +73,7 @@ public:
   void forward(PassType passType) override;
 
   void reshape(
-      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
+      int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
 
   void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                 MKLDNNMatrixPtr& in,
......
@@ -32,7 +32,7 @@ bool MKLDNNConcatLayer::init(const LayerMap& layerMap,
 }
 
 void MKLDNNConcatLayer::reshape(
-    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+    int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
   reshapeInput(bs, ih, iw);
   ic = inputLayers_[0]->getSize() / ih / iw;
   CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize());
@@ -40,9 +40,7 @@ void MKLDNNConcatLayer::reshape(
   CHECK_GT(inputLayers_.size(), 1UL);
   channels_.resize(inputLayers_.size());
   channels_[0] = ic;
-  // need change the output channel, so use oc_ instead
-  // TODO(TJ): change API, use &oc
-  oc_ = ic;
+  oc = ic;
   for (size_t i = 1; i < inputLayers_.size(); i++) {
     int batchsize, height, witdh;
     reshapeInput(batchsize, height, witdh, i);
@@ -52,12 +50,12 @@ void MKLDNNConcatLayer::reshape(
     channels_[i] = inputLayers_[i]->getSize() / height / witdh;
     CHECK_EQ((size_t)channels_[i] * height * witdh, inputLayers_[i]->getSize());
-    oc_ += channels_[i];
+    oc += channels_[i];
   }
   oh = ih;
   ow = iw;
   reshapeOutput(oh, ow);
-  resizeOutput(bs, oc_ * oh * ow);
+  resizeOutput(bs, oc * oh * ow);
 }
 
 void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline,
......
@@ -47,7 +47,7 @@ public:
             const ParameterMap& parameterMap) override;
 
   void reshape(
-      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
+      int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
 
   void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                 MKLDNNMatrixPtr& in,
......
@@ -90,7 +90,7 @@ void MKLDNNConvLayer::convertWeightsToPaddle() {
 }
 
 void MKLDNNConvLayer::reshape(
-    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+    int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
   reshapeInput(bs, ih, iw);
 
   // cal output sizes
......
@@ -69,7 +69,7 @@ public:
             const ParameterMap& parameterMap) override;
 
   void reshape(
-      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
+      int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
 
   void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                 MKLDNNMatrixPtr& in,
......
@@ -74,7 +74,7 @@ void MKLDNNFcLayer::convertWeightsToPaddle() {
 }
 
 void MKLDNNFcLayer::reshape(
-    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+    int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
   reshapeInput(bs, ih, iw);
 
   CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize());
......
@@ -52,7 +52,7 @@ public:
             const ParameterMap& parameterMap) override;
 
   void reshape(
-      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
+      int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
 
   void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                 MKLDNNMatrixPtr& in,
......
@@ -55,6 +55,7 @@ void MKLDNNLayer::forward(PassType passType) {
     inputElemenCnt_ = elemenCnt;
     pipelineFwd_.clear();
     reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_);
+    printSizeInfo();
     // all cpu device output grad or value share output's
     shareCPUDevice();
     resetFwd(pipelineFwd_, inVal_, wgtVal_, biasVal_, outVal_);
@@ -72,7 +73,6 @@ void MKLDNNLayer::forward(PassType passType) {
       pipelineFwd_.push_back(*cvtOutVal_);
     }
     convertWeightsFromPaddle();
-    printSizeInfo();
    printValueFormat();
    needResetBwd_ = true;
  }
......
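The hunk above also moves printSizeInfo() to immediately after reshape(), so the sizes logged already reflect whatever the derived layer wrote. As a rough illustration of the caller side (hypothetical classes, not the PaddlePaddle sources), the base class passes its oc_ member by reference into the virtual reshape(), and the update is visible as soon as the call returns:

// Minimal sketch (hypothetical, not PaddlePaddle code): a base class member
// passed by reference into a virtual reshape() is updated by the derived layer
// and can be printed right afterwards, mirroring printSizeInfo() after reshape().
#include <iostream>

class BaseLayer {
 public:
  virtual ~BaseLayer() = default;
  virtual void reshape(int& oc) = 0;
  void forward() {
    reshape(oc_);                           // the derived class may change oc_ here
    std::cout << "oc_ = " << oc_ << "\n";   // logs the already-updated value
  }
 protected:
  int oc_ = 0;
};

class ConcatLikeLayer : public BaseLayer {
 public:
  void reshape(int& oc) override { oc = 3 + 5 + 8; }  // e.g. sum of input channels
};

int main() {
  ConcatLikeLayer layer;
  layer.forward();  // prints "oc_ = 16"
  return 0;
}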
@@ -125,12 +125,11 @@ public:
   virtual void backward(const UpdateCallback& callback);
 
   /**
-   * reshape the input image sizes
-   * and reset output image and buffer size
-   * output channel can not be changed
+   * reshape the input and output channels and image sizes
+   * and reset output buffer size
    */
   virtual void reshape(
-      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) = 0;
+      int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) = 0;
 
   /**
    * reset the mkldnn forward primitve and memories
......
@@ -58,7 +58,7 @@ bool MKLDNNPoolLayer::init(const LayerMap& layerMap,
 }
 
 void MKLDNNPoolLayer::reshape(
-    int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+    int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
   reshapeInput(bs, ih, iw);
   // ic_ and oc can not be changed
   CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
......
@@ -53,7 +53,7 @@ public:
             const ParameterMap& parameterMap) override;
 
   void reshape(
-      int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
+      int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
 
   void resetFwd(std::vector<mkldnn::primitive>& pipeline,
                 MKLDNNMatrixPtr& in,
......