未验证 提交 27e3b06f 编写于 作者: Z zhangkaihuo 提交者: GitHub

Fix submanifold conv (#45060)


* fix submanifold conv
上级 26c573de
......@@ -662,9 +662,9 @@ int ProductRuleBook(const Context& dev_ctx,
dev_ctx.stream());
dev_ctx.Wait();
size_t cache_size = kernel_size * 2 + kernel_size *
config.thread_per_block.x * 2 *
sizeof(int);
size_t cache_size =
kernel_size * 2 * sizeof(int) +
kernel_size * config.thread_per_block.x * 2 * sizeof(int);
const int MAX_CACHE_SIZE = 48 * 1024;
while (cache_size >= MAX_CACHE_SIZE) {
config.thread_per_block.x /= 2;
......@@ -672,7 +672,7 @@ int ProductRuleBook(const Context& dev_ctx,
PADDLE_ENFORCE_GE(config.thread_per_block.x,
32,
phi::errors::Fatal("the shared memory is not enough"));
cache_size = kernel_size * 2 +
cache_size = kernel_size * 2 * sizeof(int) +
kernel_size * config.thread_per_block.x * 2 * sizeof(int);
}
ProductSubmRuleBookKernel<IntT><<<config.block_per_grid.x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册