提交 183baceb 编写于 作者: Z Zhen Wang

clean codes and fix some bugs. test=develop

上级 d68a02af
...@@ -26,6 +26,17 @@ __all__ = [ ...@@ -26,6 +26,17 @@ __all__ = [
] ]
def _init_var_node(var_node, value, scope, place):
assert isinstance(value,
np.ndarray), 'The type of value should be numpy array.'
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
'The place cannot be set None.'
tensor = scope.var(var_node.name()).get_tensor()
tensor.set(value, place)
class QuantizationTransformPass(object): class QuantizationTransformPass(object):
def __init__(self, def __init__(self,
scope=None, scope=None,
...@@ -88,14 +99,14 @@ class QuantizationTransformPass(object): ...@@ -88,14 +99,14 @@ class QuantizationTransformPass(object):
assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'." assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type: if activation_quantize_type not in quant_type:
raise ValueError( raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be ", "Unknown activation_quantize_type : '%s'. It can only be "
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.", "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'." %
str(activation_quantize_type)) (str(activation_quantize_type)))
if weight_quantize_type not in quant_type: if weight_quantize_type not in quant_type:
raise ValueError( raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be ", "Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'.", "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
str(weight_quantize_type)) % (str(weight_quantize_type)))
self._activation_quantize_type = activation_quantize_type self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type self._weight_quantize_type = weight_quantize_type
...@@ -121,8 +132,6 @@ class QuantizationTransformPass(object): ...@@ -121,8 +132,6 @@ class QuantizationTransformPass(object):
""" """
assert isinstance(graph, assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.' IrGraph), 'graph must be the instance of IrGraph.'
#sequential_execution = core.get_pass('sequential_execution_pass')
#sequential_execution.apply(graph.graph)
self._is_test = graph.is_test() self._is_test = graph.is_test()
# marked the variable which has been dequantized. # marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict() dequantized_vars = collections.OrderedDict()
...@@ -203,9 +212,12 @@ class QuantizationTransformPass(object): ...@@ -203,9 +212,12 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1], shape=[1],
var_dtype=core.VarDesc.VarType.INT64) var_dtype=core.VarDesc.VarType.INT64)
self._init_var_node( _init_var_node(
global_step_in, np.zeros( global_step_in,
[1], dtype='int64')) np.zeros(
[1], dtype='int64'),
self._scope,
self._place)
global_step_out = graph.create_var_node_from_desc( global_step_out = graph.create_var_node_from_desc(
global_step_in.var()) global_step_in.var())
# The attribute of `op_role` is needed by ParallelExecutor. # The attribute of `op_role` is needed by ParallelExecutor.
...@@ -284,7 +296,12 @@ class QuantizationTransformPass(object): ...@@ -284,7 +296,12 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type)) _init_var_node(
scale_in_node,
np.array(
[0.001], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node} inputs = {'X': var_node, 'InScale': scale_in_node}
...@@ -299,9 +316,13 @@ class QuantizationTransformPass(object): ...@@ -299,9 +316,13 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node( _init_var_node(
scales_node, np.zeros( scales_node,
[self._window_size], dtype=data_type)) np.zeros(
[self._window_size], dtype=data_type),
self._scope,
self._place)
inputs['Iter'] = self._global_step inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node outputs['OutScales'] = scales_node
attrs = { attrs = {
...@@ -343,7 +364,12 @@ class QuantizationTransformPass(object): ...@@ -343,7 +364,12 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type)) _init_var_node(
scale_in_node,
np.array(
[0.001], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
ins = {'X': var_node, 'InScale': scale_in_node} ins = {'X': var_node, 'InScale': scale_in_node}
...@@ -356,13 +382,23 @@ class QuantizationTransformPass(object): ...@@ -356,13 +382,23 @@ class QuantizationTransformPass(object):
shape=[1]) shape=[1])
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.ones([1], dtype=data_type)) _init_var_node(
scale_in_node,
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
accum_in_node = graph.create_persistable_node( accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'), name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(), var_dtype=var_node.dtype(),
shape=[1]) shape=[1])
self._init_var_node(accum_in_node, np.ones([1], dtype=data_type)) _init_var_node(
accum_in_node,
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
state_out_node = graph.create_var_node_from_desc(state_in_node.var( state_out_node = graph.create_var_node_from_desc(state_in_node.var(
)) ))
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var( accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
...@@ -482,16 +518,6 @@ class QuantizationTransformPass(object): ...@@ -482,16 +518,6 @@ class QuantizationTransformPass(object):
graph.link_to(dequant_op_node, dequant_var_node) graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node return dequant_var_node
def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
def _quantized_var_name(self, var_name): def _quantized_var_name(self, var_name):
""" """
Return quantized variable name for the input `var_name`. Return quantized variable name for the input `var_name`.
...@@ -594,8 +620,8 @@ class QuantizationFreezePass(object): ...@@ -594,8 +620,8 @@ class QuantizationFreezePass(object):
self._weight_bits) self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v) self._restore_var(input_arg_name, quantized_param_v)
else: else:
scale_v = self._to_node(op_node.outputs, scale_v = graph._find_node_by_name(
op_node.output('OutScale')[0]) op_node.outputs, op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v self._var_scale_map[input_arg_name] = scale_v
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
...@@ -627,8 +653,8 @@ class QuantizationFreezePass(object): ...@@ -627,8 +653,8 @@ class QuantizationFreezePass(object):
return graph return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node): def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = self._to_node(op_node.outputs, op_node.output('Out')[0]) k = graph._find_node_by_name(op_node.outputs, op_node.output('Out')[0])
v = self._to_node(op_node.inputs, op_node.input('X')[0]) v = graph._find_node_by_name(op_node.inputs, op_node.input('X')[0])
if v.node not in self._op_input_rename_map: if v.node not in self._op_input_rename_map:
self._op_input_rename_map[k.node] = v self._op_input_rename_map[k.node] = v
else: else:
...@@ -663,8 +689,8 @@ class QuantizationFreezePass(object): ...@@ -663,8 +689,8 @@ class QuantizationFreezePass(object):
raise ValueError("Only support one output, but op %s has" raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name())) " more than one output." % (op_node.name()))
output_var_node = self._to_node(op_node.outputs, output_var_node = graph._find_node_by_name(
op_node.output_arg_names()[0]) op_node.outputs, op_node.output_arg_names()[0])
weight_scale_node = graph.create_persistable_node( weight_scale_node = graph.create_persistable_node(
name=unique_name.generate('channel_scale'), name=unique_name.generate('channel_scale'),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
...@@ -672,7 +698,9 @@ class QuantizationFreezePass(object): ...@@ -672,7 +698,9 @@ class QuantizationFreezePass(object):
var_dtype=output_var_node.dtype()) var_dtype=output_var_node.dtype())
data_type = 'float64' if output_var_node.dtype( data_type = 'float64' if output_var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(weight_scale_node, channel_scale.astype(data_type)) _init_var_node(weight_scale_node,
channel_scale.astype(data_type), self._scope,
self._place)
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()), name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(), var_type=output_var_node.type(),
...@@ -724,8 +752,8 @@ class QuantizationFreezePass(object): ...@@ -724,8 +752,8 @@ class QuantizationFreezePass(object):
raise ValueError("Only support one output, but op %s has" raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name())) " more than one output." % (op_node.name()))
output_var_node = self._to_node(op_node.outputs, output_var_node = graph._find_node_by_name(
op_node.output_arg_names()[0]) op_node.outputs, op_node.output_arg_names()[0])
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()), name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(), var_type=output_var_node.type(),
...@@ -746,24 +774,6 @@ class QuantizationFreezePass(object): ...@@ -746,24 +774,6 @@ class QuantizationFreezePass(object):
self._op_output_rename_map[output_var_node.node] = dequant_var_node self._op_output_rename_map[output_var_node.node] = dequant_var_node
return dequant_var_node return dequant_var_node
def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
def _to_node(self, nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
def _load_var(self, name): def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor()) return np.array(self._scope.find_var(name).get_tensor())
......
...@@ -45,13 +45,14 @@ class QuantizationStrategy(Strategy): ...@@ -45,13 +45,14 @@ class QuantizationStrategy(Strategy):
activation_bits=8, activation_bits=8,
weight_bits=8, weight_bits=8,
activation_quantize_type='abs_max', activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
save_in_nodes=None, save_in_nodes=None,
save_out_nodes=None): save_out_nodes=None):
""" """
Args: Args:
start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0 start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0 end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
float_model_save_path(str): The path to save model with float weights. float_model_save_path(str): The path to save model with float weights.
None means it doesn't save float model. defalut: None. None means it doesn't save float model. defalut: None.
mobile_model_save_path(str): The path to save model for paddle-mobile execution. mobile_model_save_path(str): The path to save model for paddle-mobile execution.
None means it doesn't save mobile model. defalut: None. None means it doesn't save mobile model. defalut: None.
...@@ -66,9 +67,11 @@ class QuantizationStrategy(Strategy): ...@@ -66,9 +67,11 @@ class QuantizationStrategy(Strategy):
dynamically each step in both training and testing period. If use dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated 'range_abs_max', a static quantization scale will be calculated
during training and used in inference. during training and used in inference.
save_in_nodes(list<str>): A list of variable names used to prune graph weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained.
save_in_nodes(list<str>): A list of variable names used to prune graph
for saving inference model. for saving inference model.
save_out_nodes(list<str>): A list of variable names used to prune graph save_out_nodes(list<str>): A list of variable names used to prune graph
for saving inference model. for saving inference model.
""" """
...@@ -81,6 +84,7 @@ class QuantizationStrategy(Strategy): ...@@ -81,6 +84,7 @@ class QuantizationStrategy(Strategy):
self.activation_bits = activation_bits self.activation_bits = activation_bits
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.activation_quantize_type = activation_quantize_type self.activation_quantize_type = activation_quantize_type
self.weight_quantize_type = weight_quantize_type
self.save_out_nodes = save_out_nodes self.save_out_nodes = save_out_nodes
self.save_in_nodes = save_in_nodes self.save_in_nodes = save_in_nodes
...@@ -100,7 +104,8 @@ class QuantizationStrategy(Strategy): ...@@ -100,7 +104,8 @@ class QuantizationStrategy(Strategy):
place=context.place, place=context.place,
weight_bits=self.weight_bits, weight_bits=self.weight_bits,
activation_bits=self.activation_bits, activation_bits=self.activation_bits,
activation_quantize_type=self.activation_quantize_type) activation_quantize_type=self.activation_quantize_type,
weight_quantize_type=self.weight_quantize_type)
transform_pass.apply(train_ir_graph) transform_pass.apply(train_ir_graph)
transform_pass.apply(test_ir_graph) transform_pass.apply(test_ir_graph)
...@@ -134,7 +139,8 @@ class QuantizationStrategy(Strategy): ...@@ -134,7 +139,8 @@ class QuantizationStrategy(Strategy):
scope=context.scope, scope=context.scope,
place=context.place, place=context.place,
weight_bits=self.weight_bits, weight_bits=self.weight_bits,
activation_bits=self.activation_bits) activation_bits=self.activation_bits,
weight_quantize_type=self.weight_quantize_type)
freeze_pass.apply(test_ir_graph) freeze_pass.apply(test_ir_graph)
# for other strategies # for other strategies
......
...@@ -35,6 +35,8 @@ strategies: ...@@ -35,6 +35,8 @@ strategies:
start_epoch: 0 start_epoch: 0
end_epoch: 0 end_epoch: 0
float_model_save_path: './output/float' float_model_save_path: './output/float'
mobile_model_save_path: './output/mobile'
int8_model_save_path: './output/int8'
weight_bits: 8 weight_bits: 8
activation_bits: 8 activation_bits: 8
weight_quantize_type: 'abs_max' weight_quantize_type: 'abs_max'
......
...@@ -256,8 +256,6 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -256,8 +256,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
place=place, place=place,
activation_quantize_type=activation_quant_type, activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quant_type) weight_quantize_type=weight_quant_type)
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=activation_quant_type)
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
transform_pass.apply(test_graph) transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_' dev_name = '_gpu_' if use_cuda else '_cpu_'
...@@ -315,7 +313,6 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -315,7 +313,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type. # Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass( freeze_pass = QuantizationFreezePass(
scope=scope, place=place, weight_quantize_type=weight_quant_type) scope=scope, place=place, weight_quantize_type=weight_quant_type)
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph) freeze_pass.apply(test_graph)
if not for_ci: if not for_ci:
marked_nodes = set() marked_nodes = set()
......
...@@ -104,14 +104,14 @@ def cuda_places(device_ids=None): ...@@ -104,14 +104,14 @@ def cuda_places(device_ids=None):
:code:`FLAGS_selected_gpus=0,1,2`, the returned list would :code:`FLAGS_selected_gpus=0,1,2`, the returned list would
be [fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)]. be [fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)].
If :code:`FLAGS_selected_gpus` is not set, all visible If :code:`FLAGS_selected_gpus` is not set, all visible
gpu places would be returned. gpu places would be returned.
If :code:`device_ids` is not None, it should be the device If :code:`device_ids` is not None, it should be the device
ids of gpus. For example, if :code:`device_ids=[0,1,2]`, ids of gpus. For example, if :code:`device_ids=[0,1,2]`,
the returned list would be the returned list would be
[fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)]. [fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)].
Args: Args:
device_ids (None|list(int)|tuple(int)): gpu device id list. device_ids (None|list(int)|tuple(int)): gpu device id list.
Returns: Returns:
...@@ -133,11 +133,11 @@ def cuda_places(device_ids=None): ...@@ -133,11 +133,11 @@ def cuda_places(device_ids=None):
def cpu_places(device_count=None): def cpu_places(device_count=None):
''' '''
Create a list of :code:`fluid.CPUPlace` objects. Create a list of :code:`fluid.CPUPlace` objects.
If :code:`device_count` is None, the device count would If :code:`device_count` is None, the device count would
be determined by environment variable :code:`CPU_NUM`. be determined by environment variable :code:`CPU_NUM`.
If :code:`CPU_NUM` is not set, the device count would If :code:`CPU_NUM` is not set, the device count would
be determined by :code:`multiprocessing.cpu_count()`. be determined by :code:`multiprocessing.cpu_count()`.
Args: Args:
device_count (None|int): device number. device_count (None|int): device number.
...@@ -155,9 +155,9 @@ def cuda_pinned_places(device_count=None): ...@@ -155,9 +155,9 @@ def cuda_pinned_places(device_count=None):
Create a list of :code:`fluid.CUDAPinnedPlace` objects. Create a list of :code:`fluid.CUDAPinnedPlace` objects.
If :code:`device_count` is None, the device count would If :code:`device_count` is None, the device count would
be determined by environment variable :code:`CPU_NUM`. be determined by environment variable :code:`CPU_NUM`.
If :code:`CPU_NUM` is not set, the device count would If :code:`CPU_NUM` is not set, the device count would
be determined by :code:`multiprocessing.cpu_count()`. be determined by :code:`multiprocessing.cpu_count()`.
Args: Args:
device_count (None|int): device number. device_count (None|int): device number.
...@@ -2164,40 +2164,6 @@ class IrGraph(object): ...@@ -2164,40 +2164,6 @@ class IrGraph(object):
""" """
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()} return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
def _find_var_node(self, key):
"""
Get a variable node by the `key` from this graph. The key
can be a node name or a node id.
WARNS:
There are some nodes may have the same name. So, be
cautious about using this method when you find the
target var node by its name.
Args:
key(str|int): The str type denotes that the target variable node's name.
And the int type denotes that the target variable node's id.
Raises:
ValueError: If this graph doesn't have a variable with the giving name or id.
Returns:
IrVarNode: the variable node with the giving name or id.
"""
target_var_node = None
var_nodes = self.all_var_nodes()
if isinstance(key, six.string_types):
for var_node in var_nodes:
if var_node.name() == key:
target_var_node = var_node
elif isinstance(key, int):
for var_node in var_nodes:
if var_node.id() == key:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % key)
return target_var_node
def create_persistable_node(self, name, var_type, shape, var_dtype): def create_persistable_node(self, name, var_type, shape, var_dtype):
""" """
Create a persistable variable node in the graph. In IrGraph, Create a persistable variable node in the graph. In IrGraph,
...@@ -2342,14 +2308,6 @@ class IrGraph(object): ...@@ -2342,14 +2308,6 @@ class IrGraph(object):
core.graph_safe_remove_nodes(self.graph, original_nodes) core.graph_safe_remove_nodes(self.graph, original_nodes)
def resolve_hazard(self): def resolve_hazard(self):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = core.topology_sort(self.graph) ordered_nodes = core.topology_sort(self.graph)
var_nodes = dict() var_nodes = dict()
for node in ordered_nodes: for node in ordered_nodes:
...@@ -2357,16 +2315,17 @@ class IrGraph(object): ...@@ -2357,16 +2315,17 @@ class IrGraph(object):
for each_var_name in node.op().input_arg_names(): for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes: if each_var_name not in var_nodes:
var_nodes[each_var_name] = [ var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name) self._find_node_by_name(node.inputs, each_var_name)
] ]
for each_var_name in node.op().output_arg_names(): for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes: if each_var_name not in var_nodes:
var_nodes[each_var_name] = [ var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name) self._find_node_by_name(node.outputs, each_var_name)
] ]
else: else:
var_nodes[each_var_name].append( var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name)) self._find_node_by_name(node.outputs,
each_var_name))
self.graph.resolve_hazard(var_nodes) self.graph.resolve_hazard(var_nodes)
def has_circle(self): def has_circle(self):
...@@ -2479,6 +2438,17 @@ class IrGraph(object): ...@@ -2479,6 +2438,17 @@ class IrGraph(object):
program = Program._construct_from_desc(desc) program = Program._construct_from_desc(desc)
return program return program
def _find_node_by_name(self, nodes, node_name):
"""
Find a node in the giving nodes set by the name.
"""
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
def _update_desc_attr(self, desc, name, val): def _update_desc_attr(self, desc, name, val):
""" """
Update the value of desc's attribute by attribute's name. Update the value of desc's attribute by attribute's name.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册