program_utils.py 5.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

M
minqiyang 已提交
17 18
import six

19 20 21
from paddle.fluid import core
import paddle

22 23

def delete_ops(block, ops):
    """
    Remove the given operators from a block, on a best-effort basis.

    Args:
        block: An object exposing ``ops`` (an iterable of operators) and
            ``_remove_op(index)``.
        ops: The operators to remove.

    Any exception raised while locating or removing an operator (for
    example the operator is not present in ``block.ops``) is printed and
    the remaining operators are still processed.
    """
    for op in ops:
        try:
            # The index is looked up in the current block.ops on every
            # iteration rather than being precomputed.
            idx = list(block.ops).index(op)
            block._remove_op(idx)
        except Exception as e:
            # Deliberately best effort: report the failure and carry on.
            print(e)
30 31 32 33 34 35 36 37 38


def find_op_by_input_arg(block, arg_name):
    """
    Locate the first operator in ``block.ops`` that consumes ``arg_name``.

    Args:
        block: An object exposing ``ops``; each op exposes
            ``input_arg_names``.
        arg_name: The argument name to look for.

    Returns:
        int: Position of the first matching operator, or -1 if none match.
    """
    matches = (
        position
        for position, candidate in enumerate(block.ops)
        if arg_name in candidate.input_arg_names
    )
    return next(matches, -1)


39 40 41 42 43 44 45 46 47 48 49 50
def find_op_by_output_arg(block, arg_name, reverse=False):
    """
    Locate an operator in ``block.ops`` that produces ``arg_name``.

    Args:
        block: An object exposing ``ops`` (indexable); each op exposes
            ``output_arg_names``.
        arg_name: The argument name to look for.
        reverse: If True, scan from the last operator towards the first,
            so the last matching operator is reported. Otherwise the
            first matching operator is reported.

    Returns:
        int: Position of the matching operator, or -1 if none match.
    """
    ops = block.ops
    positions = range(len(ops))
    if reverse:
        positions = reversed(positions)

    for pos in positions:
        if arg_name in ops[pos].output_arg_names:
            return pos
    return -1
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71


def get_indent_space(indent, space_num=4):
    """
    Build the whitespace prefix for a given indentation level.

    Args:
        indent: Indentation level.
        space_num: Number of spaces per level.

    Returns:
        string: ``indent * space_num`` spaces; empty when that product is
        zero or negative.
    """
    return " " * (indent * space_num)


def variable_to_code(var):
    """
    Get readable codes of fluid variable.

    Args:
        var: A fluid variable (or parameter).

    Returns:
        string: The formatted string. It reads
        ``[persist ]<kind> <name> : fluid.<type>[.shape<shape>.astype(<dtype>)]``
        where ``<kind>`` is ``trainable parameter``, ``parameter`` or
        ``var``; shape and dtype are only included for SELECTED_ROWS and
        LOD_TENSOR variables.
    """
    var_types = core.VarDesc.VarType
    if var.type in (var_types.SELECTED_ROWS, var_types.LOD_TENSOR):
        var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})".format(
            name=var.name, type=var.type, shape=var.shape, dtype=var.dtype)
    else:
        # No shape/dtype suffix for other variable types (the previous
        # template ended in a stray, unbalanced ")").
        var_str = "{name} : fluid.{type}".format(name=var.name, type=var.type)

    if isinstance(var, paddle.fluid.framework.Parameter):
        if var.trainable:
            var_str = "trainable parameter " + var_str
        else:
            var_str = "parameter " + var_str
    else:
        var_str = "var " + var_str

    if var.persistable:
        var_str = "persist " + var_str

    return var_str


T
tangwei12 已提交
93
def op_to_code(op, skip_op_callstack=True):
    """
    Get readable codes of fluid operator.

    Args:
        op: A fluid operator.
        skip_op_callstack: If True, the "op_callstack" attribute is left
            out of the result.

    Returns:
        string: The formatted string, of the form
        ``{outputs} = op_type(inputs={inputs}, attrs)``; the
        ``{outputs} = `` prefix is omitted when the operator has no
        output slots. Attributes are listed in sorted name order.
    """
    outputs_str = "{" + ", ".join(
        "{name}={value}".format(name=name, value=op.output(name))
        for name in op.output_names) + "}"

    inputs_str = "{" + ", ".join(
        "{name}={value}".format(name=name, value=op.input(name))
        for name in op.input_names) + "}"

    # Collect the attributes first and join them afterwards, so that a
    # skipped attribute can never leave a dangling ", " separator behind.
    attr_parts = []
    for name in sorted(op.attr_names):
        if skip_op_callstack and name == "op_callstack":
            continue

        attr_type = op.desc.attr_type(name)
        if attr_type == core.AttrType.BLOCK:
            attr_parts.append("{name} = block[{value}]".format(
                name=name, value=op._block_attr_id(name)))
        elif attr_type == core.AttrType.BLOCKS:
            attr_parts.append("{name} = blocks{value}".format(
                name=name, value=op._blocks_attr_ids(name)))
        else:
            attr_parts.append("{name} = {value}".format(
                name=name, value=op.desc.attr(name)))
    attrs_str = ", ".join(attr_parts)

    if outputs_str != "{}":
        op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".format(
            outputs=outputs_str,
            op_type=op.type,
            inputs=inputs_str,
            attrs=attrs_str)
    else:
        op_str = "{op_type}(inputs={inputs}, {attrs})".format(
            op_type=op.type, inputs=inputs_str, attrs=attrs_str)
    return op_str


162
def block_to_code(block, block_idx, fout=None, skip_op_callstack=False):
    """
    Print readable codes of one fluid block.

    The output is an opening ``{ // block <block_idx>`` line, one indented
    line per variable (sorted by name), a blank line if there were any
    variables, one indented line per operator, and a closing ``}``.

    Args:
        block: An object exposing ``vars`` (a mapping from name to
            variable) and ``ops`` (an iterable of operators).
        block_idx: The index shown in the block header.
        fout: Passed to ``print`` as ``file``; ``None`` prints to the
            default output.
        skip_op_callstack: Forwarded to ``op_to_code``.
    """
    outer_indent = get_indent_space(0)
    inner_indent = get_indent_space(1)

    print(
        "{0}{1} // block {2}".format(outer_indent, '{', block_idx),
        file=fout)

    # Sort all vars by name so the listing order is stable.
    all_vars = sorted(block.vars.items(), key=lambda item: item[0])
    for _, var in all_vars:
        print("{}{}".format(inner_indent, variable_to_code(var)), file=fout)

    if all_vars:
        print("", file=fout)

    for op in block.ops:
        print(
            "{}{}".format(inner_indent, op_to_code(op, skip_op_callstack)),
            file=fout)

    print("{0}{1}".format(outer_indent, '}'), file=fout)
188 189


T
tangwei12 已提交
190
def program_to_code(prog, fout=None, skip_op_callstack=True):
    """
    Print readable codes of fluid program.

    Args:
        prog : A fluid program; every block in ``prog.blocks`` is printed
            in order, numbered from 0.
        fout : Forwarded to ``block_to_code`` as the output file.
        skip_op_callstack : Forwarded to ``block_to_code``.

    An example result like below:
    https://github.com/PaddlePaddle/Paddle/pull/12673
    """
    for block_idx, block in enumerate(prog.blocks):
        block_to_code(block, block_idx, fout, skip_op_callstack)