basic_api_transformer.py 8.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import astor
16

17
from paddle.utils import gast
18

19
from . import utils
20 21
from .base_transformer import BaseTransformer
from .static_analysis import AstNodeWrapper
22

23 24
__all__ = []

25

26
class BasicApiTransformer(BaseTransformer):
    """
    Class to transform basic API from dygraph to static graph.

    ``transform`` runs three stages over the wrapped AST:
      1. ``ToTensorTransformer`` rewrites dygraph ``to_variable`` calls into
         ``paddle.assign`` calls.
      2. ``AttributeJstTransformer`` rewrites loads of special attributes
         (e.g. ``x.size``) into ``_jst.Attr(...)`` calls.
      3. This visitor records assignments whose value is a dygraph API call
         with a static counterpart, drops those assignments, and rewrites
         later calls of the assigned target via ``utils.to_static_ast``.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer."

        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node
        # Maps the source text of an assignment target (e.g. "self.fc") to the
        # gast.Call node of the dygraph API call that was assigned to it.
        self.class_node_dict = {}

    def transform(self):
        """
        Run the basic API transformation over the wrapped AST.

        Returns:
            AstNodeWrapper: the same wrapper that was passed to ``__init__``.
        """
        to_tensor_transformer = ToTensorTransformer(self.root)
        to_tensor_transformer.transform()

        attribute_transformer = AttributeJstTransformer(self.root)
        attribute_transformer.transform()

        self.visit(self.root)
        return self.wrapper_root

    def visit_Assign(self, node):
        """
        Record convertible dygraph API assignments; otherwise rewrite the calls
        found anywhere inside the assigned value.

        Returns None (dropping the statement) when the assignment was recorded
        by ``_update_class_node_dict``, and ``node`` otherwise.
        """
        if self._update_class_node_dict(node):
            # The recorded call is handed to ``utils.to_static_ast`` at each
            # later call site, so the assignment statement itself is removed.
            return None

        for child_node in gast.walk(node.value):
            if isinstance(child_node, gast.Call):
                # NOTE(review): the return value of _visit_Call is discarded,
                # so this relies on utils.to_static_ast mutating child_node in
                # place -- confirm.
                self._visit_Call(child_node)
        return node

    def visit_Expr(self, node):
        """
        Rewrite calls inside an expression statement.

        If any call in the expression is a dygraph API call, the whole
        expression statement is dropped (returns None).
        """
        value_node = node.value
        for child_node in gast.walk(value_node):
            if isinstance(child_node, gast.Call):
                # TODO(liym27):
                #  Consider a dygraph api which modifies its input or has an output.
                if utils.is_dygraph_api(child_node):
                    return
                else:
                    self._visit_Call(child_node)
        return node

    def _visit_Call(self, node):
        """
        Rewrite ``node`` with ``utils.to_static_ast`` when its callee's source
        text matches a recorded assignment target; otherwise return ``node``.
        """
        assert isinstance(node, gast.Call)
        func_name = astor.to_source(gast.gast_to_ast(node.func))

        if self._is_dygraph_forward(func_name):
            class_node = self._get_class_node(func_name)
            static_node = utils.to_static_ast(node, class_node)
            return static_node
        else:
            return node

    def _is_dygraph_forward(self, func_id):
        """Return True if ``func_id`` names a recorded dygraph API target."""
        return func_id in self.class_node_dict

    def _get_class_node(self, func_id):
        """Return the recorded dygraph API call assigned to ``func_id``."""
        return self.class_node_dict[func_id]

    def _update_class_node_dict(self, node):
        """
        Record ``node`` if it assigns a convertible dygraph API call.

        Only the first assignment target is used as the key. ``to_variable``
        calls and dygraph APIs without an entry in
        ``utils.dygraph_class_to_static_api`` are never recorded.

        Returns:
            bool: True if the call was recorded in ``class_node_dict``.
        """
        assert isinstance(node, gast.Assign)
        node_value = node.value
        if isinstance(node_value, gast.Call):
            if is_to_variable(node_value):
                return False

            if utils.is_dygraph_api(node_value):
                # Assumes is_dygraph_api guarantees node_value.func is a
                # gast.Attribute -- TODO confirm in utils.
                dygraph_api = node_value.func.attr
                if not utils.dygraph_class_to_static_api.get(dygraph_api):
                    return False

                # NOTE(review): presumably normalizes the call's arguments
                # against the "__init__" signature -- confirm in utils.
                utils.update_args_of_func(node_value, node_value, "__init__")
                target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
                self.class_node_dict[target_str] = node_value
                return True
            # TODO: node.value is not dygraph class
        return False
104 105


106
class ToTensorTransformer(BaseTransformer):
    """
    Class to transform dygraph ``to_variable`` calls into ``paddle.assign``.

    A call is rewritten when ``is_to_variable`` accepts it, i.e. it is a
    dygraph API call whose callee source text ends with "to_variable"; the
    rewrite itself is done by ``to_assign_node``.
    """

    def __init__(self, node):
        assert isinstance(
            node, gast.AST
        ), "Input non-gast.AST node for the initialization of ToTensorTransformer."
        self.root = node

    def transform(self):
        """Rewrite matching calls under the root node and return the root."""
        self.visit(self.root)
        return self.root

    def visit_Call(self, node):
        """Rewrite ``node`` if it is a ``to_variable`` call, then visit its children."""
        assert isinstance(node, gast.Call)
        if is_to_variable(node):
            node = to_assign_node(node)
        self.generic_visit(node)
        return node


129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
class NameloadJstTransformer(BaseTransformer):
    """
    Change name and attribute loads into the ``_jst.Ld(...)`` pattern.
    For example:
        a.dtype  -->  _jst.Ld(_jst.Ld(a).dtype)

    In paddle science and deepxde, we have to support changing a tensor into a
    variable on arbitrary occasions, such as a global tensor.

    NOTE: we only deal with the ctx=Load() case. The callee of a call
    expression is never converted (see ``visit_Call``).
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of NameloadJstTransformer."

        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

    def transform(self):
        """Convert loads under the wrapped root node and return that node."""
        self.visit(self.root)
        return self.root

    def _surround_with_ld(self, node):
        """Return a freshly parsed expression node for ``_jst.Ld(<node>)``."""
        source = utils.ast_to_source_code(node).strip()
        return gast.parse("_jst.Ld({})".format(source)).body[0].value

    def visit_Call(self, node):
        """
        Convert the arguments of a call but never its callee, because
        converting the name of the called function would affect
        CallTransformer.
        """
        # self.visit (not generic_visit) dispatches each argument to its own
        # visitor. A bare Name/Attribute argument is therefore wrapped itself,
        # and a nested Call argument goes through visit_Call again, so its
        # callee is left untouched as well.
        node.args = [self.visit(arg) for arg in node.args]
        for keyword in node.keywords:
            keyword.value = self.visit(keyword.value)
        return node

    def visit_Attribute(self, node):
        """Convert the attribute's children, then wrap it if it is a load."""
        assert isinstance(node, gast.Attribute)
        assert isinstance(node.attr, str)
        self.generic_visit(node)
        if isinstance(node.ctx, gast.Load):
            node = self._surround_with_ld(node)
        return node

    def visit_Name(self, node):
        """Wrap a name in ``_jst.Ld(...)`` if it is a load."""
        assert isinstance(node, gast.Name)
        self.generic_visit(node)
        if isinstance(node.ctx, gast.Load):
            node = self._surround_with_ld(node)
        return node


186 187 188 189 190 191
class AttributeJstTransformer(BaseTransformer):
    """
    Change some special attributes into the ``_jst.Attr(obj, "attr_name")`` format.
    For example:
        a.size  -->  _jst.Attr(a, "size")

    This is needed because `size` behaves differently in dygraph and static
    graph mode.

    NOTE: we only deal with the ctx=Load() case.
    """

    def __init__(self, node):
        assert isinstance(
            node, gast.AST
        ), "Input non-gast.AST node for the initialization of AttributeJstTransformer."
        # Attribute names whose loads are routed through _jst.Attr.
        self.interested_name = {
            'size',
        }
        self.root = node

    def transform(self):
        """Rewrite matching attribute loads under the root node and return it."""
        self.visit(self.root)
        return self.root

    def visit_Attribute(self, node):
        """Replace a load of an interesting attribute with a ``_jst.Attr`` call."""
        assert isinstance(node, gast.Attribute)
        assert isinstance(node.attr, str)
        if (
            isinstance(node.ctx, gast.Load)
            and node.attr in self.interested_name
        ):
            attr = node.attr
            value_source = utils.ast_to_source_code(node.value).strip()
            node = (
                gast.parse('_jst.Attr({}, "{}")'.format(value_source, attr))
                .body[0]
                .value
            )
        # Visit the children of the (possibly new) node so that nested
        # occurrences such as ``a.size.size`` are converted too.
        self.generic_visit(node)
        return node


231 232 233 234 235 236 237 238 239 240 241
def is_to_variable(node):
    """
    Tell whether ``node`` is a dygraph ``to_variable`` call.

    Args:
        node (gast.Call): the call expression to inspect.

    Returns:
        bool: True only when ``utils.is_dygraph_api`` accepts the call and the
        source text of its callee ends with "to_variable".
    """
    assert isinstance(node, gast.Call)
    callee_source = utils.ast_to_source_code(node.func).strip()
    return (
        callee_source.endswith("to_variable")
        if utils.is_dygraph_api(node)
        else False
    )


def to_assign_node(node):
    """
    Rewrite a dygraph ``to_variable`` call into a ``paddle.assign`` call, in place.

    Transforms dygraph api `fluid.dygraph.to_variable` (alias
    `paddle.to_tensor`) into static api `paddle.assign`.

    NOTE:
      1. Api `to_variable` supports data type {float16, float32, float64,
         int16, int32, int64, uint8, uint16}, but api `assign` only supports
         {float32, float64, int32, int64, bool};
      2. If the input of api `assign` is numpy.ndarray, its size cannot be
         greater than 1024 * 1024.

    Only the data argument is kept. That is the first positional argument if
    there is one, otherwise the first keyword named ``value`` or ``data``,
    which is renamed to ``x``. Every other argument is dropped. If there are
    no positional arguments and no such keyword, the keywords are left as-is.

    Args:
        node (gast.Call): the call to rewrite; it is mutated.

    Returns:
        gast.Call: ``node`` itself.
    """
    assert isinstance(node, gast.Call)
    node.func = gast.parse('paddle.assign').body[0].value

    if node.args:
        node.args = [node.args[0]]
        node.keywords = []
    else:
        data_keyword = next(
            (kw for kw in node.keywords if kw.arg in ('value', 'data')), None
        )
        if data_keyword is not None:
            # `assign` names its input parameter `x`.
            data_keyword.arg = 'x'
            node.keywords = [data_keyword]
            node.args = []
    return node