提交 8273dd79 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #3045 from reyoung/feature/make_golang_client_lazy_load

Make C lib in `paddle.v2.master.client` lazy load
import ctypes import ctypes
import os import os
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so") __lib__ = None
lib = ctypes.cdll.LoadLibrary(path)
def get_c_lib():
global __lib__
if __lib__ is None:
path = os.path.join(os.path.dirname(__file__), "libpaddle_master.so")
__lib__ = ctypes.cdll.LoadLibrary(path)
return __lib__
class client(object): class client(object):
...@@ -11,8 +18,8 @@ class client(object): ...@@ -11,8 +18,8 @@ class client(object):
""" """
def __init__(self, etcd_endpoints, timeout_sec, buf_size=0): def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
self.c = lib.paddle_new_etcd_master_client(etcd_endpoints, timeout_sec, self.c = get_c_lib().paddle_new_etcd_master_client(
buf_size) etcd_endpoints, timeout_sec, buf_size)
def request_save_model(self, trainer_id, block_ms): def request_save_model(self, trainer_id, block_ms):
"""request to save model """request to save model
...@@ -32,10 +39,11 @@ class client(object): ...@@ -32,10 +39,11 @@ class client(object):
saving the model, -1 if error happened. saving the model, -1 if error happened.
""" """
return lib.paddle_request_save_model(self.c, trainer_id, block_ms) return get_c_lib().paddle_request_save_model(self.c, trainer_id,
block_ms)
def release(self): def release(self):
lib.paddle_release_master_client(self.c) get_c_lib().paddle_release_master_client(self.c)
self.c = None self.c = None
def set_dataset(self, paths): def set_dataset(self, paths):
...@@ -45,7 +53,7 @@ class client(object): ...@@ -45,7 +53,7 @@ class client(object):
for idx, path in enumerate(paths): for idx, path in enumerate(paths):
c_ptr = ctypes.c_char_p(path) c_ptr = ctypes.c_char_p(path)
holder[idx] = c_ptr holder[idx] = c_ptr
lib.paddle_set_dataset(self.c, holder, len(paths)) get_c_lib().paddle_set_dataset(self.c, holder, len(paths))
def next_record(self): def next_record(self):
"""gets next record for training """gets next record for training
...@@ -56,7 +64,7 @@ class client(object): ...@@ -56,7 +64,7 @@ class client(object):
""" """
p = ctypes.c_char_p() p = ctypes.c_char_p()
ret = ctypes.pointer(p) ret = ctypes.pointer(p)
size = lib.paddle_next_record(self.c, ret) size = get_c_lib().paddle_next_record(self.c, ret)
if size < 0: if size < 0:
# Error # Error
return None, size return None, size
...@@ -67,5 +75,5 @@ class client(object): ...@@ -67,5 +75,5 @@ class client(object):
record = ret.contents.value[:size] record = ret.contents.value[:size]
# Memory created from C should be freed. # Memory created from C should be freed.
lib.mem_free(ret.contents) get_c_lib().mem_free(ret.contents)
return record, 0 return record, 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册