#!/usr/bin/env python3

# Copyright (c) 2021 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import cinn
import numpy as np
from cinn import Target, ir, lang, runtime, utils
from cinn.poly import create_stages


class TestMamul(unittest.TestCase):
    """Compare CINN-compiled matmul kernels against a NumPy reference."""

    def setUp(self):
        """Seed NumPy, describe the target, and create the execution engine."""
        np.random.seed(0)

        self.target = Target()
        self.target.arch = Target.Arch.X86
        self.target.bits = Target.Bit.k32
        self.target.os = Target.OS.Linux

        self.m = 1024
        self.n = 1024
        self.k = 1024
        self.bn = 32

        self.engine = cinn.ExecutionEngine()

        utils.ProfilerHelper.enable_cpu()
        self.assertTrue(utils.ProfilerHelper.is_enable_cpu())

    def test_matmul_basic(self):
        """Run the unscheduled "matmul" function and check its output."""
        # `a` and `b` are not read below.
        # NOTE(review): presumably they must stay referenced so the pod-value
        # arguments wrapping them remain valid during the call -- confirm.
        a, b, c, c_target, *args = create_data(self.m, self.n, self.k, self.bn)
        module = create_matmul_basic(self.target, self.m, self.n, self.k)

        self.engine.link(module)
        matmul = self.engine.lookup("matmul")
        matmul(args)

        cd = c.numpy()
        cd_target = c_target.numpy()
        np.testing.assert_allclose(cd, cd_target, atol=1e-4, rtol=1e-5)
        print(utils.HostEventRecorder.table())

    def test_matmul_tile(self):
        """Run the tiled "matmul_tile" function and check its output."""
        # See test_matmul_basic for why `a` and `b` are unpacked but unused.
        a, b, c, c_target, *args = create_data(self.m, self.n, self.k, self.bn)
        module = create_matmul_tile(self.target, self.m, self.n, self.k)
        print('module:\n', module.get_c_code())

        self.engine.link(module)
        matmul = self.engine.lookup("matmul_tile")
        matmul(args)

        cd = c.numpy()
        cd_target = c_target.numpy()
        np.testing.assert_allclose(cd, cd_target, atol=1e-4, rtol=1e-5)


def _build_matmul_module(name, target, m, n, k, schedule=None):
    """Define C = reduce_sum(A * B) over k1, lower it, and build a module.

    Args:
        name: Used both as the module builder name and the lowered function
            name.
        target: Target passed to ``lang.Module.Builder``.
        m, n, k: Problem sizes; A is declared [m, k], B is [k, n], and the
            compute domain of C is [m, n].
        schedule: Optional callable invoked with the stage of ``c`` before
            lowering.

    Returns:
        The result of ``builder.build()``.
    """
    m, n, k = (ir.Expr(dim) for dim in (m, n, k))

    a = lang.Placeholder("float32", "A", [m, k])
    b = lang.Placeholder("float32", "B", [k, n])

    # Reduction variable built from k.
    k1 = ir.Var(k.as_int32(), "k1")
    c = lang.compute(
        [m, n],
        lambda v: lang.reduce_sum(
            a(v[0], k1.to_expr_mutable()) * b(k1.to_expr_mutable(), v[1]),
            [k1],
        ),
        "c",
    )

    stages = create_stages([c])
    if schedule is not None:
        schedule(stages[c])

    builder = lang.Module.Builder(name, target)
    tensors = [a.to_tensor(), b.to_tensor(), c]
    func = lang.lower(name, stages, tensors)
    print('func', func)
    builder.add_function(func)
    return builder.build()


def create_matmul_basic(target, m, n, k):
    """Build a module holding the function "matmul" with no extra scheduling."""
    return _build_matmul_module("matmul", target, m, n, k)


def create_matmul_tile(target, m, n, k):
    """Build a module holding "matmul_tile", with ``tile(0, 1, 4, 4)`` applied.

    The tile call is applied to the stage of the output tensor before lowering.
    """
    return _build_matmul_module(
        "matmul_tile",
        target,
        m,
        n,
        k,
        schedule=lambda stage: stage.tile(0, 1, 4, 4),
    )


def create_data(m, n, k, bn):
    """Create random inputs, a zeroed output, and the NumPy reference result.

    Args:
        m, n, k: Problem sizes; A is [m, k], B is [k, n], C is [m, n].
        bn: Not used by this function; kept so existing callers that pass it
            keep working.

    Returns:
        A list ``[a, b, c, c_target, a_arg, b_arg, c_arg]``: the four
        ``cinn_buffer_t`` objects followed by ``cinn_pod_value_t`` wrappers
        for ``a``, ``b`` and ``c``.
    """
    # Round the inputs to two decimals so the NumPy reference result does not
    # drift too far from the float precision of the C computation.
    a_init = np.around(np.random.randn(m, k).astype("float32"), 2)
    b_init = np.around(np.random.randn(k, n).astype("float32"), 2)

    a = runtime.cinn_buffer_t(a_init, runtime.cinn_x86_device)
    b = runtime.cinn_buffer_t(b_init, runtime.cinn_x86_device)
    c = runtime.cinn_buffer_t(
        np.zeros([m, n]).astype("float32"), runtime.cinn_x86_device
    )
    c_target = runtime.cinn_buffer_t(
        a.numpy() @ b.numpy(), runtime.cinn_x86_device
    )

    a_arg = runtime.cinn_pod_value_t(a)
    b_arg = runtime.cinn_pod_value_t(b)
    c_arg = runtime.cinn_pod_value_t(c)
    return [a, b, c, c_target, a_arg, b_arg, c_arg]


if __name__ == "__main__":
    unittest.main()