# test_viterbi_decode_op.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import core
import unittest
import paddle

# Both test classes build static programs / OpTest cases, so switch Paddle
# into static-graph mode as soon as the module is imported.
paddle.enable_static()

class Decoder:
    """NumPy reference implementation of Viterbi decoding.

    This module's tests use it to compute the expected ``Scores``/``Path``
    outputs for the ``viterbi_decode`` op and ``paddle.text.ViterbiDecoder``.

    Args:
        transitions: square ``(n_label, n_label)`` array; ``transitions[p, q]``
            is the score added when tag ``p`` is followed by tag ``q``.
        use_tag: when True, tag ``start_idx`` (-1) acts as an implicit start
            tag (every path leaves from it) and row ``stop_idx`` (-2) of
            ``transitions`` is added on each sequence's last valid step.
    """

    def __init__(self, transitions, use_tag=True):
        self.transitions = transitions
        self.use_tag = use_tag
        self.start_idx, self.stop_idx = -1, -2

    def __call__(self, inputs, length):
        """Decode the highest-scoring tag path of every sequence in a batch.

        Args:
            inputs: emission scores of shape ``(batch, seq_len, n_label)``.
            length: valid length of each sequence, shape ``(batch,)``; every
                entry is assumed to lie in ``[1, seq_len]``.

        Returns:
            ``(scores, path)``. ``scores`` has shape ``(batch,)``. ``path`` has
            shape ``(batch, max(length))``; positions past a sequence's length
            are filled with tag 0.
        """
        bs, _, n_label = inputs.shape
        steps = np.transpose(inputs, (1, 0, 2))  # (seq_len, batch, n_label)
        trans_exp = np.expand_dims(self.transitions, axis=0)  # (1, from, to)
        left_length = np.array(length)
        max_seq_len = np.amax(left_length)
        # (batch, 1) so it broadcasts against (batch, n_label) score arrays.
        left_length = np.expand_dims(left_length, 1)
        history = []  # back-pointers: one (batch, n_label) array per step >= 1

        if self.use_tag:
            # A large negative init forces every path to leave the start tag.
            alpha = np.full((bs, n_label), -1e4, dtype='float32')
        else:
            alpha = np.zeros((bs, n_label), dtype='float32')
        alpha[:, self.start_idx] = 0

        for step, logit in enumerate(steps[:max_seq_len]):
            if step == 0 and not self.use_tag:
                # Without a start tag, the first step's score is the emission.
                alpha = logit
                left_length = left_length - 1
                continue
            # candidates[b, p, q]: best score ending in tag p, then moving to q.
            candidates = np.expand_dims(alpha, 2) + trans_exp
            best_prev = np.argmax(candidates, 1)
            alpha_nxt = np.amax(candidates, 1) + logit
            # Step 0 (only reached with use_tag) effectively starts from the
            # start tag, so its back-pointers are not needed for backtracking.
            if step >= 1:
                history.append(best_prev)
            # Sequences that have already ended keep their previous scores.
            mask = left_length > 0
            alpha = mask * alpha_nxt + (1 - mask) * alpha
            if self.use_tag:
                # NOTE(review): this adds row ``stop_idx`` (transitions[stop, q])
                # rather than column transitions[q, stop]; it presumably mirrors
                # the op kernel's convention -- confirm before changing.
                alpha += (left_length == 1) * trans_exp[:, self.stop_idx]
            left_length = left_length - 1

        scores, last_ids = np.amax(alpha, 1), np.argmax(alpha, 1)
        left_length = left_length[:, 0]  # now length - max_seq_len (<= 0)
        # Padded final positions of shorter sequences become tag 0.
        batch_path = [last_ids * (left_length >= 0)]
        batch_offset = np.arange(bs) * n_label  # row starts in flattened hist
        for hist in reversed(history):
            left_length = left_length + 1
            # Flat take is hist[b, last_ids[b]] for every batch entry b.
            prev_ids = np.take(hist, batch_offset + last_ids) * (left_length > 0)
            # At a sequence's last valid position, emit its best final tag.
            at_end = left_length == 0
            prev_ids = prev_ids * (1 - at_end) + last_ids * at_end
            batch_path.append(prev_ids)
            # Until a sequence's end is reached, keep carrying its final tag.
            last_ids = prev_ids + (left_length < 0) * last_ids
        # Paths were collected back-to-front.
        batch_path.reverse()
        return scores, np.stack(batch_path, 1)


class TestViterbiOp(OpTest):
    """Checks the ``viterbi_decode`` operator against the NumPy ``Decoder``."""

    def set_attr(self):
        """Choose dtype, tag mode and problem size for this case."""
        # float32 on ROCm builds, float64 otherwise.
        self.dtype = "float32" if core.is_compiled_with_rocm() else "float64"
        self.use_tag = True
        self.bz, self.len, self.ntags = 4, 8, 10

    def setUp(self):
        """Build random inputs and the reference outputs for the op."""
        self.op_type = "viterbi_decode"
        self.python_api = paddle.text.viterbi_decode
        self.set_attr()
        bz, length, ntags = self.bz, self.len, self.ntags
        self.input = np.random.randn(bz, length, ntags).astype(self.dtype)
        self.trans = np.random.randn(ntags, ntags).astype(self.dtype)
        # Valid lengths are drawn from [1, length].
        self.length = np.random.randint(1, length + 1, [bz]).astype('int64')
        decoder = Decoder(self.trans, self.use_tag)
        scores, path = decoder(self.input, self.length)
        self.inputs = {
            'Input': self.input,
            'Transition': self.trans,
            'Length': self.length,
        }
        self.attrs = {
            'include_bos_eos_tag': self.use_tag,
        }
        self.outputs = {'Scores': scores, 'Path': path}

    def test_output(self):
        """Compare the op's outputs with the reference ``Decoder`` results."""
        self.check_output(check_eager=True)


class TestViterbiAPI(unittest.TestCase):
    """Checks ``paddle.text.ViterbiDecoder`` in a static program against the
    NumPy ``Decoder`` on every available place."""

    def set_attr(self):
        """Choose tag mode, problem size and the places to run on."""
        self.use_tag = True
        self.bz, self.len, self.ntags = 4, 8, 10
        # Also run on GPU 0 when Paddle is built with CUDA.
        self.places = (
            [fluid.CPUPlace(), fluid.CUDAPlace(0)]
            if core.is_compiled_with_cuda()
            else [fluid.CPUPlace()]
        )

    def setUp(self):
        """Build random float32 inputs and the reference scores/paths."""
        self.set_attr()
        bz, length, ntags = self.bz, self.len, self.ntags
        self.input = np.random.randn(bz, length, ntags).astype('float32')
        self.transitions = np.random.randn(ntags, ntags).astype('float32')
        # Valid lengths are drawn from [1, length].
        self.length = np.random.randint(1, length + 1, [bz]).astype('int64')
        decoder = Decoder(self.transitions, self.use_tag)
        self.scores, self.path = decoder(self.input, self.length)

    def check_static_result(self, place):
        """Run ViterbiDecoder in a fresh static program on ``place`` and
        compare its outputs with the reference."""
        bz, length, ntags = self.bz, self.len, self.ntags
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            Input = fluid.data(
                name="Input", shape=[bz, length, ntags], dtype="float32"
            )
            Transition = fluid.data(
                name="Transition", shape=[ntags, ntags], dtype="float32"
            )
            Length = fluid.data(name="Length", shape=[bz], dtype="int64")
            decoder = paddle.text.ViterbiDecoder(Transition, self.use_tag)
            score, path = decoder(Input, Length)
            exe = fluid.Executor(place)
            feed_list = {
                "Input": self.input,
                "Transition": self.transitions,
                "Length": self.length,
            }
            fetches = exe.run(feed=feed_list, fetch_list=[score, path])
            np.testing.assert_allclose(fetches[0], self.scores, rtol=1e-5)
            np.testing.assert_allclose(fetches[1], self.path)

    def test_static_net(self):
        """Check the static-graph result on every configured place."""
        for place in self.places:
            self.check_static_result(place)


if __name__ == "__main__":
    # Ensure static-graph mode (the module also enables it at import time),
    # then hand control to the unittest runner.
    paddle.enable_static()
    unittest.main()