未验证 提交 dd27996c 编写于 作者: S sneaxiy 提交者: GitHub

fix adam thread num (#48297)

上级 bcf75132
...@@ -253,7 +253,7 @@ void AdamDenseKernel(const Context& dev_ctx, ...@@ -253,7 +253,7 @@ void AdamDenseKernel(const Context& dev_ctx,
param.numel()); param.numel());
if (!use_global_beta_pow) { if (!use_global_beta_pow) {
// Update with gpu // Update with gpu
UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>( UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
beta1_, beta1_,
beta2_, beta2_,
beta1_pow.data<MPDType>(), beta1_pow.data<MPDType>(),
...@@ -352,7 +352,7 @@ void MergedAdamKernel( ...@@ -352,7 +352,7 @@ void MergedAdamKernel(
param[idx]->numel()); param[idx]->numel());
if (!use_global_beta_pow) { if (!use_global_beta_pow) {
// Update with gpu // Update with gpu
UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>( UpdateBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
beta1_, beta1_,
beta2_, beta2_,
beta1_pow[idx]->data<MPDType>(), beta1_pow[idx]->data<MPDType>(),
......
...@@ -282,7 +282,7 @@ void AdamwDenseKernel(const Context& dev_ctx, ...@@ -282,7 +282,7 @@ void AdamwDenseKernel(const Context& dev_ctx,
param.numel()); param.numel());
if (!use_global_beta_pow) { if (!use_global_beta_pow) {
// Update with gpu // Update with gpu
UpdateAdamWBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>( UpdateAdamWBetaPow<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
beta1_, beta1_,
beta2_, beta2_,
beta1_pow.data<MPDType>(), beta1_pow.data<MPDType>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册