提交 5d226743 编写于 作者: T tensor-tang

fix mkldnn concat dimension issue for rnn

上级 38c61053
...@@ -43,7 +43,7 @@ void MKLDNNConcatLayer::reshape( ...@@ -43,7 +43,7 @@ void MKLDNNConcatLayer::reshape(
channels_[0] = ic; channels_[0] = ic;
oc = ic; oc = ic;
for (size_t i = 1; i < inputLayers_.size(); i++) { for (size_t i = 1; i < inputLayers_.size(); i++) {
int batchsize, height, witdh; int batchsize = 0, height = 0, witdh = 0;
reshapeInput(batchsize, height, witdh, i); reshapeInput(batchsize, height, witdh, i);
CHECK_EQ(bs, batchsize); CHECK_EQ(bs, batchsize);
CHECK_EQ(ih, height); CHECK_EQ(ih, height);
...@@ -84,6 +84,7 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, ...@@ -84,6 +84,7 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
bool has8c = false, has16c = false, hasnc = false; bool has8c = false, has16c = false, hasnc = false;
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
resetInValue(inputs[i], nullptr, i, channels_[i]); resetInValue(inputs[i], nullptr, i, channels_[i]);
inputs[i]->downSpatial();
CHECK(inputs[i]); CHECK(inputs[i]);
auto dm = inputs[i]->getDims(); auto dm = inputs[i]->getDims();
// inputs format can be different, but ndims must equal // inputs format can be different, but ndims must equal
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册