module.py 6.8 KB
Newer Older
K
kinghuin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
# -*- coding:utf-8 -*-
import os
import argparse

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
from paddle import fluid
import paddlehub as hub
from paddlehub.module.module import serving, moduleinfo, runnable

try:
    from ddparser import DDParser as DDParserModel
except:
    raise ImportError(
        "The module requires additional dependencies: ddparser. Please run 'pip install ddparser' to install it."
    )


@moduleinfo(
    name="ddparser",
    version="1.0.0",
    summary="Baidu's open-source DDParser model.",
    author="baidu-nlp",
    author_email="",
    type="nlp/syntactic_analysis")
class ddparser(hub.NLPPredictionModule):
    def _initialize(self):
        """
        initialize with the necessary elements
        """
        self.ddp = DDParserModel(prob=True, use_pos=True)
        self.font = font_manager.FontProperties(
35
            fname=os.path.join(self.directory, "SourceHanSans-Regular.ttf"))
K
kinghuin 已提交
36 37 38

    @serving
    def serving_parse(self, texts=[], return_visual=False):
39 40 41 42
        results = self.parse(texts, return_visual)
        if return_visual:
            for i, result in enumerate(results):
                result['visual'] = result['visual'].tolist()
K
kinghuin 已提交
43

44
        return results
K
kinghuin 已提交
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60

    def parse(self, texts=[], return_visual=False):
        """
        parse the dependency.

        Args:
            texts(list[list[str] or list[list[str]]]): the input texts to be parse. It should be a list with elements: untokenized string or tokens list.
            return_visual(bool): if set True, the result will contain the dependency visualization.

        Returns:
            results(list[dict]): a list, with elements corresponding to each of the elements in texts. The element is a dictionary of shape:
                {
                    'word': list[str], the tokenized words.
                    'head': list[int], the head ids.
                    'deprel': list[str], the dependency relation.
                    'prob': list[float], the prediction probility of the dependency relation.
61 62
                    'postag': list[str], the POS tag. If the element of the texts is list, the key 'postag' will not return.
                    'visual' : list[numpy.array]: the dependency visualization. Use cv2.imshow to show or cv2.imwrite to save it. If return_visual=False, it will not return.
K
kinghuin 已提交
63 64 65 66 67 68 69 70 71 72 73 74 75 76
                }
       """

        if not texts:
            return
        if all([isinstance(i, str) and i for i in texts]):
            do_parse = self.ddp.parse
        elif all([isinstance(i, list) and i for i in texts]):
            do_parse = self.ddp.parse_seg
        else:
            raise ValueError("All of the elements should be string or list")
        results = do_parse(texts)
        if return_visual:
            for result in results:
77 78 79
                result['visual'] = self.visualize(
                    result['word'], result['head'], result['deprel'])
        return results
K
kinghuin 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193

    @runnable
    def run_cmd(self, argvs):
        """
        Run as a command
        """
        self.parser = argparse.ArgumentParser(
            description='Run the %s module.' % self.name,
            prog='hub run %s' % self.name,
            usage='%(prog)s',
            add_help=True)

        self.arg_input_group = self.parser.add_argument_group(
            title="Input options", description="Input data. Required")

        self.add_module_input_arg()

        args = self.parser.parse_args(argvs)

        input_data = self.check_input_data(args)

        results = self.parse(texts=input_data)

        return results

    def visualize(self, word, head, deprel):
        """
        Visualize the dependency.

        Args:
            word: list[str], the tokenized words.
            head: list[int], the head ids.
            deprel: list[str], the dependency relation.

        Returns:
            data: a numpy array, use cv2.imshow to show it or cv2.imwrite to save it.
        """
        nodes = ['ROOT'] + word
        x = list(range(len(nodes)))
        y = [0] * (len(nodes))
        fig, ax = plt.subplots()
        # control the picture size
        max_span = max([abs(i + 1 - j) for i, j in enumerate(head)])
        fig.set_size_inches((len(nodes), max_span / 2))
        # set the points
        plt.scatter(x, y, c='w')

        for i in range(len(nodes)):
            txt = nodes[i]
            xytext = (i, 0)
            if i == 0:
                # set 'ROOT'
                ax.annotate(
                    txt,
                    xy=xytext,
                    xycoords='data',
                    xytext=xytext,
                    textcoords='data',
                )
            else:
                xy = (head[i - 1], 0)
                rad = 0.5 if head[i - 1] < i else -0.5
                # set the word
                ax.annotate(
                    txt,
                    xy=xy,
                    xycoords='data',
                    xytext=(xytext[0] - 0.1, xytext[1]),
                    textcoords='data',
                    fontproperties=self.font)
                # draw the curve
                ax.annotate(
                    "",
                    xy=xy,
                    xycoords='data',
                    xytext=xytext,
                    textcoords='data',
                    arrowprops=dict(
                        arrowstyle="<-",
                        shrinkA=12,
                        shrinkB=12,
                        color='blue',
                        connectionstyle="arc3,rad=%s" % rad,
                    ),
                )
                # set the deprel label. Calculate its position by the radius
                text_x = min(i, head[i - 1]) + abs((i - head[i - 1])) / 2 - 0.2
                text_y = abs((i - head[i - 1])) / 4
                ax.annotate(
                    deprel[i - 1],
                    xy=xy,
                    xycoords='data',
                    xytext=[text_x, text_y],
                    textcoords='data')

        # control the axis
        plt.axis('equal')
        plt.axis('off')

        # save to numpy array
        fig.canvas.draw()
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        data = data.reshape(fig.canvas.get_width_height()[::-1] +
                            (3, ))[:, :, ::-1]
        return data


if __name__ == "__main__":
    module = ddparser()
    # Data to be predicted
    test_text = ["百度是一家高科技公司"]
    results = module.parse(texts=test_text)
    print(results)
    test_tokens = [['百度', '是', '一家', '高科技', '公司']]
194
    results = module.parse(texts=test_text, return_visual=True)
K
kinghuin 已提交
195 196 197 198 199
    print(results)
    result = results[0]
    data = module.visualize(result['word'], result['head'], result['deprel'])
    import cv2
    import numpy as np
200 201
    cv2.imwrite('test1.jpg', data)
    cv2.imwrite('test2.jpg', result['visual'])