未验证 提交 6a9cbb27 编写于 作者: F Fisher 提交者: GitHub

[CINN] Refactor op test mul (#55135)

* Refactor op test mul

* Fix test error for mul
上级 85831c32
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) 2021 CINN Authors. All Rights Reserved. # Copyright (c) 2021 CINN Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -14,17 +13,45 @@ ...@@ -14,17 +13,45 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import unittest
import cinn
import numpy as np
from cinn.common import * from cinn.common import *
from cinn.frontend import * from cinn.frontend import *
from op_test import OpTest, OpTestTool from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper
import paddle import paddle
import paddle.nn.functional as F
def infer_shape(
    x_shape: list,
    y_shape: list,
    x_num_col_dim: int,
    y_num_col_dim: int,
    is_infer: bool,
):
    """Mirror the shape inference of the ``mul`` op.

    Each input shape is collapsed into a 2-D matrix shape by folding the
    leading ``num_col_dim`` axes into rows and the rest into columns.

    Returns a tuple ``(x_flattened, y_flattened, out_shape)`` where
    ``out_shape`` is the shape of the mul result before any reshape.
    """

    def flatten_shape(shape: list, num_col_dim: int) -> list:
        # Shapes that are already at most 2-D pass through untouched.
        if len(shape) <= 2:
            return shape
        rows, cols = 1, 1
        for axis, extent in enumerate(shape):
            if axis < num_col_dim:
                rows *= extent
            else:
                cols *= extent
        return [rows, cols]

    x_new_shape = flatten_shape(x_shape, x_num_col_dim)
    y_new_shape = flatten_shape(y_shape, y_num_col_dim)
    # Leading output dims always come from x; the trailing dims come from y,
    # and which half of y contributes depends on is_infer (y transposed).
    out_shape = list(x_shape[:x_num_col_dim])
    if is_infer:
        out_shape += list(y_shape[:y_num_col_dim])
    else:
        out_shape += list(y_shape[y_num_col_dim:])
    return x_new_shape, y_new_shape, out_shape
@OpTestTool.skip_if(
    # NOTE(review): the original skip condition was lost in the mangled diff
    # paste (replaced by a hunk marker); CINN op tests conventionally skip on
    # non-CUDA builds with this message — confirm against the upstream file.
    not is_compiled_with_cuda(),
    "x86 test will be skipped due to timeout.",
)
class TestMulOp(OpTest):
    """Compares CINN's `mul` op against a reference built from paddle.matmul.

    Each test case dict supplies: x_shape, y_shape, dtype, x_num_col_dims,
    y_num_col_dims, is_infer, and optionally max_relative_error.
    """

    def setUp(self):
        self.prepare_inputs()

    def prepare_inputs(self):
        # Random inputs sized/typed per the current parameterized case.
        self.x_np = self.random(
            shape=self.case["x_shape"], dtype=self.case["dtype"]
        )
        self.y_np = self.random(
            shape=self.case["y_shape"], dtype=self.case["dtype"]
        )

    def build_paddle_program(self, target):
        """Reference path: flatten to 2-D, matmul, reshape back."""
        x = paddle.to_tensor(self.x_np, stop_gradient=False)
        y = paddle.to_tensor(self.y_np, stop_gradient=False)
        x_shape, y_shape, out_shape = infer_shape(
            x.shape,
            y.shape,
            self.case["x_num_col_dims"],
            self.case["y_num_col_dims"],
            self.case["is_infer"],
        )
        x = paddle.reshape(x, x_shape)
        y = paddle.reshape(y, y_shape)
        if self.case["is_infer"]:
            # is_infer means mul multiplies by y transposed.
            out = paddle.matmul(x, y, transpose_x=False, transpose_y=True)
        else:
            out = paddle.matmul(x, y)
        out = paddle.reshape(out, out_shape)
        self.paddle_outputs = [out]

    def build_cinn_program(self, target):
        """CINN path: the `mul` op handles flattening internally."""
        builder = NetBuilder("mul")
        x = builder.create_input(
            self.nptype2cinntype(self.case["dtype"]), self.case["x_shape"], "x"
        )
        y = builder.create_input(
            self.nptype2cinntype(self.case["dtype"]), self.case["y_shape"], "y"
        )
        out = builder.mul(
            x,
            y,
            x_num_col_dims=self.case["x_num_col_dims"],
            y_num_col_dims=self.case["y_num_col_dims"],
            is_infer=self.case["is_infer"],
        )
        prog = builder.build()
        res = self.get_cinn_output(
            prog, target, [x, y], [self.x_np, self.y_np], [out]
        )
        self.cinn_outputs = res

    def test_check_results(self):
        # Cases for low-precision dtypes override the default tolerance.
        max_relative_error = self.case.get("max_relative_error", 1e-5)
        self.check_outputs_and_grads(max_relative_error=max_relative_error)
class TestMulOpShape(TestCaseHelper):
    """Sweeps the mul op over shape pairs and num_col_dims splits."""

    def init_attrs(self):
        self.class_name = "TestMulOpShape"
        self.cls = TestMulOp
        # One row per case: (x_shape, y_shape, x_num_col_dims, y_num_col_dims).
        shape_cases = [
            ([1, 1], [1, 1], 1, 1),
            ([32, 64], [64, 32], 1, 1),
            ([2, 3, 4], [4, 3, 2], 1, 2),
            ([16, 8, 4, 2], [2, 4, 8, 16], 2, 2),
            ([1, 1, 1, 1], [1, 1, 1, 1], 2, 2),
        ]
        self.inputs = [
            {
                "x_shape": x_shape,
                "y_shape": y_shape,
                "x_num_col_dims": x_cols,
                "y_num_col_dims": y_cols,
            }
            for x_shape, y_shape, x_cols, y_cols in shape_cases
        ]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"is_infer": False}]
class TestMulOpDtype(TestCaseHelper):
    """Sweeps the mul op over supported dtypes on a fixed 2-D shape pair."""

    def init_attrs(self):
        self.class_name = "TestMulOpDtype"
        self.cls = TestMulOp
        self.inputs = [
            {
                "x_shape": [32, 64],
                "y_shape": [64, 32],
                "x_num_col_dims": 1,
                "y_num_col_dims": 1,
            }
        ]
        # Lower-precision dtypes carry a looser comparison tolerance.
        self.dtypes = [
            # cublas bf16 gemm requires GPU compute capability >= 80
            # {
            #     "dtype": "bfloat16",
            #     "max_relative_error": 1e-3,
            # },
            {"dtype": "float16", "max_relative_error": 1e-2},
            {"dtype": "float32"},
            {"dtype": "float64"},
        ]
        self.attrs = [{"is_infer": False}]
class TestMulOpAttr(TestCaseHelper):
    """Exercises the is_infer attribute (mul against a transposed y)."""

    def init_attrs(self):
        self.class_name = "TestMulOpAttr"
        self.cls = TestMulOp
        self.inputs = [
            {
                "x_shape": [16, 8, 4, 2],
                "y_shape": [16, 8, 4, 2],
                "x_num_col_dims": 2,
                "y_num_col_dims": 2,
            }
        ]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"is_infer": True}]
# Entry point: run every parameterized suite. Reconstructed from the new side
# of the mangled diff paste (the old side called unittest.main()).
if __name__ == "__main__":
    TestMulOpShape().run()
    TestMulOpDtype().run()
    TestMulOpAttr().run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册