提交 bac761fb 编写于 作者: W wopeizl 提交者: qingqing01

fix the out of memory issue test=develop (#3486)

上级 7dcc3890
......@@ -26,6 +26,7 @@ import numpy as np
import sys
import paddle.fluid as fluid
from paddle.fluid import core
import multiprocessing as mp
def print_arguments(args):
......@@ -68,21 +69,23 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
help=help + ' Default: %(default)s.',
**kwargs)
def fmt_time():
""" get formatted time for now
"""
now_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
return now_str
def recall_topk(fea, lab, k = 1):
def recall_topk_ori(fea, lab, k):
fea = np.array(fea)
fea = fea.reshape(fea.shape[0], -1)
n = np.sqrt(np.sum(fea**2, 1)).reshape(-1, 1)
fea = fea / n
a = np.sum(fea ** 2, 1).reshape(-1, 1)
a = np.sum(fea**2, 1).reshape(-1, 1)
b = a.T
ab = np.dot(fea, fea.T)
d = a + b - 2*ab
d = a + b - 2 * ab
d = d + np.eye(len(fea)) * 1e8
sorted_index = np.argsort(d, 1)
res = 0
......@@ -95,13 +98,72 @@ def recall_topk(fea, lab, k = 1):
res = res / len(fea)
return res
def func(param):
sharedlist, s, e = param
fea, a, b = sharedlist
ab = np.dot(fea[s:e], fea.T)
d = a[s:e] + b - 2 * ab
for i in range(e - s):
d[i][s + i] += 1e8
sorted_index = np.argsort(d, 1)[:, :10]
return sorted_index
def recall_topk_parallel(fea, lab, k):
fea = np.array(fea)
fea = fea.reshape(fea.shape[0], -1)
n = np.sqrt(np.sum(fea**2, 1)).reshape(-1, 1)
fea = fea / n
a = np.sum(fea**2, 1).reshape(-1, 1)
b = a.T
sharedlist = mp.Manager().list()
sharedlist.append(fea)
sharedlist.append(a)
sharedlist.append(b)
N = 100
L = fea.shape[0] / N
params = []
for i in xrange(N):
if i == N - 1:
s, e = int(i * L), int(fea.shape[0])
else:
s, e = int(i * L), int((i + 1) * L)
params.append([sharedlist, s, e])
pool = mp.Pool(processes=4)
sorted_index_list = pool.map(func, params)
pool.close()
pool.join()
sorted_index = np.vstack(sorted_index_list)
res = 0
for i in range(len(fea)):
for j in range(k):
pred = lab[sorted_index[i][j]]
if lab[i] == pred:
res += 1.0
break
res = res / len(fea)
return res
def recall_topk(fea, lab, k=1):
if fea.shape[0] < 20:
return recall_topk_ori(fea, lab, k)
else:
return recall_topk_parallel(fea, lab, k)
def get_gpu_num():
visibledevice = os.getenv('CUDA_VISIBLE_DEVICES')
if visibledevice:
devicenum = len(visibledevice.split(','))
else:
devicenum = subprocess.check_output(
[str.encode('nvidia-smi'), str.encode('-L')]).decode('utf-8').count('\n')
[str.encode('nvidia-smi'), str.encode('-L')]).decode('utf-8').count(
'\n')
return devicenum
def check_cuda(use_cuda, err = \
......@@ -114,4 +176,3 @@ def check_cuda(use_cuda, err = \
sys.exit(1)
except Exception as e:
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册