From fddf442470da6d0b1c8ea480423d290fbf48e6a9 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Tue, 12 May 2020 04:33:00 +0200 Subject: [PATCH] add batch size to the mkldnn matmul cache key (#24408) test=develop --- paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index bc1a8522b0f..5ca0ed1182e 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -322,9 +322,10 @@ static std::shared_ptr> GetPrimitiveFactory( const ExecutionContext& ctx) { const auto& out_name = ctx.OutputName("Out"); const auto& dev_ctx = ctx.template device_context(); + const auto batch_size = ctx.Input("X")->dims()[0]; const std::string key = - platform::CreateKey(platform::ThreadIDasStr(), out_name); + platform::CreateKey(platform::ThreadIDasStr(), batch_size, out_name); auto factory = std::static_pointer_cast>(dev_ctx.GetBlob(key)); -- GitLab