From d44d17306e8d00ffaa54ec3d4cafbe9cfc8bd2d8 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Thu, 25 Mar 2021 04:58:56 +0100 Subject: [PATCH] fix cache key in concat oneDNN kernel (#31820) (#31837) * fix cache key in concat oneDNN kernel * key simplified --- .../operators/mkldnn/concat_mkldnn_op.cc | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 63aa2357bee..a69506076f9 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -71,6 +71,15 @@ static const std::vector ReduceMultiInput( return reduced; } +static const std::vector GetDimsForKey( + const std::vector& inputs) { + auto dims_key = paddle::framework::vectorize(inputs[0]->dims()); + for (auto it = std::next(inputs.begin()); it != inputs.end(); ++it) { + dims_key.push_back((*it)->dims()[0]); + } + return dims_key; +} + template class ConcatPrimitiveFactory { public: @@ -134,6 +143,8 @@ template class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { + // If any of the multiple inputs of concat has an input size of 0, the + // actual size of the multi_input will change auto multi_input = ReduceMultiInput(ctx.MultiInput("X")); EnforceLayouts(multi_input); Tensor* output = ctx.Output("Out"); @@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { paddle::framework::ToMKLDNNDataType(multi_input[0]->type()); ConcatPrimitiveFactory prim_creator; - // If one of the multiple inputs of concat has an input size of 0, the - // actual size of the multi_input will change - std::string key = platform::CreateKey( - dev_ctx, paddle::framework::vectorize(multi_input[0]->dims()), - multi_input.size(), ctx.OutputName("Out"), dt, - platform::ThreadIDasStr()); + std::string key = + platform::CreateKey(dev_ctx, GetDimsForKey(multi_input), + multi_input.size(), ctx.OutputName("Out"), dt); key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); const std::string key_prim = key + "@concat_p"; -- GitLab