#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.proto import framework_pb2
import paddle.fluid as fluid
import numpy
import math
import os


def string(param):
    """Wrap a value in single quotes, e.g. 3 -> '3'."""
    return "\'{}\'".format(param)


def get_same_padding(in_size, kernel_size, stride):
    """Return [pad_before, pad_after] so that the output size becomes
    ceil(in_size / stride), i.e. "SAME"-style padding."""
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]
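
# Worked example (illustrative): for in_size=224, kernel_size=7 and stride=2,
# new_size = ceil(224 / 2) = 112, so pad_size = 111 * 2 + 7 - 224 = 5 and the
# padding is split as [2, 3].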


def export_paddle_param(param, param_name, dir):
    """Dump a numpy parameter to <dir>/<param_name> in the binary tensor
    format that fluid.io.load_vars reads back (see init_net below)."""
    dtype_map = {
        "int16": [framework_pb2.VarType.INT16, 'h'],
        "int32": [framework_pb2.VarType.INT32, 'i'],
        "int64": [framework_pb2.VarType.INT64, 'q'],
        "float16": [framework_pb2.VarType.FP16, 'e'],
        "float32": [framework_pb2.VarType.FP32, 'f'],
        "float64": [framework_pb2.VarType.FP64, 'd']
    }
    shape = param.shape
    if len(shape) == 0:
        assert param.size == 1, "Scalar parameter is expected to have size 1."
        shape = [1]
    assert str(param.dtype) in dtype_map, "Unknown dtype {} of param.".format(
        param.dtype)
    if not os.path.exists(dir):
        os.makedirs(dir)

    fp = open(os.path.join(dir, param_name), 'wb')
    # Header consumed by fluid's tensor deserializer: version (int32),
    # LoD information (int64, empty here) and tensor version (int32).
    numpy.array([0], dtype='int32').tofile(fp)
    numpy.array([0], dtype='int64').tofile(fp)
    numpy.array([0], dtype='int32').tofile(fp)
    # Describe dtype and shape via a serialized TensorDesc proto.
    tensor_desc = framework_pb2.VarType.TensorDesc()
    tensor_desc.data_type = dtype_map[str(param.dtype)][0]
    tensor_desc.dims.extend(shape)
    desc_size = tensor_desc.ByteSize()
    # Write the size of the TensorDesc, the TensorDesc itself, then the raw data.
    numpy.array([desc_size], dtype='int32').tofile(fp)
    fp.write(tensor_desc.SerializeToString())
    param.tofile(fp)
    fp.close()


def init_net(param_dir="./"):
    """Run the startup program on CPU and load every variable of the default
    main program that has a matching parameter file under param_dir."""
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    def if_exist(var):
        return os.path.exists(os.path.join(param_dir, var.name))

    fluid.io.load_vars(exe,
                       param_dir,
                       fluid.default_main_program(),
                       predicate=if_exist)
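

# ---------------------------------------------------------------------------
# Minimal, illustrative usage sketch. The weight name, shape and directory
# below are arbitrary placeholders, not values required by the module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_weight = numpy.random.rand(3, 3).astype("float32")
    export_paddle_param(demo_weight, "demo_fc_w", "./demo_params")
    print("SAME padding for (in=224, k=7, s=2): {}".format(
        get_same_padding(224, 7, 2)))
    # A generated model would build its network here and then call
    # init_net("./demo_params") so that fluid.io.load_vars picks up files
    # whose names match the program's variable names.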