未验证 提交 674aa061 编写于 作者: W Wilber 提交者: GitHub

fix concat mkldnn in extream condition. test=develop test=release/1.7 (#22877)

上级 143023ba
...@@ -142,9 +142,12 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -142,9 +142,12 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(multi_input[0]->type()); paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
ConcatPrimitiveFactory<T> prim_creator; ConcatPrimitiveFactory<T> prim_creator;
// If one of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
std::string key = platform::CreateKey( std::string key = platform::CreateKey(
paddle::framework::vectorize<int>(multi_input[0]->dims()), paddle::framework::vectorize<int>(multi_input[0]->dims()),
ctx.OutputName("Out"), dt, platform::ThreadIDasStr()); multi_input.size(), ctx.OutputName("Out"), dt,
platform::ThreadIDasStr());
const std::string key_prim = key + "@concat_p"; const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd"; const std::string key_concat_pd = key + "@concat_pd";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册