提交 f40d5f58 编写于 作者: T tensor-tang

remove syncOutputGrad, rename syncInputValue to updateInputData

上级 d4c07348
......@@ -220,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdWgt_);
/// backward data
device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
const MatrixPtr& in = getInputGrad(0, device);
const MatrixPtr& in = inputLayers_[0]->getOutput().grad;
if (in == nullptr) {
return;
}
if (getInput(0, device).getAllCount() > 1) {
// TODO(TJ): use outputMaps_ ways when merge outgrad done
if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
// TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done
} else {
inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc());
}
......@@ -243,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdData_);
}
void MKLDNNFcLayer::updateInputData() {
if (inputLayers_[0]->getType() != "data") {
return;
}
real* iData = getInputValue(0, CPU_DEVICE)->getData();
inVal_->setData(iData);
}
void MKLDNNFcLayer::forward(PassType passType) {
Layer::forward(passType);
reshape();
{
REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
syncInputValue();
updateInputData();
// just submit forward pipeline
stream_->submit(pipelineFwd_);
......@@ -271,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
resetBwd();
syncOutputGrad();
// just sumbmit backward pipeline
stream_->submit(pipelineBwd_);
}
......
......@@ -53,6 +53,8 @@ public:
void backward(const UpdateCallback& callback) override;
void updateInputData() override;
protected:
/**
* reshape the input image sizes
......
......@@ -113,6 +113,12 @@ public:
*/
virtual void convertWeightsToPaddle() {}
/**
* Update input value data when input layer is "data" type.
* Since the input value data address might be changed.
*/
virtual void updateInputData() {}
/**
* print info about sizes
*/
......@@ -194,32 +200,6 @@ protected:
return outputOtherDevice_.size() == 0;
}
/**
* Sync input value data
*/
void syncInputValue() {
if (inputIsOnlyMKLDNN()) {
return;
}
real* iData = getInputValue(0, CPU_DEVICE)->getData();
// update input data
// since it might be changed if this is after data layer
inVal_->setData(iData);
}
/**
* Sync output grad data
*/
void syncOutputGrad() {
if (outputIsOnlyMKLDNN()) {
return;
}
// update diff
real* oDiff = getOutput(CPU_DEVICE).grad->getData();
outGrad_->setData(oDiff);
}
/**
* Set deviceId of this layer.
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册