Unverified · Commit b869e963 · authored by F Fisher · committed by GitHub

[CINN] Add test and full dtype support for argx op (#54939)

* Add test and full dtype support for argsort, argmax, argmin

* Reformat code

* Merge develop and remove printf in python code

* Reformat code

* Reformat code
Parent 77bef883
@@ -686,6 +686,8 @@ __device__ inline int cinn_cuda_find_float_from(const float *buf, int size, floa
 CINN_NVGPU_LT_NUM(fp32, float)
 CINN_NVGPU_LT_NUM(fp64, double)
 CINN_NVGPU_LT_NUM(uint8, uint8_t)
+CINN_NVGPU_LT_NUM(int16, int16_t)
 CINN_NVGPU_LT_NUM(int32, int)
 CINN_NVGPU_LT_NUM(int64, long long int)
 #ifdef CINN_CUDA_FP16
@@ -706,6 +708,8 @@ CINN_NVGPU_LT_NUM(fp16, float16)
 CINN_NVGPU_GT_NUM(fp32, float)
 CINN_NVGPU_GT_NUM(fp64, double)
 CINN_NVGPU_GT_NUM(uint8, uint8_t)
+CINN_NVGPU_GT_NUM(int16, int16_t)
 CINN_NVGPU_GT_NUM(int32, int)
 CINN_NVGPU_GT_NUM(int64, long long int)
 #ifdef CINN_CUDA_FP16
...
@@ -345,6 +345,9 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
 _REGISTER_CINN_NVGPU_LT_NUM(fp32, float);
 _REGISTER_CINN_NVGPU_LT_NUM(fp64, double);
 _REGISTER_CINN_NVGPU_LT_NUM(uint8, uint8_t);
+_REGISTER_CINN_NVGPU_LT_NUM(int16, int16_t);
 _REGISTER_CINN_NVGPU_LT_NUM(int32, int);
 _REGISTER_CINN_NVGPU_LT_NUM(int64, int64_t);
@@ -362,6 +365,8 @@ CINN_REGISTER_HELPER(cuda_intrinsics) {
 _REGISTER_CINN_NVGPU_GT_NUM(fp32, float);
 _REGISTER_CINN_NVGPU_GT_NUM(fp64, double);
 _REGISTER_CINN_NVGPU_GT_NUM(uint8, uint8_t);
+_REGISTER_CINN_NVGPU_GT_NUM(int16, int16_t);
 _REGISTER_CINN_NVGPU_GT_NUM(int32, int);
 _REGISTER_CINN_NVGPU_GT_NUM(int64, int64_t);
...
#!/usr/bin/env python3
# Copyright (c) 2023 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from cinn.common import *
from cinn.frontend import *
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper
import paddle
@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestArgMaxOp(OpTest):
    """Check CINN's argmax against paddle.argmax for one configured case.

    The concrete shape/dtype/axis/keepdim values come from ``self.case``,
    which is injected by the TestCaseHelper-driven suites below.
    """

    def setUp(self):
        self.prepare_inputs()

    def prepare_inputs(self):
        # Random input drawn from [0, 10) with the configured shape/dtype.
        case = self.case
        self.x_np = self.random(case["shape"], case["dtype"], low=0, high=10)
        self.axis = case["axis"]
        self.keepdim = case["keepdim"]

    def build_paddle_program(self, target):
        tensor = paddle.to_tensor(self.x_np, stop_gradient=True)
        self.paddle_outputs = [paddle.argmax(tensor, self.axis, self.keepdim)]

    def build_cinn_program(self, target):
        builder = NetBuilder("argmax")
        input_var = builder.create_input(
            self.nptype2cinntype(self.case["dtype"]), self.case["shape"], "x"
        )
        result = builder.argmax(input_var, self.axis, self.keepdim)
        program = builder.build()
        outputs = self.get_cinn_output(
            program, target, [input_var], [self.x_np], [result]
        )
        # Cast to int64 so the index dtype matches paddle's output.
        self.cinn_outputs = np.array(outputs).astype("int64")

    def test_check_results(self):
        self.check_outputs_and_grads()
class TestArgMaxOpShapeTest(TestCaseHelper):
    """Sweep argmax over a range of input ranks and sizes (float32, axis 0)."""

    def init_attrs(self):
        self.class_name = "ArgMaxOpShapeTest"
        self.cls = TestArgMaxOp
        shapes = [
            [512],
            [1024],
            [1200],
            [64, 16],
            [4, 32, 8],
            [16, 8, 4, 2],
            [2, 8, 4, 2, 5],
            [4, 8, 1, 2, 16],
            [1],
            [1, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ]
        self.inputs = [{"shape": shape} for shape in shapes]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"axis": 0, "keepdim": False}]
class TestArgMaxOpDtypeTest(TestCaseHelper):
    """Sweep argmax over every supported input dtype (shape [1024], axis 0)."""

    def init_attrs(self):
        self.class_name = "ArgMaxOpDtypeTest"
        self.cls = TestArgMaxOp
        self.inputs = [{"shape": [1024]}]
        self.dtypes = [
            {"dtype": name}
            for name in (
                "float16",
                "float32",
                "float64",
                "uint8",
                "int16",
                "int32",
                "int64",
            )
        ]
        self.attrs = [{"axis": 0, "keepdim": False}]
class TestArgMaxOpAxisTest(TestCaseHelper):
    """Sweep argmax over every axis of a rank-4 float32 input."""

    def init_attrs(self):
        self.class_name = "ArgMaxOpAxisTest"
        self.cls = TestArgMaxOp
        self.inputs = [{"shape": [16, 8, 4, 2]}]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"axis": axis, "keepdim": False} for axis in range(4)]
class TestArgMaxOpKeepdimTest(TestCaseHelper):
    """Sweep argmax with keepdim=True over every axis of a rank-4 input."""

    def init_attrs(self):
        self.class_name = "ArgMaxOpKeepdimTest"
        self.cls = TestArgMaxOp
        self.inputs = [{"shape": [16, 8, 4, 2]}]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"axis": axis, "keepdim": True} for axis in range(4)]
if __name__ == "__main__":
    # Run each argmax suite in declaration order.
    for suite in (
        TestArgMaxOpShapeTest,
        TestArgMaxOpDtypeTest,
        TestArgMaxOpAxisTest,
        TestArgMaxOpKeepdimTest,
    ):
        suite().run()
#!/usr/bin/env python3
# Copyright (c) 2023 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from cinn.common import *
from cinn.frontend import *
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper
import paddle
@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestArgMinOp(OpTest):
    """Check CINN's argmin against paddle.argmin for one configured case.

    The concrete shape/dtype/axis/keepdim values come from ``self.case``,
    which is injected by the TestCaseHelper-driven suites below.
    """

    def setUp(self):
        self.prepare_inputs()

    def prepare_inputs(self):
        # Random input drawn from [0, 10) with the configured shape/dtype.
        case = self.case
        self.x_np = self.random(case["shape"], case["dtype"], low=0, high=10)
        self.axis = case["axis"]
        self.keepdim = case["keepdim"]

    def build_paddle_program(self, target):
        tensor = paddle.to_tensor(self.x_np, stop_gradient=True)
        self.paddle_outputs = [paddle.argmin(tensor, self.axis, self.keepdim)]

    def build_cinn_program(self, target):
        builder = NetBuilder("argmin")
        input_var = builder.create_input(
            self.nptype2cinntype(self.case["dtype"]), self.case["shape"], "x"
        )
        result = builder.argmin(input_var, self.axis, self.keepdim)
        program = builder.build()
        outputs = self.get_cinn_output(
            program, target, [input_var], [self.x_np], [result]
        )
        # Cast to int64 so the index dtype matches paddle's output.
        self.cinn_outputs = np.array(outputs).astype("int64")

    def test_check_results(self):
        self.check_outputs_and_grads()
class TestArgMinOpShapeTest(TestCaseHelper):
    """Sweep argmin over a range of input ranks and sizes (float32, axis 0)."""

    def init_attrs(self):
        self.class_name = "ArgMinOpShapeTest"
        self.cls = TestArgMinOp
        shapes = [
            [512],
            [1024],
            [1200],
            [64, 16],
            [4, 32, 8],
            [16, 8, 4, 2],
            [2, 8, 4, 2, 5],
            [4, 8, 1, 2, 16],
            [1],
            [1, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ]
        self.inputs = [{"shape": shape} for shape in shapes]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"axis": 0, "keepdim": False}]
class TestArgMinOpDtypeTest(TestCaseHelper):
    """Sweep argmin over every supported input dtype (shape [1024], axis 0)."""

    def init_attrs(self):
        self.class_name = "ArgMinOpDtypeTest"
        self.cls = TestArgMinOp
        self.inputs = [{"shape": [1024]}]
        self.dtypes = [
            {"dtype": name}
            for name in (
                "float16",
                "float32",
                "float64",
                "uint8",
                "int16",
                "int32",
                "int64",
            )
        ]
        self.attrs = [{"axis": 0, "keepdim": False}]
class TestArgMinOpAxisTest(TestCaseHelper):
    """Sweep argmin over every axis of a rank-4 float32 input."""

    def init_attrs(self):
        self.class_name = "ArgMinOpAxisTest"
        self.cls = TestArgMinOp
        self.inputs = [{"shape": [16, 8, 4, 2]}]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"axis": axis, "keepdim": False} for axis in range(4)]
class TestArgMinOpKeepdimTest(TestCaseHelper):
    """Sweep argmin with keepdim=True over every axis of a rank-4 input."""

    def init_attrs(self):
        self.class_name = "ArgMinOpKeepdimTest"
        self.cls = TestArgMinOp
        self.inputs = [{"shape": [16, 8, 4, 2]}]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"axis": axis, "keepdim": True} for axis in range(4)]
if __name__ == "__main__":
    # Run each argmin suite in declaration order.
    for suite in (
        TestArgMinOpShapeTest,
        TestArgMinOpDtypeTest,
        TestArgMinOpAxisTest,
        TestArgMinOpKeepdimTest,
    ):
        suite().run()
@@ -14,16 +14,14 @@ (side-by-side diff de-garbled; the resulting new version of the file is shown)
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from cinn.common import *
from cinn.frontend import *
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper

import paddle


@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestArgSortOp(OpTest):
    def setUp(self):
        self.prepare_inputs()

    def prepare_inputs(self):
        self.x_np = self.random(self.case["shape"], self.case["dtype"])
        self.axis = self.case["axis"]
        self.descending = self.case["descending"]

    def build_paddle_program(self, target):
        x1 = paddle.to_tensor(self.x_np, stop_gradient=True)
        out = paddle.argsort(x1, self.axis, self.descending)
        self.paddle_outputs = [out]

    def build_cinn_program(self, target):
        builder = NetBuilder("argsort")
        x1 = builder.create_input(
            self.nptype2cinntype(self.case["dtype"]), self.case["shape"], "x1"
        )
        out = builder.argsort(x1, self.axis, not self.descending)
        prog = builder.build()
        forward_res = self.get_cinn_output(prog, target, [x1], [self.x_np], out)
        self.cinn_outputs = np.array([forward_res[0]]).astype("int64")

    def test_check_results(self):
        self.check_outputs_and_grads()


class TestArgSortOpShapeTest(TestCaseHelper):
    def init_attrs(self):
        self.class_name = "ArgSortOpShapeTest"
        self.cls = TestArgSortOp
        self.inputs = [
            {"shape": [512]},
            {"shape": [1024]},
            {"shape": [1200]},
            {"shape": [64, 16]},
            {"shape": [4, 32, 8]},
            {"shape": [16, 8, 4, 2]},
            {"shape": [2, 8, 4, 2, 5]},
            {"shape": [4, 8, 1, 2, 16]},
            {"shape": [1]},
            {"shape": [1, 1, 1, 1]},
            {"shape": [1, 1, 1, 1, 1]},
        ]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [{"axis": 0, "descending": False}]


class TestArgSortOpDtypeTest(TestCaseHelper):
    def init_attrs(self):
        self.class_name = "ArgSortOpDtypeTest"
        self.cls = TestArgSortOp
        self.inputs = [
            {"shape": [1024]},
        ]
        self.dtypes = [
            {"dtype": "float32"},
            {"dtype": "float64"},
            # Throw dtype not support error in paddle
            # {"dtype": "uint8"},
            {"dtype": "int32"},
            {"dtype": "int64"},
        ]
        self.attrs = [{"axis": 0, "descending": False}]


class TestArgSortOpAxisTest(TestCaseHelper):
    def init_attrs(self):
        self.class_name = "ArgSortOpAxisTest"
        self.cls = TestArgSortOp
        self.inputs = [
            {"shape": [16, 8, 4, 2]},
        ]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [
            {"axis": 0, "descending": False},
            {"axis": 1, "descending": False},
            {"axis": 2, "descending": False},
            {"axis": 3, "descending": False},
        ]


class TestArgSortOpDescedingTest(TestCaseHelper):
    def init_attrs(self):
        self.class_name = "ArgSortOpDescedingTest"
        self.cls = TestArgSortOp
        self.inputs = [
            {"shape": [16, 8, 4, 2]},
        ]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = [
            {"axis": 0, "descending": True},
            {"axis": 1, "descending": True},
            {"axis": 2, "descending": True},
            {"axis": 3, "descending": True},
        ]


if __name__ == "__main__":
    TestArgSortOpShapeTest().run()
    TestArgSortOpDtypeTest().run()
    TestArgSortOpAxisTest().run()
    TestArgSortOpDescedingTest().run()
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.