未验证 提交 d0a1e2b6 编写于 作者: Z zhouzj 提交者: GitHub

fix demo. (#1645)

* fix demo.

* fix code style.
上级 ccc52673
...@@ -97,21 +97,16 @@ def convert_python_to_tensor(weight, batch_size, sample_reader): ...@@ -97,21 +97,16 @@ def convert_python_to_tensor(weight, batch_size, sample_reader):
if len(result[0]) == batch_size: if len(result[0]) == batch_size:
tensor_result = [] tensor_result = []
for tensor in result: for tensor in result:
t = paddle.Tensor()
dat = np.array(tensor, dtype='int64') dat = np.array(tensor, dtype='int64')
if len(dat.shape) > 2: if len(dat.shape) > 2:
dat = dat.reshape((dat.shape[0], dat.shape[2])) dat = dat.reshape((dat.shape[0], dat.shape[2]))
elif len(dat.shape) == 1: elif len(dat.shape) == 1:
dat = dat.reshape((-1, 1)) dat = dat.reshape((-1, 1))
t.set(dat, paddle.CPUPlace()) tensor_result.append(dat)
tensor_result.append(t)
tt = paddle.Tensor()
neg_array = cs.searchsorted(np.random.sample(args.nce_num)) neg_array = cs.searchsorted(np.random.sample(args.nce_num))
neg_array = np.tile(neg_array, batch_size) neg_array = np.tile(neg_array, batch_size)
tt.set( tensor_result.append(
neg_array.reshape((batch_size, args.nce_num)), neg_array.reshape((batch_size, args.nce_num)))
paddle.CPUPlace())
tensor_result.append(tt)
yield tensor_result yield tensor_result
result = [[], []] result = [[], []]
......
import os import os
import sys import sys
import math
import time
import numpy as np import numpy as np
import paddle import paddle
import logging import logging
...@@ -52,8 +50,8 @@ def quantize(args): ...@@ -52,8 +50,8 @@ def quantize(args):
eval_sample_generator=reader_generator(reader.val()), eval_sample_generator=reader_generator(reader.val()),
model_filename=args.model_filename, model_filename=args.model_filename,
params_filename=args.params_filename, params_filename=args.params_filename,
save_model_filename='__model__', save_model_filename='model.pdmodel',
save_params_filename='__params__', save_params_filename='model.pdiparams',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
weight_quantize_type=['channel_wise_abs_max'], weight_quantize_type=['channel_wise_abs_max'],
runcount_limit=args.max_model_quant_count) runcount_limit=args.max_model_quant_count)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册