Commit 4ec3a77a authored by tensor-tang

should run resetBwd before bwdAct

Parent: fa722385
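Editor's note: the change reorders backward() so that resetBwd() runs before backwardActivation(). The other two hunks move the setData() calls into resetOutGrad(), which is invoked from resetBwd(), so the output gradient appears to be re-bound there; running resetBwd() first presumably ensures the activation derivative works on that freshly bound buffer. Below is a minimal sketch of the resulting call order, reconstructed from the third hunk (all names are taken from the diff; this is not a verbatim copy of the file):

void backward(const UpdateCallback& callback) override {
  if (needResetBwd_) {
    // Rebuild the backward primitives and re-bind gradient buffers first ...
    resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_);
    needResetBwd_ = false;
  }
  {
    // ... so backwardActivation() sees the freshly prepared output gradient.
    REGISTER_TIMER_INFO("BpActTimer", getName().c_str());
    backwardActivation();
  }
  {
    REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
    stream_->submit(pipelineBwd_);
  }
}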
@@ -449,13 +449,14 @@ void MKLDNNConvLayer::resetOutGrad(
   cvtOutGrad_ = nullptr;
   if (!outputIsOnlyMKLDNN()) {
     const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad;
-    outMat->setData(cpuOut->getData());
     // same PrimitiveDesc with cpuInVal_
     CHECK(cpuOutVal_);
     cpuOutGrad_ = MKLDNNMatrix::create(cpuOut, cpuOutVal_->getPrimitiveDesc());
     if (cpuOutGrad_->getPrimitiveDesc() == out->getPrimitiveDesc()) {
+      outMat->setData(cpuOut->getData());
       out = cpuOutGrad_;
     } else {
+      out = MKLDNNMatrix::create(nullptr, wgtPD->diff_dst_primitive_desc());
       cvtOutGrad_ = MKLDNNMatrix::createReorder(cpuOutGrad_, out);
       CHECK(cvtOutGrad_);
     }
...
@@ -232,6 +232,7 @@ void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
 void MKLDNNFcLayer::resetOutGrad(MKLDNNMatrixPtr& out) {
   // TODO(TJ): merge outgrad
   int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
+  output_.grad->setData(getOutput(device).grad->getData());
   // for MKLDNN device:
   // can not directly cast outputgrad to mkldnnmatrix,
   // since each layer can not write the inputgrad to mkldnn inputgrad.
...
@@ -141,18 +141,16 @@ public:
   }

   void backward(const UpdateCallback& callback) override {
-    /* Do derivation */ {
+    if (needResetBwd_) {
+      resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_);
+      needResetBwd_ = false;
+    }
+    {
       REGISTER_TIMER_INFO("BpActTimer", getName().c_str());
       backwardActivation();
     }
     {
       REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
-      if (needResetBwd_) {
-        resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_);
-        needResetBwd_ = false;
-      }
       stream_->submit(pipelineBwd_);
     }
...