未验证 提交 ee44bcdd 编写于 作者: Z Zhen Wang 提交者: GitHub

add more unit tests for imperative qat. test=develop (#25486)

上级 fd961b0d
......@@ -254,6 +254,172 @@ class TestImperativeQat(unittest.TestCase):
np.allclose(after_save, before_save.numpy()),
msg='Failed to save the inference quantized model.')
def test_qat_acc(self):
def _build_static_lenet(main, startup, is_test=False, seed=1000):
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
main.random_seed = seed
startup.random_seed = seed
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
prediction = StaticLenet(img)
if not is_test:
loss = fluid.layers.cross_entropy(
input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
else:
avg_loss = prediction
return img, label, avg_loss
reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32, drop_last=True)
weight_quantize_type = 'abs_max'
activation_quant_type = 'moving_average_abs_max'
param_init_map = {}
seed = 1000
lr = 0.1
# imperative train
_logger.info(
"--------------------------dynamic graph qat--------------------------"
)
imperative_qat = ImperativeQuantAware(
weight_quantize_type=weight_quantize_type,
activation_quantize_type=activation_quant_type)
with fluid.dygraph.guard():
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
lenet = ImperativeLenet()
fixed_state = {}
for name, param in lenet.named_parameters():
p_shape = param.numpy().shape
p_value = param.numpy()
if name.endswith("bias"):
value = np.zeros_like(p_value).astype('float32')
else:
value = np.random.normal(
loc=0.0, scale=0.01, size=np.product(p_shape)).reshape(
p_shape).astype('float32')
fixed_state[name] = value
param_init_map[param.name] = value
lenet.set_dict(fixed_state)
imperative_qat.quantize(lenet)
adam = AdamOptimizer(
learning_rate=lr, parameter_list=lenet.parameters())
dynamic_loss_rec = []
lenet.train()
for batch_id, data in enumerate(reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
lenet.clear_gradients()
dynamic_loss_rec.append(avg_loss.numpy()[0])
if batch_id % 100 == 0:
_logger.info('{}: {}'.format('loss', avg_loss.numpy()))
imperative_qat.save_quantized_model(
dirname="./dynamic_mnist",
model=lenet,
input_shape=[(1, 28, 28)],
input_dtype=['float32'],
feed=[0],
fetch=[0])
# static graph train
_logger.info(
"--------------------------static graph qat--------------------------"
)
static_loss_rec = []
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
main = fluid.Program()
infer = fluid.Program()
startup = fluid.Program()
static_img, static_label, static_loss = _build_static_lenet(
main, startup, False, seed)
infer_img, _, infer_pre = _build_static_lenet(infer, startup, True,
seed)
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
opt = AdamOptimizer(learning_rate=lr)
opt.minimize(static_loss)
scope = core.Scope()
with fluid.scope_guard(scope):
exe.run(startup)
for param in main.all_parameters():
param_tensor = scope.var(param.name).get_tensor()
param_tensor.set(param_init_map[param.name], place)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
infer_graph = IrGraph(core.Graph(infer.desc), for_test=True)
transform_pass = QuantizationTransformPass(
scope=scope,
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
transform_pass.apply(main_graph)
transform_pass.apply(infer_graph)
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
loss_name=static_loss.name, build_strategy=build_strategy)
feeder = fluid.DataFeeder(
feed_list=[static_img, static_label], place=place)
with fluid.scope_guard(scope):
for batch_id, data in enumerate(reader()):
loss_v, = exe.run(binary,
feed=feeder.feed(data),
fetch_list=[static_loss])
static_loss_rec.append(loss_v[0])
if batch_id % 100 == 0:
_logger.info('{}: {}'.format('loss', loss_v))
save_program = infer_graph.to_program()
with fluid.scope_guard(scope):
fluid.io.save_inference_model("./static_mnist", [infer_img.name],
[infer_pre], exe, save_program)
rtol = 1e-05
atol = 1e-08
for i, (loss_d,
loss_s) in enumerate(zip(dynamic_loss_rec, static_loss_rec)):
diff = np.abs(loss_d - loss_s)
if diff > (atol + rtol * np.abs(loss_s)):
_logger.info(
"diff({}) at {}, dynamic loss = {}, static loss = {}".
format(diff, i, loss_d, loss_s))
break
self.assertTrue(
np.allclose(
np.array(dynamic_loss_rec),
np.array(static_loss_rec),
rtol=rtol,
atol=atol,
equal_nan=True),
msg='Failed to do the imperative qat.')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册