client.py 3.0 KB
Newer Older
D
dzhwinter 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
14 15 16
import ctypes
import os

17 18 19 20
# Cached handle to libpaddle_master.so; stays None until get_c_lib() is
# called for the first time.
__lib__ = None


def get_c_lib():
    """Return the ctypes handle to libpaddle_master.so, loading it on first use.

    The shared library is looked up next to this module. The handle is
    cached in the module-level ``__lib__`` so the library is loaded at most
    once per process.

    Returns:
        ctypes.CDLL: the loaded shared library.

    Raises:
        OSError: if the shared library cannot be loaded.
    """
    global __lib__
    if __lib__ is None:
        # Anchor on the absolute module path: when __file__ is a bare
        # relative name, dirname() is empty and the loader would be handed a
        # bare library name, which it resolves through the system search
        # path instead of this directory.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(module_dir, "libpaddle_master.so")
        __lib__ = ctypes.cdll.LoadLibrary(path)
    return __lib__
26 27 28 29 30 31 32


class client(object):
    """
    client is a client to the master server.

    Every method forwards to a function exported by the shared library
    returned by get_c_lib(), passing along the handle stored in ``self.c``.
    """

    @staticmethod
    def _to_c_bytes(value):
        """Return value as bytes so ctypes passes it to C as a ``char *``.

        None and bytes are returned unchanged; any other value is encoded
        with UTF-8.
        """
        if value is None or isinstance(value, bytes):
            return value
        return value.encode("utf-8")

    def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
        """create the master client handle

        :param etcd_endpoints: etcd endpoints, as text or bytes. Text is
        encoded to UTF-8 before it is handed to the C library.
        :param timeout_sec: forwarded unchanged to the C library.
        :param buf_size: forwarded unchanged to the C library.
        """
        # Without argtypes, ctypes marshals a Python 3 str as wchar_t*;
        # encoding first makes the call receive a char* on every version.
        self.c = get_c_lib().paddle_new_etcd_master_client(
            self._to_c_bytes(etcd_endpoints), timeout_sec, buf_size)

    def request_save_model(self, trainer_id, block_ms):
        """request to save model

        Conventionally the 0-th trainer will save model. But in
        distributed training, any trainer could be killed. This
        function asks the master server if the trainer should proceed
        with saving model.

        :param trainer_id: trainer id.
        :param block_ms: number of milliseconds that other save model
        requests will be blocked if this save model request succeeded.

        Returns:
            int: 1 if the save model request is approved, 0 if the
            request is rejected because another trainer is saving the
            model, -1 if an error happened.

        """
        return get_c_lib().paddle_request_save_model(self.c, trainer_id,
                                                     block_ms)

    def release(self):
        """release the master client handle

        Calling it again after the handle has been released is a no-op, so
        the C library is never handed a handle that is already gone.
        """
        if self.c is None:
            return
        get_c_lib().paddle_release_master_client(self.c)
        self.c = None

    def set_dataset(self, paths):
        """pass the dataset paths to the master client

        :param paths: sequence of paths, each given as text or bytes. Text
        is encoded to UTF-8 because c_char_p only accepts bytes on
        Python 3.
        """
        encoded = [self._to_c_bytes(path) for path in paths]
        # The ctypes array keeps references to the bytes objects, so the
        # pointers stay valid for the duration of the call.
        holder = (ctypes.c_char_p * len(encoded))(*encoded)
        get_c_lib().paddle_set_dataset(self.c, holder, len(encoded))

    def next_record(self):
        """gets next record for training

        Returns:
            string: the record.
            int: error code, 0 if successful, < 0 otherwise.
        """
        p = ctypes.c_char_p()
        size = get_c_lib().paddle_next_record(self.c, ctypes.byref(p))
        if size < 0:
            # Error
            return None, size

        if size == 0:
            # Empty record
            return "", 0

        try:
            # Copy exactly `size` bytes. c_char_p.value stops at the first
            # NUL byte, which truncates records containing NULs and reads
            # past the buffer when it is not NUL-terminated.
            record = ctypes.string_at(p, size)
        finally:
            # Memory created from C should be freed.
            get_c_lib().mem_free(p)
        return record, 0

    def paddle_start_get_records(self, pass_id):
        """forward pass_id to paddle_start_get_records in the C library

        :param pass_id: pass id, forwarded unchanged.
        """
        get_c_lib().paddle_start_get_records(self.c, pass_id)