Unverified · Commit b7b231a6, authored by sneaxiy, committed by GitHub

support pure bfloat16 for more ops (#46364)

* support pure bfloat16

* support bf16 linear

* update PR to pass CI

* tiny fix where_grad_kernel.cu

* add bfloat16 to selu_grad to pass CI

* fix selu grad compilation error
Parent 9012787f
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/scope_guard.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/dynload/cublasLt.h"
 #include "paddle/fluid/platform/float16.h"
@@ -63,6 +64,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
     if (std::is_same<T, paddle::platform::float16>::value) {
       mat_type = CUDA_R_16F;
     }
+    if (std::is_same<T, platform::bfloat16>::value) {
+      mat_type = CUDA_R_16BF;
+    }
     if (std::is_same<T, double>::value) {
       mat_type = CUDA_R_64F;
       scale_type = CUDA_R_64F;
@@ -354,6 +358,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
     if (std::is_same<T, paddle::platform::float16>::value) {
       mat_type = CUDA_R_16F;
     }
+    if (std::is_same<T, platform::bfloat16>::value) {
+      mat_type = CUDA_R_16BF;
+    }
     if (std::is_same<T, double>::value) {
       mat_type = CUDA_R_64F;
       scale_type = CUDA_R_64F;
@@ -688,12 +695,14 @@ REGISTER_OP_CUDA_KERNEL(
     fused_gemm_epilogue,
     ops::FusedGemmEpilogueKernel<phi::GPUContext, float>,
     ops::FusedGemmEpilogueKernel<phi::GPUContext, double>,
-    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::float16>);
+    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::float16>,
+    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::bfloat16>);
 REGISTER_OP_CUDA_KERNEL(
     fused_gemm_epilogue_grad,
     ops::FusedGemmEpilogueGradKernel<phi::GPUContext, float>,
     ops::FusedGemmEpilogueGradKernel<phi::GPUContext, double>,
     ops::FusedGemmEpilogueGradKernel<phi::GPUContext,
-                                     paddle::platform::float16>);
+                                     paddle::platform::float16>,
+    ops::FusedGemmEpilogueKernel<phi::GPUContext, paddle::platform::bfloat16>);
 #endif
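For reference, a minimal standalone sketch (not Paddle's code; ToCudaDataType is a hypothetical helper, assuming CUDA 11+ headers) of the type-to-cudaDataType_t dispatch that the hunks above extend with a CUDA_R_16BF branch for bfloat16. In the kernel above, mat_type is what describes the matrices to cuBLASLt, so each new dtype only needs one extra branch like the ones added here.

// Hypothetical helper, not part of Paddle: maps a C++ element type to the
// cudaDataType_t enum that cuBLASLt expects, mirroring the branches above.
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <library_types.h>  // cudaDataType_t, CUDA_R_16F, CUDA_R_16BF, ...

#include <cstdio>
#include <type_traits>

template <typename T>
constexpr cudaDataType_t ToCudaDataType() {
  return std::is_same<T, __half>::value          ? CUDA_R_16F
         : std::is_same<T, __nv_bfloat16>::value ? CUDA_R_16BF
         : std::is_same<T, double>::value        ? CUDA_R_64F
                                                 : CUDA_R_32F;  // default: float
}

int main() {
  std::printf("fp16  -> %d\n", static_cast<int>(ToCudaDataType<__half>()));
  std::printf("bf16  -> %d\n", static_cast<int>(ToCudaDataType<__nv_bfloat16>()));
  std::printf("fp64  -> %d\n", static_cast<int>(ToCudaDataType<double>()));
  std::printf("fp32  -> %d\n", static_cast<int>(ToCudaDataType<float>()));
  return 0;
}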
@@ -198,61 +198,6 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr,
                                               T value) {
   CudaAtomicAdd(arr + index, value);
 }
-#ifdef PADDLE_WITH_CUDA
-/*
- * One thread block deals with elementwise atomicAdd for vector of len.
- * @in: [x1, x2, x3, ...]
- * @out:[y1+x1, y2+x2, y3+x3, ...]
- * */
-template <typename T,
-          typename std::enable_if<
-              !std::is_same<platform::float16, T>::value>::type * = nullptr>
-__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
-    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
-  for (int i = tid; i < len; i += threads_per_block) {
-    CudaAtomicAdd(&out[i], in[i]);
-  }
-}
-// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
-template <typename T,
-          typename std::enable_if<
-              std::is_same<platform::float16, T>::value>::type * = nullptr>
-__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
-    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
-#if ((CUDA_VERSION < 10000) || \
-     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
-  for (int i = tid; i < len; i += threads_per_block) {
-    CudaAtomicAdd(&out[i], in[i]);
-  }
-#else
-  int i = 0;
-  int loops = len / 2 * 2;
-  bool aligned_half2 =
-      (reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
-  if (aligned_half2) {
-    for (i = tid * 2; i < loops; i += threads_per_block * 2) {
-      __half2 value2;
-      T value_1 = in[i];
-      T value_2 = in[i + 1];
-      value2.x = *reinterpret_cast<__half *>(&value_1);
-      value2.y = *reinterpret_cast<__half *>(&value_2);
-      atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
-    }
-    for (; i < len; i += threads_per_block) {
-      fastAtomicAdd(out, i, len, in[i]);
-    }
-  } else {
-    for (int i = tid; i < len; i += threads_per_block) {
-      fastAtomicAdd(out, i, len, in[i]);
-    }
-  }
-#endif
-}
-#endif
 #endif
 // NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16.
@@ -601,5 +546,61 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
 }
 #endif
+#ifdef PADDLE_CUDA_FP16
+#ifdef PADDLE_WITH_CUDA
+/*
+ * One thread block deals with elementwise atomicAdd for vector of len.
+ * @in: [x1, x2, x3, ...]
+ * @out:[y1+x1, y2+x2, y3+x3, ...]
+ * */
+template <typename T,
+          typename std::enable_if<
+              !std::is_same<platform::float16, T>::value>::type * = nullptr>
+__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
+    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
+  for (int i = tid; i < len; i += threads_per_block) {
+    CudaAtomicAdd(&out[i], in[i]);
+  }
+}
+// Note: assume that len is even. If len is odd, call fastAtomicAdd directly.
+template <typename T,
+          typename std::enable_if<
+              std::is_same<platform::float16, T>::value>::type * = nullptr>
+__device__ __forceinline__ void VectorizedAtomicAddPerBlock(
+    const int64_t len, int tid, int threads_per_block, const T *in, T *out) {
+#if ((CUDA_VERSION < 10000) || \
+     (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
+  for (int i = tid; i < len; i += threads_per_block) {
+    CudaAtomicAdd(&out[i], in[i]);
+  }
+#else
+  int i = 0;
+  int loops = len / 2 * 2;
+  bool aligned_half2 =
+      (reinterpret_cast<std::uintptr_t>(out) % sizeof(__half2) == 0);
+  if (aligned_half2) {
+    for (i = tid * 2; i < loops; i += threads_per_block * 2) {
+      __half2 value2;
+      T value_1 = in[i];
+      T value_2 = in[i + 1];
+      value2.x = *reinterpret_cast<__half *>(&value_1);
+      value2.y = *reinterpret_cast<__half *>(&value_2);
+      atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2);
+    }
+    for (; i < len; i += threads_per_block) {
+      fastAtomicAdd(out, i, len, in[i]);
+    }
+  } else {
+    for (int i = tid; i < len; i += threads_per_block) {
+      fastAtomicAdd(out, i, len, in[i]);
+    }
+  }
+#endif
+}
+#endif
+#endif
 } // namespace platform
 } // namespace paddle
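The block relocated above (unchanged apart from its new PADDLE_CUDA_FP16 guard) packs two fp16 values into a __half2 so that a single 32-bit atomic covers both; the helper takes that path when CUDA >= 10.0, the architecture is sm_70+, and the destination is aligned to sizeof(__half2). A minimal standalone CUDA sketch of that core trick follows (assumed names, not Paddle code; it omits the alignment and odd-length fallbacks the real helper handles; compile with e.g. nvcc -arch=sm_70).

// Standalone sketch: each pair of fp16 inputs is packed into a __half2 and
// added with a single 32-bit atomicAdd, the same trick the helper above uses.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdio>

__global__ void AccumulateHalf2(const __half *in, __half *out, int len) {
  int i = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
  if (i + 1 < len) {
    __half2 v = __halves2half2(in[i], in[i + 1]);
    atomicAdd(reinterpret_cast<__half2 *>(out + i), v);  // two halves, one atomic
  }
}

int main() {
  const int len = 8;
  __half *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, len * sizeof(__half));
  cudaMallocManaged(&out, len * sizeof(__half));
  for (int i = 0; i < len; ++i) {
    in[i] = __float2half(1.0f);
    out[i] = __float2half(0.0f);
  }
  AccumulateHalf2<<<1, 32>>>(in, out, len);
  cudaDeviceSynchronize();
  std::printf("out[0] = %f\n", __half2float(out[0]));  // expect 1.000000
  cudaFree(in);
  cudaFree(out);
  return 0;
}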
@@ -88,6 +88,7 @@ PD_REGISTER_KERNEL(empty,
                    int64_t,
                    bool,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -2169,12 +2169,14 @@ struct CudaSeluFunctor : public BaseActivationFunctor<T> {
   }
   __device__ __forceinline__ T operator()(const T x) const {
-    T res = x;
-    if (res <= zero) {
+    using MT =
+        typename std::conditional<(sizeof(T) > sizeof(float)), T, float>::type;
+    MT res = static_cast<MT>(x);
+    if (x <= zero) {
       res = alpha * expf(res) - alpha;
     }
     res *= scale;
-    return res;
+    return static_cast<T>(res);
   }
  private:
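The hunk above switches CudaSeluFunctor to an intermediate compute type so that 16-bit inputs do the math in float. A standalone host-side sketch of the same std::conditional pattern follows (illustrative names and constants, not Paddle's functor; Paddle's phi::dtype::float16/bfloat16 are 2-byte storage types and would therefore select float as the math type).

// Illustrative only: ComputeType<T> is float for any type no wider than float
// (which covers fp16/bf16 storage types) and T itself for double.
#include <cmath>
#include <cstdio>
#include <type_traits>

template <typename T>
using ComputeType =
    typename std::conditional<(sizeof(T) > sizeof(float)), T, float>::type;

template <typename T>
T SeluRef(T x, float scale = 1.0507010f, float alpha = 1.6732632f) {
  using MT = ComputeType<T>;
  MT res = static_cast<MT>(x);
  if (res <= static_cast<MT>(0)) {
    res = static_cast<MT>(alpha) * std::exp(res) - static_cast<MT>(alpha);
  }
  res *= static_cast<MT>(scale);
  return static_cast<T>(res);  // cast back to the storage type
}

int main() {
  static_assert(std::is_same<ComputeType<float>, float>::value, "");
  static_assert(std::is_same<ComputeType<double>, double>::value, "");
  std::printf("selu(-1.0f) = %f\n", SeluRef(-1.0f));
  std::printf("selu( 2.0 ) = %f\n", SeluRef(2.0));
  return 0;
}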
@@ -84,6 +84,7 @@ INSTANTIATION(EigenBroadcast, int);
 INSTANTIATION(EigenBroadcast, int64_t);
 INSTANTIATION(EigenBroadcastGrad, bool);
 INSTANTIATION(EigenBroadcastGrad, float);
+INSTANTIATION(EigenBroadcastGrad, dtype::bfloat16);
 INSTANTIATION(EigenBroadcastGrad, dtype::float16);
 INSTANTIATION(EigenBroadcastGrad, double);
 INSTANTIATION(EigenBroadcastGrad, dtype::complex<float>);
@@ -449,4 +449,5 @@ PD_REGISTER_KERNEL(pow_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -265,5 +265,12 @@ PD_REGISTER_KERNEL(pow,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
-PD_REGISTER_KERNEL(selu, GPU, ALL_LAYOUT, phi::SeluKernel, float, double) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(selu,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SeluKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
@@ -372,7 +372,8 @@ PD_REGISTER_KERNEL(adam,
                    phi::AdamDenseKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow, skip_update data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
@@ -385,7 +386,8 @@ PD_REGISTER_KERNEL(merged_adam,
                    phi::MergedAdamKernel,
                    float,
                    double,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   // Skip beta1_pow, beta2_pow data transform
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip_grad,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {}
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -256,7 +256,8 @@ PD_REGISTER_KERNEL(embedding_grad,
                    phi::EmbeddingGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(embedding_sparse_grad,
                    GPU,
@@ -264,4 +265,5 @@ PD_REGISTER_KERNEL(embedding_sparse_grad,
                    phi::EmbeddingSparseGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -125,4 +125,5 @@ PD_REGISTER_KERNEL(embedding,
                    phi::EmbeddingKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -99,4 +99,5 @@ PD_REGISTER_KERNEL(gelu_grad,
                    phi::GeluGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -93,4 +93,5 @@ PD_REGISTER_KERNEL(gelu,
                    phi::GeluKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -509,4 +509,5 @@ PD_REGISTER_KERNEL(pad3d_grad,
                    phi::Pad3dGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -583,6 +583,7 @@ PD_REGISTER_KERNEL(pad3d,
                    ALL_LAYOUT,
                    phi::Pad3dKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(pixel_shuffle_grad,
                    ALL_LAYOUT,
                    phi::PixelShuffleGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -18,5 +18,11 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/pixel_shuffle_kernel_impl.h"
-PD_REGISTER_KERNEL(
-    pixel_shuffle, GPU, ALL_LAYOUT, phi::PixelShuffleKernel, float, double) {}
+PD_REGISTER_KERNEL(pixel_shuffle,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::PixelShuffleKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -18,5 +18,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/selu_grad_kernel_impl.h"
-PD_REGISTER_KERNEL(
-    selu_grad, GPU, ALL_LAYOUT, phi::SeluGradKernel, float, double) {}
+PD_REGISTER_KERNEL(selu_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SeluGradKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
@@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -25,10 +25,10 @@ __global__ void WhereGradCUDAKernel(
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < N; idx += blockDim.x * gridDim.x) {
     if (dx != nullptr) {
-      dx[idx] = cond[idx] ? dout[idx] : 0.;
+      dx[idx] = cond[idx] ? dout[idx] : static_cast<T>(0.);
     }
     if (dy != nullptr) {
-      dy[idx] = cond[idx] ? 0. : dout[idx];
+      dy[idx] = cond[idx] ? static_cast<T>(0.) : dout[idx];
     }
   }
 }
@@ -61,6 +61,8 @@ PD_REGISTER_KERNEL(where_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::WhereGradKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
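For context, a standalone sketch (assumed names, plain CUDA types instead of phi::dtype, not Paddle's kernel) of the select-by-condition gradient pattern touched above: producing the zero via static_cast<T>(...) keeps both branches of the ternary in type T, so the same template body also instantiates for 16-bit float types without relying on an implicit conversion from a double literal.

// Standalone sketch, not phi::WhereGradKernel: dx receives dout where cond is
// true, dy where it is false.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdio>

template <typename T>
__global__ void WhereGradSketch(
    const bool *cond, const T *dout, T *dx, T *dy, int64_t n) {
  int64_t idx = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  for (; idx < n; idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    if (dx != nullptr) dx[idx] = cond[idx] ? dout[idx] : static_cast<T>(0.0f);
    if (dy != nullptr) dy[idx] = cond[idx] ? static_cast<T>(0.0f) : dout[idx];
  }
}

// Forcing the fp16 instantiation shows the template also builds for 16-bit types.
template __global__ void WhereGradSketch<__half>(
    const bool *, const __half *, __half *, __half *, int64_t);

int main() {
  const int n = 4;
  bool *cond;
  float *dout, *dx, *dy;
  cudaMallocManaged(&cond, n * sizeof(bool));
  cudaMallocManaged(&dout, n * sizeof(float));
  cudaMallocManaged(&dx, n * sizeof(float));
  cudaMallocManaged(&dy, n * sizeof(float));
  for (int i = 0; i < n; ++i) {
    cond[i] = (i % 2 == 0);
    dout[i] = 1.0f + i;
  }
  WhereGradSketch<<<1, 32>>>(cond, dout, dx, dy, static_cast<int64_t>(n));
  cudaDeviceSynchronize();
  for (int i = 0; i < n; ++i) {
    std::printf("dx[%d]=%g dy[%d]=%g\n", i, dx[i], i, dy[i]);
  }
  cudaFree(cond); cudaFree(dout); cudaFree(dx); cudaFree(dy);
  return 0;
}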
@@ -45,5 +45,13 @@ void WhereKernel(const Context& ctx,
 } // namespace phi
-PD_REGISTER_KERNEL(
-    where, GPU, ALL_LAYOUT, phi::WhereKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(where,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::WhereKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -57,14 +57,17 @@ struct SeluGradFunctor {
         dx_data_ptr_(dx_data_ptr) {}
   HOSTDEVICE void operator()(size_t idx) const {
-    T y_ele = y_data_ptr_[idx];
-    T dy_ele = dy_data_ptr_[idx];
-    float tmp = scale_;
+    using MT =
+        typename std::conditional<(sizeof(T) > sizeof(float)), T, float>::type;
+    auto y_ele = static_cast<MT>(y_data_ptr_[idx]);
+    auto dy_ele = static_cast<MT>(dy_data_ptr_[idx]);
+    auto tmp = static_cast<MT>(scale_);
     if (y_ele <= 0) {
-      tmp = y_ele + la_;
+      tmp = y_ele + static_cast<MT>(la_);
     }
-    dx_data_ptr_[idx] = dy_ele * tmp;
+    dx_data_ptr_[idx] = static_cast<T>(dy_ele * tmp);
   }
   const T* y_data_ptr_;
   const T* dy_data_ptr_;
@@ -50,8 +50,9 @@ def _clip_by_global_norm_using_mp_type(*args):
 def _cast_to_mp_type_if_enabled(x):
-    if x.dtype == core.VarDesc.VarType.FP16 and _clip_by_global_norm_using_mp_type(
-    ):
+    if (x.dtype == core.VarDesc.VarType.FP16
+            or x.dtype == core.VarDesc.VarType.BF16
+        ) and _clip_by_global_norm_using_mp_type():
         return x.astype(core.VarDesc.VarType.FP32)
     else:
         return x
@@ -63,7 +64,8 @@ def _squared_l2_norm(x):
     """
     x = _cast_to_mp_type_if_enabled(x)
-    if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16:
+    if core.is_compiled_with_xpu(
+    ) or x.dtype == core.VarDesc.VarType.FP16 or x.dtype == core.VarDesc.VarType.BF16:
         square = layers.square(x)
         sum_square = layers.reduce_sum(square)
         return sum_square
@@ -499,7 +501,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
             sum_square = _squared_l2_norm(merge_grad)
-            if sum_square.dtype == core.VarDesc.VarType.FP16:
+            if sum_square.dtype == core.VarDesc.VarType.FP16 or sum_square.dtype == core.VarDesc.VarType.BF16:
                 sum_square_list_fp16.append(sum_square)
             elif sum_square.dtype == core.VarDesc.VarType.FP32:
                 sum_square_list_fp32.append(sum_square)
@@ -552,8 +554,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
                     continue
                 # TODO(wangxi): use inplace elementwise_mul
                 if need_clip:
-                    clip_input = (clip_var.astype('float16') if g.dtype
-                                  == core.VarDesc.VarType.FP16 else clip_var)
+                    clip_input = (clip_var.astype(g.dtype)
+                                  if clip_var.dtype != g.dtype else clip_var)
                     new_grad = layers.elementwise_mul(g, clip_input)
                     params_and_grads.append((p, new_grad))
                 else:
@@ -275,7 +275,7 @@ class Adam(Optimizer):
     def _add_moments_pows(self, p):
         acc_dtype = p.dtype
-        if acc_dtype == core.VarDesc.VarType.FP16:
+        if acc_dtype == core.VarDesc.VarType.FP16 or acc_dtype == core.VarDesc.VarType.BF16:
             acc_dtype = core.VarDesc.VarType.FP32
         self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
         self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
@@ -159,8 +159,10 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
     u = mean(x, axis, True, name)
     out = paddle.sum((x - u)**2, axis, keepdim=keepdim, name=name)
-    n = paddle.cast(paddle.numel(x), x.dtype) \
-        / paddle.cast(paddle.numel(out), x.dtype)
+    dtype = x.dtype
+    n = paddle.cast(paddle.numel(x), paddle.int64) \
+        / paddle.cast(paddle.numel(out), paddle.int64)
+    n = n.astype(dtype)
     if unbiased:
         one_const = paddle.ones([1], x.dtype)
         n = where(n > one_const, n - 1., one_const)