未验证 提交 dd27996c 编写于 作者: S sneaxiy 提交者: GitHub

fix adam thread num (#48297): launch the UpdateBetaPow / UpdateAdamWBetaPow kernels with 1 thread per block instead of 32

上级 bcf75132
......@@ -253,7 +253,7 @@ void AdamDenseKernel(const Context& dev_ctx,
param.numel());
if (!use_global_beta_pow) {
// Update with gpu
- UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
+ UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
beta1_pow.data<MPDType>(),
......@@ -352,7 +352,7 @@ void MergedAdamKernel(
param[idx]->numel());
if (!use_global_beta_pow) {
// Update with gpu
- UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
+ UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
beta1_pow[idx]->data<MPDType>(),
......
......@@ -282,7 +282,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
param.numel());
if (!use_global_beta_pow) {
// Update with gpu
- UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
+ UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
beta1_,
beta2_,
beta1_pow.data<MPDType>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册