提交 0db41a9c 编写于 作者: W WangZhen

add op_role attr when creating op node.

上级 c67b29c1
...@@ -180,9 +180,14 @@ class QuantizationTransformPass(object): ...@@ -180,9 +180,14 @@ class QuantizationTransformPass(object):
Constant(value=0, force_cpu=True) Constant(value=0, force_cpu=True)
global_step_out = graph.create_var_node_from_desc( global_step_out = graph.create_var_node_from_desc(
global_step_in.var()) global_step_in.var())
# The attribute of `op_role` is needed by ParallelExecutor.
increment_op = graph.create_op_node( increment_op = graph.create_op_node(
op_type='increment', op_type='increment',
attrs={'step': 1.0}, attrs={
'step': 1.0,
'op_role':
core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': global_step_in}, inputs={'X': global_step_in},
outputs={'Out': global_step_out}) outputs={'Out': global_step_out})
graph.link_to(global_step_in, increment_op) graph.link_to(global_step_in, increment_op)
...@@ -217,7 +222,10 @@ class QuantizationTransformPass(object): ...@@ -217,7 +222,10 @@ class QuantizationTransformPass(object):
var_dtype=var_node.var().dtype()) var_dtype=var_node.var().dtype())
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max', op_type='fake_quantize_abs_max',
attrs={'bit_length': quant_bits}, attrs={
'bit_length': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node}, inputs={'X': var_node},
outputs={'Out': quant_var_node, outputs={'Out': quant_var_node,
'OutScale': scale_var_node}) 'OutScale': scale_var_node})
...@@ -262,7 +270,8 @@ class QuantizationTransformPass(object): ...@@ -262,7 +270,8 @@ class QuantizationTransformPass(object):
attrs = { attrs = {
'window_size': self._window_size, 'window_size': self._window_size,
'bit_length': quant_bits, 'bit_length': quant_bits,
'is_test': self._is_test 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
} }
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max', op_type='fake_quantize_range_abs_max',
...@@ -295,7 +304,10 @@ class QuantizationTransformPass(object): ...@@ -295,7 +304,10 @@ class QuantizationTransformPass(object):
max_range = (1 << (quant_bits - 1)) - 1 max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs', op_type='fake_dequantize_max_abs',
attrs={'max_range': float(max_range)}, attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node, inputs={'X': var_node,
'Scale': scale_var_node}, 'Scale': scale_var_node},
outputs={'Out': dequant_var_node}) outputs={'Out': dequant_var_node})
...@@ -444,7 +456,10 @@ class QuantizationFreezePass(object): ...@@ -444,7 +456,10 @@ class QuantizationFreezePass(object):
var_dtype=output_var_node.var().dtype()) var_dtype=output_var_node.var().dtype())
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs', op_type='fake_dequantize_max_abs',
attrs={'max_range': float(max_range)}, attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': output_var_node, inputs={'X': output_var_node,
'Scale': scale_var_node}, 'Scale': scale_var_node},
outputs={'Out': dequant_var_node}) outputs={'Out': dequant_var_node})
......
...@@ -251,6 +251,11 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -251,6 +251,11 @@ class TestQuantizationFreezePass(unittest.TestCase):
iters = 10 iters = 10
batch_size = 128 batch_size = 128
train_exe = fluid.ParallelExecutor(
main_program=quantized_main_program,
use_cuda=bool(use_cuda),
loss_name=loss.name,
scope=scope)
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500), paddle.dataset.mnist.train(), buf_size=500),
...@@ -261,9 +266,11 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -261,9 +266,11 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
for _ in range(iters): for _ in range(iters):
data = next(train_reader()) data = next(train_reader())
loss_v = exe.run(program=quantized_main_program, #loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data), # feed=feeder.feed(data),
fetch_list=[loss]) # fetch_list=[loss])
loss_v = train_exe.run(feed=feeder.feed(data),
fetch_list=[loss.name])
print('{}: {}'.format('loss' + dev_name + quant_type, loss_v)) print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data = next(test_reader()) test_data = next(test_reader())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册