From e5f7a834d4200ad9d7e8b748d2d96fc7faeb0e63 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Wed, 24 Mar 2021 08:41:47 +0100 Subject: [PATCH] fix cache key in concat oneDNN kernel (#31820) * fix cache key in concat oneDNN kernel * key simplified --- .../operators/mkldnn/concat_mkldnn_op.cc | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 4beb7ad0178..df1b5af121d 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -71,6 +71,15 @@ static const std::vector ReduceMultiInput( return reduced; } +static const std::vector GetDimsForKey( + const std::vector& inputs) { + auto dims_key = paddle::framework::vectorize(inputs[0]->dims()); + for (auto it = std::next(inputs.begin()); it != inputs.end(); ++it) { + dims_key.push_back((*it)->dims()[0]); + } + return dims_key; +} + template class ConcatPrimitiveFactory { public: @@ -134,6 +143,8 @@ template class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { + // If any of the multiple inputs of concat has an input size of 0, the + // actual size of the multi_input will change auto multi_input = ReduceMultiInput(ctx.MultiInput("X")); EnforceLayouts(multi_input); Tensor* output = ctx.Output("Out"); @@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { paddle::framework::ToMKLDNNDataType(multi_input[0]->type()); ConcatPrimitiveFactory prim_creator; - // If one of the multiple inputs of concat has an input size of 0, the - // actual size of the multi_input will change - std::string key = platform::CreateKey( - dev_ctx, paddle::framework::vectorize(multi_input[0]->dims()), - multi_input.size(), ctx.OutputName("Out"), dt, - platform::ThreadIDasStr()); + std::string key = + platform::CreateKey(dev_ctx, GetDimsForKey(multi_input), + multi_input.size(), ctx.OutputName("Out"), dt); key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); const std::string key_prim = key + "@concat_p"; -- GitLab