# axpy.py (881 bytes) -- last committed by SunAhong1993.
from .register import register
from x2paddle.core.util import *


def axpy_shape(input_shapes):
    """Infer the output shape of an Axpy layer (``alpha * x + y``).

    Args:
        input_shapes: Sequence of exactly three shapes, in the order
            ``[alpha_shape, x_shape, y_shape]``.

    Returns:
        A one-element list holding the output shape, which is ``x_shape``
        itself (the same object as ``input_shapes[1]``, not a copy).

    Raises:
        AssertionError: If there are not exactly three input shapes, if
            ``alpha`` and ``x`` have a different number of dimensions, or if
            ``y``'s shape differs from ``x``'s shape.
    """
    # Explicit raises instead of ``assert`` statements: asserts are removed
    # under ``python -O``, which would let invalid shapes through silently.
    # AssertionError is kept so existing callers see the same exception type.
    if len(input_shapes) != 3:
        raise AssertionError("not valid input shape for axpy layer")
    alpha_shape, x_shape, y_shape = input_shapes
    if len(alpha_shape) != len(x_shape):
        raise AssertionError('should have same dims')
    # The result takes x's shape; y must match it exactly.
    output_shape = x_shape
    if y_shape != output_shape:
        raise AssertionError(
            "shape not consistent for axpy[%s <--> %s]"
            % (str(output_shape), str(y_shape)))
    return [output_shape]


def axpy_layer(inputs, input_shape=None, name=None):
    """Emit Paddle fluid ops computing ``alpha * x + y``.

    Args:
        inputs: Sequence whose first three items are ``alpha``, ``x`` and
            ``y``, in that order.
        input_shape: Part of the registered layer signature; not used by
            this function.
        name: Name passed to the final ``elementwise_add`` op.

    Returns:
        The output of ``fluid.layers.elementwise_add``.
    """
    alpha, x, y = inputs[0], inputs[1], inputs[2]
    # axis=0 aligns alpha with x starting at x's first dimension.
    # NOTE(review): presumably alpha is a per-sample/per-channel scale whose
    # shape is a leading sub-shape of x -- confirm against the caller.
    scaled = fluid.layers.elementwise_mul(x, alpha, axis=0)
    return fluid.layers.elementwise_add(scaled, y, name=name)


def axpy_weights(name, data=None):
    """Return the weight names declared by an Axpy layer.

    This layer declares no weights, so a fresh empty list is returned on
    every call; ``name`` and ``data`` are accepted for the registration
    interface but ignored.
    """
    return list()


# Register the shape, layer and weights callbacks for layers of kind 'Axpy'.
register(
    kind='Axpy',
    shape=axpy_shape,
    layer=axpy_layer,
    weights=axpy_weights,
)