未验证 提交 d0a1e2b6 编写于 作者: Z zhouzj 提交者: GitHub

fix demo. (#1645)

* fix demo.

* fix code style.
上级 ccc52673
......@@ -97,21 +97,16 @@ def convert_python_to_tensor(weight, batch_size, sample_reader):
if len(result[0]) == batch_size:
tensor_result = []
for tensor in result:
t = paddle.Tensor()
dat = np.array(tensor, dtype='int64')
if len(dat.shape) > 2:
dat = dat.reshape((dat.shape[0], dat.shape[2]))
elif len(dat.shape) == 1:
dat = dat.reshape((-1, 1))
t.set(dat, paddle.CPUPlace())
tensor_result.append(t)
tt = paddle.Tensor()
tensor_result.append(dat)
neg_array = cs.searchsorted(np.random.sample(args.nce_num))
neg_array = np.tile(neg_array, batch_size)
tt.set(
neg_array.reshape((batch_size, args.nce_num)),
paddle.CPUPlace())
tensor_result.append(tt)
tensor_result.append(
neg_array.reshape((batch_size, args.nce_num)))
yield tensor_result
result = [[], []]
......
import os
import sys
import math
import time
import numpy as np
import paddle
import logging
......@@ -52,8 +50,8 @@ def quantize(args):
eval_sample_generator=reader_generator(reader.val()),
model_filename=args.model_filename,
params_filename=args.params_filename,
save_model_filename='__model__',
save_params_filename='__params__',
save_model_filename='model.pdmodel',
save_params_filename='model.pdiparams',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
weight_quantize_type=['channel_wise_abs_max'],
runcount_limit=args.max_model_quant_count)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册