提交 98b7c673 编写于 作者: T tensor-tang

add todo

上级 4cc57836
...@@ -184,15 +184,14 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -184,15 +184,14 @@ void MKLDNNFcLayer::resetBwd() {
const MatrixPtr& wgt = weight_->getWGrad(); const MatrixPtr& wgt = weight_->getWGrad();
const MatrixPtr& bias = hasBias ? biases_->getWGrad() : nullptr; const MatrixPtr& bias = hasBias ? biases_->getWGrad() : nullptr;
// TODO(TJ): merge topdiffs
if (nextIsMKLDNN()) { if (nextIsMKLDNN()) {
// can not directly cast outputgrad to mkldnnmatrix, // can not directly cast outputgrad to mkldnnmatrix,
// since each layer can not write the inputgrad to mkldnn inputgrad. // since each layer can not write the inputgrad to mkldnn inputgrad.
// So just create from matrix with outputvalue format. // So just create from matrix with outputvalue format.
const MatrixPtr& out = getOutput(MKLDNN_DEVICE).grad; const MatrixPtr& out = getOutput(MKLDNN_DEVICE).grad;
outGrad_ = MKLDNNMatrix::create(out, outVal_->getPD()); outGrad_ = MKLDNNMatrix::create(out, outVal_->getPD());
// TODO: maybe need merge topdiffs
} else { } else {
// TODO: merge topdiffs
const MatrixPtr& out = getOutput(CPU_DEVICE).grad; const MatrixPtr& out = getOutput(CPU_DEVICE).grad;
// fc do not need to convert from cpu device since output always nc // fc do not need to convert from cpu device since output always nc
// only need create from cpu device // only need create from cpu device
...@@ -234,8 +233,7 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -234,8 +233,7 @@ void MKLDNNFcLayer::resetBwd() {
return; return;
} }
if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) { if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
// TODO: many mkldnn bots // TODO(TJ): use outputMaps_ ways when merge topdiff done
// add sum handle
} else { } else {
inGrad_ = MKLDNNMatrix::create(in, inVal_->getPD()); inGrad_ = MKLDNNMatrix::create(in, inVal_->getPD());
} }
...@@ -245,8 +243,7 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -245,8 +243,7 @@ void MKLDNNFcLayer::resetBwd() {
return; return;
} }
if (getInput(0, CPU_DEVICE).getAllCount() > 1) { if (getInput(0, CPU_DEVICE).getAllCount() > 1) {
// TODO: many bots // TODO(TJ): use outputMaps_ ways when merge topdiff done
// add sum handle
} else { } else {
inGrad_ = MKLDNNMatrix::create(in, inVal_->getPD()); inGrad_ = MKLDNNMatrix::create(in, inVal_->getPD());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册