未验证 提交 27e3b06f 编写于 作者: Z zhangkaihuo 提交者: GitHub

Fix submanifold conv (#45060)


* fix submanifold conv
上级 26c573de
...@@ -662,9 +662,9 @@ int ProductRuleBook(const Context& dev_ctx, ...@@ -662,9 +662,9 @@ int ProductRuleBook(const Context& dev_ctx,
dev_ctx.stream()); dev_ctx.stream());
dev_ctx.Wait(); dev_ctx.Wait();
- size_t cache_size = kernel_size * 2 + kernel_size *
-                     config.thread_per_block.x * 2 *
-                     sizeof(int);
+ size_t cache_size =
+     kernel_size * 2 * sizeof(int) +
+     kernel_size * config.thread_per_block.x * 2 * sizeof(int);
const int MAX_CACHE_SIZE = 48 * 1024; const int MAX_CACHE_SIZE = 48 * 1024;
while (cache_size >= MAX_CACHE_SIZE) { while (cache_size >= MAX_CACHE_SIZE) {
config.thread_per_block.x /= 2; config.thread_per_block.x /= 2;
...@@ -672,7 +672,7 @@ int ProductRuleBook(const Context& dev_ctx, ...@@ -672,7 +672,7 @@ int ProductRuleBook(const Context& dev_ctx,
PADDLE_ENFORCE_GE(config.thread_per_block.x, PADDLE_ENFORCE_GE(config.thread_per_block.x,
32, 32,
phi::errors::Fatal("the shared memory is not enough")); phi::errors::Fatal("the shared memory is not enough"));
- cache_size = kernel_size * 2 +
+ cache_size = kernel_size * 2 * sizeof(int) +
kernel_size * config.thread_per_block.x * 2 * sizeof(int); kernel_size * config.thread_per_block.x * 2 * sizeof(int);
} }
ProductSubmRuleBookKernel<IntT><<<config.block_per_grid.x, ProductSubmRuleBookKernel<IntT><<<config.block_per_grid.x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册