Unverified commit 71a513c2 authored by Zhang Ting, committed by GitHub

[AMP] support promote kernel for static graph (#52514)

* support promote dtype for static amp training

* unify o1 and o2

* update for unittest

* fix op_role

* add use_promote arg

* fix doc

* add promote unittest

* polish unittests

* fix control flow and tests
Parent 040f8aa5
......@@ -4982,8 +4982,8 @@ class PipelineOptimizer:
device = post_op.attr(self._op_device_key)
assert device, "The post op must have op_device set."
op._set_attr(self._op_device_key, device)
elif (op.type == "cast" or op.type == "scale") and self._is_backward_op(
op
elif (op.type == "cast" or op.type == "scale") and (
self._is_backward_op(op) or self._is_forward_op(op)
):
prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
......
......@@ -356,7 +356,9 @@ class TestAdadeltaMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -467,7 +469,9 @@ class TestAdadeltaMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -322,7 +322,9 @@ class TestAdagradMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -431,7 +433,9 @@ class TestAdagradMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -1235,7 +1235,9 @@ class TestMultiTensorAdam(unittest.TestCase):
optimizer.minimize(loss)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -352,7 +352,9 @@ class TestAdamaxMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -459,7 +461,9 @@ class TestAdamaxMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -1059,7 +1059,9 @@ class TestMultiTensorMomentumStatic(unittest.TestCase):
optimizer.minimize(loss)
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place=place, scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = numpy.random.random(size=(2, 2)).astype('float16')
else:
x = numpy.random.random(size=(2, 2)).astype('float32')
......
......@@ -474,7 +474,9 @@ class TestRMSPropMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -585,7 +587,9 @@ class TestRMSPropMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -382,7 +382,9 @@ class TestSGDMultiPrecision2_0(unittest.TestCase):
exe.run(startup_program)
if mp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......@@ -492,7 +494,9 @@ class TestSGDMultiPrecision1_0(unittest.TestCase):
exe.run(startup_program)
if mp:
optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
optimizer.amp_init(
place=paddle.CUDAPlace(0), scope=paddle.static.global_scope()
)
x = np.random.random(size=(2, 2)).astype('float16')
else:
x = np.random.random(size=(2, 2)).astype('float32')
......
......@@ -294,8 +294,8 @@ class PartialProgramLayer:
def _create_amp_program(self, is_infer_mode=False):
amp_program = self._origin_main_program.clone(for_test=is_infer_mode)
with program_guard(amp_program):
paddle.static.amp.fp16_utils.rewrite_program(
amp_program, self._amp_list
paddle.static.amp.fp16_utils.cast_model_to_fp16(
amp_program, self._amp_list, use_fp16_guard=False, level='O1'
)
if is_infer_mode:
if self._hooker:
......
......@@ -29,7 +29,6 @@ from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
from .fp16_utils import (
cast_model_to_fp16,
cast_parameters_to_fp16,
rewrite_program,
update_role_var_grad,
)
from .function_overload import FunctionType, overload
......@@ -67,6 +66,7 @@ class OptimizerWithMixedPrecision:
the loss scaling.
use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value is equal to `use_pure_fp16`.
use_promote(bool): Whether to promote to fp32 when the op has any float32 inputs. Default is False.
"""
def __init__(
......@@ -82,6 +82,7 @@ class OptimizerWithMixedPrecision:
incr_ratio,
decr_ratio,
use_amp_guard=None,
use_promote=False,
):
self._optimizer = optimizer
self._amp_lists = amp_lists
......@@ -116,6 +117,7 @@ class OptimizerWithMixedPrecision:
self._decr_ratio = decr_ratio
self._num_good_steps = None
self._num_bad_steps = None
self.use_promote = use_promote
def _set_distributed(self, flag):
# if distributed, all cards will communication with each other,
......@@ -231,10 +233,18 @@ class OptimizerWithMixedPrecision:
self._amp_lists,
self._use_fp16_guard,
self._amp_vartype,
level='O2',
use_promote=self.use_promote,
)
else:
rewrite_program(
self._train_program, self._amp_lists, self._amp_vartype
# use_fp16_guard is not supported for amp-o1.
cast_model_to_fp16(
self._train_program,
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
use_promote=self.use_promote,
)
if loss.dtype != core.VarDesc.VarType.FP32:
......@@ -362,10 +372,18 @@ class OptimizerWithMixedPrecision:
self._amp_lists,
self._use_fp16_guard,
self._amp_vartype,
level='O2',
use_promote=self.use_promote,
)
elif use_fp16_test:
rewrite_program(
test_program, self._amp_lists, self._amp_vartype
# use_fp16_guard is not supported for amp-o1.
cast_model_to_fp16(
test_program,
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
use_promote=self.use_promote,
)
def apply_gradients(self, params_grads):
......@@ -624,6 +642,7 @@ def decorate(
use_pure_fp16=False,
use_fp16_guard=None,
use_bf16=False,
use_promote=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
......@@ -736,6 +755,7 @@ def decorate(
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
use_amp_guard=use_fp16_guard,
use_promote=use_promote,
)
return mp_optimizer
......@@ -754,6 +774,7 @@ def decorate(
decr_ratio=0.8,
use_dynamic_loss_scaling=True,
use_amp_guard=False,
use_promote=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
......@@ -781,6 +802,7 @@ def decorate(
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
use_amp_guard=use_amp_guard,
use_promote=use_promote,
)
return mp_optimizer
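For reference, a minimal usage sketch of the new use_promote argument through the public decorate API (not part of the patch; it mirrors the decorate(..., use_promote=...) call used in the test helpers below, assumes a CUDA build of Paddle, and uses a placeholder toy network and data):

import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    out = paddle.static.nn.fc(x, size=4)
    loss = paddle.mean(out)
    optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    # use_promote=True keeps an op in fp32 whenever any of its float inputs is fp32.
    optimizer = paddle.static.amp.decorate(
        optimizer, None, level='O1', dtype='float16', use_promote=True
    )
    optimizer.minimize(loss)

exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
exe.run(main_prog,
        feed={'x': np.random.random([2, 16]).astype('float32')},
        fetch_list=[loss])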
......@@ -98,6 +98,20 @@ def _get_sys_unsupported_list(dtype):
else:
device = 'GPU'
_, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
# The system reports the following ops as unsupported, but they actually support fp16, so remove them from sys_unsupported_list.
supported_fp16_list = {
"conditional_block",
"conditional_block_infer",
"select_input",
"while",
"cast",
"tensor_array_to_tensor",
"lod_array_length",
"write_to_array",
}
sys_unsupported_list -= supported_fp16_list
return device, sys_unsupported_list
......@@ -108,6 +122,29 @@ def _get_unsupported_list(dtype):
return unsupported_list
# The three sets listed below are changed dynamically. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
_only_supported_fp16_list = {'resnet_unit', 'fused_bn_add_activation'}
white_list = {
'conv2d',
'matmul',
'matmul_v2',
'mul',
}
def _get_white_list(dtype):
white_list_for_dtype = copy.copy(white_list)
if dtype == 'float16':
white_list_for_dtype = white_list_for_dtype | _only_supported_fp16_list
return white_list_for_dtype
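As a quick illustration of the dtype-dependent white list defined above (a sketch for inspection only, not part of the patch):

# Illustrative only: 'resnet_unit' and 'fused_bn_add_activation' support fp16 exclusively,
# so they are appended to the white list for float16 but not for bfloat16.
from paddle.static.amp import fp16_lists

fp16_white = fp16_lists._get_white_list("float16")
bf16_white = fp16_lists._get_white_list("bfloat16")
assert fp16_lists._only_supported_fp16_list <= fp16_white
assert not (fp16_lists._only_supported_fp16_list & bf16_white)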
class AutoMixedPrecisionLists:
"""
AutoMixedPrecisionLists is a class for black/white list. It can update
......@@ -132,7 +169,7 @@ class AutoMixedPrecisionLists:
self.amp_dtype = check_amp_dtype(dtype)
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list)
self.white_list = copy.copy(_get_white_list(self.amp_dtype))
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
......@@ -143,6 +180,9 @@ class AutoMixedPrecisionLists:
"""
Update black and white list according to users' custom list.
"""
_logger.debug(f"---- custom_white_list {self._custom_white_list} ---- ")
_logger.debug(f"---- custom_black_list {self._custom_black_list} ---- ")
_logger.debug(f"---- custom_black_varnames {self.black_varnames} ---- ")
if self._custom_white_list and self._custom_black_list:
for op_name in self._custom_white_list:
if op_name in self._custom_black_list:
......@@ -177,18 +217,6 @@ class AutoMixedPrecisionLists:
)
# The three sets listed below are changed dynamiclly. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
white_list = {
'conv2d',
'matmul',
'matmul_v2',
'mul',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list = {
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import numpy as np
......@@ -22,7 +21,11 @@ from paddle.fluid import core, framework, global_scope
from paddle.fluid.log_helper import get_logger
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from .fp16_lists import AutoMixedPrecisionLists, get_low_precision_dtypestr
from .fp16_lists import (
AutoMixedPrecisionLists,
black_list,
get_low_precision_dtypestr,
)
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
......@@ -144,7 +147,7 @@ def _keep_fp32_output(op, out_name):
def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
"""
Insert cast op and rename args of input and output.
Insert cast op and rename op's input.
Args:
block (Program): The block in which the operator is.
......@@ -167,8 +170,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
in_var = block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
continue
if in_var.dtype == src_dtype:
# The op's input was already cast to dest_dtype by an earlier op; rename the input to the existing casted var.
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
casted_var = block._find_var_recursive(cast_name)
if casted_var and casted_var.dtype == dest_dtype:
_rename_arg(op, in_var.name, casted_var.name)
continue
# insert cast for op's input.
if in_var.dtype == src_dtype:
out_var = block.vars.get(cast_name)
if out_var is None or out_var.dtype != dest_dtype:
op_device = op.attr('op_device')
......@@ -206,6 +216,13 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
stop_gradient=in_var.stop_gradient,
)
# Cast ops are only inserted into the forward program, but some ops
# (e.g. resnet_unit) have no op_role attr, so set it directly here.
op_role = (
int(core.op_proto_and_checker_maker.OpRole.Forward)
if not op.has_attr('op_role')
else op.attr('op_role')
)
block._insert_op_without_sync(
idx,
type="cast",
......@@ -215,70 +232,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype,
"op_device": op_device,
"op_role": op.attr("op_role"),
"op_role": op_role,
},
)
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype in [
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.BF16,
]:
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(dest_dtype)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', dest_dtype)
return num_cast_ops
def _insert_cast_post_op(
block, op, idx, src_dtype, dest_dtype, target_name, op_var_rename_map
):
num_cast_ops = 0
target_var = block.var(target_name)
if target_var.type not in _valid_types or target_var.dtype == dest_dtype:
return num_cast_ops
assert (
target_var.dtype == src_dtype
), "The real dtype({}) is not equal to the src dtype({})".format(
_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)
)
cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
cast_var = block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dest_dtype:
cast_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=target_var.stop_gradient,
)
block._insert_op(
idx,
type="cast",
inputs={"X": target_var},
outputs={"Out": cast_var},
attrs={
"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype,
"op_device": op.attr("op_device"),
"op_role": op.attr("op_role"),
},
)
num_cast_ops += 1
op_var_rename_map[block.idx][target_var.name] = cast_var.name
for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
if op.has_attr(attr_name) and is_float_dtype(op.attr(attr_name)):
op._set_attr(attr_name, dest_dtype)
return num_cast_ops
......@@ -420,11 +382,204 @@ def fp16_guard():
yield
def is_float_dtype(dtype):
return (
dtype == core.VarDesc.VarType.FP32
or dtype == core.VarDesc.VarType.FP16
or dtype == core.VarDesc.VarType.BF16
or dtype == core.VarDesc.VarType.FP64
)
def set_var_dst_dtype(
op, var_names, block, global_block, dtype, need_set_dtype
):
low_precison_var_names = set()
for var_name in var_names:
var = None
try:
var = block._var_recursive(var_name)
except ValueError as e:
_logger.debug(f"-- {e}, try to get it in the global block --")
var = global_block.var(var_name)
if var is not None:
_logger.debug(
f"-- var {var_name} is got in the global block --"
)
if var is None or var.type not in _valid_types:
continue
if is_float_dtype(var.dtype):
low_precison_var_names.add(var_name)
if need_set_dtype:
var.desc.set_dtype(dtype)
_logger.debug(
"---- op type: {}, var name: {}, var dtype: {} ----".format(
op.type, var_name, var.dtype
)
)
return low_precison_var_names
def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
if level == "O1":
return
keep_fp32_var_names = set()
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())
ops = block.ops
for op in ops:
if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
for in_name in op.input_names:
keep_fp32_var_names = keep_fp32_var_names.union(
op.input(in_name)
)
else:
for in_name in op.input_names:
if not core.is_compiled_with_ipu() and _keep_fp32_input(
op, in_name
):
keep_fp32_var_names = keep_fp32_var_names.union(
op.input(in_name)
)
for param in all_parameters:
if param.name not in keep_fp32_var_names:
_logger.debug(f"-- set param {param.name} to {dtype} --.")
param.desc.set_dtype(dtype)
def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
need_keep_fp32 = False
if _need_keep_fp32(
op,
amp_lists.unsupported_list,
use_fp16_guard,
):
need_keep_fp32 = True
elif amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists
):
need_keep_fp32 = True
elif op.type in amp_lists.black_list:
need_keep_fp32 = True
return need_keep_fp32
def get_promote_dtype(op, amp_dtype, block):
dst_dtype = amp_dtype
for in_name in op.input_names:
# for ipu, all inputs must be converted to fp16
if not core.is_compiled_with_ipu() and _keep_fp32_input(op, in_name):
_logger.debug(
"---- Input {} {} should be kept fp32 ----".format(
in_name, op.input(in_name)
)
)
continue
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = block._find_var_recursive(in_var_name)
if in_var and in_var.dtype == core.VarDesc.VarType.FP32:
dst_dtype = core.VarDesc.VarType.FP32
break
else:
dst_dtype = core.VarDesc.VarType.FP32
return dst_dtype
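In other words, with promotion enabled an op falls back to fp32 as soon as any of its float inputs (outside the keep-fp32 inputs) is fp32. A framework-free sketch of that decision rule, using a hypothetical helper name for illustration only:

# Hypothetical helper mirroring the promote rule above; not part of the patch.
def promote_dtype(input_dtypes, amp_dtype="float16"):
    """Return amp_dtype only when no float input is float32; otherwise float32."""
    for dtype in input_dtypes:
        if dtype == "float32":
            return "float32"
    return amp_dtype

# An op fed one fp16 and one fp32 tensor is kept in fp32:
assert promote_dtype(["float16", "float32"]) == "float32"
# An op whose float inputs are all fp16 runs in fp16:
assert promote_dtype(["float16", "float16"]) == "float16"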
def get_amp_dst_dtype(
op, amp_dtype, level, block, amp_lists, keep_fp32_ops, keep_fp16_ops
):
if level == 'O2':
return amp_dtype
ops = block.ops
dst_dtype = amp_dtype
if op.type in amp_lists.gray_list:
keep_fp32 = False
keep_fp16 = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = block._find_var_recursive(in_var_name)
# this in_var isn't the output of other op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
# if it's one of inputs
if (
prev_op in keep_fp32_ops
or prev_op.type in amp_lists.black_list
):
dst_dtype = core.VarDesc.VarType.FP32
elif (
prev_op in keep_fp16_ops
or prev_op.type in amp_lists.white_list
):
dst_dtype = amp_dtype
else:
# For numerical safety, we apply fp32 computation to ops that are not
# determined to belong to any list.
dst_dtype = core.VarDesc.VarType.FP32
return dst_dtype
def process_op_input_and_outputs(op, block, global_block, dtype):
low_precison_var_names = set()
# Collect the fp16 inputs because low_precison_var_names is required for parameter casting.
# The inputs' dtype is not set to fp16 here; that is done in step 3 of cast_model_to_fp16.
for in_name in op.input_names:
# for ipu, all inputs must be converted to fp16
if not core.is_compiled_with_ipu() and _keep_fp32_input(op, in_name):
continue
in_vars = set_var_dst_dtype(
op,
op.input(in_name),
block,
global_block,
dtype,
need_set_dtype=False,
)
low_precison_var_names = low_precison_var_names.union(in_vars)
# Set the output to FP16 because its consumer OP needs to determine if the dtype needs
# to be promoted.
for out_name in op.output_names:
# for ipu, all outputs must be converted to fp16
if not core.is_compiled_with_ipu() and _keep_fp32_output(op, out_name):
continue
set_var_dst_dtype(
op,
op.output(out_name),
block,
global_block,
dtype,
need_set_dtype=True,
)
return low_precison_var_names
def cast_model_to_fp16(
program,
amp_lists=None,
use_fp16_guard=True,
dest_type=core.VarDesc.VarType.FP16,
level='O2',
use_promote=False,
):
"""
Traverse all ops in the whole model and set their inputs and outputs
......@@ -438,158 +593,132 @@ def cast_model_to_fp16(
constructing the program. Default True.
dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
"""
_logger.debug("---- before cast model to fp16 ----")
_logger.debug(program)
if amp_lists is None:
dtype = get_low_precision_dtypestr(dest_type)
amp_lists = AutoMixedPrecisionLists(dtype)
amp_lists.unsupported_list -= {
"conditional_block_grad",
"conditional_block",
"conditional_block_infer",
"select_input",
"while",
"while_grad",
"cast",
"tensor_array_to_tensor",
"lod_array_length",
"write_to_array",
}
# For amp o2 there is no blacklist by default.
if level == 'O2':
amp_lists.black_list = amp_lists.black_list - black_list
global_block = program.global_block()
keep_fp32_ops = set()
keep_fp16_ops = set()
to_fp16_var_names = set()
origin_ops = []
for block in program.blocks:
origin_ops.extend(block.ops)
# step 1: set params dtype.
set_param_dtype(
program,
dtype=dest_type,
amp_lists=amp_lists,
use_fp16_guard=use_fp16_guard,
level=level,
)
def need_process(op):
need_process = True
if op.type in ["cast", "create_py_reader", "read"]:
need_process = False
else:
for attr_name in ['out_dtype', 'dtype']:
if op.has_attr(attr_name) and is_float_dtype(
op.attr(attr_name)
):
need_process = False
return need_process
# step 2: divide op into different sets according to the black/unsupported and white lists.
for block in program.blocks:
ops = block.ops
for op in ops:
if op.type == 'create_py_reader' or op.type == 'read':
_logger.debug(f"-- process op: {op} --")
if not need_process(op):
_logger.debug("---- The op does not need to be processed ----.")
continue
if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard):
if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
keep_fp32_ops.add(op)
continue # processed below
for in_name in op.input_names:
# for ipu, all inputs must be converted to fp16
if not core.is_compiled_with_ipu() and _keep_fp32_input(
op, in_name
):
continue
for in_var_name in op.input(in_name):
in_var = None
try:
in_var = block._var_recursive(in_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block --".format(
e
)
process_op_input_and_outputs(
op, block, global_block, core.VarDesc.VarType.FP32
)
in_var = global_block.var(in_var_name)
if in_var is not None:
_logger.debug(
"-- var {} is got in the global block --".format(
in_var_name
"---- Add into keep_fp32_ops because the op needs to be kept fp32 ----"
)
elif op.type in amp_lists.white_list:
keep_fp16_ops.add(op)
# get fp16 inputs and set op's outputs to fp16 for promote judgments
fp16_var_names = process_op_input_and_outputs(
op, block, global_block, dest_type
)
if in_var is None or in_var.type not in _valid_types:
continue
if in_var.dtype == core.VarDesc.VarType.FP32:
in_var.desc.set_dtype(dest_type)
to_fp16_var_names.add(in_var_name)
to_fp16_var_names = to_fp16_var_names.union(fp16_var_names)
_logger.debug(
"-- op type: {}, in var name: {}, in var dtype: {} --".format(
op.type, in_var_name, in_var.dtype
"---- Add into keep_fp16_ops because the op in white_list ----"
)
else:
# divide others ops into fp16/fp32 sets according to promoting principle.
dst_dtype = dest_type
if not use_promote:
dst_dtype = get_amp_dst_dtype(
op,
dest_type,
level,
block,
amp_lists,
keep_fp32_ops,
keep_fp16_ops,
)
else:
dst_dtype = get_promote_dtype(op, dest_type, block)
for out_name in op.output_names:
# for ipu, all outputs must be converted to fp16
if not core.is_compiled_with_ipu() and _keep_fp32_output(
op, out_name
):
continue
for out_var_name in op.output(out_name):
out_var = None
try:
out_var = block._var_recursive(out_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block --".format(
e
)
if dst_dtype == dest_type:
keep_fp16_ops.add(op)
fp16_var_names = process_op_input_and_outputs(
op, block, global_block, dest_type
)
out_var = global_block.var(out_var_name)
if out_var is not None:
to_fp16_var_names = to_fp16_var_names.union(fp16_var_names)
_logger.debug(
"-- var {} is got in the global block --".format(
out_var_name
"---- Add into keep_fp16_ops because it should be promoted to fp16 ----"
)
else:
keep_fp32_ops.add(op)
process_op_input_and_outputs(
op, block, global_block, core.VarDesc.VarType.FP32
)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(dest_type)
_logger.debug(
"-- op type: {}, out var name: {}, out var dtype: {} --".format(
op.type, out_var_name, out_var.dtype
"---- Add into keep_fp32_ops because it should be promoted to fp32 ----"
)
)
for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
if (
op.has_attr(attr_name)
and op.attr(attr_name) == core.VarDesc.VarType.FP32
):
op._set_attr(attr_name, dest_type)
# process ops in keep_fp32_ops
op_var_rename_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
# step 3: insert cast op for op's inputs.
for block in program.blocks:
ops = block.ops
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op in keep_fp32_ops:
pre_cast_num = _insert_cast_op(
if op in keep_fp16_ops:
in_var_cast_num = _insert_cast_op(
block,
op,
idx,
dest_type,
core.VarDesc.VarType.FP32,
dest_type,
)
num_cast_ops += pre_cast_num
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == dest_type:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
post_ops = find_true_post_op(ops, op, out_var_name)
for post_op in post_ops:
if post_op in keep_fp32_ops:
continue
post_cast_num = _insert_cast_post_op(
num_cast_ops += in_var_cast_num
if op in keep_fp32_ops:
in_var_cast_num = _insert_cast_op(
block,
op,
idx + pre_cast_num + 1,
core.VarDesc.VarType.FP32,
idx,
dest_type,
out_var_name,
op_var_rename_map,
core.VarDesc.VarType.FP32,
)
num_cast_ops += post_cast_num
idx += num_cast_ops + 1
num_cast_ops += in_var_cast_num
_rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops)
idx += num_cast_ops + 1
_logger.debug("---- after cast model to fp16 ----")
_logger.debug(program)
return to_fp16_var_names
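For context, the same entry point used in PartialProgramLayer above can be applied to a standalone program (a sketch assuming a CUDA build of Paddle in static-graph mode; the tiny network is a placeholder):

# Sketch only: O1 casting of a plain static program, mirroring the
# cast_model_to_fp16(..., use_fp16_guard=False, level='O1') call introduced
# for PartialProgramLayer.
import paddle

paddle.enable_static()
prog, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(prog, startup):
    x = paddle.static.data(name='x', shape=[None, 8], dtype='float32')
    y = paddle.static.nn.fc(x, size=4)
    loss = paddle.mean(y)

# Returns the names of variables whose dtype was rewritten to the low-precision dtype.
to_fp16_var_names = paddle.static.amp.fp16_utils.cast_model_to_fp16(
    prog, use_fp16_guard=False, level='O1'
)
print(sorted(to_fp16_var_names))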
......@@ -646,108 +775,6 @@ def cast_parameters_to_fp16(
_logger.warning(f"Cannot find {param.name}")
def rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
1. When an op belongs to the black list, add it to black set
2. When an op belongs to the white list, add it to white set
3. When an op belongs to the gray list. If one
of its inputs is the output of black set op or black list op,
add it to black set. If all of its previous ops are not black
op and one of its inputs is the output of white set op or
white list op, add it to white set.
4. When an op isn't in the lists, add it to black op set.
5. Add necessary cast ops to make sure that black set op will be
computed in fp32 mode, while white set op will be computed in
fp16 mode.
Args:
main_prog (Program): The main program for training.
dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
"""
block = main_prog.global_block()
block._sync_with_cpp()
ops = block.ops
white_op_set = set()
black_op_set = set()
for op in ops:
# NOTE(zhiqiu): 'create_py_reader' and 'read' is used in non-iterable DataLoder,
# we don't need to handle reader op and the input of 'create_py_reader' is not
# in block, which may result in errors.
# See GeneratorLoader._init_non_iterable() for details.
if op.type == 'create_py_reader' or op.type == 'read':
continue
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists
):
black_op_set.add(op)
continue
if op.type in amp_lists.black_list:
black_op_set.add(op)
elif op.type in amp_lists.white_list:
white_op_set.add(op)
elif op.type in amp_lists.gray_list:
is_black_op = False
is_white_op = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
# this in_var isn't the output of other op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
# if it's one of inputs
if (
prev_op in black_op_set
or prev_op.type in amp_lists.black_list
):
is_black_op = True
elif (
prev_op in white_op_set
or prev_op.type in amp_lists.white_list
):
is_white_op = True
if is_black_op:
black_op_set.add(op)
elif is_white_op:
white_op_set.add(op)
else:
pass
else:
# For numerical safe, we apply fp32 computation on ops that
# are not determined which list they should stay.
black_op_set.add(op)
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op in black_op_set:
num_cast_ops = _insert_cast_op(
block, op, idx, dest_type, core.VarDesc.VarType.FP32
)
elif op in white_op_set:
num_cast_ops = _insert_cast_op(
block, op, idx, core.VarDesc.VarType.FP32, dest_type
)
else:
pass
idx += num_cast_ops + 1
def update_role_var_grad(main_prog, params_grads):
"""
Update op_role_var attr for some ops to make sure the gradients
......
......@@ -29,6 +29,7 @@ def _build_optimizer(
amp_level="O1",
amp_lists=None,
use_grad_clip=False,
use_promote=False,
):
if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
......@@ -45,7 +46,11 @@ def _build_optimizer(
)
if use_amp:
optimizer = paddle.static.amp.decorate(
optimizer, amp_lists, level=amp_level, dtype=amp_dtype
optimizer,
amp_lists,
level=amp_level,
dtype=amp_dtype,
use_promote=use_promote,
)
return optimizer
......@@ -67,7 +72,9 @@ class SimpleAddNet(nn.Layer):
return x + self.weight
def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
def build_add_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
......@@ -92,7 +99,11 @@ def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
else:
amp_lists = None
optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, amp_lists
use_amp,
amp_dtype,
amp_level,
amp_lists,
use_promote=use_promote,
)
optimizer.minimize(loss)
feed_vars = [x]
......@@ -104,30 +115,37 @@ class SimpleConvNet(nn.Layer):
def __init__(self):
super().__init__()
self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = nn.Linear(in_features=6, out_features=10)
self.linear = nn.Linear(in_features=96, out_features=4)
def forward(self, x):
out = self.conv(x)
out = nn.functional.relu(out)
out = out.flatten(start_axis=1, stop_axis=3)
out = self.linear(out)
out = nn.functional.softmax(out)
return out
def build_conv_model(use_amp, amp_dtype="float16", amp_level="O1"):
def build_conv_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
name='input', shape=[None, 1, 6, 6], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, use_promote=use_promote
)
optimizer.minimize(loss)
return main_program, startup_program
feed_vars = [x]
fetch_vars = [loss]
return main_program, startup_program, optimizer, feed_vars, fetch_vars
class SimpleEmbeddingNet(nn.Layer):
......@@ -149,7 +167,9 @@ class SimpleEmbeddingNet(nn.Layer):
return out
def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
def build_embedding_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
......@@ -159,7 +179,12 @@ def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(
use_amp, amp_dtype, amp_level, None, True
use_amp,
amp_dtype,
amp_level,
None,
True,
use_promote=use_promote,
)
optimizer.minimize(loss)
return main_program, startup_program
......@@ -211,3 +236,48 @@ class AmpTestBase(unittest.TestCase):
def setUp(self):
self.amp_dtype = None
self.amp_level = None
def _check_op_calls(
self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={}
):
for op_type, value in expected_bf16_calls.items():
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
)
for op_type, value in expected_fp16_calls.items():
self.assertEqual(
op_stats_dict[op_type].fp16_calls,
value,
f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.",
)
def run_program(
self,
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_np,
max_iters,
level,
):
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
if level == 'O2':
optimizer.amp_init(place)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from amp_base_models import AmpTestBase, build_conv_model
import paddle
from paddle.static import amp
paddle.enable_static()
class TestAMPPromote(AmpTestBase):
def check_promote_results(
self, use_amp, dtype, level, use_promote, expected_op_calls
):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_conv_model(use_amp, dtype, level, use_promote)
self.assertEqual(main_program.num_blocks, 1)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
self._check_op_calls(
op_stats_list[0], expected_fp16_calls=expected_op_calls
)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2
x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
print(main_program)
losses_o1 = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_fp32,
max_iters,
level,
)
def test_static_amp_o1(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 0,
"relu": 0,
"matmul_v2": 1,
"softmax": 0,
"reduce_mean": 0,
"adamw": 0,
}
self.check_promote_results(
True,
'float16',
'O1',
use_promote=True,
expected_op_calls=expected_fp16_calls,
)
def test_static_amp_o2(self):
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 2,
"relu": 1,
"matmul_v2": 1,
"softmax": 1,
"reduce_mean": 1,
"adamw": 4,
}
self.check_promote_results(
True,
'float16',
'O2',
use_promote=True,
expected_op_calls=expected_fp16_calls,
)
if __name__ == '__main__':
unittest.main()
......@@ -221,14 +221,6 @@ class TestModelCastBF16(unittest.TestCase):
class TestProgramBF16(AmpTestBase):
def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls):
for op_type, value in expected_bf16_calls.items():
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
)
def test_amp_bf16_o1(self):
main_program, startup_program = build_embedding_model(
True, "bfloat16", "O1"
......@@ -245,7 +237,7 @@ class TestProgramBF16(AmpTestBase):
"squared_l2_norm": 0,
"adamw": 0,
}
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
def test_amp_bf16_o2(self):
main_program, startup_program = build_embedding_model(
......@@ -263,7 +255,7 @@ class TestProgramBF16(AmpTestBase):
"squared_l2_norm": 2,
"adamw": 2,
}
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
class TestStaticBF16(AmpTestBase):
......@@ -274,60 +266,35 @@ class TestStaticBF16(AmpTestBase):
return x_fp32, x_bf16
def test_compare_o1_o2(self):
def _run_o1(place, exe, x_np, max_iters):
def _run(place, exe, x_np, max_iters, level):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", "O1")
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O1] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
) = build_add_model(True, "bfloat16", level)
def _run_o2(place, exe, x_np, max_iters):
(
losses = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", "O2")
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
optimizer.amp_init(place)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
place,
exe,
x_np,
max_iters,
level,
)
print(f"-- [BF16 O2] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2
x_fp32, x_bf16 = self._generate_feed_x()
losses_o1 = _run_o1(place, exe, x_fp32, max_iters)
losses_o2 = _run_o2(place, exe, x_bf16, max_iters)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
losses_o1 = _run(place, exe, x_fp32, max_iters, 'O1')
losses_o2 = _run(place, exe, x_bf16, max_iters, 'O2')
if __name__ == '__main__':
......
......@@ -314,7 +314,10 @@ class TestImageClassification(unittest.TestCase):
# infer(use_cuda, save_dirname)
def test_amp_lists(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -324,7 +327,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_1(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -338,7 +344,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_2(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -352,7 +361,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_3(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -365,7 +377,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_4(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -381,7 +396,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_5(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......@@ -397,7 +415,10 @@ class TestImageClassification(unittest.TestCase):
self.assertEqual(amp_lists.gray_list, gray_list)
def test_amp_lists_6(self):
white_list = copy.copy(paddle.static.amp.fp16_lists.white_list)
white_list = (
copy.copy(paddle.static.amp.fp16_lists.white_list)
| paddle.static.amp.fp16_lists._only_supported_fp16_list
)
black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)
......
......@@ -39,7 +39,7 @@ class TestFuseResNetUnit(unittest.TestCase):
startup_program = paddle.static.Program()
with paddle.static.amp.fp16_guard():
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data("x", [1, 64, 64, 8])
x = paddle.static.data("x", [1, 64, 64, 8], dtype="float16")
conv2d = paddle.nn.Conv2D(
8, 32, 1, bias_attr=False, data_format='NHWC'
)
......@@ -66,3 +66,7 @@ class TestFuseResNetUnit(unittest.TestCase):
np.testing.assert_allclose(
before_out[0], after_out[0], rtol=1e-05, atol=0.005
)
if __name__ == '__main__':
unittest.main()
......@@ -25,10 +25,10 @@ paddle.enable_static()
def build_resnet50(use_amp=False):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
dtype = 'float16' if use_amp else 'float32'
with paddle.static.program_guard(main_program, startup_program):
image = paddle.static.data(
name='image', shape=[32, 3, 224, 224], dtype='float32'
name='image', shape=[32, 3, 224, 224], dtype=dtype
)
label = paddle.static.data(name='label', shape=[32], dtype='int64')
model = paddle.vision.models.resnet50()
......