diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index bc1a8522b0fa86f6e186368f42cc426a20fecbe0..5ca0ed1182e74e681e9e36e55b61f58b5da66170 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -322,9 +322,10 @@ static std::shared_ptr> GetPrimitiveFactory( const ExecutionContext& ctx) { const auto& out_name = ctx.OutputName("Out"); const auto& dev_ctx = ctx.template device_context(); + const auto batch_size = ctx.Input("X")->dims()[0]; const std::string key = - platform::CreateKey(platform::ThreadIDasStr(), out_name); + platform::CreateKey(platform::ThreadIDasStr(), batch_size, out_name); auto factory = std::static_pointer_cast>(dev_ctx.GetBlob(key));