# im2sequence.py
# Committed by jiangjiajun.
import onnx
import numpy as np
from onnx import onnx_pb, helper

# Id embedded in the names of the ONNX tensors/nodes that im2sequence() creates
# ("im2sequence.starts.<id>.<i>.<j>", ...).
im2seq_counter = 0


def im2sequence(op, block):
    """Translate a Paddle ``im2sequence`` operator into a list of ONNX nodes.

    Every sliding window is cut out of input ``X`` by a ``Slice`` node fed by
    two ``Constant`` nodes (starts / ends) and flattened to a single row by a
    ``Flatten`` node. One final ``Concat`` node stacks all rows, window by
    window (row-major over the output grid), into the operator's ``Out``.

    Args:
        op: Operator description. It must provide ``input(name)``,
            ``output(name)`` and ``attr(name)``; the attributes read are
            ``strides``, ``paddings``, ``kernels`` and ``out_stride``.
        block: Block owning the operator. ``block.var(name).shape`` is read
            to obtain the ``(n, c, h, w)`` shape of ``X``.

    Returns:
        list: The generated ONNX nodes, in the order they were created.

    Raises:
        AssertionError: If the height/width of ``X`` is not a fixed positive
            value, or if ``out_stride`` is anything other than 1.
    """
    # The counter lives at module level; it has to be declared global here
    # because it is advanced below so that every translated operator gets its
    # own, non-clashing set of tensor names.
    global im2seq_counter

    input_name = op.input('X')[0]
    _, _, h, w = block.var(input_name).shape
    assert h > 0 and w > 0, "Only supported fixed input shape for im2sequence operator."

    # out_stride may come back as a scalar or as a per-axis sequence; a missing
    # attribute is treated as the default (1). Every component must equal 1.
    out_stride = op.attr('out_stride')
    if out_stride is None:
        out_stride = []
    elif not isinstance(out_stride, (list, tuple)):
        out_stride = [out_stride]
    assert all(
        s == 1 for s in out_stride
    ), "Only out_stride==1 is supported for im2sequence operator."

    stride_h, stride_w = op.attr('strides')
    kernel_h, kernel_w = op.attr('kernels')

    # Paddle lays paddings out as [up, left, down, right]: the height grows by
    # up + down and the width by left + right.
    # NOTE(review): no Pad node is emitted, so with non-zero paddings the
    # windows below address positions outside the real input -- confirm
    # whether the input has to be padded explicitly for the target opset.
    paddings = op.attr('paddings')
    h = h + paddings[0] + paddings[2]
    w = w + paddings[1] + paddings[3]

    # Number of window positions along each axis (ceil division).
    out_h = 1 + (h - kernel_h + stride_h - 1) // stride_h
    out_w = 1 + (w - kernel_w + stride_w - 1) // stride_w

    # [start, end) range of every window along each spatial axis.
    h_steps = [(i * stride_h, i * stride_h + kernel_h) for i in range(out_h)]
    w_steps = [(j * stride_w, j * stride_w + kernel_w) for j in range(out_w)]

    # Reserve a unique id for this operator's tensor names.
    op_id = im2seq_counter
    im2seq_counter += 1

    nodes = []
    flattened_blocks = []
    for i, (h_start, h_end) in enumerate(h_steps):
        for j, (w_start, w_end) in enumerate(w_steps):
            suffix = "{}.{}.{}".format(op_id, i, j)

            starts_name = "im2sequence.starts." + suffix
            starts_tensor = helper.make_tensor(
                name=starts_name,
                data_type=onnx_pb.TensorProto.INT64,
                dims=[4],
                vals=[0, 0, h_start, w_start])
            # 999999 is an "until the end" sentinel for the batch and channel
            # axes, which are kept whole.
            ends_name = "im2sequence.ends." + suffix
            ends_tensor = helper.make_tensor(
                name=ends_name,
                data_type=onnx_pb.TensorProto.INT64,
                dims=[4],
                vals=[999999, 999999, h_end, w_end])
            starts_node = helper.make_node(
                'Constant',
                inputs=[],
                outputs=[starts_name],
                value=starts_tensor)
            ends_node = helper.make_node(
                'Constant', inputs=[], outputs=[ends_name], value=ends_tensor)

            slice_name = "im2sequence.slice." + suffix
            slice_node = helper.make_node(
                'Slice',
                inputs=[input_name, starts_name, ends_name],
                outputs=[slice_name])
            # axis=0 collapses the whole window into one row of shape [1, -1].
            flatten_name = "im2sequence.flatten." + suffix
            flatten_node = helper.make_node(
                "Flatten",
                inputs=[slice_name],
                outputs=[flatten_name],
                axis=0)

            nodes.extend([starts_node, ends_node, slice_node, flatten_node])
            flattened_blocks.append(flatten_name)

    # Stack the flattened windows; the result is written straight to 'Out'.
    concat_node = helper.make_node(
        "Concat", inputs=flattened_blocks, outputs=op.output('Out'), axis=0)
    nodes.append(concat_node)

    print("\n\n==========Importance Notice===========")
    print(
        "Since im2sequence operator is used in your paddlepaddle model, the translated onnx model only support input data with batch_size=1."
    )
    print("======================================\n")
    return nodes