提交 3117d977 编写于 作者: T tensor-tang

add inputChannel in resetInValue for concat layer

上级 c397599d
......@@ -84,10 +84,7 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
inputs.resize(inputLayers_.size());
bool has8c = false, has16c = false, hasnc = false;
for (size_t i = 0; i < inputs.size(); i++) {
// resetInValue will use ic_ so temporary change as current input's channel
// TODO(TJ): change ic_ as vector then can remove channels_
ic_ = channels_[i];
resetInValue(inputs[i], nullptr, i);
resetInValue(inputs[i], nullptr, i, channels_[i]);
CHECK(inputs[i]);
auto dm = inputs[i]->getDims();
// inputs format can be different, but ndims must equal
......@@ -108,8 +105,6 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
has16c = true;
}
}
// change back, ic_ always save the input 0 size
ic_ = channels_[0];
format outFmt;
if (has16c && oc_ % 16 == 0) {
......
......@@ -176,13 +176,15 @@ void MKLDNNLayer::resetWithMatrix(MKLDNNMatrixPtr& dnn,
void MKLDNNLayer::resetInValue(
MKLDNNMatrixPtr& in,
const std::shared_ptr<memory::primitive_desc>& intPD,
size_t inputIdx) {
size_t inputIdx,
int inputChannel) {
cvtInVal_ = nullptr;
extInVal_ = nullptr;
in = nullptr;
CHECK_GT(bs_ * ic_ * ih_ * iw_, 0);
inputChannel = inputChannel == 0 ? ic_ : inputChannel;
CHECK_GT(bs_ * inputChannel * ih_ * iw_, 0);
auto extPD = MKLDNNMatrix::createPrimitiveDesc(
{bs_, ic_, ih_, iw_}, format::nchw, engine_);
{bs_, inputChannel, ih_, iw_}, format::nchw, engine_);
const MatrixPtr& inMat = inputLayers_[inputIdx]->getOutputValue();
extInVal_ = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
CHECK_EQ(inputIsOnlyMKLDNN(), extInVal_ != nullptr);
......
......@@ -38,6 +38,7 @@ protected:
size_t inputElemenCnt_;
// batch size
int bs_;
// they sizes are always from the first input layer
// input image channel, height and width
int ic_, ih_, iw_;
// output image channel, height and width
......@@ -196,11 +197,13 @@ protected:
/**
* reset input value from input MKLDNNMatrix and internal primitive desc.
* reset both internal and external buffer and create reorder if necessary.
* input channel may be different in concat.
*/
void resetInValue(
MKLDNNMatrixPtr& in,
const std::shared_ptr<mkldnn::memory::primitive_desc>& intPD = nullptr,
size_t inputIdx = 0);
size_t inputIdx = 0,
int inputChannel = 0);
/**
* reset output value from internal primitive desc.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册