# test_einsum_op.py
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle
from op_test import OpTest


class TestEinsumBinary(OpTest):
    """OpTest case for the static-graph ``einsum`` operator.

    Subclasses override ``set_mandatory`` to choose the operand shapes,
    dtypes and einsum equation. The expected ``Out`` is computed with
    ``np.einsum`` on the same randomly generated operands.
    """

    def setUp(self):
        paddle.enable_static()
        self.op_type = "einsum"
        # When True, test_check_output and test_grad do nothing.
        self.disable = False
        self.set_mandatory()
        # Seed *before* generating operands so the test data is reproducible.
        # (Seeding after init_input() had no effect on the generated inputs.)
        np.random.seed(123)
        self.init_input()
        out = np.einsum(self.equation, *self.inputs)
        # Name operands "x0", "x1", ... so check_grad can refer to them.
        self.operands = [("x" + str(idx), inp)
                         for idx, inp in enumerate(self.inputs)]
        self.inputs = {"Operands": self.operands}
        self.attrs = {"equation": self.equation}
        num_operands = len(self.operands)
        # InnerCache / XShape get placeholder values here; they are listed in
        # no_check_set in test_check_output, so only 'Out' is compared.
        self.outputs = {
            'Out': out,
            "InnerCache": [('cache_' + str(i), np.array([1.0]))
                           for i in range(num_operands)],
            "XShape": [('xshape_' + str(i), np.array([1.0]))
                       for i in range(num_operands)],
        }

    def init_input(self):
        """Draw one uniform-random array per (dtype, shape) pair."""
        self.inputs = [np.random.random(s).astype(t)
                       for t, s in zip(self.types, self.shapes)]

    def set_mandatory(self):
        """Configure shapes, dtypes and the einsum equation for this case."""
        self.shapes = [(10, 10, 20), (20, 6)]
        self.types = [np.float64, np.float64]
        self.equation = "mij,jk->ki"

    def test_check_output(self):
        """Compare the operator's 'Out' against the np.einsum reference."""
        if not self.disable:
            self.check_output(no_check_set=["InnerCache", "XShape"])

    def test_grad(self):
        """Check gradients of 'Out' with respect to every operand."""
        if not self.disable:
            self.check_grad([name for name, _ in self.operands], ["Out"])


class TestEinsum1(TestEinsumBinary):
    """Batched matrix product over a shared leading axis ``m``."""

    def set_mandatory(self):
        self.shapes = [(20, 3, 3), (20, 3, 3)]
        self.types = [np.float64, np.float64]
        self.equation = "mij,mjk->mik"


class TestEinsum2(TestEinsumBinary):
    """Batched matrix product with a permuted output axis order (i, k, m)."""

    def set_mandatory(self):
        self.shapes = [(20, 3, 3), (20, 3, 3)]
        self.types = [np.float64, np.float64]
        self.equation = "mij,mjk->ikm"


class TestEinsum3(TestEinsumBinary):
    """Plain 2-D matrix multiplication."""

    def set_mandatory(self):
        self.shapes = [(10, 10), (10, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "ij,jk->ik"


class TestEinsumWithReduction(TestEinsumBinary):
    """Contraction over ``k`` plus reduction of ``i``, present only in operand 0."""

    def set_mandatory(self):
        self.shapes = [(10, 3, 5), (5, 30)]
        self.types = [np.float64, np.float64]
        self.equation = "ijk,kl->jl"


class TestEinsumWithReduction1(TestEinsumBinary):
    """Shared batch axis ``m`` with ``i``, ``k`` and ``h`` summed out."""

    def set_mandatory(self):
        self.shapes = [(10, 3, 3, 5), (10, 5, 10, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "mijk,mklh->ljm"


class TestEinsumWithUnary(TestEinsumBinary):
    """Single operand: sum over the trailing ``j`` and ``k`` axes."""

    def set_mandatory(self):
        self.shapes = [(10, 10, 3, 5)]
        self.types = [np.float64]
        self.equation = "mijk->mi"


class TestEinsumWithUnary1(TestEinsumBinary):
    """Two-operand case (despite the name) contracting ``j``, ``l`` and ``m``.

    ``m`` also appears in the output, so it is kept rather than summed.
    """

    def set_mandatory(self):
        self.shapes = [(5, 10, 3, 3), (3, 6, 3, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "imjl,jklm->imk"


class TestEinsumWithBroadcast1(TestEinsumBinary):
    """Single operand with ellipsis: sum out the leading axis ``i``."""

    def set_mandatory(self):
        self.shapes = [(5, 10, 3, 3)]
        self.types = [np.float64]
        self.equation = "i...->..."


class TestEinsumWithBroadcast2(TestEinsumBinary):
    """Ellipsis broadcasting where operand 0's ``...`` covers no axes."""

    def set_mandatory(self):
        self.shapes = [(10, 11), (3, 4, 5, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "...ij,...i->j..."


class TestEinsumWithBroadcast3(TestEinsumBinary):
    """Ellipsis broadcasting; ``j`` is summed out and ``k`` kept in the output."""

    def set_mandatory(self):
        self.shapes = [(10, 3, 2, 3, 4), (12, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "k...,...jk->...k"


class TestEinsumWithBroadcast4(TestEinsumBinary):
    """Ellipsis broadcasting with no contraction: every label is kept."""

    def set_mandatory(self):
        self.shapes = [(10, 3, 2, 3, 4), (12, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "a...d,...cb->...abcd"


class TestEinsumWithBroadcast5(TestEinsumBinary):
    """Contract ``a``, which sits at opposite ends of the two operands."""

    def set_mandatory(self):
        self.shapes = [(3, 2, 2, 10), (10, 3, 2, 2)]
        self.types = [np.float64, np.float64]
        self.equation = "...a,a...->..."


class TestEinsumWithBroadcast6(TestEinsumBinary):
    """Vector dot product reduced to a 0-d result."""

    def set_mandatory(self):
        # `(100)` is just the int 100; spell the 1-D shape as a real tuple.
        self.shapes = [(100,), (100,)]
        self.types = [np.float64, np.float64]
        self.equation = "i,i->"


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when the file
    # is executed directly ('__main__' is unittest.main's default module).
    unittest.main(module="__main__")