client.py 3.0 KB
Newer Older
D
dzhwinter 已提交
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16 17
import ctypes
import os

18 19 20 21
# Cached ctypes handle for the master shared library; populated lazily by
# get_c_lib() on its first call.
__lib__ = None


def get_c_lib():
    """Return the ctypes handle for ``libpaddle_master.so``, loading it once.

    The shared library is looked up in the directory that contains this
    module. The loaded handle is stored in the module-level ``__lib__`` so
    that subsequent calls reuse it instead of loading the library again.
    """
    global __lib__
    if __lib__ is not None:
        return __lib__

    module_dir = os.path.dirname(__file__)
    __lib__ = ctypes.cdll.LoadLibrary(
        os.path.join(module_dir, "libpaddle_master.so"))
    return __lib__
27 28 29 30 31 32 33


class client(object):
    """
    client is a client to the master server.

    It holds the handle returned by ``paddle_new_etcd_master_client`` in
    ``self.c`` and forwards every method to the matching function of the
    shared library returned by ``get_c_lib()``.
    """

    def __init__(self, etcd_endpoints, timeout_sec, buf_size=0):
        """Create the underlying master client.

        :param etcd_endpoints: etcd endpoints handed to the C library. Text
            is UTF-8 encoded first so the library receives a ``char*`` (a
            bare ``str`` would be marshalled as ``wchar_t*`` on Python 3).
        :param timeout_sec: forwarded unchanged to the C library.
        :param buf_size: forwarded unchanged to the C library, 0 by default.
        """
        self.c = get_c_lib().paddle_new_etcd_master_client(
            self._to_bytes(etcd_endpoints), timeout_sec, buf_size)

    @staticmethod
    def _to_bytes(value):
        """Return text *value* UTF-8 encoded; pass anything else through."""
        # type(u"") is ``unicode`` on Python 2 and ``str`` on Python 3, so
        # byte strings (and non-string values) are returned untouched.
        if isinstance(value, type(u"")):
            return value.encode("utf-8")
        return value

    def request_save_model(self, trainer_id, block_ms):
        """request to save model

        Conventionally the 0-th trainer will save model. But in
        distributed training, any trainer could be killed. This
        function asks the master server if the trainer should proceed
        with saving model.

        :param trainer_id: trainer id.
        :param block_ms: number of milliseconds that other save model
        requests will be blocked if this save model request succeeds.

        Returns:
            int: 1 if the save model request is approved, 0 if the
            request is rejected because another trainer is saving the
            model, -1 if an error happened.

        """
        return get_c_lib().paddle_request_save_model(self.c, trainer_id,
                                                     block_ms)

    def release(self):
        """Release the underlying master client.

        Calling this more than once is safe: once ``self.c`` has been
        cleared, later calls do nothing instead of passing ``None`` to the
        C library.
        """
        if self.c is None:
            return
        get_c_lib().paddle_release_master_client(self.c)
        self.c = None

    def set_dataset(self, paths):
        """Hand a list of dataset paths to the master client.

        :param paths: sized sequence of paths. Text paths are UTF-8 encoded
            because ``ctypes.c_char_p`` only accepts bytes on Python 3.
        """
        encoded = [self._to_bytes(path) for path in paths]
        # Array of char* handed to the C library together with its length.
        holder = (ctypes.c_char_p * len(encoded))(*encoded)
        get_c_lib().paddle_set_dataset(self.c, holder, len(encoded))

    def next_record(self):
        """gets next record for training

        Returns:
            string: the record (``None`` on error, ``""`` when empty).
            int: error code, 0 if successful, < 0 otherwise.
        """
        p = ctypes.c_char_p()
        ret = ctypes.pointer(p)
        size = get_c_lib().paddle_next_record(self.c, ret)
        if size < 0:
            # Error
            return None, size

        if size == 0:
            # Empty record
            # NOTE(review): this is a text "" while non-empty records are
            # byte strings on Python 3 - confirm what callers expect.
            return "", 0

        # Copy exactly `size` bytes. Reading through `c_char_p.value` would
        # stop at the first NUL byte and truncate binary records.
        record = ctypes.string_at(p, size)
        # Memory created from C should be freed.
        get_c_lib().mem_free(p)
        return record, 0

    def paddle_start_get_records(self, pass_id):
        """Forward ``pass_id`` to the library's ``paddle_start_get_records``.

        :param pass_id: passed through unchanged together with the client
            handle (presumably the training pass to fetch records for -
            confirm against the C library).
        """
        get_c_lib().paddle_start_get_records(self.c, pass_id)