program_utils.py 5.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

M
minqiyang 已提交
17 18
import six

19 20 21
from paddle.fluid import core
import paddle

22 23

def delete_ops(block, ops):
    """
    Remove the given operators from a block, best-effort.

    Each op is looked up by position in ``block.ops`` and removed through
    ``block._remove_op``. Any failure (e.g. an op that is not present in the
    block) is printed and skipped so the remaining ops are still processed.

    Args:
        block: A fluid block.
        ops: Iterable of operators to remove from ``block``.
    """
    for op in ops:
        try:
            # Re-locate the op on every iteration: each removal shifts the
            # indices of the ops that follow it.
            idx = list(block.ops).index(op)
            block._remove_op(idx)
        except Exception as e:
            print(e)
30 31 32 33 34 35 36 37 38


def find_op_by_input_arg(block, arg_name):
    """
    Locate the first operator in a block that consumes a given argument.

    Args:
        block: A fluid block whose ``ops`` are scanned front to back.
        arg_name: Name of the input argument to look for.

    Returns:
        int: Position of the first op listing ``arg_name`` among its
        ``input_arg_names``, or -1 when no op does.
    """
    matches = (position for position, candidate in enumerate(block.ops)
               if arg_name in candidate.input_arg_names)
    return next(matches, -1)


39 40 41 42 43 44 45 46 47 48 49 50
def find_op_by_output_arg(block, arg_name, reverse=False):
    """
    Locate an operator in a block that produces a given argument.

    Args:
        block: A fluid block whose ``ops`` are scanned.
        arg_name: Name of the output argument to look for.
        reverse (bool): If True, scan from the last op backwards and return
            the last producer; otherwise return the first producer.

    Returns:
        int: Index into ``block.ops`` of the matching op, or -1 if no op
        lists ``arg_name`` among its ``output_arg_names``.
    """
    indices = range(len(block.ops))
    if reverse:
        indices = reversed(indices)
    for index in indices:
        if arg_name in block.ops[index].output_arg_names:
            return index
    return -1
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71


def get_indent_space(indent, space_num=4):
    """
    Build the whitespace prefix for a given nesting level.

    Args:
        indent: Nesting level.
        space_num: Number of spaces per level (default 4).

    Returns:
        string: ``indent * space_num`` space characters; empty when that
        product is zero or negative.
    """
    width = indent * space_num
    return " " * width


def variable_to_code(var):
    """
    Get readable codes of fluid variable.

    Args:
        var: A fluid variable (a Parameter is also accepted).

    Returns:
        string: The formatted string. SELECTED_ROWS and LOD_TENSOR variables
        include their shape and dtype. The string is prefixed with
        "trainable parameter", "parameter" or "var", and additionally with
        "persist" when the variable is persistable.
    """
    shaped_types = (core.VarDesc.VarType.SELECTED_ROWS,
                    core.VarDesc.VarType.LOD_TENSOR)
    if var.type in shaped_types:
        var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})".format(
            name=var.name, type=var.type, shape=var.shape, dtype=var.dtype)
    else:
        # Previously emitted an unbalanced trailing ")" here.
        var_str = "{name} : fluid.{type}".format(name=var.name, type=var.type)

    if isinstance(var, paddle.fluid.framework.Parameter):
        if var.trainable:
            var_str = "trainable parameter " + var_str
        else:
            var_str = "parameter " + var_str
    else:
        var_str = "var " + var_str

    if var.persistable:
        var_str = "persist " + var_str

    return var_str


T
tangwei12 已提交
93
def op_to_code(op, skip_op_callstack=True):
    """
    Get readable codes of fluid operator.

    Args:
        op: A fluid operator.
        skip_op_callstack (bool): If True, the "op_callstack" attribute is
            left out of the output.

    Returns:
        string: The formatted string, shaped like
        ``{Out=[...]} = op_type(inputs={X=[...]}, attr = value, ...)``.
        The ``{...} = `` prefix is omitted when the op has no outputs, and
        the attribute list (with its leading ", ") is omitted when no
        attributes remain.
    """
    outputs_str = "{" + ", ".join(
        "{name}={value}".format(name=name, value=op.output(name))
        for name in op.output_names) + "}"

    inputs_str = "{" + ", ".join(
        "{name}={value}".format(name=name, value=op.input(name))
        for name in op.input_names) + "}"

    # Collect the rendered attributes first and join them afterwards, so a
    # skipped attribute (op_callstack) can never leave a dangling separator.
    attr_parts = []
    for name in sorted(op.attr_names):
        if skip_op_callstack and name == "op_callstack":
            continue

        attr_type = op.desc.attr_type(name)
        if attr_type == core.AttrType.BLOCK:
            attr_parts.append("{name} = block[{value}]".format(
                name=name, value=op._block_attr_id(name)))
        elif attr_type == core.AttrType.BLOCKS:
            attr_parts.append("{name} = blocks{value}".format(
                name=name, value=op._blocks_attr_ids(name)))
        else:
            attr_parts.append("{name} = {value}".format(
                name=name, value=op.desc.attr(name)))
    attrs_str = ", ".join(attr_parts)

    call_str = "{op_type}(inputs={inputs}".format(
        op_type=op.type, inputs=inputs_str)
    if attrs_str:
        call_str += ", " + attrs_str
    call_str += ")"

    if outputs_str != "{}":
        return "{outputs} = {call}".format(outputs=outputs_str, call=call_str)
    return call_str


162
def block_to_code(block, block_idx, fout=None, skip_op_callstack=False):
    """
    Print readable codes of a single fluid block.

    The output is a brace-delimited section headed ``{ // block <idx>``.
    It lists every variable of the block, sorted by name, then a blank line
    if there were any variables, then every operator in block order.

    Args:
        block: A fluid block.
        block_idx: Index of the block, printed in the header line.
        fout: File-like object passed to ``print``; ``None`` prints to stdout.
        skip_op_callstack (bool): Forwarded to ``op_to_code``. Note that the
            default here (False) differs from ``program_to_code`` (True).
    """
    indent = 0
    print(
        "{0}{1} // block {2}".format(get_indent_space(indent), '{', block_idx),
        file=fout)

    indent += 1
    body_prefix = get_indent_space(indent)

    # sort all vars by name
    all_vars = sorted(six.iteritems(block.vars), key=lambda x: x[0])
    for _, var in all_vars:
        print("{}{}".format(body_prefix, variable_to_code(var)), file=fout)

    if all_vars:
        print("", file=fout)

    for op in block.ops:
        print(
            "{}{}".format(body_prefix, op_to_code(op, skip_op_callstack)),
            file=fout)

    indent -= 1
    print("{0}{1}".format(get_indent_space(indent), '}'), file=fout)
188 189


T
tangwei12 已提交
190
def program_to_code(prog, fout=None, skip_op_callstack=True):
    """
    Print readable codes of fluid program.

    Every block of the program is printed in order via ``block_to_code``,
    numbered from 0.

    Args:
        prog : A fluid program.
        fout: File-like object to print to; ``None`` prints to stdout.
        skip_op_callstack (bool): If True, the "op_callstack" attribute of
            every operator is left out of the output.

    An example result like below:
    https://github.com/PaddlePaddle/Paddle/pull/12673
    """
    for block_idx, block in enumerate(prog.blocks):
        block_to_code(block, block_idx, fout, skip_op_callstack)