未验证 提交 f686310d 编写于 作者: W Wilber 提交者: GitHub

fix concat_mkldnn op. test=develop (#22692)

fix concat_mkldnn op when encounter extreame conditions.
上级 de1b390b
......@@ -142,9 +142,12 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
ConcatPrimitiveFactory<T> prim_creator;
// If one of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
std::string key = platform::CreateKey(
paddle::framework::vectorize<int>(multi_input[0]->dims()),
ctx.OutputName("Out"), dt, platform::ThreadIDasStr());
multi_input.size(), ctx.OutputName("Out"), dt,
platform::ThreadIDasStr());
const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册