program_utils.py 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

M
minqiyang 已提交
17 18
import six

19 20 21
from paddle.fluid import core
import paddle

22 23 24 25 26

def delete_ops(block, ops):
    """Remove a contiguous run of operators from ``block``.

    Every op positioned from ``ops[0]`` through ``ops[-1]`` (inclusive, by
    position in ``block.ops``) is removed -- including any op lying between
    them that is not itself listed in ``ops``. Afterwards
    ``block.program._sync_with_cpp()`` is called.

    Args:
        block: The block to remove operators from.
        ops: Operators belonging to ``block``; only the first and last
            entries determine the removed range. An empty sequence is a
            no-op.

    Raises:
        ValueError: If ``ops[0]`` or ``ops[-1]`` is not in ``block.ops``.
    """
    if not ops:
        return
    block_ops = list(block.ops)
    start = block_ops.index(ops[0])
    end = block_ops.index(ops[-1])
    # Removing at ``start`` shifts the following ops down by one, so
    # repeatedly removing the same index walks the whole [start, end] range.
    for _ in range(end - start + 1):
        block._remove_op(start)
    block.program._sync_with_cpp()
31 32 33 34 35 36 37 38 39 40 41 42 43 44


def find_op_by_input_arg(block, arg_name):
    """Locate the first operator in ``block`` that reads ``arg_name``.

    Args:
        block: A fluid block whose ``ops`` are scanned in order.
        arg_name: Name of the input argument to look for.

    Returns:
        int: Index into ``block.ops`` of the first op listing ``arg_name``
        among its input argument names, or -1 when no op does.
    """
    readers = (
        position
        for position, candidate in enumerate(block.ops)
        if arg_name in candidate.input_arg_names
    )
    return next(readers, -1)


def find_op_by_output_arg(block, arg_name):
    """Locate the first operator in ``block`` that writes ``arg_name``.

    Args:
        block: A fluid block whose ``ops`` are scanned in order.
        arg_name: Name of the output argument to look for.

    Returns:
        int: Index into ``block.ops`` of the first op listing ``arg_name``
        among its output argument names, or -1 when no op does.
    """
    writers = (
        position
        for position, candidate in enumerate(block.ops)
        if arg_name in candidate.output_arg_names
    )
    return next(writers, -1)
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64


def get_indent_space(indent, space_num=4):
    """Return the whitespace prefix for an indentation level.

    Args:
        indent: Indentation level.
        space_num: Number of spaces per level (default 4).

    Returns:
        string: ``indent * space_num`` spaces; empty when that product is
        zero or negative.
    """
    # String repetition instead of appending one space at a time.
    return " " * (indent * space_num)


def variable_to_code(var):
    """
    Get readable codes of fluid variable.

    The result has the form
    ``[persist ][trainable parameter |parameter |var ]<name> : fluid.<type>``,
    and for SELECTED_ROWS / LOD_TENSOR variables it additionally ends with
    ``.shape<shape>.astype(<dtype>)``.

    Args:
        var: A fluid variable (possibly a Parameter).

    Returns:
        string: The formatted string.
    """
    if var.type in (core.VarDesc.VarType.SELECTED_ROWS,
                    core.VarDesc.VarType.LOD_TENSOR):
        # Tensor-like variables: also show shape and dtype.
        var_str = "{name} : fluid.{type}.shape{shape}.astype({dtype})".format(
            name=var.name, type=var.type, shape=var.shape, dtype=var.dtype)
    else:
        # Previously "fluid.{type})" -- the trailing ")" was unbalanced.
        var_str = "{name} : fluid.{type}".format(name=var.name, type=var.type)

    if isinstance(var, paddle.fluid.framework.Parameter):
        prefix = "trainable parameter " if var.trainable else "parameter "
    else:
        prefix = "var "
    var_str = prefix + var_str

    if var.persistable:
        var_str = "persist " + var_str

    return var_str


def _format_arg_map(names, lookup):
    """Format op slots as ``{Slot=args, ...}`` in the order of ``names``.

    Args:
        names: Slot names, e.g. ``op.input_names``.
        lookup: Callable returning the argument list for a slot name,
            e.g. ``op.input``.

    Returns:
        string: ``"{}"`` when ``names`` is empty, otherwise each
        ``name=args`` pair joined by ", " and wrapped in braces.
    """
    return "{" + ", ".join(
        "{name}={value}".format(name=name, value=lookup(name))
        for name in names) + "}"


def _format_attr(op, name):
    """Format one attribute of ``op`` as ``name = value``.

    BLOCK attributes are shown as ``block[<id>]`` and BLOCKS attributes as
    ``blocks<ids>`` using the op's block-id accessors; any other attribute
    shows ``op.desc.attr(name)``.
    """
    attr_type = op.desc.attr_type(name)
    if attr_type == core.AttrType.BLOCK:
        return "{name} = block[{value}]".format(
            name=name, value=op.block_attr_id(name))
    if attr_type == core.AttrType.BLOCKS:
        return "{name} = blocks{value}".format(
            name=name, value=op.blocks_attr_ids(name))
    return "{name} = {value}".format(name=name, value=op.desc.attr(name))


def op_to_code(op):
    """
    Get readable codes of fluid operator.

    The result looks like
    ``{Out=args, ...} = op_type(inputs={In=args, ...}, attr = value, ...)``;
    the ``{...} = `` prefix is omitted when the op has no output slots, and
    attributes are listed sorted by name.

    Args:
        op: A fluid operator.

    Returns:
        string: The formatted string.
    """
    outputs_str = _format_arg_map(op.output_names, op.output)
    inputs_str = _format_arg_map(op.input_names, op.input)
    attrs_str = ", ".join(
        _format_attr(op, name) for name in sorted(op.attr_names))

    # Only append the attribute list when there is one, so an op without
    # attributes does not render with a dangling ", " before ")".
    call_args = "inputs={inputs}".format(inputs=inputs_str)
    if attrs_str:
        call_args += ", " + attrs_str

    if outputs_str != "{}":
        return "{outputs} = {op_type}({args})".format(
            outputs=outputs_str, op_type=op.type, args=call_args)
    return "{op_type}({args})".format(op_type=op.type, args=call_args)


153 154 155 156 157 158 159 160
def block_to_code(block, block_idx):
    """Print readable codes of one fluid block.

    Prints an opening brace tagged with ``block_idx``. Next it prints every
    variable of the block, sorted by name, indented one level. If there
    were any variables, a blank line follows. Then it prints every
    operator in order, indented one level, and finally a closing brace.

    Args:
        block: A fluid block.
        block_idx: Index of the block, shown in the header line.
    """
    outer_pad = get_indent_space(0)
    inner_pad = get_indent_space(1)

    print("{0}{1} // block {2}".format(outer_pad, '{', block_idx))

    # Sort variables by name for stable output.
    named_vars = sorted(six.iteritems(block.vars), key=lambda item: item[0])
    for _, variable in named_vars:
        print("{}{}".format(inner_pad, variable_to_code(variable)))
    if named_vars:
        print("")

    for operator in block.ops:
        print("{}{}".format(inner_pad, op_to_code(operator)))

    print("{0}{1}".format(outer_pad, '}'))


175 176 177 178 179 180 181 182 183 184 185 186
def program_to_code(prog):
    """
    Print readable codes of fluid program.

    Each block in ``prog.blocks`` is printed in order, labelled with its
    index in that sequence.

    Args:
        prog : A fluid program.

    An example result is shown below:
    https://github.com/PaddlePaddle/Paddle/pull/12673
    """
    for block_idx, block in enumerate(prog.blocks):
        block_to_code(block, block_idx)