# test_einsum_op.py
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
from op_test import OpTest


class TestEinsumBinary(OpTest):
    """Base OpTest case for the ``einsum`` operator.

    ``setUp`` computes the reference result with ``np.einsum`` and feeds the
    same arrays to the Paddle op. Subclasses override ``set_mandatory`` to
    choose the operand shapes, dtypes and the einsum equation.
    """

    def setUp(self):
        paddle.enable_static()
        self.op_type = "einsum"
        # set_mandatory() runs after this assignment, so a subclass can set
        # self.disable = True there to skip both checks.
        self.disable = False
        self.set_mandatory()
        # Seed *before* drawing the operands so every run uses the same data
        # (previously the seed was applied after init_input and had no
        # effect on the inputs).
        np.random.seed(123)
        self.init_input()
        out = np.einsum(self.equation, *self.inputs)
        # Operands are passed as (name, array) pairs; the names x0, x1, ...
        # are reused by test_grad to select the gradient inputs.
        self.operands = [("x" + str(idx), inp)
                         for idx, inp in enumerate(self.inputs)]
        self.inputs = {"Operands": self.operands}
        self.attrs = {"equation": self.equation}
        # InnerCache / XShape get dummy arrays, one per operand; their values
        # are excluded from comparison via no_check_set in test_check_output.
        self.outputs = {
            'Out': out,
            "InnerCache": [('cache_' + str(i), np.array([1.0]))
                           for i, _ in enumerate(self.operands)],
            "XShape": [('xshape_' + str(i), np.array([1.0]))
                       for i, _ in enumerate(self.operands)],
        }

    def init_input(self):
        """Create one random array per (dtype, shape) pair of the case."""
        self.inputs = []
        for t, s in zip(self.types, self.shapes):
            self.inputs.append(np.random.random(s).astype(t))

    def set_mandatory(self):
        """Define operand shapes, dtypes and the equation for this case."""
        self.shapes = [(10, 10, 20), (20, 6)]
        self.types = [np.float64, np.float64]
        self.equation = "mij,jk->ki"

    def test_check_output(self):
        """Compare ``Out`` against the NumPy reference."""
        if not self.disable:
            self.check_output(no_check_set=["InnerCache", "XShape"])

    def test_grad(self):
        """Check gradients of ``Out`` with respect to every operand."""
        if not self.disable:
            self.check_grad([op[0] for op in self.operands], ["Out"])


class TestEinsum1(TestEinsumBinary):
    """Batched matrix product keeping the batch axis first."""

    def set_mandatory(self):
        self.shapes = [(20, 3, 3), (20, 3, 3)]
        self.types = [np.float64, np.float64]
        self.equation = "mij,mjk->mik"


class TestEinsum2(TestEinsumBinary):
    """Batched matrix product with the batch axis moved to the end."""

    def set_mandatory(self):
        self.shapes = [(20, 3, 3), (20, 3, 3)]
        self.types = [np.float64, np.float64]
        self.equation = "mij,mjk->ikm"


class TestEinsum3(TestEinsumBinary):
    """Plain 2-D matrix multiplication."""

    def set_mandatory(self):
        self.shapes = [(10, 10), (10, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "ij,jk->ik"


class TestEinsumWithReduction(TestEinsumBinary):
    """Contraction over ``k`` plus a sum over the unused axis ``i``."""

    def set_mandatory(self):
        self.shapes = [(10, 3, 5), (5, 30)]
        self.types = [np.float64, np.float64]
        self.equation = "ijk,kl->jl"


class TestEinsumWithReduction1(TestEinsumBinary):
    """4-D operands sharing ``m`` and ``k``; ``i`` and ``h`` are summed out."""

    def set_mandatory(self):
        self.shapes = [(10, 3, 3, 5), (10, 5, 10, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "mijk,mklh->ljm"


class TestEinsumWithUnary(TestEinsumBinary):
    """Single operand: sum over the trailing axes ``j`` and ``k``."""

    def set_mandatory(self):
        self.shapes = [(10, 10, 3, 5)]
        self.types = [np.float64]
        self.equation = "mijk->mi"


class TestEinsumWithUnary1(TestEinsumBinary):
    """4-D operands contracted over ``j``, ``l`` and the shared ``m``.

    NOTE: despite the class name this case has two operands; the name is
    kept unchanged so existing test IDs stay stable.
    """

    def set_mandatory(self):
        self.shapes = [(5, 10, 3, 3), (3, 6, 3, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "imjl,jklm->imk"


class TestEinsumWithBroadcast1(TestEinsumBinary):
    """Single operand with ellipsis: sum out the leading axis ``i``."""

    def set_mandatory(self):
        self.shapes = [(5, 10, 3, 3)]
        self.types = [np.float64]
        self.equation = "i...->..."


class TestEinsumWithBroadcast2(TestEinsumBinary):
    """Ellipsis dims that are empty in one operand and 3-D in the other."""

    def set_mandatory(self):
        self.shapes = [(10, 11), (3, 4, 5, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "...ij,...i->j..."


class TestEinsumWithBroadcast3(TestEinsumBinary):
    """Ellipsis after a labelled leading axis, with ``j`` summed out."""

    def set_mandatory(self):
        self.shapes = [(10, 3, 2, 3, 4), (12, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "k...,...jk->...k"


class TestEinsumWithBroadcast4(TestEinsumBinary):
    """Outer product across labelled axes with ellipsis broadcasting."""

    def set_mandatory(self):
        self.shapes = [(10, 3, 2, 3, 4), (12, 10)]
        self.types = [np.float64, np.float64]
        self.equation = "a...d,...cb->...abcd"


class TestEinsumWithBroadcast5(TestEinsumBinary):
    """Contract ``a`` where it sits at opposite ends of the two operands."""

    def set_mandatory(self):
        self.shapes = [(3, 2, 2, 10), (10, 3, 2, 2)]
        self.types = [np.float64, np.float64]
        self.equation = "...a,a...->..."


class TestEinsumWithBroadcast6(TestEinsumBinary):
    """Inner product of two vectors, producing a 0-D result."""

    def set_mandatory(self):
        # Written as 1-tuples: ``(100)`` is just the int 100, which only
        # worked because np.random.random also accepts a bare int.
        self.shapes = [(100,), (100,)]
        self.types = [np.float64, np.float64]
        self.equation = "i,i->"


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when executed
    # directly; "__main__" is unittest.main's default module argument.
    unittest.main(module="__main__")