Unverified commit cbeff5fc authored by Charles-hit, committed by GitHub

support activation prim op bf16 dtype (#54193)

* support activation prim op bf16 dtype

* remove useless code
Parent 2db64d08
@@ -3083,11 +3083,15 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
 template <typename T>
 struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
-  T minus_one_half = static_cast<T>(-0.5f);
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType minus_one_half = static_cast<MPType>(-0.5f);
   // dx = -0.5 * dout * out^3
-  __device__ __forceinline__ T operator()(const T dout, const T out) const {
-    return minus_one_half * dout * out * out * out;
+  __device__ __forceinline__ T operator()(const T arg_dout,
+                                          const T arg_out) const {
+    MPType dout = static_cast<MPType>(arg_dout);
+    MPType out = static_cast<MPType>(arg_out);
+    return static_cast<T>(minus_one_half * dout * out * out * out);
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() {
......
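The MPType change above widens the bf16 arithmetic to float32 inside the kernel and only casts back to T at the boundary, which avoids compounding bf16 rounding error across the three multiplications. A minimal Python sketch of the same idea, simulating bf16 the way the op tests do (as uint16 bit patterns; the helper names below are made up for illustration, and the float-to-bf16 step truncates rather than rounds, for brevity):

```python
import numpy as np


def float_to_bf16_bits(x):
    """Truncate float32 values to bf16, returned as uint16 bit patterns."""
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)


def bf16_bits_to_float(bits):
    """Widen uint16 bf16 bit patterns back to float32."""
    return (np.asarray(bits, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)


def rsqrt_grad_bf16(dout_bits, out_bits):
    """dx = -0.5 * dout * out^3, accumulated in float32 like the MPType functor."""
    dout = bf16_bits_to_float(dout_bits)
    out = bf16_bits_to_float(out_bits)
    dx = np.float32(-0.5) * dout * out * out * out
    return float_to_bf16_bits(dx)  # cast back to bf16 only at the boundary


out = np.array([0.5, 0.25], dtype=np.float32)   # rsqrt outputs
dout = np.ones_like(out)                        # upstream gradient
dx = rsqrt_grad_bf16(float_to_bf16_bits(dout), float_to_bf16_bits(out))
print(bf16_bits_to_float(dx))                   # [-0.0625, -0.0078125]
```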
@@ -578,45 +578,45 @@ class PrimForwardChecker:
         # forward comp only for comp op
         if self.prim_op_type == "prim":
             return
-        paddle.enable_static()
-        core._set_prim_forward_enabled(self.enable_fw_comp)
-        startup_program, main_program = (
-            paddle.static.Program(),
-            paddle.static.Program(),
-        )
-        with paddle.static.program_guard(main_program, startup_program):
-            (
-                static_inputs,
-                attrs,
-                input_dict,
-                feed,
-            ) = self.get_static_input_attr_inputdict_and_feed(
-                stop_gradient=True
-            )
-            args = OpTestUtils.prepare_python_api_arguments(
-                self.public_python_api,
-                static_inputs,
-                attrs,
-                self.kernel_sig,
-            )
-            inputs_sig, _, _ = self.kernel_sig
-            args = OpTestUtils.assumption_assert_and_transform(
-                args, len(inputs_sig)
-            )
-            ret = flatten(_as_list(self.public_python_api(*args)))
-            primapi.to_prim(main_program.blocks)
-            # ensure the operator not in program if check_prim is True
-            forward_ops = [op.type for op in main_program.blocks[0].ops]
-            assert self.op_type not in forward_ops, (
-                "%s shouldn't appear in program when check_prim is True"
-            ) % (self.op_type)
-            exe = paddle.static.Executor(self.place)
-            exe.run(startup_program)
-            ret = exe.run(main_program, feed=feed, fetch_list=ret)
-            if OpTestUtils.is_bfloat16_type(self.dtype):
-                ret = paddle.utils.map_structure(
-                    lambda x: convert_uint16_to_float(x), ret
-                )
+        with paddle.fluid.framework._static_guard():
+            core._set_prim_forward_enabled(self.enable_fw_comp)
+            startup_program, main_program = (
+                paddle.static.Program(),
+                paddle.static.Program(),
+            )
+            with paddle.static.program_guard(main_program, startup_program):
+                (
+                    static_inputs,
+                    attrs,
+                    input_dict,
+                    feed,
+                ) = self.get_static_input_attr_inputdict_and_feed(
+                    stop_gradient=True
+                )
+                args = OpTestUtils.prepare_python_api_arguments(
+                    self.public_python_api,
+                    static_inputs,
+                    attrs,
+                    self.kernel_sig,
+                )
+                inputs_sig, _, _ = self.kernel_sig
+                args = OpTestUtils.assumption_assert_and_transform(
+                    args, len(inputs_sig)
+                )
+                ret = flatten(_as_list(self.public_python_api(*args)))
+                primapi.to_prim(main_program.blocks)
+                # ensure the operator not in program if check_prim is True
+                forward_ops = [op.type for op in main_program.blocks[0].ops]
+                assert self.op_type not in forward_ops, (
+                    "%s shouldn't appear in program when check_prim is True"
+                ) % (self.op_type)
+                exe = paddle.static.Executor(self.place)
+                exe.run(startup_program)
+                ret = exe.run(main_program, feed=feed, fetch_list=ret)
+                if OpTestUtils.is_bfloat16_type(self.dtype):
+                    ret = paddle.utils.map_structure(
+                        lambda x: convert_uint16_to_float(x), ret
+                    )
         # check static forward
         if len(ret) != len(self.eager_desire):
             msg = (
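The main change in this hunk is that the bare paddle.enable_static() call is replaced by the paddle.fluid.framework._static_guard() context manager, so the checker restores the caller's dygraph/static mode when it finishes instead of leaking static mode into later eager-mode checks. A rough sketch of what such a guard does (illustrative only, not the framework's actual implementation):

```python
from contextlib import contextmanager

import paddle


@contextmanager
def static_guard():
    """Run the body in static-graph mode, then restore the previous mode."""
    was_dynamic = paddle.in_dynamic_mode()
    paddle.enable_static()
    try:
        yield
    finally:
        if was_dynamic:
            paddle.disable_static()
```

Used as `with static_guard(): ...`, the surrounding test keeps whatever execution mode it started in.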
@@ -1024,7 +1024,6 @@ class PrimGradChecker(PrimForwardChecker):
         core.set_prim_eager_enabled(False)

     def check_static_comp(self):
-        paddle.enable_static()
         if self.prim_op_type == "prim":
             core._set_prim_backward_enabled(self.enable_rev_comp)
         else:
@@ -1032,67 +1031,70 @@ class PrimGradChecker(PrimForwardChecker):
             core._set_prim_backward_enabled(self.enable_rev_comp)
         atol = self.rev_comp_atol if self.enable_rev_comp else self.fw_comp_atol
         rtol = self.rev_comp_rtol if self.enable_rev_comp else self.fw_comp_rtol
-        startup_program, main_program = (
-            paddle.static.Program(),
-            paddle.static.Program(),
-        )
-        with paddle.static.program_guard(main_program, startup_program):
-            (
-                static_inputs,
-                attrs,
-                inputs_dict,
-                feed,
-            ) = self.get_static_input_attr_inputdict_and_feed(
-                stop_gradient=False
-            )
-            args = OpTestUtils.prepare_python_api_arguments(
-                self.public_python_api,
-                static_inputs,
-                attrs,
-                self.kernel_sig,
-            )
-            inputs_sig, _, outputs_sig = self.kernel_sig
-            if hasattr(self.op_test, "python_out_sig"):
-                outputs_sig = self.op_test.python_out_sig
-            args = OpTestUtils.assumption_assert_and_transform(
-                args, len(inputs_sig)
-            )
-            fw_outs = _as_list(self.public_python_api(*args))
-            outputs_dict = self.get_output_dict(
-                self.outputs, fw_outs, outputs_sig
-            )
-            primapi.to_prim(main_program.blocks)
-            ys = []
-            if isinstance(self.output_names, list):
-                for output_name in self.output_names:
-                    ys.append(outputs_dict[output_name])
-            else:
-                ys.append(outputs_dict[self.output_names])
-            xs = []
-            if isinstance(self.inputs_to_check, list):
-                for input_name in self.inputs_to_check:
-                    xs.append(inputs_dict[input_name])
-            else:
-                xs.append(inputs_dict[self.inputs_to_check])
-            vs, vs_feed = self.gen_static_grad_outputs_and_feed()
-            feed.update(vs_feed)
-            no_grad_vars = self.gen_no_grad_set(
-                var_dict={**inputs_dict, **outputs_dict}
-            )
-            ret = paddle.static.gradients(ys, xs, vs, no_grad_set=no_grad_vars)
-            # check the backward operator not in program when check_prim is True
-            ops = [op.type for op in main_program.blocks[0].ops]
-            backward_op_type = self.op_type + "_grad"
-            assert backward_op_type not in ops, (
-                "%s shouldn't appear in program when check_prim is True"
-            ) % (backward_op_type)
-            exe = paddle.static.Executor(self.place)
-            exe.run(startup_program)
-            actual_ret = exe.run(main_program, feed=feed, fetch_list=ret)
-            if OpTestUtils.is_bfloat16_type(self.dtype):
-                actual_ret = paddle.utils.map_structure(
-                    lambda x: convert_uint16_to_float(x), actual_ret
-                )
+        with paddle.fluid.framework._static_guard():
+            startup_program, main_program = (
+                paddle.static.Program(),
+                paddle.static.Program(),
+            )
+            with paddle.static.program_guard(main_program, startup_program):
+                (
+                    static_inputs,
+                    attrs,
+                    inputs_dict,
+                    feed,
+                ) = self.get_static_input_attr_inputdict_and_feed(
+                    stop_gradient=False
+                )
+                args = OpTestUtils.prepare_python_api_arguments(
+                    self.public_python_api,
+                    static_inputs,
+                    attrs,
+                    self.kernel_sig,
+                )
+                inputs_sig, _, outputs_sig = self.kernel_sig
+                if hasattr(self.op_test, "python_out_sig"):
+                    outputs_sig = self.op_test.python_out_sig
+                args = OpTestUtils.assumption_assert_and_transform(
+                    args, len(inputs_sig)
+                )
+                fw_outs = _as_list(self.public_python_api(*args))
+                outputs_dict = self.get_output_dict(
+                    self.outputs, fw_outs, outputs_sig
+                )
+                primapi.to_prim(main_program.blocks)
+                ys = []
+                if isinstance(self.output_names, list):
+                    for output_name in self.output_names:
+                        ys.append(outputs_dict[output_name])
+                else:
+                    ys.append(outputs_dict[self.output_names])
+                xs = []
+                if isinstance(self.inputs_to_check, list):
+                    for input_name in self.inputs_to_check:
+                        xs.append(inputs_dict[input_name])
+                else:
+                    xs.append(inputs_dict[self.inputs_to_check])
+                vs, vs_feed = self.gen_static_grad_outputs_and_feed()
+                feed.update(vs_feed)
+                no_grad_vars = self.gen_no_grad_set(
+                    var_dict={**inputs_dict, **outputs_dict}
+                )
+                ret = paddle.static.gradients(
+                    ys, xs, vs, no_grad_set=no_grad_vars
+                )
+                # check the backward operator not in program when check_prim is True
+                ops = [op.type for op in main_program.blocks[0].ops]
+                backward_op_type = self.op_type + "_grad"
+                assert backward_op_type not in ops, (
+                    "%s shouldn't appear in program when check_prim is True"
+                ) % (backward_op_type)
+                exe = paddle.static.Executor(self.place)
+                exe.run(startup_program)
+                actual_ret = exe.run(main_program, feed=feed, fetch_list=ret)
+                if OpTestUtils.is_bfloat16_type(self.dtype):
+                    actual_ret = paddle.utils.map_structure(
+                        lambda x: convert_uint16_to_float(x), actual_ret
+                    )
         # check static grad out
         if len(actual_ret) != len(self.eager_desire):
             msg = (
......
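The grad checker builds the backward graph with paddle.static.gradients inside a program_guard, runs it with an Executor, and then compares the fetched gradients (widening bf16 outputs via convert_uint16_to_float and using the looser bf16 tolerances). A condensed, standalone usage sketch for rsqrt, not the checker itself (the shapes, variable names, and feed value below are made up):

```python
import numpy as np
import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data("x", shape=[2, 3], dtype="float32")
    x.stop_gradient = False
    y = paddle.rsqrt(x)
    # Build d(rsqrt)/dx; a no_grad_set could exclude variables from differentiation.
    grads = paddle.static.gradients([y], [x])

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
(dx,) = exe.run(main, feed={"x": np.full([2, 3], 4.0, "float32")}, fetch_list=grads)
# rsqrt(4.0) = 0.5 and d/dx x^{-1/2} = -0.5 * x^{-3/2}, so dx is filled with -0.0625.
```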
@@ -631,12 +631,13 @@ def rsqrt_composite(x):
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype

-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype == "float16" or dtype == "uint16":
         is_amp = True
         x = cast(x, "float32")
     y = full(x.shape if len(x.shape) == 0 else [1], -0.5, x.dtype)
     res = pow(x, y)
-    return res if not is_amp else cast(res, "float16")
+    return res if not is_amp else cast(res, dtype)


 @REGISTER_COMPOSITE('group_norm')
......
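The composite rule decomposes rsqrt(x) into pow(x, -0.5); with this change it upcasts both float16 and bf16 (reported by convert_dtype as "uint16") inputs to float32 and casts the result back to the original dtype instead of hard-coding "float16". A hedged numpy rendition of the same decomposition for the fp16 path (numpy has no native bf16, so that branch is only indicated in a comment; the function name is invented for illustration):

```python
import numpy as np


def rsqrt_composite_ref(x):
    """rsqrt(x) decomposed as pow(x, -0.5), upcasting low-precision inputs."""
    x = np.asarray(x)
    orig_dtype = x.dtype
    # Paddle's rule also treats "uint16" (bf16 storage) as an AMP dtype.
    is_amp = orig_dtype == np.float16
    if is_amp:
        x = x.astype(np.float32)
    y = np.full(x.shape if x.ndim == 0 else [1], -0.5, dtype=x.dtype)
    res = np.power(x, y)
    return res.astype(orig_dtype) if is_amp else res


print(rsqrt_composite_ref(np.float16(4.0)))  # 0.5, returned as float16
```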