diff --git a/paddle/phi/kernels/gpu/adagrad_kernel.cu b/paddle/phi/kernels/gpu/adagrad_kernel.cu
index 14046ed4491c62d32e356ee30734d449cf22342b..1ce78b7e0f6eaeb2ebc3c3b82054d6c4e1e4c50c 100644
--- a/paddle/phi/kernels/gpu/adagrad_kernel.cu
+++ b/paddle/phi/kernels/gpu/adagrad_kernel.cu
@@ -37,7 +37,7 @@ __global__ void AdagradGPUKernel(const T* param,
                                  MT* master_param_out,
                                  int num) {
   auto idx = blockDim.x * blockIdx.x + threadIdx.x;
-  MT lr_data = static_cast<T>(lr[0]);
+  MT lr_data = static_cast<MT>(lr[0]);
 
   for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
     MT grad_data = static_cast<MT>(grad[i]);
@@ -47,7 +47,7 @@ __global__ void AdagradGPUKernel(const T* param,
     MT param_out_data =
         in - (lr_data * grad_data) / (sqrt(moment_out_data) + epsilon);
 
-    param_out[i] = static_cast<MT>(param_out_data);
+    param_out[i] = static_cast<T>(param_out_data);
 
     if (master_param_out) {
       master_param_out[i] = param_out_data;
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index d835a3fbfcf8ba6aff16c7f3638d13634a5f27e9..ded8883ffdb533093c1c77162f0093af3a71662b 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -2081,7 +2081,11 @@ class AdagradOptimizer(Optimizer):
         for p in parameters:
             if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                 master_p = self._create_master_weight(p)
-                self._add_accumulator(self._moment_acc_str, master_p)
+                self._add_accumulator(
+                    self._moment_acc_str,
+                    master_p,
+                    fill_value=self.initial_accumulator_value,
+                )
                 continue
             if (
                 self._is_dtype_fp16_or_bf16(p.dtype)
diff --git a/python/paddle/optimizer/adagrad.py b/python/paddle/optimizer/adagrad.py
index 3d2935c74073199a0e44d21c5c8ef182c572a724..c19b3116de3fcff66fd058f2633a5ef71b673210 100644
--- a/python/paddle/optimizer/adagrad.py
+++ b/python/paddle/optimizer/adagrad.py
@@ -146,7 +146,11 @@ class Adagrad(Optimizer):
                 continue
             if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                 master_p = self._create_master_weight(p)
-                self._add_accumulator(self._moment_acc_str, master_p)
+                self._add_accumulator(
+                    self._moment_acc_str,
+                    master_p,
+                    fill_value=self.initial_accumulator_value,
+                )
                 self._already_create_accumulater.add(p.name)
                 continue
             if (
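
Note (not part of the diff): the kernel hunks only tighten the two casts. A minimal NumPy sketch of the same Adagrad update, with hypothetical names chosen to mirror the kernel's variables, shows why the arithmetic must stay in the compute dtype MT and only the final store may narrow to the storage dtype T:

import numpy as np

def adagrad_step(param, grad, moment, lr, epsilon=1e-6, T=np.float16, MT=np.float32):
    # All math in MT (float32), as in the kernel.
    grad_data = grad.astype(MT)             # static_cast<MT>(grad[i])
    lr_data = MT(lr)                        # static_cast<MT>(lr[0]); casting via T would round lr to fp16
    moment_out = moment.astype(MT) + grad_data * grad_data
    param_out = param.astype(MT) - (lr_data * grad_data) / (np.sqrt(moment_out) + epsilon)
    # Only the stored copy narrows to T: static_cast<T>(param_out_data).
    # The master copy (master_param_out) keeps the full-precision param_out.
    return param_out.astype(T), moment_out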
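
The Python hunks make the moment accumulator that backs a float32 master weight start at initial_accumulator_value instead of the default fill of 0; previously the fp16/bf16 branch silently dropped that argument, so the first multi-precision step diverged from the fp32 one. A hedged usage sketch (assuming the GPU multi-precision path and the multi_precision constructor argument that the touched code implies; the model, shapes, and values are illustrative):

import paddle

paddle.set_device("gpu")  # the fixed kernel is the GPU multi-precision path

model = paddle.nn.Linear(4, 4)
model = paddle.amp.decorate(models=model, level="O2")  # fp16 params + fp32 master weights

opt = paddle.optimizer.Adagrad(
    learning_rate=0.1,
    parameters=model.parameters(),
    initial_accumulator_value=0.1,  # now also honored on the master-weight branch
    multi_precision=True,
)

loss = model(paddle.randn([2, 4]).astype("float16")).mean()
loss.backward()
opt.step()  # first update divides by sqrt(0.1 + g*g) + epsilon, not sqrt(g*g) + epsilon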