提交 38d5ae7f 编写于 作者: D Dang Qingqing

Update code when get_inference_program is removed.

上级 ba8ba300
......@@ -177,21 +177,30 @@ class TestQuantizeTranspiler(unittest.TestCase):
self.residual_block_quant('range_abs_max')
def freeze_program(self, use_cuda):
def build_program(main, startup, is_test):
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
return [img, label], loss
main = fluid.Program()
startup = fluid.Program()
quant_transpiler = QuantizeTranspiler()
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
loss = conv_net(img, label)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
quant_transpiler.training_transpile(main)
test_program = fluid.Program()
test_program = main.clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program(loss)
feeds, loss = build_program(main, startup, True)
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
quant_transpiler = QuantizeTranspiler()
quant_transpiler.training_transpile(main)
quant_transpiler.training_transpile(test_program)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -206,7 +215,7 @@ class TestQuantizeTranspiler(unittest.TestCase):
batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.program_guard(main):
for _ in range(iter):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册