Unverified commit cbeff5fc authored by Charles-hit, committed by GitHub

support activation prim op bf16 dtype (#54193)

* support activation prim op bf16 dtype

* remove useless code
Parent 2db64d08
@@ -3083,11 +3083,15 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
 template <typename T>
 struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
-  T minus_one_half = static_cast<T>(-0.5f);
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType minus_one_half = static_cast<MPType>(-0.5f);
   // dx = -0.5 * dout * out^3
-  __device__ __forceinline__ T operator()(const T dout, const T out) const {
-    return minus_one_half * dout * out * out * out;
+  __device__ __forceinline__ T operator()(const T arg_dout,
+                                          const T arg_out) const {
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType out = static_cast<MPType>(arg_out);
+    return static_cast<T>(minus_one_half * dout * out * out * out);
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() {
......
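The functor change above routes low-precision inputs (float16 and the newly supported bfloat16) through MPTypeTrait, so the gradient dx = -0.5 * dout * out^3 is evaluated in float32 and only the final result is cast back to T. A minimal sketch of the same idea using the public Paddle Python API rather than the CUDA kernel; the helper name rsqrt_grad_reference is ours, not Paddle's:

import paddle

def rsqrt_grad_reference(dout, out):
    # Mirror of the functor's math: upcast to float32, compute
    # dx = -0.5 * dout * out^3, then cast back to the original dtype.
    orig_dtype = dout.dtype
    dout32 = paddle.cast(dout, "float32")
    out32 = paddle.cast(out, "float32")
    dx = -0.5 * dout32 * out32 * out32 * out32
    return paddle.cast(dx, orig_dtype)

x = paddle.rand([4]) + 0.1          # keep inputs positive for rsqrt
out = paddle.rsqrt(x)
dout = paddle.ones_like(out)
print(rsqrt_grad_reference(dout, out))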
@@ -578,7 +578,7 @@ class PrimForwardChecker:
         # forward comp only for comp op
         if self.prim_op_type == "prim":
             return
-        paddle.enable_static()
+        with paddle.fluid.framework._static_guard():
             core._set_prim_forward_enabled(self.enable_fw_comp)
             startup_program, main_program = (
                 paddle.static.Program(),
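Instead of flipping the process-wide mode with paddle.enable_static(), the checker now scopes static-graph mode inside paddle.fluid.framework._static_guard(), so the previous mode is restored when the block exits and one test cannot leak static mode into the next. A hedged stand-in for that guard built only from public API (the name static_mode_guard is illustrative, not Paddle's internal implementation):

import contextlib
import paddle

@contextlib.contextmanager
def static_mode_guard():
    # Enter static-graph mode for the body of the `with` block and restore
    # dynamic mode afterwards.
    was_dynamic = paddle.in_dynamic_mode()
    paddle.enable_static()
    try:
        yield
    finally:
        if was_dynamic:
            paddle.disable_static()

with static_mode_guard():
    prog = paddle.static.Program()  # build and run static programs here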
@@ -1024,7 +1024,6 @@ class PrimGradChecker(PrimForwardChecker):
         core.set_prim_eager_enabled(False)

     def check_static_comp(self):
-        paddle.enable_static()
         if self.prim_op_type == "prim":
             core._set_prim_backward_enabled(self.enable_rev_comp)
         else:
@@ -1032,6 +1031,7 @@ class PrimGradChecker(PrimForwardChecker):
             core._set_prim_backward_enabled(self.enable_rev_comp)
         atol = self.rev_comp_atol if self.enable_rev_comp else self.fw_comp_atol
         rtol = self.rev_comp_rtol if self.enable_rev_comp else self.fw_comp_rtol
+        with paddle.fluid.framework._static_guard():
             startup_program, main_program = (
                 paddle.static.Program(),
                 paddle.static.Program(),
@@ -1079,7 +1079,9 @@ class PrimGradChecker(PrimForwardChecker):
             no_grad_vars = self.gen_no_grad_set(
                 var_dict={**inputs_dict, **outputs_dict}
            )
-            ret = paddle.static.gradients(ys, xs, vs, no_grad_set=no_grad_vars)
+            ret = paddle.static.gradients(
+                ys, xs, vs, no_grad_set=no_grad_vars
+            )
             # check the backward operator not in program when check_prim is True
             ops = [op.type for op in main_program.blocks[0].ops]
             backward_op_type = self.op_type + "_grad"
......
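After building the backward with paddle.static.gradients, the checker scans the program's operator types and expects self.op_type + "_grad" to disappear once the composite (prim) backward is enabled via core._set_prim_backward_enabled. A small sketch of that inspection, shown here without the prim switch, so the fused grad op is still present:

import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data("x", shape=[4], dtype="float32")
    x.stop_gradient = False
    y = paddle.rsqrt(x)
    grads = paddle.static.gradients([y], [x])

op_types = [op.type for op in main.blocks[0].ops]
# With core._set_prim_backward_enabled(True), the checker expects this to print False.
print("rsqrt_grad" in op_types)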
@@ -631,12 +631,13 @@ def rsqrt_composite(x):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype

-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype == "float16" or dtype == "uint16":
         is_amp = True
         x = cast(x, "float32")
     y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
     res = pow(x, y)
-    return res if not is_amp else cast(res, "float16")
+    return res if not is_amp else cast(res, dtype)


 @REGISTER_COMPOSITE('group_norm')
......
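The composite rule now treats both "float16" and "uint16" (how bfloat16 tensors are reported in the legacy IR) as AMP dtypes: upcast to float32, compute x^(-0.5), then cast the result back to the original dtype instead of hard-coding "float16". A standalone sketch of the same behaviour using the public API rather than the prim primitives; rsqrt_composite_sketch is an illustrative name:

import paddle

def rsqrt_composite_sketch(x):
    # rsqrt(x) = x^(-0.5); for low-precision inputs compute in float32
    # and cast the result back to the original dtype.
    orig_dtype = x.dtype
    is_amp = orig_dtype in (paddle.float16, paddle.bfloat16)
    if is_amp:
        x = paddle.cast(x, "float32")
    y = paddle.full([1], -0.5, x.dtype)
    res = paddle.pow(x, y)
    return paddle.cast(res, orig_dtype) if is_amp else res

x = paddle.rand([8]) + 0.1
print(paddle.allclose(rsqrt_composite_sketch(x), paddle.rsqrt(x)))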