未验证 提交 fddf4424 编写于 作者: W Wojciech Uss 提交者: GitHub

add batch size to the mkldnn matmul cache key (#24408)

test=develop
上级 53125c2f
......@@ -322,9 +322,10 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const ExecutionContext& ctx) {
const auto& out_name = ctx.OutputName("Out");
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
const std::string key =
platform::CreateKey(platform::ThreadIDasStr(), out_name);
platform::CreateKey(platform::ThreadIDasStr(), batch_size, out_name);
auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册