permute.py 709 字节
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
from .register import register
from x2paddle.core.util import *


def permute_shape(input_shape, order=None):
    """Compute the output shape of a Permute layer.

    Args:
        input_shape: list whose first element is the input's shape (a
            sequence of dims); only ``input_shape[0]`` is read.
        order: iterable of axis indices giving the new axis order. It must
            be supplied; iterating the ``None`` default raises TypeError.

    Returns:
        A one-element list holding the permuted shape as a list.

    Raises:
        AssertionError: if an index in ``order`` is ``>= len(input_shape[0])``.
    """
    inshape = input_shape[0]
    output_shape = []
    for ii in order:
        # This function has no layer name in scope, so the message reports
        # the offending axis and the input rank instead.
        assert ii < len(inshape), (
            "invalid order for permute: axis %s out of range for input "
            "of rank %d" % (ii, len(inshape)))
        output_shape.append(inshape[ii])
    return [output_shape]


def permute_layer(inputs, order=None, input_shape=None, name=None):
    """Build the Paddle op for a Permute layer by transposing its input.

    Args:
        inputs: list of input tensors; only ``inputs[0]`` is used.
        order: iterable of axis indices, passed to ``paddle.transpose`` as
            ``perm`` after conversion to a list.
        input_shape: not used here; presumably part of the common layer
            signature expected by ``register`` (confirm against caller).
        name: forwarded to ``paddle.transpose`` as the op name.

    Returns:
        The result of ``paddle.transpose``.
    """
    # NOTE(review): ``paddle`` is not imported explicitly in this module; it
    # presumably comes from ``from x2paddle.core.util import *`` -- confirm.
    x = inputs[0]
    perm = list(order)
    return paddle.transpose(x, perm=perm, name=name)


def permute_weights(name, data=None):
    """Return the weight names associated with a Permute layer.

    No weight names are produced for this layer, so the result is always a
    new empty list. ``name`` and ``data`` are accepted but ignored.
    """
    return []


J
jiangjiajun 已提交
26 27 28 29 30
# Register the Permute shape, layer and weights builders under kind 'Permute'.
register(kind='Permute',
         shape=permute_shape,
         layer=permute_layer,
         weights=permute_weights)