提交 0db41a9c 编写于 作者: W WangZhen

Add the `op_role` attribute when creating op nodes.

上级 c67b29c1
......@@ -180,9 +180,14 @@ class QuantizationTransformPass(object):
Constant(value=0, force_cpu=True)
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
# The `op_role` attribute is needed by ParallelExecutor.
increment_op = graph.create_op_node(
op_type='increment',
attrs={'step': 1.0},
attrs={
'step': 1.0,
'op_role':
core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': global_step_in},
outputs={'Out': global_step_out})
graph.link_to(global_step_in, increment_op)
......@@ -217,7 +222,10 @@ class QuantizationTransformPass(object):
var_dtype=var_node.var().dtype())
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={'bit_length': quant_bits},
attrs={
'bit_length': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node},
outputs={'Out': quant_var_node,
'OutScale': scale_var_node})
......@@ -262,7 +270,8 @@ class QuantizationTransformPass(object):
attrs = {
'window_size': self._window_size,
'bit_length': quant_bits,
'is_test': self._is_test
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}
quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max',
......@@ -295,7 +304,10 @@ class QuantizationTransformPass(object):
max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={'max_range': float(max_range)},
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node,
'Scale': scale_var_node},
outputs={'Out': dequant_var_node})
......@@ -444,7 +456,10 @@ class QuantizationFreezePass(object):
var_dtype=output_var_node.var().dtype())
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={'max_range': float(max_range)},
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': output_var_node,
'Scale': scale_var_node},
outputs={'Out': dequant_var_node})
......
......@@ -251,6 +251,11 @@ class TestQuantizationFreezePass(unittest.TestCase):
iters = 10
batch_size = 128
train_exe = fluid.ParallelExecutor(
main_program=quantized_main_program,
use_cuda=bool(use_cuda),
loss_name=loss.name,
scope=scope)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
......@@ -261,9 +266,11 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.scope_guard(scope):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data),
fetch_list=[loss])
#loss_v = exe.run(program=quantized_main_program,
# feed=feeder.feed(data),
# fetch_list=[loss])
loss_v = train_exe.run(feed=feeder.feed(data),
fetch_list=[loss.name])
print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data = next(test_reader())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册