未验证 提交 2f2b1f23 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle][B009][B010] use normal property access instead of getattr/setattr (#51530)

上级 d1e2c61b
...@@ -54,6 +54,10 @@ select = [ ...@@ -54,6 +54,10 @@ select = [
# NumPy-specific rules # NumPy-specific rules
"NPY001", "NPY001",
# Bugbear
"B009",
"B010",
] ]
unfixable = [ unfixable = [
"NPY001" "NPY001"
......
...@@ -138,10 +138,7 @@ def fetch_all(): ...@@ -138,10 +138,7 @@ def fetch_all():
if "fetch" in dir( if "fetch" in dir(
importlib.import_module("paddle.dataset.%s" % module_name) importlib.import_module("paddle.dataset.%s" % module_name)
): ):
getattr( importlib.import_module('paddle.dataset.%s' % module_name).fetch()
importlib.import_module("paddle.dataset.%s" % module_name),
"fetch",
)()
def split(reader, line_count, suffix="%05d.pickle", dumper=pickle.dump): def split(reader, line_count, suffix="%05d.pickle", dumper=pickle.dump):
......
...@@ -1282,7 +1282,7 @@ class Fleet: ...@@ -1282,7 +1282,7 @@ class Fleet:
self.origin_main_program = loss.block.program self.origin_main_program = loss.block.program
# add distributed attr # add distributed attr
if not hasattr(self.origin_main_program, "distributed_info_"): if not hasattr(self.origin_main_program, "distributed_info_"):
setattr(self.origin_main_program, "distributed_info_", dict()) self.origin_main_program.distributed_info_ = dict()
self.origin_main_program.distributed_info_[ self.origin_main_program.distributed_info_[
"dp_degree" "dp_degree"
] = self._user_defined_strategy.sharding_configs["dp_degree"] ] = self._user_defined_strategy.sharding_configs["dp_degree"]
......
...@@ -143,7 +143,7 @@ class VocabParallelEmbedding(Layer): ...@@ -143,7 +143,7 @@ class VocabParallelEmbedding(Layer):
self.weight.is_distributed = True if self.is_mp else False self.weight.is_distributed = True if self.is_mp else False
if self.weight.is_distributed: if self.weight.is_distributed:
setattr(self.weight, "split_axis", 0) self.weight.split_axis = 0
def forward(self, x): def forward(self, x):
if self.is_mp: if self.is_mp:
...@@ -277,7 +277,7 @@ class ColumnParallelLinear(Layer): ...@@ -277,7 +277,7 @@ class ColumnParallelLinear(Layer):
self.weight.is_distributed = True if self.is_mp else False self.weight.is_distributed = True if self.is_mp else False
if self.weight.is_distributed: if self.weight.is_distributed:
setattr(self.weight, "split_axis", 1) self.weight.split_axis = 1
if has_bias: if has_bias:
# initialize bias to zero like Megatron # initialize bias to zero like Megatron
...@@ -289,7 +289,7 @@ class ColumnParallelLinear(Layer): ...@@ -289,7 +289,7 @@ class ColumnParallelLinear(Layer):
) )
self.bias.is_distributed = True if self.is_mp else False self.bias.is_distributed = True if self.is_mp else False
if self.bias.is_distributed: if self.bias.is_distributed:
setattr(self.bias, "split_axis", 0) self.bias.split_axis = 0
else: else:
self.bias = None self.bias = None
...@@ -443,7 +443,7 @@ class RowParallelLinear(Layer): ...@@ -443,7 +443,7 @@ class RowParallelLinear(Layer):
self.weight.is_distributed = True if self.is_mp else False self.weight.is_distributed = True if self.is_mp else False
if self.weight.is_distributed: if self.weight.is_distributed:
setattr(self.weight, "split_axis", 0) self.weight.split_axis = 0
if has_bias: if has_bias:
self.bias = self.create_parameter( self.bias = self.create_parameter(
......
...@@ -493,7 +493,7 @@ class PipelineLayer(nn.Layer): ...@@ -493,7 +493,7 @@ class PipelineLayer(nn.Layer):
for param in comm['layer'].parameters(): for param in comm['layer'].parameters():
if self.global_rank != min(comm['ranks']): if self.global_rank != min(comm['ranks']):
setattr(param, 'is_firstly_shared', False) param.is_firstly_shared = False
def allreduce_shared_weight_gradients(self): def allreduce_shared_weight_gradients(self):
for key, comm in self.shared_comm.items(): for key, comm in self.shared_comm.items():
...@@ -641,7 +641,7 @@ class PipelineLayer(nn.Layer): ...@@ -641,7 +641,7 @@ class PipelineLayer(nn.Layer):
for param in self.shared_layers[ for param in self.shared_layers[
layer.layer_name layer.layer_name
].parameters(): ].parameters():
setattr(param, "is_firstly_shared", True) param.is_firstly_shared = True
if layer.forward_func is None: if layer.forward_func is None:
run_function.append(self.shared_layers[layer.layer_name]) run_function.append(self.shared_layers[layer.layer_name])
......
...@@ -1047,18 +1047,18 @@ def _create_params_grad(trainable_params, param2buffer_size, task_flow): ...@@ -1047,18 +1047,18 @@ def _create_params_grad(trainable_params, param2buffer_size, task_flow):
def _PartitionParam(param): def _PartitionParam(param):
if not hasattr(param, "fw_storage"): if not hasattr(param, "fw_storage"):
setattr(param, "fw_storage", None) param.fw_storage = None
setattr(param, "bw_storage", None) param.bw_storage = None
setattr(param, "master_weight", None) param.master_weight = None
setattr(param, "status", "all") param.status = "all"
setattr(param, "use_count", 0) param.use_count = 0
return param return param
def _UnsliceParam(param): def _UnsliceParam(param):
if not hasattr(param, "unslice"): if not hasattr(param, "unslice"):
setattr(param, "unslice", True) param.unslice = True
setattr(param, "master_weight", None) param.master_weight = None
return param return param
...@@ -1078,11 +1078,11 @@ def _VarBaseWrapper(param): ...@@ -1078,11 +1078,11 @@ def _VarBaseWrapper(param):
def _OptimizerWrapper(optimizer, offload, group, update_params_slice): def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
if not hasattr(optimizer, "_optim"): if not hasattr(optimizer, "_optim"):
setattr(optimizer, "_optim", optimizer) optimizer._optim = optimizer
setattr(optimizer, "offload", offload) optimizer.offload = offload
setattr(optimizer, "_group", group) optimizer._group = group
setattr(optimizer, "update_scaler", None) optimizer.update_scaler = None
setattr(optimizer, "update_slice", update_params_slice) optimizer.update_slice = update_params_slice
return optimizer return optimizer
......
...@@ -67,7 +67,7 @@ class TestCustomRawReluOp(unittest.TestCase): ...@@ -67,7 +67,7 @@ class TestCustomRawReluOp(unittest.TestCase):
def custom_raw_relu(self, x): def custom_raw_relu(self, x):
module = importlib.import_module(MODULE_NAME) module = importlib.import_module(MODULE_NAME)
custom_raw_relu_op = getattr(module, "custom_raw_relu") custom_raw_relu_op = module.custom_raw_relu
self.assertIsNotNone(custom_raw_relu_op) self.assertIsNotNone(custom_raw_relu_op)
return custom_raw_relu_op(x) return custom_raw_relu_op(x)
......
...@@ -31,7 +31,7 @@ class ForwardNotExist(paddle.nn.Layer): ...@@ -31,7 +31,7 @@ class ForwardNotExist(paddle.nn.Layer):
net = ForwardNotExist() net = ForwardNotExist()
setattr(net, "forward", "A string so that convert forward will fail") net.forward = "A string so that convert forward will fail"
class TestConvertCall(unittest.TestCase): class TestConvertCall(unittest.TestCase):
......
...@@ -449,7 +449,7 @@ class OpTest(unittest.TestCase): ...@@ -449,7 +449,7 @@ class OpTest(unittest.TestCase):
) )
or ( or (
hasattr(self, 'mkldnn_data_type') hasattr(self, 'mkldnn_data_type')
and getattr(self, 'mkldnn_data_type') == "bfloat16" and self.mkldnn_data_type == "bfloat16"
) )
or ( or (
hasattr(self, 'attrs') hasattr(self, 'attrs')
...@@ -469,7 +469,7 @@ class OpTest(unittest.TestCase): ...@@ -469,7 +469,7 @@ class OpTest(unittest.TestCase):
) )
or ( or (
hasattr(self, 'mkldnn_data_type') hasattr(self, 'mkldnn_data_type')
and getattr(self, 'mkldnn_data_type') == "float16" and self.mkldnn_data_type == "float16"
) )
or ( or (
hasattr(self, 'attrs') hasattr(self, 'attrs')
...@@ -1713,7 +1713,7 @@ class OpTest(unittest.TestCase): ...@@ -1713,7 +1713,7 @@ class OpTest(unittest.TestCase):
prim_checker = PrimForwardChecker(self, place) prim_checker = PrimForwardChecker(self, place)
prim_checker.check() prim_checker.check()
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
setattr(self.__class__, 'check_prim', True) self.__class__.check_prim = True
self.__class__.op_type = self.op_type self.__class__.op_type = self.op_type
# set some flags by the combination of arguments. # set some flags by the combination of arguments.
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
...@@ -1728,8 +1728,9 @@ class OpTest(unittest.TestCase): ...@@ -1728,8 +1728,9 @@ class OpTest(unittest.TestCase):
if self.is_mkldnn_op(): if self.is_mkldnn_op():
check_dygraph = False check_dygraph = False
if hasattr(self, 'force_fp32_output') and getattr( if (
self, 'force_fp32_output' hasattr(self, 'force_fp32_output')
and self.force_fp32_output
): ):
atol = 1e-2 if atol < 1e-2 else atol atol = 1e-2 if atol < 1e-2 else atol
else: else:
...@@ -2078,7 +2079,7 @@ class OpTest(unittest.TestCase): ...@@ -2078,7 +2079,7 @@ class OpTest(unittest.TestCase):
) )
prim_grad_checker.check() prim_grad_checker.check()
# Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
setattr(self.__class__, 'check_prim', True) self.__class__.check_prim = True
self._check_grad_helper() self._check_grad_helper()
if only_check_prim: if only_check_prim:
return return
......
...@@ -451,7 +451,7 @@ class OpTest(unittest.TestCase): ...@@ -451,7 +451,7 @@ class OpTest(unittest.TestCase):
) )
or ( or (
hasattr(self, 'mkldnn_data_type') hasattr(self, 'mkldnn_data_type')
and getattr(self, 'mkldnn_data_type') == "bfloat16" and self.mkldnn_data_type == "bfloat16"
) )
or ( or (
hasattr(self, 'attrs') hasattr(self, 'attrs')
...@@ -471,7 +471,7 @@ class OpTest(unittest.TestCase): ...@@ -471,7 +471,7 @@ class OpTest(unittest.TestCase):
) )
or ( or (
hasattr(self, 'mkldnn_data_type') hasattr(self, 'mkldnn_data_type')
and getattr(self, 'mkldnn_data_type') == "float16" and self.mkldnn_data_type == "float16"
) )
or ( or (
hasattr(self, 'attrs') hasattr(self, 'attrs')
...@@ -1502,7 +1502,7 @@ class OpTest(unittest.TestCase): ...@@ -1502,7 +1502,7 @@ class OpTest(unittest.TestCase):
prim_checker = PrimForwardChecker(self, place) prim_checker = PrimForwardChecker(self, place)
prim_checker.check() prim_checker.check()
# Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 # Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
setattr(self.__class__, 'check_prim', True) self.__class__.check_prim = True
self.__class__.op_type = self.op_type self.__class__.op_type = self.op_type
# disable legacy dygraph check when check_eager is True # disable legacy dygraph check when check_eager is True
if check_eager: if check_eager:
...@@ -1907,8 +1907,9 @@ class OpTest(unittest.TestCase): ...@@ -1907,8 +1907,9 @@ class OpTest(unittest.TestCase):
if self.is_mkldnn_op(): if self.is_mkldnn_op():
check_dygraph = False check_dygraph = False
check_eager = False check_eager = False
if hasattr(self, 'force_fp32_output') and getattr( if (
self, 'force_fp32_output' hasattr(self, 'force_fp32_output')
and self.force_fp32_output
): ):
atol = 1e-2 if atol < 1e-2 else atol atol = 1e-2 if atol < 1e-2 else atol
else: else:
...@@ -2288,7 +2289,7 @@ class OpTest(unittest.TestCase): ...@@ -2288,7 +2289,7 @@ class OpTest(unittest.TestCase):
) )
prim_grad_checker.check() prim_grad_checker.check()
# Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32 # Support operators which not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
setattr(self.__class__, 'check_prim', True) self.__class__.check_prim = True
self._check_grad_helper() self._check_grad_helper()
if only_check_prim: if only_check_prim:
return return
......
...@@ -312,8 +312,8 @@ def summary_string(model, input_size=None, dtypes=None, input=None): ...@@ -312,8 +312,8 @@ def summary_string(model, input_size=None, dtypes=None, input=None):
params += np.prod(v.shape) params += np.prod(v.shape)
try: try:
if (getattr(getattr(layer, k), 'trainable')) and ( if (getattr(layer, k).trainable) and (
not getattr(getattr(layer, k), 'stop_gradient') not getattr(layer, k).stop_gradient
): ):
summary[m_key]["trainable_params"] += np.prod(v.shape) summary[m_key]["trainable_params"] += np.prod(v.shape)
summary[m_key]["trainable"] = True summary[m_key]["trainable"] = True
......
...@@ -219,7 +219,7 @@ def _get_dims_mapping(dist_parameter, mp_group): ...@@ -219,7 +219,7 @@ def _get_dims_mapping(dist_parameter, mp_group):
dist_shape = np.array(dist_parameter.shape) dist_shape = np.array(dist_parameter.shape)
if hasattr(dist_parameter, "split_axis"): if hasattr(dist_parameter, "split_axis"):
aixs = getattr(dist_parameter, "split_axis") aixs = dist_parameter.split_axis
mapping = [-1 for _ in dist_shape] mapping = [-1 for _ in dist_shape]
mapping[aixs] = 1 mapping[aixs] = 1
logger.debug( logger.debug(
...@@ -351,7 +351,7 @@ def _get_wrapped_dist_state_dict(dist_state_dict): ...@@ -351,7 +351,7 @@ def _get_wrapped_dist_state_dict(dist_state_dict):
logger.debug(f"not first used : {v.name}") logger.debug(f"not first used : {v.name}")
continue continue
wrapped_state_dict[name_mapping[v.name]] = v wrapped_state_dict[name_mapping[v.name]] = v
setattr(v, "dims_mapping", _get_dims_mapping(v, mp_group)) v.dims_mapping = _get_dims_mapping(v, mp_group)
logger.debug( logger.debug(
f"saving param: {v.name} -> {name_mapping[v.name]} shape: {v.shape}" f"saving param: {v.name} -> {name_mapping[v.name]} shape: {v.shape}"
) )
......
...@@ -312,7 +312,7 @@ def convert_call(func): ...@@ -312,7 +312,7 @@ def convert_call(func):
# Bound mothod will be convert into plain function after `convert_to_static`. # Bound mothod will be convert into plain function after `convert_to_static`.
# So descriptor mechanism is used to bound `self` instance on function to # So descriptor mechanism is used to bound `self` instance on function to
# keep it as bound method. # keep it as bound method.
setattr(func, 'forward', forward_func.__get__(func)) func.forward = forward_func.__get__(func)
except (IOError, OSError, TypeError): except (IOError, OSError, TypeError):
# NOTE: func.forward may have been decorated. # NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self func_self = None if func_self else func_self
......
...@@ -314,8 +314,8 @@ class StaticFunction: ...@@ -314,8 +314,8 @@ class StaticFunction:
# save the instance `self` while decorating a method of class. # save the instance `self` while decorating a method of class.
if inspect.ismethod(function): if inspect.ismethod(function):
self._dygraph_function = getattr(function, '__func__') self._dygraph_function = function.__func__
self._class_instance = getattr(function, '__self__') self._class_instance = function.__self__
if not hasattr(self._class_instance, '_original_funcs'): if not hasattr(self._class_instance, '_original_funcs'):
raise TypeError( raise TypeError(
...@@ -885,7 +885,7 @@ class HookHelper: ...@@ -885,7 +885,7 @@ class HookHelper:
self.need_apply_hook = ( self.need_apply_hook = (
with_hook with_hook
and isinstance(self.class_instance, layers.Layer) and isinstance(self.class_instance, layers.Layer)
and getattr(func, "__name__") == "forward" and func.__name__ == "forward"
) )
def apply_pre_hooks(self, inputs): def apply_pre_hooks(self, inputs):
......
...@@ -576,7 +576,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): ...@@ -576,7 +576,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
# The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained # The 'forward' or 'another_forward' of 'TranslatedLayer' cannot be obtained
# through 'func_name'. So set the special function name '__i_m_p_l__'. # through 'func_name'. So set the special function name '__i_m_p_l__'.
if hasattr(module, '__i_m_p_l__'): if hasattr(module, '__i_m_p_l__'):
callable_func = getattr(module, '__i_m_p_l__') callable_func = module.__i_m_p_l__
callable_func.__name__ = func_name callable_func.__name__ = func_name
elif hasattr(module, func_name): elif hasattr(module, func_name):
callable_func = getattr(module, func_name) callable_func = getattr(module, func_name)
...@@ -1120,11 +1120,11 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1120,11 +1120,11 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
def _reset_name_scope(self, node): def _reset_name_scope(self, node):
# always reset the node as empty namescope. # always reset the node as empty namescope.
setattr(node, "pd_scope", NameScope()) node.pd_scope = NameScope()
def _get_name_scope(self, node): def _get_name_scope(self, node):
if not hasattr(node, "pd_scope"): if not hasattr(node, "pd_scope"):
setattr(node, "pd_scope", NameScope()) node.pd_scope = NameScope()
return node.pd_scope return node.pd_scope
def _current_name_scope(self): def _current_name_scope(self):
...@@ -1224,11 +1224,7 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor): ...@@ -1224,11 +1224,7 @@ class FunctionNameLivenessAnalysis(gast.NodeVisitor):
) )
def pre_func(): def pre_func():
setattr( node.before_created = self._nearest_function_scope().existed_vars()
node,
"before_created",
self._nearest_function_scope().existed_vars(),
)
self._visit_scope_node(node, pre_func, post_func) self._visit_scope_node(node, pre_func, post_func)
......
...@@ -320,7 +320,7 @@ def grid_sample( ...@@ -320,7 +320,7 @@ def grid_sample(
'use_cudnn', 'use_cudnn',
use_cudnn, use_cudnn,
) )
out = getattr(_legacy_C_ops, 'grid_sampler')(x, grid, *attrs) out = _legacy_C_ops.grid_sampler(x, grid, *attrs)
else: else:
helper = LayerHelper("grid_sample", **locals()) helper = LayerHelper("grid_sample", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'grid_sample') check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'grid_sample')
......
...@@ -30,18 +30,18 @@ class QuantedConv2D(ConvertibleQuantedLayer): ...@@ -30,18 +30,18 @@ class QuantedConv2D(ConvertibleQuantedLayer):
super(QuantedConv2D, self).__init__() super(QuantedConv2D, self).__init__()
# For Conv2D # For Conv2D
self._groups = getattr(layer, '_groups') self._groups = layer._groups
self._stride = getattr(layer, '_stride') self._stride = layer._stride
self._padding = getattr(layer, '_padding') self._padding = layer._padding
self._padding_mode = getattr(layer, '_padding_mode') self._padding_mode = layer._padding_mode
if self._padding_mode != 'zeros': if self._padding_mode != 'zeros':
self._reversed_padding_repeated_twice = getattr( self._reversed_padding_repeated_twice = (
layer, '_reversed_padding_repeated_twice' layer._reversed_padding_repeated_twice
) )
self._dilation = getattr(layer, '_dilation') self._dilation = layer._dilation
self._data_format = getattr(layer, '_data_format') self._data_format = layer._data_format
self.weight = getattr(layer, 'weight') self.weight = layer.weight
self.bias = getattr(layer, 'bias') self.bias = layer.bias
self.weight_quanter = None self.weight_quanter = None
self.activation_quanter = None self.activation_quanter = None
......
...@@ -28,9 +28,9 @@ class QuantedLinear(ConvertibleQuantedLayer): ...@@ -28,9 +28,9 @@ class QuantedLinear(ConvertibleQuantedLayer):
def __init__(self, layer: Layer, q_config): def __init__(self, layer: Layer, q_config):
super(QuantedLinear, self).__init__() super(QuantedLinear, self).__init__()
# For Linear # For Linear
self.weight = getattr(layer, 'weight') self.weight = layer.weight
self.bias = getattr(layer, 'bias') self.bias = layer.bias
self.name = getattr(layer, 'name') self.name = layer.name
# For FakeQuant # For FakeQuant
self.weight_quanter = None self.weight_quanter = None
......
...@@ -533,18 +533,18 @@ class QuantizedConv2D(Layer): ...@@ -533,18 +533,18 @@ class QuantizedConv2D(Layer):
): ):
super().__init__() super().__init__()
# For Conv2D # For Conv2D
self._groups = getattr(layer, '_groups') self._groups = layer._groups
self._stride = getattr(layer, '_stride') self._stride = layer._stride
self._padding = getattr(layer, '_padding') self._padding = layer._padding
self._padding_mode = getattr(layer, '_padding_mode') self._padding_mode = layer._padding_mode
if self._padding_mode != 'zeros': if self._padding_mode != 'zeros':
self._reversed_padding_repeated_twice = getattr( self._reversed_padding_repeated_twice = (
layer, '_reversed_padding_repeated_twice' layer._reversed_padding_repeated_twice
) )
self._dilation = getattr(layer, '_dilation') self._dilation = layer._dilation
self._data_format = getattr(layer, '_data_format') self._data_format = layer._data_format
self.weight = getattr(layer, 'weight') self.weight = layer.weight
self.bias = getattr(layer, 'bias') self.bias = layer.bias
# For FakeQuant # For FakeQuant
self._conv2d_quant_axis = 0 self._conv2d_quant_axis = 0
...@@ -654,14 +654,14 @@ class QuantizedConv2DTranspose(Layer): ...@@ -654,14 +654,14 @@ class QuantizedConv2DTranspose(Layer):
""" """
super().__init__() super().__init__()
# For Conv2DTranspose # For Conv2DTranspose
self._groups = getattr(layer, '_groups') self._groups = layer._groups
self._stride = getattr(layer, '_stride') self._stride = layer._stride
self._padding = getattr(layer, '_padding') self._padding = layer._padding
self._output_padding = getattr(layer, 'output_padding') self._output_padding = layer.output_padding
self._dilation = getattr(layer, '_dilation') self._dilation = layer._dilation
self._data_format = getattr(layer, '_data_format') self._data_format = layer._data_format
self.weight = getattr(layer, 'weight') self.weight = layer.weight
self.bias = getattr(layer, 'bias') self.bias = layer.bias
# For FakeQuant # For FakeQuant
self._conv2d_transpose_quant_axis = 1 self._conv2d_transpose_quant_axis = 1
if weight_quant_layer is not None: if weight_quant_layer is not None:
...@@ -748,9 +748,9 @@ class QuantizedLinear(Layer): ...@@ -748,9 +748,9 @@ class QuantizedLinear(Layer):
): ):
super().__init__() super().__init__()
# For Linear # For Linear
self.weight = getattr(layer, 'weight') self.weight = layer.weight
self.bias = getattr(layer, 'bias') self.bias = layer.bias
self.name = getattr(layer, 'name') self.name = layer.name
# For FakeQuant # For FakeQuant
self._linear_quant_axis = 1 self._linear_quant_axis = 1
...@@ -829,15 +829,15 @@ class QuantizedColumnParallelLinear(Layer): ...@@ -829,15 +829,15 @@ class QuantizedColumnParallelLinear(Layer):
act_quant_layer is None act_quant_layer is None
), "When quantizing ColumnParallelLinear, act_quant_layer should be None." ), "When quantizing ColumnParallelLinear, act_quant_layer should be None."
self.weight = getattr(layer, 'weight') self.weight = layer.weight
self.bias = getattr(layer, 'bias') self.bias = layer.bias
self.name = getattr(layer, '_name') self.name = layer._name
# For FakeQuant # For FakeQuant
self._linear_quant_axis = 1 self._linear_quant_axis = 1
self.is_mp = getattr(layer, 'is_mp') self.is_mp = layer.is_mp
self.model_parallel_group = getattr(layer, 'model_parallel_group') self.model_parallel_group = layer.model_parallel_group
self.gather_output = getattr(layer, 'gather_output') self.gather_output = layer.gather_output
self._fake_quant_weight = _get_fake_quant_type( self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type, weight_quantize_type,
...@@ -923,15 +923,15 @@ class QuantizedRowParallelLinear(Layer): ...@@ -923,15 +923,15 @@ class QuantizedRowParallelLinear(Layer):
), "When quantizing RowParallelLinear, act_quant_layer cannot defined by yourself." ), "When quantizing RowParallelLinear, act_quant_layer cannot defined by yourself."
# For Linear # For Linear
self.weight = getattr(layer, 'weight') self.weight = layer.weight
self.bias = getattr(layer, 'bias') self.bias = layer.bias
self.name = getattr(layer, '_name') self.name = layer._name
# For FakeQuant # For FakeQuant
self._linear_quant_axis = 1 self._linear_quant_axis = 1
self.input_is_parallel = getattr(layer, 'input_is_parallel') self.input_is_parallel = layer.input_is_parallel
self.is_mp = getattr(layer, 'is_mp') self.is_mp = layer.is_mp
self.model_parallel_group = getattr(layer, 'model_parallel_group') self.model_parallel_group = layer.model_parallel_group
self._fake_quant_weight = _get_fake_quant_type( self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type, weight_quantize_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册