gen_c++_tutorial.py 6.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
# Copyright (c) 2021 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
This script helps to extract the tutorial content from a C++ source file.
'''

# syntax definition
# The text content locates in the comments with `//!` prefix.
# Some predefined marks:
#  - @h1, @h2, @h3, the nth headline
#  - @IGNORE-NEXT, hide the next line of code
#  - @ROC, the code block inside a C++ multi-line string guard `ROC()ROC`,
#          display as a markdown code block.

26
import logging
27 28 29 30
import sys
from typing import List


31
class Markdown:
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
    '''
    A simple markdown generator.
    '''

    def __init__(self):
        self.content: List[str] = []

    def h1(self, title: str):
        self.add_line('# ' + title)

    def h2(self, title: str):
        self.add_line('## ' + title)

    def h3(self, title: str):
        self.add_line('### ' + title)

    def code_block(self, lang: str, block: List[str]):
        # drop the precending and tailing empty lines to make code block more compact
        pre_valid_offset = 0
        tail_valid_offset = 0
        for x in block:
            if x.strip():
                break
            else:
                pre_valid_offset += 1
        for x in reversed(block):
            if x.strip():
                break
            else:
                tail_valid_offset += 1
        logging.warning("block0: %s" % block)
63 64 65 66 67
        block = (
            block[pre_valid_offset:-tail_valid_offset]
            if tail_valid_offset > 0
            else block[pre_valid_offset:]
        )
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
        logging.warning("block1: %s" % block)
        if not block:
            return

        c = "```" + lang

        # add empty lines to wrap code block
        self.add_line('')
        self.add_line('\n'.join([c, '\n'.join(block), "```"]))
        self.add_line('')

    def add_line(self, content: str):
        self.content.append(content)

    def generate(self):
        return '\n'.join(self.content)


class Mark:
    h1 = "@h1"
    h2 = "@h2"
    h3 = "@h3"
    h4 = "@h4"
    ignore_next = "@IGNORE-NEXT"
    roc = "@ROC"


95
class ContentGenerator:
96 97 98 99 100 101 102 103 104 105 106
    '''
    Interface for some content passed into the parser.
    '''

    def has_next(self) -> bool:
        pass

    def get_line(self) -> str:
        pass


107
class Parser:
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
    DOC_COMMENT_PREFIX = "//!"

    def __init__(self):
        self.doc = Markdown()
        self.code_block = []

    def parse(self, content: ContentGenerator):
        while content.has_next():
            line = content.get_line()
            line_striped = line.strip()
            is_doc = False
            if line_striped.startswith(self.DOC_COMMENT_PREFIX):
                is_doc = True
                if self.code_block:
                    self.doc.code_block('c++', self.code_block)
                    self.code_block = []

125 126 127
                line_striped = line_striped[
                    len(self.DOC_COMMENT_PREFIX) :
                ].strip()
128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147

                if line_striped.startswith(Mark.h1):
                    self.eat_h1(line_striped)
                elif line_striped.startswith(Mark.h2):
                    self.eat_h2(line_striped)
                elif line_striped.startswith(Mark.h3):
                    self.eat_h3(line_striped)
                elif line_striped.startswith(Mark.h4):
                    self.eat_h4(line_striped)
                elif line_striped.startswith(Mark.ignore_next):
                    self.eat_ignore_next(content)
                elif line_striped.startswith(Mark.roc):
                    self.eat_roc(line_striped, content)
                else:
                    self.doc.add_line(line_striped)

            else:  # normal code
                self.code_block.append(line)

    def eat_h1(self, content: str) -> None:
148
        self.doc.h1(content[len(Mark.h1) :].strip())
149 150

    def eat_h2(self, content: str) -> None:
151
        self.doc.h2(content[len(Mark.h2) :].strip())
152 153

    def eat_h3(self, content: str) -> None:
154
        self.doc.h3(content[len(Mark.h3) :].strip())
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178

    def eat_ignore_next(self, content: ContentGenerator) -> None:
        content.get_line()

    def eat_roc(self, header: str, content: ContentGenerator) -> None:
        '''
        Get the content from a pair of ROC guards.
        @param header the string contains description of the ROC block.
        @content: the content generator.

        e.g.

        the content:

            //! @ROC[c++]
            auto target_source = R"ROC(
            function fn0 (_A, _B, _tensor)
            {
            }
            ROC);

        The parameter header is `//! @ROC[c++]`.
        '''
        assert "ROC" in header
179
        lang = header[len("@ROC[") : -1]
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203

        logging.warning("eating ROC")

        assert content.has_next()
        line: str = content.get_line()
        assert "ROC(" in line
        line = content.get_line()
        code_block = []
        while ")ROC" not in line:
            code_block.append(line)
            line: str = content.get_line()

        logging.warning("DOC content: %s" % code_block)

        self.doc.code_block(lang, code_block)

    def generate(self):
        return self.doc.generate()


if __name__ == '__main__':

    class Content(ContentGenerator):
        def __init__(self):
204
            self.lines = list(sys.stdin)
205 206 207 208 209 210 211 212 213 214 215 216 217 218
            self.cur = 0

        def has_next(self):
            return self.cur < len(self.lines)

        def get_line(self):
            assert self.has_next()
            res = self.lines[self.cur]
            self.cur += 1
            return res.rstrip()

    parser = Parser()
    parser.parse(Content())
    sys.stdout.write(parser.generate())