未验证 提交 cb57443e 编写于 作者: C Chang Xu 提交者: GitHub

AnalysisQAT fit new Quant (#1630)

上级 785004f5
...@@ -160,7 +160,7 @@ class AnalysisQAT(object): ...@@ -160,7 +160,7 @@ class AnalysisQAT(object):
op_node.input('X')[0]) op_node.input('X')[0])
out_var = graph._find_node_by_name(op_node.outputs, out_var = graph._find_node_by_name(op_node.outputs,
op_node.output('Y')[0]) op_node.output('Y')[0])
if 'quantized' in in_var.name(): if not in_var.persistable():
# act # act
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
o_ns = op.output_arg_names() o_ns = op.output_arg_names()
...@@ -173,8 +173,9 @@ class AnalysisQAT(object): ...@@ -173,8 +173,9 @@ class AnalysisQAT(object):
else: else:
# weight # weight
with paddle.static.scope_guard(float_scope): with paddle.static.scope_guard(float_scope):
float_name = in_var.name().replace('.quantized', '')
float_weight = np.array( float_weight = np.array(
float_scope.find_var(in_var.name()).get_tensor()) float_scope.find_var(float_name).get_tensor())
with paddle.static.scope_guard(quant_scope): with paddle.static.scope_guard(quant_scope):
quant_scope.find_var(in_var.name()).get_tensor().set( quant_scope.find_var(in_var.name()).get_tensor().set(
float_weight, self.places) float_weight, self.places)
...@@ -216,7 +217,7 @@ class AnalysisQAT(object): ...@@ -216,7 +217,7 @@ class AnalysisQAT(object):
with paddle.static.scope_guard(float_scope): with paddle.static.scope_guard(float_scope):
float_preds = executor.run(program=float_program, float_preds = executor.run(program=float_program,
feed=data, feed=data,
fetch_list=self.fetch_list, fetch_list=self.float_fetch_list,
return_numpy=False) return_numpy=False)
float_preds = float_preds[0] float_preds = float_preds[0]
with paddle.static.scope_guard(quant_scope): with paddle.static.scope_guard(quant_scope):
...@@ -253,11 +254,12 @@ class AnalysisQAT(object): ...@@ -253,11 +254,12 @@ class AnalysisQAT(object):
idx + 1, len(self.inputs_of_quantized_op), weight_name)) idx + 1, len(self.inputs_of_quantized_op), weight_name))
with paddle.static.scope_guard(float_scope): with paddle.static.scope_guard(float_scope):
[float_program, _, _] = load_inference_model( [float_program, self.float_feed_list,
self.float_model_dir, self.float_fetch_list] = load_inference_model(
executor=executor, self.float_model_dir,
model_filename=self.model_filename, executor=executor,
params_filename=self.params_filename) model_filename=self.model_filename,
params_filename=self.params_filename)
with paddle.static.scope_guard(quant_scope): with paddle.static.scope_guard(quant_scope):
[program, self.feed_list, [program, self.feed_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册