提交 f40d5f58 编写于 作者: T tensor-tang

remove syncOutputGrad, rename syncInputValue to updateInputData

上级 d4c07348
...@@ -220,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -220,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdWgt_); pipelineBwd_.push_back(*bwdWgt_);
/// backward data /// backward data
device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; const MatrixPtr& in = inputLayers_[0]->getOutput().grad;
const MatrixPtr& in = getInputGrad(0, device);
if (in == nullptr) { if (in == nullptr) {
return; return;
} }
if (getInput(0, device).getAllCount() > 1) { if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
// TODO(TJ): use outputMaps_ ways when merge outgrad done // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done
} else { } else {
inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc()); inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc());
} }
...@@ -243,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -243,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdData_); pipelineBwd_.push_back(*bwdData_);
} }
void MKLDNNFcLayer::updateInputData() {
  // Only a "data"-type input layer can swap its CPU value buffer between
  // iterations; for any other input layer type the address is stable, so
  // there is nothing to rebind.
  if (inputLayers_[0]->getType() == "data") {
    // Rebind the MKLDNN input value to the (possibly relocated) CPU buffer.
    inVal_->setData(getInputValue(0, CPU_DEVICE)->getData());
  }
}
void MKLDNNFcLayer::forward(PassType passType) { void MKLDNNFcLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
reshape(); reshape();
{ {
REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
syncInputValue(); updateInputData();
// just submit forward pipeline // just submit forward pipeline
stream_->submit(pipelineFwd_); stream_->submit(pipelineFwd_);
...@@ -271,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) { ...@@ -271,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
resetBwd(); resetBwd();
syncOutputGrad();
// just submit backward pipeline // just submit backward pipeline
stream_->submit(pipelineBwd_); stream_->submit(pipelineBwd_);
} }
......
...@@ -53,6 +53,8 @@ public: ...@@ -53,6 +53,8 @@ public:
void backward(const UpdateCallback& callback) override; void backward(const UpdateCallback& callback) override;
void updateInputData() override;
protected: protected:
/** /**
* reshape the input image sizes * reshape the input image sizes
......
...@@ -113,6 +113,12 @@ public: ...@@ -113,6 +113,12 @@ public:
*/ */
virtual void convertWeightsToPaddle() {} virtual void convertWeightsToPaddle() {}
/**
* Update input value data when input layer is "data" type.
* Since the input value data address might be changed.
*/
virtual void updateInputData() {}
/** /**
* print info about sizes * print info about sizes
*/ */
...@@ -194,32 +200,6 @@ protected: ...@@ -194,32 +200,6 @@ protected:
return outputOtherDevice_.size() == 0; return outputOtherDevice_.size() == 0;
} }
/**
 * Refresh the MKLDNN input value buffer from the CPU-side input value.
 *
 * When some input comes from a non-MKLDNN (CPU) layer, the CPU matrix's
 * data pointer may change between iterations (e.g. after a data layer),
 * so rebind it before submitting the forward pipeline.
 */
void syncInputValue() {
  if (inputIsOnlyMKLDNN()) {
    // Every input already lives in an MKLDNN buffer; nothing to sync.
    return;
  }
  // Rebind, since the CPU-side data address might have changed.
  inVal_->setData(getInputValue(0, CPU_DEVICE)->getData());
}
/**
 * Refresh the MKLDNN output gradient buffer from the CPU-side gradient.
 *
 * When the output is consumed by a non-MKLDNN (CPU) layer, the CPU
 * gradient matrix's data pointer may change between iterations, so rebind
 * it before submitting the backward pipeline.
 */
void syncOutputGrad() {
  if (outputIsOnlyMKLDNN()) {
    // Output lives only in MKLDNN buffers; nothing to sync.
    return;
  }
  // Rebind the diff pointer, which may have been reallocated on the CPU side.
  outGrad_->setData(getOutput(CPU_DEVICE).grad->getData());
}
/** /**
* Set deviceId of this layer. * Set deviceId of this layer.
*/ */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册