From d0a1e2b6af50ea64d319dda114ed96df6f8198bd Mon Sep 17 00:00:00 2001 From: zhouzj <41366441+zzjjay@users.noreply.github.com> Date: Mon, 6 Feb 2023 14:25:14 +0800 Subject: [PATCH] fix demo. (#1645) * fix demo. * fix code style. --- demo/quant/quant_embedding/train.py | 11 +++-------- demo/quant/quant_post_hpo/quant_post_hpo.py | 6 ++---- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/demo/quant/quant_embedding/train.py b/demo/quant/quant_embedding/train.py index f5343b96..1236a87d 100755 --- a/demo/quant/quant_embedding/train.py +++ b/demo/quant/quant_embedding/train.py @@ -97,21 +97,16 @@ def convert_python_to_tensor(weight, batch_size, sample_reader): if len(result[0]) == batch_size: tensor_result = [] for tensor in result: - t = paddle.Tensor() dat = np.array(tensor, dtype='int64') if len(dat.shape) > 2: dat = dat.reshape((dat.shape[0], dat.shape[2])) elif len(dat.shape) == 1: dat = dat.reshape((-1, 1)) - t.set(dat, paddle.CPUPlace()) - tensor_result.append(t) - tt = paddle.Tensor() + tensor_result.append(dat) neg_array = cs.searchsorted(np.random.sample(args.nce_num)) neg_array = np.tile(neg_array, batch_size) - tt.set( - neg_array.reshape((batch_size, args.nce_num)), - paddle.CPUPlace()) - tensor_result.append(tt) + tensor_result.append( + neg_array.reshape((batch_size, args.nce_num))) yield tensor_result result = [[], []] diff --git a/demo/quant/quant_post_hpo/quant_post_hpo.py b/demo/quant/quant_post_hpo/quant_post_hpo.py index f96e869c..c6bd5509 100755 --- a/demo/quant/quant_post_hpo/quant_post_hpo.py +++ b/demo/quant/quant_post_hpo/quant_post_hpo.py @@ -1,7 +1,5 @@ import os import sys -import math -import time import numpy as np import paddle import logging @@ -52,8 +50,8 @@ def quantize(args): eval_sample_generator=reader_generator(reader.val()), model_filename=args.model_filename, params_filename=args.params_filename, - save_model_filename='__model__', - save_params_filename='__params__', + save_model_filename='model.pdmodel', + save_params_filename='model.pdiparams', quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], weight_quantize_type=['channel_wise_abs_max'], runcount_limit=args.max_model_quant_count) -- GitLab