# dygraph.py
import paddle
import collections
import logging
import numpy as np
from paddle.fluid.framework import _dygraph_tracer, dygraph_only, _dygraph_guard
from paddle.fluid.dygraph.base import program_desc_tracing_guard
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from ..common import get_logger

__all__ = ["dygraph2program"]

_logger = get_logger(__name__, level=logging.INFO)


def _is_shape(values):
    if not isinstance(values, (list, tuple)):
        return False
    for v in values:
        if not isinstance(v, int):
            return False
    return True


def _is_shapes(values):
    """Tell whether ``values`` looks like a collection of shapes.

    Args:
        values: Any object.

    Returns:
        bool: True if ``values`` is a list/tuple and every element passes
        ``_is_shape``; an empty list/tuple qualifies. False otherwise.
    """
    return isinstance(values, (list, tuple)) and all(
        _is_shape(candidate) for candidate in values)


def _create_tensors(shapes, dtypes=None):
    if dtypes is not None:
        assert len(shapes) == len(
            dtypes
        ), "Length of shapes and dtypes must be same. But get len(shapes): {}; len(dtypes): {}; shapes: {}; dtypes: {}".format(
            len(shapes), len(dtypes), shapes, dtypes)
    else:
        dtypes = len(shapes) * ['float32']
    tensors = []
    for shape, dtype in zip(shapes, dtypes):
        data = np.ones(tuple(shape)).astype(dtype)
        tensors.append(paddle.to_tensor(data))
    return tensors


def extract_vars(inputs):
    """
    Extract a list of variables from inputs.

    A single Variable yields a one-element list. For a dict, only values
    that are directly Variables are collected (nested containers inside a
    dict are not searched); every other value triggers a warning. Lists and
    tuples are searched recursively. Any other input yields an empty list.

    Args:
        inputs(Variable | list<Object> | tuple<Object> | dict): Object to
            search for variables.

    Returns:
        list<Variable>: The variables found, in iteration order.
    """
    found = []
    if isinstance(inputs, Variable):
        found = [inputs]
    elif isinstance(inputs, dict):
        for _key, _value in inputs.items():
            if isinstance(_value, Variable):
                found.append(_value)
            else:
                # Logger.warn is a deprecated alias; use Logger.warning.
                _logger.warning(
                    "Variable is expected, but get an element with type({}) from inputs whose type is dict. And the key of element is {}.".format(
                        type(_value), _key))
    elif isinstance(inputs, (tuple, list)):
        for _value in inputs:
            found.extend(extract_vars(_value))
    if not found:
        # Emitted at every recursion level that finds nothing.
        _logger.warning("Extract none variables from inputs.")
    return found


def to_variables(inputs):
    """
    Recursively convert Variables and np.ndarrays found in inputs.

    Every Variable or np.ndarray is passed through
    ``paddle.fluid.dygraph.to_variable``. Dicts, lists and tuples are
    rebuilt with their elements converted; the original containers are not
    modified. Objects of any other type are returned unchanged so that
    non-tensor arguments are not lost.

    Args:
        inputs(Variable | np.ndarray | dict | list | tuple | Object):
            Object or (nested) container to convert.

    Returns:
        The converted object. Containers keep their kind (dict -> dict,
        list -> list, tuple -> tuple).
    """
    if isinstance(inputs, (Variable, np.ndarray)):
        return paddle.fluid.dygraph.to_variable(inputs)
    if isinstance(inputs, dict):
        # Return the converted mapping, not the untouched input dict.
        return {_key: to_variables(_value) for _key, _value in inputs.items()}
    if isinstance(inputs, list):
        return [to_variables(_value) for _value in inputs]
    if isinstance(inputs, tuple):
        return tuple(to_variables(_value) for _value in inputs)
    # Unsupported leaf types pass through instead of silently becoming None.
    return inputs


@dygraph_only
def dygraph2program(layer,
                    inputs,
                    feed_prefix='feed_',
                    fetch_prefix='fetch_',
                    tmp_prefix='t_',
                    extract_inputs_fn=None,
                    extract_outputs_fn=None,
                    dtypes=None):
    """
    Run one forward pass of a dygraph layer under program-desc tracing and
    wrap the traced program description in a ``Program``.

    Args:
        layer(Layer): The dygraph layer to trace.
        inputs: Either a single shape (list/tuple of ints), a list/tuple of
            shapes, or already-built inputs. Shapes are turned into all-ones
            tensors by ``_create_tensors``; anything else goes through
            ``to_variables``. The result is unpacked positionally into
            ``layer``.
        feed_prefix(str): Forwarded to the tracer's ``create_program_desc``.
        fetch_prefix(str): Forwarded to the tracer's ``create_program_desc``.
        tmp_prefix(str): Forwarded to the tracer's ``create_program_desc``.
        extract_inputs_fn(callable): Collects the input variables from
            non-shape ``inputs``. Defaults to ``extract_vars``.
        extract_outputs_fn(callable): Collects the output variables from the
            layer's outputs. Defaults to ``extract_vars``.
        dtypes: Optional dtypes used only when ``inputs`` are shapes.

    Returns:
        Program: A program whose ``desc`` is the traced program description.

    Raises:
        AssertionError: If ``layer`` is not a ``Layer`` instance.
    """
    assert isinstance(layer, Layer)

    if extract_inputs_fn is None:
        extract_inputs_fn = extract_vars
    if extract_outputs_fn is None:
        extract_outputs_fn = extract_vars
    desc_tracer = _dygraph_tracer()._get_program_desc_tracer()

    with program_desc_tracing_guard(True):
        # Normalize "one shape" and "many shapes" into a list of shapes;
        # None means the caller passed real inputs instead of shapes.
        if _is_shape(inputs):
            shape_list = [inputs]
        elif _is_shapes(inputs):
            shape_list = inputs
        else:
            shape_list = None

        if shape_list is not None:
            inputs = _create_tensors(shape_list, dtypes=dtypes)
            input_var_list = inputs
        else:
            inputs = to_variables(inputs)
            input_var_list = extract_inputs_fn(inputs)

        original_outputs = layer(*inputs)
        # 'original_outputs' may be a dict, so it is flattened to a list of
        # variables here; the extractor should not create new variables.
        out_var_list = extract_outputs_fn(original_outputs)
        # Only the program description is used; the feed/fetch names and
        # parameters returned alongside it are discarded.
        program_desc, _feed_names, _fetch_names, _parameters = (
            desc_tracer.create_program_desc(input_var_list, feed_prefix,
                                            out_var_list, fetch_prefix,
                                            tmp_prefix))
        desc_tracer.reset()

    with _dygraph_guard(None):
        traced_program = Program()
        traced_program.desc = program_desc
        traced_program.blocks = [Block(traced_program, 0)]
        traced_program._sync_with_cpp()
    return traced_program