From 44e3306b9c2aa5ec80b2e8c690de7a8d9cf18c5d Mon Sep 17 00:00:00 2001
From: zhouzj <41366441+zzjjay@users.noreply.github.com>
Date: Thu, 9 Feb 2023 18:56:25 +0800
Subject: [PATCH] fix demo's bug (#1651)

* fix demo's bug

* fix code style
---
 demo/quant/quant_embedding/infer.py | 78 +++++++++++++++--------------
 demo/quant/quant_embedding/net.py   |  2 +-
 2 files changed, 42 insertions(+), 38 deletions(-)

diff --git a/demo/quant/quant_embedding/infer.py b/demo/quant/quant_embedding/infer.py
index bd9fa936..611b8080 100755
--- a/demo/quant/quant_embedding/infer.py
+++ b/demo/quant/quant_embedding/infer.py
@@ -1,11 +1,8 @@
 import argparse
 import sys
 import time
-import math
-import unittest
-import contextlib
 import numpy as np
-import six
+import os
 import paddle
 import net
 import utils
@@ -72,7 +69,7 @@ def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
             values, pred = net.infer_network(vocab_size, emb_size)
             for epoch in range(start_index, last_index + 1):
                 copy_program = main_program.clone()
-                model_path = model_dir + "/pass-" + str(epoch)
+                model_path = os.path.join(model_dir, "pass-" + str(epoch))
                 paddle.static.load(copy_program, model_path, exe)
                 if args.emb_quant:
                     config = {
@@ -92,29 +89,33 @@ def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
                 for data in test_reader():
                     step_id += 1
                     b_size = len([dat[0] for dat in data])
-                    wa = np.array(
-                        [dat[0] for dat in data]).astype("int64").reshape(
-                            b_size, 1)
-                    wb = np.array(
-                        [dat[1] for dat in data]).astype("int64").reshape(
-                            b_size, 1)
-                    wc = np.array(
-                        [dat[2] for dat in data]).astype("int64").reshape(
-                            b_size, 1)
+                    wa = np.array([dat[0]
+                                   for dat in data]).astype("int64").reshape(
+                                       b_size, 1)
+                    wb = np.array([dat[1]
+                                   for dat in data]).astype("int64").reshape(
+                                       b_size, 1)
+                    wc = np.array([dat[2]
+                                   for dat in data]).astype("int64").reshape(
+                                       b_size, 1)

                     label = [dat[3] for dat in data]
                     input_word = [dat[4] for dat in data]
-                    para = exe.run(copy_program,
-                                   feed={
-                                       "analogy_a": wa,
-                                       "analogy_b": wb,
-                                       "analogy_c": wc,
-                                       "all_label":
-                                       np.arange(vocab_size).reshape(
-                                           vocab_size, 1).astype("int64"),
-                                   },
-                                   fetch_list=[pred.name, values],
-                                   return_numpy=False)
+                    para = exe.run(
+                        copy_program,
+                        feed={
+                            "analogy_a":
+                            wa,
+                            "analogy_b":
+                            wb,
+                            "analogy_c":
+                            wc,
+                            "all_label":
+                            np.arange(vocab_size).reshape(vocab_size,
+                                                          1).astype("int64"),
+                        },
+                        fetch_list=[pred.name, values],
+                        return_numpy=False)
                     pre = np.array(para[0])
                     val = np.array(para[1])
                     for ii in range(len(label)):
@@ -156,24 +157,27 @@ def infer_step(args, vocab_size, test_reader, use_cuda, i2w):
                     for data in test_reader():
                         step_id += 1
                         b_size = len([dat[0] for dat in data])
-                        wa = np.array(
-                            [dat[0] for dat in data]).astype("int64").reshape(
-                                b_size, 1)
-                        wb = np.array(
-                            [dat[1] for dat in data]).astype("int64").reshape(
-                                b_size, 1)
-                        wc = np.array(
-                            [dat[2] for dat in data]).astype("int64").reshape(
-                                b_size, 1)
+                        wa = np.array([dat[0] for dat in
+                                       data]).astype("int64").reshape(
+                                           b_size, 1)
+                        wb = np.array([dat[1] for dat in
+                                       data]).astype("int64").reshape(
+                                           b_size, 1)
+                        wc = np.array([dat[2] for dat in
+                                       data]).astype("int64").reshape(
+                                           b_size, 1)

                         label = [dat[3] for dat in data]
                         input_word = [dat[4] for dat in data]
                         para = exe.run(
                             copy_program,
                             feed={
-                                "analogy_a": wa,
-                                "analogy_b": wb,
-                                "analogy_c": wc,
+                                "analogy_a":
+                                wa,
+                                "analogy_b":
+                                wb,
+                                "analogy_c":
+                                wc,
                                 "all_label":
                                 np.arange(vocab_size).reshape(vocab_size, 1),
                             },
diff --git a/demo/quant/quant_embedding/net.py b/demo/quant/quant_embedding/net.py
index 27fdc4e0..aff9303d 100755
--- a/demo/quant/quant_embedding/net.py
+++ b/demo/quant/quant_embedding/net.py
@@ -131,7 +131,7 @@ def infer_network(vocab_size, emb_size):
     emb_c = paddle.static.nn.embedding(
         input=analogy_c, size=[vocab_size, emb_size], param_attr="emb")
     target = paddle.add(paddle.add(emb_b, -emb_a), emb_c)
-    emb_all_label_l2 = paddle.linalg.norm(emb_all_label, p=2, axis=1)
+    emb_all_label_l2 = F.normalize(emb_all_label, p=2, axis=1)
     dist = paddle.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
     values, pred_idx = paddle.topk(x=dist, k=4)
     return values, pred_idx
-- 
GitLab
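
Note on the net.py hunk above: paddle.linalg.norm with p=2 and axis=1 reduces the embedding table to a vector of per-row norms, while the analogy scoring needs the table itself with every row rescaled to unit L2 length, which is what paddle.nn.functional.normalize returns; the later matmul then ranks vocabulary words by a cosine-style similarity to the analogy target. The snippet below is a minimal standalone sketch of that shape difference, not part of the patch. It runs in dynamic (eager) mode with made-up sizes and assumes net.py imports paddle.nn.functional as F, which the hunk's context does not show.

    # Standalone sketch (not part of the patch): why F.normalize fits here.
    # Assumes paddle 2.x; tensor names and sizes are illustrative only.
    import paddle
    import paddle.nn.functional as F

    vocab_size, emb_size = 5, 4
    emb_all_label = paddle.rand([vocab_size, emb_size])
    target = paddle.rand([3, emb_size])          # three analogy query vectors

    # paddle.linalg.norm reduces axis 1: the result is a [vocab_size] vector
    # of row norms, so it can no longer be matmul'd against the targets.
    row_norms = paddle.linalg.norm(emb_all_label, p=2, axis=1)
    print(row_norms.shape)                       # [5]

    # F.normalize keeps the [vocab_size, emb_size] shape and rescales each
    # row to unit L2 length, so the matmul scores targets against every word.
    emb_all_label_l2 = F.normalize(emb_all_label, p=2, axis=1)
    print(emb_all_label_l2.shape)                # [5, 4]

    dist = paddle.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
    print(dist.shape)                            # [3, 5] similarity scores
    values, pred_idx = paddle.topk(x=dist, k=4)  # top-4 word ids per query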