未验证 提交 cb57443e 编写于 作者: C Chang Xu 提交者: GitHub

AnalysisQAT fit new Quant (#1630)

上级 785004f5
......@@ -160,7 +160,7 @@ class AnalysisQAT(object):
op_node.input('X')[0])
out_var = graph._find_node_by_name(op_node.outputs,
op_node.output('Y')[0])
if 'quantized' in in_var.name():
if not in_var.persistable():
# act
for op in graph.all_op_nodes():
o_ns = op.output_arg_names()
......@@ -173,8 +173,9 @@ class AnalysisQAT(object):
else:
# weight
with paddle.static.scope_guard(float_scope):
float_name = in_var.name().replace('.quantized', '')
float_weight = np.array(
float_scope.find_var(in_var.name()).get_tensor())
float_scope.find_var(float_name).get_tensor())
with paddle.static.scope_guard(quant_scope):
quant_scope.find_var(in_var.name()).get_tensor().set(
float_weight, self.places)
......@@ -216,7 +217,7 @@ class AnalysisQAT(object):
with paddle.static.scope_guard(float_scope):
float_preds = executor.run(program=float_program,
feed=data,
fetch_list=self.fetch_list,
fetch_list=self.float_fetch_list,
return_numpy=False)
float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope):
......@@ -253,11 +254,12 @@ class AnalysisQAT(object):
idx + 1, len(self.inputs_of_quantized_op), weight_name))
with paddle.static.scope_guard(float_scope):
[float_program, _, _] = load_inference_model(
self.float_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
[float_program, self.float_feed_list,
self.float_fetch_list] = load_inference_model(
self.float_model_dir,
executor=executor,
model_filename=self.model_filename,
params_filename=self.params_filename)
with paddle.static.scope_guard(quant_scope):
[program, self.feed_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册