diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index f198919b0c87bb4f2ea9991e401a8242676d3f46..e1ce705533ab4ba1c75d8f656683608365e97907 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -46,8 +46,12 @@ if(WITH_GLOO)
   endif()
 endif()
 
+if(WITH_MLU)
+  SET(MLU_DEPS mlu_baseop)
+endif()
+
 if(NOT WITH_ASCEND_CL)
-cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function phi_tensor)
+cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function phi_tensor ${MLU_DEPS})
 else()
 cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function npu_op_runner phi_tensor)
 endif()
diff --git a/paddle/fluid/operators/matmul_op_mlu.cc b/paddle/fluid/operators/matmul_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d0c84c4751e78e6bd02c4a988a7d3558962a0de5
--- /dev/null
+++ b/paddle/fluid/operators/matmul_op_mlu.cc
@@ -0,0 +1,337 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+static void Mul(const framework::ExecutionContext& ctx, const Tensor& X,
+                const Tensor& Y, Tensor* Out, const float alpha) {
+  Out->mutable_data<T>(ctx.GetPlace());
+
+  MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+  MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType<T>());
+
+  MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, ToCnnlDataType<T>(),
+                                  CNNL_NOT_PROPAGATE_NAN);
+  MLUCnnl::OpTensor(ctx, mul_op_desc.get(), x_desc.get(), GetBasePtr(&X),
+                    y_desc.get(), GetBasePtr(&Y), out_desc.get(),
+                    GetBasePtr(Out), ToCnnlDataType<T>(), alpha);
+}
+
+template <typename T>
+static void MatMul2D(const framework::ExecutionContext& ctx, const Tensor& X,
+                     const Tensor& Y, Tensor* Out, const bool trans_x,
+                     const bool trans_y, const float alpha) {
+  Out->mutable_data<T>(ctx.GetPlace());
+
+  PADDLE_ENFORCE_LT(fabs(alpha - 1.0), std::numeric_limits<float>::epsilon(),
+                    platform::errors::InvalidArgument(
+                        "MLU(matmul): alpha should be equal to 1.0! "
+                        "Other values are not supported yet. "
+ "But received alpha is %d.", + alpha)); + + MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + + MLUCnnl::Matmul(ctx, trans_x, trans_y, x_desc.get(), GetBasePtr(&X), + y_desc.get(), GetBasePtr(&Y), out_desc.get(), + GetBasePtr(Out)); +} + +template +static void MatMulND(const framework::ExecutionContext& ctx, const Tensor& X, + const Tensor& Y, Tensor* Out, const bool trans_x, + const bool trans_y, const float alpha) { + if (!Out->initialized()) { + Out->mutable_data(ctx.GetPlace()); + } + + PADDLE_ENFORCE_LT(fabs(alpha - 1.0), std::numeric_limits::epsilon(), + platform::errors::InvalidArgument( + "MLU(matmul): alpha should be equal to 1.0! " + "Other values are not supported yet." + "But received alpha is %d.", + alpha)); + + MLUCnnlTensorDesc x_desc(X, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + MLUCnnlTensorDesc y_desc(Y, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + MLUCnnlTensorDesc out_desc(*Out, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + + MLUCnnl::BatchMatmul(ctx, trans_x, trans_y, x_desc.get(), GetBasePtr(&X), + y_desc.get(), GetBasePtr(&Y), out_desc.get(), + GetBasePtr(Out)); +} + +template +static void ReduceDims(const framework::ExecutionContext& ctx, + const std::vector& dims, + const std::vector& bcast_dims, const Tensor& in, + Tensor* out) { + std::vector axes; + int64_t size = bcast_dims.size(); + int64_t diff = bcast_dims.size() - dims.size(); + for (int64_t i = 0; i < size; ++i) { + if (i < diff) { + axes.push_back(i); + continue; + } + if (bcast_dims[i] > dims[i - diff]) { + axes.push_back(i); + } + } + out->mutable_data(ctx.GetPlace()); + + MLUCnnlTensorDesc in_desc(in, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + MLUCnnlTensorDesc out_desc(*out, CNNL_LAYOUT_ARRAY, ToCnnlDataType()); + + std::vector reduce_dims(axes.begin(), axes.end()); + MLUCnnlReduceDesc reduce_desc(reduce_dims, CNNL_REDUCE_ADD, + ToCnnlDataType(), CNNL_NOT_PROPAGATE_NAN, + CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + + MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduce_desc.get(), nullptr, + in_desc.get(), GetBasePtr(&in), 0 /*indices_size*/, nullptr, + nullptr, out_desc.get(), GetBasePtr(out)); +} + +template +class MatMulMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* Out = ctx.Output("Out"); + bool transpose_x = ctx.Attr("transpose_X"); + bool transpose_y = ctx.Attr("transpose_Y"); + float alpha = static_cast(ctx.Attr("alpha")); + + std::vector x_dims = phi::vectorize(X->dims()); + std::vector y_dims = phi::vectorize(Y->dims()); + std::vector out_dims = phi::vectorize(Out->dims()); + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + + // Case 1: [K] x [K] = [1] + // Equal: [1, K] x [K, 1] = [1, 1] => [1] + const bool all_one_dim = (x_ndim == 1 && y_ndim == 1); + if (all_one_dim) { + Out->Resize({1, 1}); + } + + // Resize dim 1 to 2 + Tensor x_temp, y_temp; + x_temp.ShareDataWith(*X); + y_temp.ShareDataWith(*Y); + if (x_ndim == 1) { + x_dims.insert(x_dims.begin(), 1); + x_temp.Resize(phi::make_ddim(x_dims)); + x_ndim = 2; + // matmul op of mlu needs `std::max(x->dim, y->dim) == out->dim` + if (out_dims.size() < y_dims.size()) { + std::vector temp_out_dims(out_dims.begin(), out_dims.end()); + temp_out_dims.insert(temp_out_dims.end() - 1, 1); + Out->Resize(phi::make_ddim(temp_out_dims)); + } + } + if (y_ndim 
== 1) { + y_dims.push_back(1); + y_temp.Resize(phi::make_ddim(y_dims)); + y_ndim = 2; + // matmul op of mlu needs `std::max(x->dim, y->dim) == out->dim` + if (out_dims.size() < x_dims.size()) { + std::vector temp_out_dims(out_dims.begin(), out_dims.end()); + temp_out_dims.push_back(1); + Out->Resize(phi::make_ddim(temp_out_dims)); + } + } + + const int K = transpose_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; + if (transpose_y) { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K, + platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1])); + } else { + PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K, + platform::errors::InvalidArgument( + "Input(Y) has error dim." + "Y'dims[%d] must be equal to %d" + "But received Y'dims[%d] is %d", + y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2])); + } + + if (x_ndim == 2 && y_ndim == 2) { + // Case 2: [M, K] x [K, N] = [M, N] + MatMul2D(ctx, x_temp, y_temp, Out, transpose_x, transpose_y, alpha); + } else { + // Case 3: [B, M, K] x [K, N] = [B, M, N] + // Case 4: [B, M, K] x [B, K, N] = [B, M, N] + MatMulND(ctx, x_temp, y_temp, Out, transpose_x, transpose_y, alpha); + } + + if (phi::vectorize(Out->dims()) != out_dims) { + Out->Resize(phi::make_ddim(out_dims)); + } + } +}; + +template +class MatMulGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dY = ctx.Output(framework::GradVarName("Y")); + bool transpose_x = ctx.Attr("transpose_X"); + bool transpose_y = ctx.Attr("transpose_Y"); + float alpha = static_cast(ctx.Attr("alpha")); + + std::vector x_dims = phi::vectorize(X->dims()); + std::vector y_dims = phi::vectorize(Y->dims()); + std::vector out_dims = phi::vectorize(dOut->dims()); + int x_ndim = x_dims.size(); + int y_ndim = y_dims.size(); + int out_ndim = out_dims.size(); + + // Case 1: [K] x [K] = [1] + if (x_ndim == 1 && y_ndim == 1) { + if (dX) { + Mul(ctx, *dOut, *Y, dX, alpha); + } + if (dY) { + Mul(ctx, *dOut, *X, dY, alpha); + } + return; + } + + // Resize dim 1 to 2 + Tensor x_temp, y_temp, dout_temp; + x_temp.ShareDataWith(*X); + y_temp.ShareDataWith(*Y); + dout_temp.ShareDataWith(*dOut); + if (x_ndim == 1) { + x_dims.insert(x_dims.begin(), 1); + out_dims.insert(out_dims.end() - 1, 1); + x_temp.Resize(phi::make_ddim(x_dims)); + dout_temp.Resize(phi::make_ddim(out_dims)); + x_ndim = 2; + out_ndim += 1; + } + if (y_ndim == 1) { + y_dims.push_back(1); + out_dims.push_back(1); + y_temp.Resize(phi::make_ddim(y_dims)); + dout_temp.Resize(phi::make_ddim(out_dims)); + y_ndim = 2; + out_ndim += 1; + } + + // Case 2: [M, K] x [K, N] = [M, N] + if (out_ndim == 2) { + if (dX) { + dX->Resize(phi::make_ddim(x_dims)); + if (transpose_x) { + MatMul2D(ctx, y_temp, dout_temp, dX, transpose_y, true, alpha); + } else { + MatMul2D(ctx, dout_temp, y_temp, dX, false, !transpose_y, alpha); + } + dX->Resize(X->dims()); + } + if (dY) { + dY->Resize(phi::make_ddim(y_dims)); + if (transpose_y) { + MatMul2D(ctx, dout_temp, x_temp, dY, true, transpose_x, alpha); + } else { + MatMul2D(ctx, x_temp, dout_temp, dY, !transpose_x, false, alpha); + } + dY->Resize(Y->dims()); + } + return; + } + + // Case 3: [B, M, K] x [K, N] = [B, M, N] + // Case 4: [B, M, K] x [B, K, N] = [B, M, N] + std::vector x_bcast_dims(out_ndim, 1); + 
+    std::vector<int64_t> y_bcast_dims(out_ndim, 1);
+    std::copy(out_dims.begin(), out_dims.end() - 2, x_bcast_dims.begin());
+    std::copy(out_dims.begin(), out_dims.end() - 2, y_bcast_dims.begin());
+    std::copy(x_dims.end() - 2, x_dims.end(), x_bcast_dims.end() - 2);
+    std::copy(y_dims.end() - 2, y_dims.end(), y_bcast_dims.end() - 2);
+
+    if (dX) {
+      Tensor dx_temp(X->type());
+      if (x_dims != x_bcast_dims) {
+        dx_temp.Resize(phi::make_ddim(x_bcast_dims));
+      } else {
+        dX->mutable_data<T>(ctx.GetPlace());
+        dx_temp.ShareDataWith(*dX);
+      }
+
+      if (transpose_x) {
+        MatMulND<T>(ctx, y_temp, dout_temp, &dx_temp, transpose_y, true, alpha);
+      } else {
+        MatMulND<T>(ctx, dout_temp, y_temp, &dx_temp, false, !transpose_y,
+                    alpha);
+      }
+
+      if (x_dims != x_bcast_dims) {
+        ReduceDims<T>(ctx, x_dims, x_bcast_dims, dx_temp, dX);
+      }
+    }
+
+    if (dY) {
+      Tensor dy_temp(Y->type());
+      if (y_dims != y_bcast_dims) {
+        dy_temp.Resize(phi::make_ddim(y_bcast_dims));
+      } else {
+        dY->mutable_data<T>(ctx.GetPlace());
+        dy_temp.ShareDataWith(*dY);
+      }
+
+      if (transpose_y) {
+        MatMulND<T>(ctx, dout_temp, x_temp, &dy_temp, true, transpose_x, alpha);
+      } else {
+        MatMulND<T>(ctx, x_temp, dout_temp, &dy_temp, !transpose_x, false,
+                    alpha);
+      }
+
+      if (y_dims != y_bcast_dims) {
+        ReduceDims<T>(ctx, y_dims, y_bcast_dims, dy_temp, dY);
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_MLU_KERNEL(matmul, ops::MatMulMLUKernel<float>,
+                       ops::MatMulMLUKernel<plat::float16>);
+REGISTER_OP_MLU_KERNEL(matmul_grad, ops::MatMulGradMLUKernel<float>,
+                       ops::MatMulGradMLUKernel<plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_matmul_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_matmul_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..adfff112e6be216330ca8e542257944a1b32213b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_matmul_op_mlu.py
@@ -0,0 +1,329 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+SEED = 2022
+
+
+def reference_matmul(X, Y, transpose_X=False, transpose_Y=False, scale=1.0):
+    """Reference forward implementation using np.matmul."""
+    # np.matmul does not support the transpose flags, so we manually
+    # transpose X and Y appropriately.
+    if transpose_X:
+        if X.ndim == 1:
+            X = X.reshape((X.size, ))
+        elif X.ndim == 2:
+            X = X.T
+        else:
+            dim = [i for i in range(len(X.shape))]
+            dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
+            X = np.transpose(X, tuple(dim))
+    if transpose_Y:
+        if Y.ndim == 1:
+            Y = Y.reshape((Y.size, ))
+        else:
+            dim = [i for i in range(len(Y.shape))]
+            dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
+            Y = np.transpose(Y, tuple(dim))
+
+    Out = np.matmul(X, Y)
+    if not Out.shape:
+        # We do not support 0-dimensional Tensors (scalars). So where
+        # np.matmul outputs a scalar, we must convert to a Tensor of
+        # shape (1, ) instead.
+        # Everywhere else, we are compatible with np.matmul.
+        Out = np.array([Out], dtype="float64")
+    if abs(scale - 1.0) > 1e-09:
+        Out = Out * scale
+    return Out
+
+
+class TestMatMulOp(OpTest):
+    """
+    basic case
+    """
+
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "matmul"
+        self.init_dtype()
+        self.init_alpha()
+        self.config()
+
+        X = np.random.random(self.x_shape).astype(self.dtype)
+        Y = np.random.random(self.y_shape).astype(self.dtype)
+        # -0.1 ~ 0.1
+        X = -0.1 + 0.2 * X
+        Y = -0.1 + 0.2 * Y
+
+        Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y,
+                               self.alpha)
+        Out = Out.astype(self.dtype)
+        self.inputs = {'X': X, 'Y': Y}
+        self.attrs = {
+            'transpose_X': self.transpose_X,
+            'transpose_Y': self.transpose_Y,
+            'alpha': self.alpha
+        }
+        self.outputs = {'Out': Out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+        self.place = paddle.device.MLUPlace(0)
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (100, )
+        self.transpose_X = False
+        self.transpose_Y = False
+
+    def init_alpha(self):
+        self.alpha = 1.0
+
+    def init_dtype(self):
+        self.dtype = "float32"
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-7)
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
+
+
+class TestMatMulOp1(TestMatMulOp):
+    """
+    case x_ndim == 1, y_ndim != 1
+    """
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (1, 3, 2, 100)
+        self.transpose_X = False
+        self.transpose_Y = True
+
+
+class TestMatMulOp2(TestMatMulOp):
+    """
+    case x_ndim != 1, y_ndim == 1
+    """
+
+    def config(self):
+        self.x_shape = (1, 2, 100, 1)
+        self.y_shape = (100, )
+        self.transpose_X = True
+        self.transpose_Y = False
+
+
+class TestMatMulOp3(TestMatMulOp):
+    """
+    case [M, K] x [K, N] = [M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 100)
+        self.y_shape = (100, 2)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+class TestMatMulOp4(TestMatMulOp):
+    """
+    case [M, K] x [K, N] = [M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 100)
+        self.y_shape = (2, 100)
+        self.transpose_X = False
+        self.transpose_Y = True
+
+
+class TestMatMulOp5(TestMatMulOp):
+    """
+    case [M, K] x [K, N] = [M, N]
+    """
+
+    def config(self):
+        self.x_shape = (100, 2)
+        self.y_shape = (100, 2)
+        self.transpose_X = True
+        self.transpose_Y = False
+
+
+class TestMatMulOp6(TestMatMulOp):
+    """
+    case [B, M, K] x [K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 2, 25)
+        self.y_shape = (25, 4)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+class TestMatMulOp7(TestMatMulOp):
+    """
+    case [B, M, K] x [K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (1, 2, 25)
+        self.y_shape = (4, 25)
+        self.transpose_X = False
+        self.transpose_Y = True
+
+
+class TestMatMulOp8(TestMatMulOp):
+    """
+    case [B, M, K] x [K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (1, 25, 4)
+        self.y_shape = (25, 4)
+        self.transpose_X = True
+        self.transpose_Y = False
+
+
+class TestMatMulOp9(TestMatMulOp):
+    """
+    case [B, M, K] x [B, K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 5, 10)
+        self.y_shape = (2, 10, 5)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+class TestMatMulOp10(TestMatMulOp):
+    """
+    case [B, M, K] x [B, K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 10, 5)
+        self.y_shape = (2, 10, 5)
+        self.transpose_X = True
+        self.transpose_Y = False
+
+
+class TestMatMulOp11(TestMatMulOp):
+    """
+    case [B, M, K] x [B, K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 5, 10)
+        self.y_shape = (2, 5, 10)
+        self.transpose_X = False
+        self.transpose_Y = True
+
+
+class TestMatMulOp12(TestMatMulOp):
+    """
+    case to check the gradient for special case
+    """
+
+    def config(self):
+        self.x_shape = (100)
+        self.y_shape = (1, 2, 2, 100, 2)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+class TestMatMulOp13(TestMatMulOp):
+    """
+    case to check the gradient for special case
+    """
+
+    def config(self):
+        self.x_shape = (2, 1, 100)
+        self.y_shape = (100)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+# TODO(mlu): alpha will be supported in next version
+#--------------------test matmul alpha--------------------
+# def create_test_alpha_class(parent):
+#     class TestMatMulOpAlphaCase(parent):
+#         def init_alpha(self):
+#             self.alpha = 0.125
+
+#     cls_name = "{0}_{1}".format(parent.__name__, "Alpha")
+#     TestMatMulOpAlphaCase.__name__ = cls_name
+#     globals()[cls_name] = TestMatMulOpAlphaCase
+
+# create_test_alpha_class(TestMatMulOp)
+# create_test_alpha_class(TestMatMulOp1)
+# create_test_alpha_class(TestMatMulOp2)
+# create_test_alpha_class(TestMatMulOp3)
+# create_test_alpha_class(TestMatMulOp4)
+# create_test_alpha_class(TestMatMulOp5)
+# create_test_alpha_class(TestMatMulOp6)
+# create_test_alpha_class(TestMatMulOp9)
+# create_test_alpha_class(TestMatMulOp10)
+# create_test_alpha_class(TestMatMulOp11)
+# create_test_alpha_class(TestMatMulOp12)
+# create_test_alpha_class(TestMatMulOp13)
+
+
+#--------------------test matmul fp16--------------------
+def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
+    class TestMatMulOpFp16Case(parent):
+        def init_dtype(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place, atol=atol)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(
+                self.place, ['X', 'Y'],
+                'Out',
+                max_relative_error=max_relative_error)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestMatMulOpFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestMatMulOpFp16Case
+
+
+create_test_fp16_class(TestMatMulOp)
+create_test_fp16_class(TestMatMulOp1)
+create_test_fp16_class(TestMatMulOp2)
+create_test_fp16_class(TestMatMulOp3)
+create_test_fp16_class(TestMatMulOp4)
+create_test_fp16_class(TestMatMulOp5)
+create_test_fp16_class(TestMatMulOp6)
+create_test_fp16_class(TestMatMulOp9)
+create_test_fp16_class(TestMatMulOp10)
+create_test_fp16_class(TestMatMulOp11)
+create_test_fp16_class(TestMatMulOp12)
+create_test_fp16_class(TestMatMulOp13)
+
+if __name__ == "__main__":
+    unittest.main()
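
Reviewer note (not part of the patch): a minimal sketch of how the newly registered `matmul` MLU kernel is reached from the legacy static-graph API. It assumes an MLU build of Paddle in which `paddle.device.MLUPlace(0)` is available, as used by the unit test above; `fluid.layers.matmul` is the Python entry point that emits the legacy `matmul` op handled by this kernel.

```python
import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name="x", shape=[2, 100], dtype="float32")
    y = fluid.data(name="y", shape=[100, 2], dtype="float32")
    # Emits the legacy `matmul` op; alpha defaults to 1.0, the only value
    # the MLU kernel currently accepts.
    out = fluid.layers.matmul(x, y, transpose_x=False, transpose_y=False)

place = paddle.device.MLUPlace(0)  # assumption: running on an MLU device
exe = fluid.Executor(place)
exe.run(startup_prog)
res = exe.run(main_prog,
              feed={
                  "x": np.random.random((2, 100)).astype("float32"),
                  "y": np.random.random((100, 2)).astype("float32")
              },
              fetch_list=[out])
print(res[0].shape)  # expected: (2, 2)
```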