op.py 2.4 KB
Newer Older
Y
yelrose 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
op package provide some common ops for building paddle model.

"""
import numpy as np
import paddle.fluid as fluid

from pgl.utils import paddle_helper


def nested_lod_reset(data, reset_index):
    """Reset the lod information as the one in given reset_index.

    This function apply :code:`fluid.layers.lod_reset` recursively
    to all the tensor in nested data.

    Args:
        data: A dictionary of tensor or tensor.
        reset_index: A variable which the target lod information comes from.

    Return:
        Return a dictionary of LodTensor of LodTensor.

    """
    if data is None:
        return None
    elif isinstance(data, dict):
        new_data = {}
        for key, value in data.items():
            new_data[key] = nested_lod_reset(value, reset_index)
        return new_data
    else:
        return fluid.layers.lod_reset(data, reset_index)


def read_rows(data, index):
    """Slice tensor with given index from dictionary of tensor or tensor

    This function helps to slice data from nested dictionary structure.

    Args:
        data: A dictionary of tensor or tensor

        index: A tensor of slicing index

    Returns:
        Return a dictionary of tensor or tensor.
    """
    if data is None:
        return None
    elif isinstance(data, dict):
        new_data = {}
        for key, value in data.items():
            new_data[key] = read_rows(value, index)
        return new_data
    else:
        return paddle_helper.gather(data, index)
Y
Yelrose 已提交
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85


class RowReader(object):
    """Memory Efficient RowReader 
    """
    def __init__(self, nfeat, index):
        self.nfeat = nfeat
        self.loaded_nfeat = {}
        self.index = index
 
    def __getitem__(self, key):
        if key not in self.loaded_nfeat:
            self.loaded_nfeat[key] = read_rows(self.nfeat[key], self.index)
        return self.loaded_nfeat[key]