You need to sign in or sign up before continuing.
提交 b2c9efef 编写于 作者: Q Qiao Longfei

add more unit test for lookup_remote_table

test=develop
上级 40f68b13
...@@ -27,7 +27,7 @@ from paddle.fluid.op import Operator ...@@ -27,7 +27,7 @@ from paddle.fluid.op import Operator
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
def run_pserver(use_cuda, sync_mode): def run_pserver(pserver_id, use_cuda, sync_mode):
scope = fluid.core.Scope() scope = fluid.core.Scope()
program = Program() program = Program()
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
...@@ -36,7 +36,10 @@ def run_pserver(use_cuda, sync_mode): ...@@ -36,7 +36,10 @@ def run_pserver(use_cuda, sync_mode):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
# create and initialize Param Variable # create and initialize Param Variable
param = scope.var('table').get_tensor() param = scope.var('table').get_tensor()
param_array = np.full((10, 8), 5.0).astype("float32")
param_array = np.ones((10, 8)).astype("float32")
for i in range(len(param_array)):
param_array[i] *= param_array[i] * i + pserver_id * 10
param.set(param_array, place) param.set(param_array, place)
optimize_block = program._create_block(program.global_block().idx) optimize_block = program._create_block(program.global_block().idx)
...@@ -60,8 +63,8 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -60,8 +63,8 @@ class TestListenAndServOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.ps_timeout = 5 self.ps_timeout = 5
def _start_pserver(self, use_cuda, sync_mode, pserver_func): def _start_pserver(self, pserver_id, use_cuda, sync_mode, pserver_func):
p = Process(target=pserver_func, args=(use_cuda, sync_mode)) p = Process(target=pserver_func, args=(pserver_id, use_cuda, sync_mode))
p.daemon = True p.daemon = True
p.start() p.start()
return p return p
...@@ -85,7 +88,7 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -85,7 +88,7 @@ class TestListenAndServOp(unittest.TestCase):
port = int(f.read().strip()) port = int(f.read().strip())
return port return port
def _run_lookup_table_op(self, place, port): def _run_lookup_table_op_one_pserver(self, place, port):
scope = fluid.core.Scope() scope = fluid.core.Scope()
program = Program() program = Program()
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
...@@ -96,15 +99,17 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -96,15 +99,17 @@ class TestListenAndServOp(unittest.TestCase):
param.set(param_array, place) param.set(param_array, place)
ids = scope.var('Ids').get_tensor() ids = scope.var('Ids').get_tensor()
ids_array = np.array([[1.0], [2.0]]).astype("int64") ids_array = np.array([[1], [2], [5]]).astype("int64")
ids.set(ids_array, place) ids.set(ids_array, place)
ids.set_lod([[0, 1, 2]]) ids_lod = [[0, 1, 2, 3]]
ids.set_lod(ids_lod)
out = scope.var('Out').get_tensor() out = scope.var('Out').get_tensor()
emaps = ['127.0.0.1:' + str(port)] emaps = ['127.0.0.1:' + str(port)]
table_names = ['table'] table_names = ['table']
height_sections = [10] height_sections = [10]
# create and run sgd operator # create and run sgd operator
lookup_table_op = Operator( lookup_table_op = Operator(
"lookup_table", "lookup_table",
...@@ -120,24 +125,75 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -120,24 +125,75 @@ class TestListenAndServOp(unittest.TestCase):
# get and compare result # get and compare result
result_array = np.array(out) result_array = np.array(out)
print(result_array) self.assertEqual(out.lod(), ids_lod)
self.assertEqual(list(result_array.shape), [len(ids_array), 8])
for i in range(len(ids_array)):
id = ids_array[i][0]
self.assertTrue((result_array[i] == id).all())
self.assertTrue((result_array[0] == 5).all()) def _run_lookup_table_op_two_pserver(self, place, port0, port1):
self.assertTrue((result_array[0] == 5).all()) scope = fluid.core.Scope()
program = Program()
with fluid.scope_guard(scope):
with program_guard(program, startup_program=Program()):
# create and initialize Param Variable
param = scope.var('W').get_tensor()
param_array = np.full((10, 8), 1.0).astype("float32")
param.set(param_array, place)
ids = scope.var('Ids').get_tensor()
ids_array = np.array([[1], [2], [11], [13]]).astype("int64")
ids.set(ids_array, place)
ids_lod = [[0, 2, 3, 4]]
ids.set_lod(ids_lod)
out = scope.var('Out').get_tensor()
emaps = ['127.0.0.1:' + str(port0), '127.0.0.1:' + str(port1)]
table_names = ['table', 'table']
height_sections = [10, 20]
# create and run sgd operator
lookup_table_op = Operator(
"lookup_table",
W='W',
Ids='Ids',
Out='Out',
remote_prefetch=True,
epmap=emaps,
table_names=table_names,
height_sections=height_sections)
lookup_table_op.run(scope, place)
# get and compare result
result_array = np.array(out)
self.assertEqual(out.lod(), ids_lod)
self.assertEqual(list(result_array.shape), [len(ids_array), 8])
for i in range(len(ids_array)):
id = ids_array[i][0]
self.assertTrue((result_array[i] == id).all())
def test_lookup_remote_table(self): def test_lookup_remote_table(self):
# run pserver on CPU in sync mode # run pserver on CPU in sync mode
p1 = self._start_pserver(False, True, run_pserver) p0 = self._start_pserver(0, False, True, run_pserver)
self._wait_ps_ready(p0.pid)
port0 = self._get_pserver_port(p0.pid)
p1 = self._start_pserver(1, False, True, run_pserver)
self._wait_ps_ready(p1.pid) self._wait_ps_ready(p1.pid)
port = self._get_pserver_port(p1.pid) port1 = self._get_pserver_port(p1.pid)
places = [core.CPUPlace()] places = [core.CPUPlace()]
# if core.is_compiled_with_cuda(): # if core.is_compiled_with_cuda():
# places.append(core.CUDAPlace(0)) # places.append(core.CUDAPlace(0))
for place in places: for place in places:
self._run_lookup_table_op(place, port) self._run_lookup_table_op_one_pserver(place, port0)
self._run_lookup_table_op_two_pserver(place, port0, port1)
# raise SIGTERM to pserver # raise SIGTERM to pserver
os.kill(p0.pid, signal.SIGINT)
p0.join()
os.kill(p1.pid, signal.SIGINT) os.kill(p1.pid, signal.SIGINT)
p1.join() p1.join()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册