# test_linear_chain_crf_op.py -- unit test for the linear_chain_crf operator
# (blame view: committed by caoying03).
import unittest
import random
import numpy as np

from op_test import OpTest


class LinearChainCrfForward(object):
    """NumPy reference implementation of the linear-chain CRF forward pass.

    The transition matrix is laid out as: row 0 holds the weights for the
    start mark, row 1 the weights for the end mark, and rows 2 onward the
    tag-to-tag transition weights (indexed ``[previous_tag, current_tag]``).

    Sequences are stored back to back along axis 0; sequence ``i`` occupies
    rows ``seq_start_positions[i]`` to ``seq_start_positions[i + 1]``.
    """

    def __init__(self, seq_start_positions, emission_weights, emission_row_max,
                 emission_exps, transition_weights, transition_exps, labels):
        """Store the inputs and allocate the output buffers.

        Args:
            seq_start_positions: offsets delimiting each sequence; its last
                entry is the total number of rows.
            emission_weights: 2-D array of emission scores, one row per
                position and one column per tag.
            emission_row_max: per-row maximum of ``emission_weights``, sliced
                as a 2-D array by ``crf_forward_compute``.
            emission_exps: exponentiated emission scores, same layout as
                ``emission_weights``.
            transition_weights: 2-D array with ``tag_num + 2`` rows (start,
                end, then tag-to-tag weights).
            transition_exps: exponentiated ``transition_weights``.
            labels: gold tag index for every row, sliced as a 2-D array by
                ``crf_forward_compute``.
        """
        self.tag_num = emission_weights.shape[1]
        self.seq_num = len(seq_start_positions) - 1

        self.seq_start_positions = seq_start_positions
        self.labels = labels
        self.x = emission_weights

        self.x_row_max = emission_row_max
        self.x_exps = emission_exps

        # unnormalized logits of the transition weights for the start mark.
        self.a = transition_weights[0, :]
        self.a_exps = transition_exps[0, :]
        # unnormalized logits of the transition weights for the end mark.
        self.b = transition_weights[1, :]
        self.b_exps = transition_exps[1, :]
        # unnormalized logits of the transition weights for all the other tags.
        self.w = transition_weights[2:, :]
        self.w_exps = transition_exps[2:, :]

        # The output of linear chain crf operator.
        # alpha is a memo table in dynamic programming to calculate
        # normalization factor.
        self.alpha = np.zeros(
            (seq_start_positions[-1], self.tag_num), dtype="float64")
        # One value per sequence, filled in by crf_forward_compute().
        self.log_likelihood = np.zeros((self.seq_num, 1))

    def _l1_norm(self, x):
        """Divide ``x`` by its sum IN PLACE and return that sum."""
        s = np.sum(x)
        x /= s
        return s

    def _forward_a_sequence(self, x, x_row_max, x_exps, label, alpha):
        """Run the forward algorithm on a single sequence.

        Args:
            x: emission scores of the sequence.
            x_row_max: per-row maximum of ``x`` (2-D, one column).
            x_exps: exponentiated emission scores of the sequence.
            label: gold tag indices of the sequence (2-D, one column).
            alpha: slice of the memo table for this sequence; it is written
                in place, each row being normalized to sum to one.

        Returns:
            The negated log-likelihood of the gold tag path, i.e. the log of
            the normalization factor minus the score of the gold path.
        """
        seq_len = x_row_max.shape[0]

        # First position: start-mark weights combined with the emissions.
        alpha[0, :] = self.a_exps * x_exps[0, :]
        # Rows are renormalized at every step; the row maximum and the log of
        # each normalizer are accumulated here so the log of the
        # normalization factor can be recovered.
        log_likelihood = -x_row_max[0] - np.log(self._l1_norm(alpha[0, :]))

        # calculate the unnormalized logits of the normalization factor.
        for k in range(1, seq_len):
            # sum_j alpha[k - 1, j] * w_exps[j, i] for every tag i at once.
            alpha[k, :] = x_exps[k, :] * np.dot(alpha[k - 1, :], self.w_exps)
            log_likelihood -= x_row_max[k] + np.log(self._l1_norm(alpha[k, :]))
        # Close the chain with the end-mark weights.
        log_likelihood -= np.log(np.dot(alpha[-1, :], self.b_exps))

        # calculate the numerator part (score of the gold path).
        log_likelihood += (
            self.a[label[0]] + x[0, label[0]] + self.b[label[-1]])

        for k in range(1, seq_len):
            log_likelihood += (x[k, label[k]] + self.w[label[k - 1], label[k]])
        return -log_likelihood

    def crf_forward_compute(self):
        """Process every sequence and return ``(alpha, log_likelihood)``."""
        for i in range(self.seq_num):
            start = self.seq_start_positions[i]
            end = self.seq_start_positions[i + 1]

            self.log_likelihood[i] = self._forward_a_sequence(
                self.x[start:end, :], self.x_row_max[start:end, :],
                self.x_exps[start:end, :], self.labels[start:end, :],
                self.alpha[start:end, :])
        return self.alpha, self.log_likelihood


class TestLinearChainCrfOp(OpTest):
    """Check the linear_chain_crf operator against LinearChainCrfForward."""

    def set_test_data(self):
        """Build random inputs and the expected outputs for the operator."""
        # TODO(caoying) Fix the unittest by: add the boundary cases when
        # sequence lengths are 1, 2, and 3.

        SEQ_NUM = 3
        TAG_NUM = 17
        MAX_SEQ_LEN = 5

        # the linear_chain_crf operator only supports sequence (LoD level = 1)
        lod = [[0]]
        for _ in range(SEQ_NUM):
            # Each sequence gets a random length in [1, MAX_SEQ_LEN].
            lod[-1].append(lod[-1][-1] + random.randint(1, MAX_SEQ_LEN))
        emission = np.random.uniform(-1, 1,
                                     [lod[-1][-1], TAG_NUM]).astype("float64")
        # Subtract the row maximum before exponentiating.
        emission_row_max = np.amax(emission, axis=1, keepdims=True)
        emission_exps = np.exp(emission - emission_row_max)

        # Two extra rows: weights for the start mark and for the end mark.
        transition = np.random.uniform(-0.5, 0.5,
                                       [TAG_NUM + 2, TAG_NUM]).astype("float64")
        transition_exps = np.exp(transition)

        labels = np.random.randint(
            low=0, high=TAG_NUM, size=(lod[-1][-1], 1), dtype="int32")

        self.inputs = {
            "Emission": (emission, lod),
            "Transition": transition,
            "Label": (labels, lod)
        }
        crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
                                    emission_exps, transition, transition_exps,
                                    labels)
        alpha, log_likelihood = crf.crf_forward_compute()

        self.outputs = {
            "Alpha": alpha,
            "EmissionExps": emission_exps,
            "TransitionExps": transition_exps,
            "LogLikelihood": log_likelihood
        }

    def setUp(self):
        self.op_type = "linear_chain_crf"
        self.set_test_data()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["Emission", "Transition"], "LogLikelihood")

    def test_check_grad_ignore_transition(self):
        # NOTE: set("Transition") would build a set of single characters
        # ({'T', 'r', 'a', ...}); the set must contain the variable name.
        self.check_grad(
            ["Emission"], "LogLikelihood", no_grad_set={"Transition"})


# Run the unit tests when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()