未验证 提交 8ef020c1 编写于 作者: N niuliling123 提交者: GitHub

Fix type error in adagrad_kernel (#51790)

上级 57e368b8
...@@ -37,7 +37,7 @@ __global__ void AdagradGPUKernel(const T* param, ...@@ -37,7 +37,7 @@ __global__ void AdagradGPUKernel(const T* param,
MT* master_param_out, MT* master_param_out,
int num) { int num) {
auto idx = blockDim.x * blockIdx.x + threadIdx.x; auto idx = blockDim.x * blockIdx.x + threadIdx.x;
MT lr_data = static_cast<T>(lr[0]); MT lr_data = static_cast<MT>(lr[0]);
for (int i = idx; i < num; i += blockDim.x * gridDim.x) { for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
MT grad_data = static_cast<MT>(grad[i]); MT grad_data = static_cast<MT>(grad[i]);
...@@ -47,7 +47,7 @@ __global__ void AdagradGPUKernel(const T* param, ...@@ -47,7 +47,7 @@ __global__ void AdagradGPUKernel(const T* param,
MT param_out_data = MT param_out_data =
in - (lr_data * grad_data) / (sqrt(moment_out_data) + epsilon); in - (lr_data * grad_data) / (sqrt(moment_out_data) + epsilon);
param_out[i] = static_cast<MT>(param_out_data); param_out[i] = static_cast<T>(param_out_data);
if (master_param_out) { if (master_param_out) {
master_param_out[i] = param_out_data; master_param_out[i] = param_out_data;
......
...@@ -2081,7 +2081,11 @@ class AdagradOptimizer(Optimizer): ...@@ -2081,7 +2081,11 @@ class AdagradOptimizer(Optimizer):
for p in parameters: for p in parameters:
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_accumulator(self._moment_acc_str, master_p) self._add_accumulator(
self._moment_acc_str,
master_p,
fill_value=self.initial_accumulator_value,
)
continue continue
if ( if (
self._is_dtype_fp16_or_bf16(p.dtype) self._is_dtype_fp16_or_bf16(p.dtype)
......
...@@ -146,7 +146,11 @@ class Adagrad(Optimizer): ...@@ -146,7 +146,11 @@ class Adagrad(Optimizer):
continue continue
if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
master_p = self._create_master_weight(p) master_p = self._create_master_weight(p)
self._add_accumulator(self._moment_acc_str, master_p) self._add_accumulator(
self._moment_acc_str,
master_p,
fill_value=self.initial_accumulator_value,
)
self._already_create_accumulater.add(p.name) self._already_create_accumulater.add(p.name)
continue continue
if ( if (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册