# model.py
# Committed by suweiyue.
import numpy as np
import pgl
import paddle.fluid as F
import paddle.fluid.layers as L
from models.encoder import Encoder
from models.loss import Loss

class BaseModel(object):
    """Base class that wires a model's inputs, loss and data loader together.

    Subclasses implement :meth:`forward`. The constructor calls it once and
    hands the result to :meth:`build`, which stores the pieces on the
    instance and derives ``feed_list`` and ``data_loader`` from them.
    """

    def __init__(self, config):
        """Store ``config``, run :meth:`forward` and build the model.

        Args:
            config: Configuration object, kept on ``self.config`` so that
                subclasses can read it inside ``forward``.
        """
        self.config = config
        datas, graph_wrappers, loss, outputs = self.forward()
        self.build(datas, graph_wrappers, loss, outputs)

    def forward(self):
        """Return the 4-tuple ``(datas, graph_wrappers, loss, outputs)``.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def build(self, datas, graph_wrappers, loss, outputs):
        """Record the results of ``forward`` and derive the feeding objects.

        Args:
            datas: Input variables appended to the feed list after the
                graph-wrapper holders.
            graph_wrappers: Iterable of objects exposing ``holder_list``.
            loss: Loss value produced by ``forward``; stored unchanged.
            outputs: Outputs produced by ``forward``; stored unchanged.
        """
        self.datas = datas
        self.graph_wrappers = graph_wrappers
        self.loss = loss
        self.outputs = outputs
        # The data loader is created from the feed list, so order matters.
        self.build_feed_list()
        self.build_data_loader()

    def build_feed_list(self):
        """Populate ``self.feed_list``.

        The list holds every graph wrapper's ``holder_list`` entries, in
        wrapper order, followed by the entries of ``self.datas``.
        """
        self.feed_list = []
        # Iterate directly instead of indexing so any iterable of wrappers
        # is accepted, not only sequences that support len().
        for wrapper in self.graph_wrappers:
            self.feed_list.extend(wrapper.holder_list)
        self.feed_list.extend(self.datas)

    def build_data_loader(self):
        """Create ``self.data_loader`` as a ``PyReader`` over ``self.feed_list``."""
        self.data_loader = F.io.PyReader(
            feed_list=self.feed_list, capacity=20, use_double_buffer=True, iterable=True)


class LinkPredictModel(BaseModel):
    """Link-prediction model over (user, positive item, negative item) inputs."""

    def forward(self):
        """Declare inputs, graph wrappers, encoder and loss.

        Returns:
            A 4-tuple ``(datas, graph_wrappers, loss, outputs)`` where

            * ``datas`` is the list of five int64 input variables, in the
              order ``user_index``, ``pos_item_index``, ``neg_item_index``,
              ``user_real_index``, ``pos_item_real_index``;
            * ``graph_wrappers`` holds one ``GraphWrapper`` per
              ``config.num_layers``, named ``layer_<i>``;
            * ``loss`` is the value returned by the configured loss function;
            * ``outputs`` is the encoder output followed by
              ``user_real_index`` and ``pos_item_real_index``.
        """
        # datas: all five inputs share the same declaration, so create them
        # in one pass; the order here is the order they appear in the feed.
        input_names = [
            "user_index",
            "pos_item_index",
            "neg_item_index",
            "user_real_index",
            "pos_item_real_index",
        ]
        datas = [
            L.data(name, shape=[None], dtype="int64", append_batch_size=False)
            for name in input_names
        ]
        (user_index, pos_item_index, neg_item_index,
         user_real_index, pos_item_real_index) = datas

        # graph_wrappers: one per layer, all sharing the same feature schema.
        node_feature_info = [
            ('index', [None], np.dtype('int64')),
            ('term_ids', [None, None], np.dtype('int64')),
        ]
        edge_feature_info = []
        graph_wrappers = [
            pgl.graph_wrapper.GraphWrapper(
                "layer_%s" % i, node_feat=node_feature_info, edge_feat=edge_feature_info)
            for i in range(self.config.num_layers)
        ]

        # encoder model
        encoder = Encoder.factory(self.config)
        outputs = encoder(graph_wrappers, [user_index, pos_item_index, neg_item_index])
        user_feat, pos_item_feat, neg_item_feat = outputs

        # loss
        if self.config.neg_type == "batch_neg":
            # With "batch_neg" the positive-item features are passed to the
            # loss in the negative slot as well. `outputs` itself is left
            # untouched and still carries the encoder's negative features.
            neg_item_feat = pos_item_feat
        loss_func = Loss.factory(self.config)
        loss = loss_func(user_feat, pos_item_feat, neg_item_feat)

        # NOTE(review): `outputs + [...]` assumes the encoder returns a list.
        return datas, graph_wrappers, loss, outputs + [user_real_index, pos_item_real_index]


class NodeClassificationModel(BaseModel):
    """Node-classification model: encoder features followed by an FC layer."""

    def forward(self):
        """Declare inputs, graph wrappers, encoder, classifier and loss.

        Returns:
            A 4-tuple ``(datas, graph_wrappers, loss, outputs)`` where

            * ``datas`` is ``[node_index, node_real_index, label]``, all
              int64 input variables;
            * ``graph_wrappers`` holds one ``GraphWrapper`` per
              ``config.num_layers``, named ``layer_<i>``;
            * ``loss`` is the configured loss applied to the logits and the
              label reshaped to ``[-1, 1]``;
            * ``outputs`` is the encoder output followed by
              ``node_real_index`` and the logits.
        """
        # inputs: identical declarations, created in feed order.
        datas = [
            L.data(name, shape=[None], dtype="int64", append_batch_size=False)
            for name in ("node_index", "node_real_index", "label")
        ]
        node_index, node_real_index, label = datas

        # graph_wrappers: one per layer, all sharing the same node features.
        node_feature_info = [
            ('index', [None], np.dtype('int64')),
            ('term_ids', [None, None], np.dtype('int64')),
        ]
        graph_wrappers = [
            pgl.graph_wrapper.GraphWrapper(
                "layer_%s" % i, node_feat=node_feature_info)
            for i in range(self.config.num_layers)
        ]

        # encoder model
        encoder = Encoder.factory(self.config)
        outputs = encoder(graph_wrappers, [node_index])
        feat = outputs[0]
        logits = L.fc(feat, self.config.num_label)

        # loss: the label is fed as a flat vector and reshaped to a column
        # before being handed to the loss function. `datas` keeps the
        # original (un-reshaped) input variable.
        label = L.reshape(label, [-1, 1])
        loss_func = Loss.factory(self.config)
        loss = loss_func(logits, label)

        # NOTE(review): `outputs + [...]` assumes the encoder returns a list.
        return datas, graph_wrappers, loss, outputs + [node_real_index, logits]