未验证 提交 5bfbaa8b 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports topk, sort, argsort, fix infershape (#55510)

上级 7341e6fc
......@@ -326,8 +326,18 @@ std::vector<std::vector<int>> InferShapeForSort(
break;
}
}
CHECK_GT(inputs_shape[0].size(), axis)
<< "The input's dim should be greater than axis! ";
if (inputs_shape[0].empty()) {
// 0D Tensor
CHECK(axis == 0 || axis == -1)
<< "Axis must be 0 or -1 if input tensor is 0-dim";
} else {
if (axis < 0) {
axis += inputs_shape[0].size();
}
CHECK_GT(inputs_shape[0].size(), axis)
<< "The input's dim should be greater than axis! ";
}
std::vector<std::vector<int>> res{inputs_shape[0]};
return res;
}
......@@ -352,11 +362,17 @@ std::vector<std::vector<int>> InferShapeForArgSort(
break;
}
}
if (axis < 0) {
axis += inputs_shape[0].size();
if (inputs_shape[0].empty()) {
// 0D Tensor
CHECK(axis == 0 || axis == -1)
<< "Axis must be 0 or -1 if input tensor is 0-dim";
} else {
if (axis < 0) {
axis += inputs_shape[0].size();
}
CHECK_GT(inputs_shape[0].size(), axis)
<< "The input's dim should be greater than axis! ";
}
CHECK_GT(inputs_shape[0].size(), axis)
<< "The input's dim should be greater than axis! ";
std::vector<std::vector<int>> res{inputs_shape[0], inputs_shape[0]};
return res;
......@@ -381,12 +397,19 @@ std::vector<std::vector<int>> InferShapeForTopK(
auto axis_it = attrs.find("axis");
CHECK(axis_it != attrs.end()) << "The attr axis of topk does not exist.";
int axis = absl::get<int>(axis_it->second);
if (axis < 0) {
axis += res[0].size();
if (inputs_shape[0].empty()) {
// 0D Tensor
CHECK(axis == 0 || axis == -1)
<< "Axis must be 0 or -1 if input tensor is 0-dim";
} else {
if (axis < 0) {
axis += inputs_shape[0].size();
}
CHECK_GE(axis, 0);
CHECK_LT(axis, res[0].size());
res[0][axis] = std::min(res[0][axis], k);
}
CHECK_GE(axis, 0);
CHECK_LT(axis, res[0].size());
res[0][axis] = std::min(res[0][axis], k);
return {res[0], res[0]};
}
......
......@@ -788,6 +788,153 @@ class TestTransposeOp(OpTest):
self.check_outputs_and_grads()
@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestArgsortOp(OpTest):
    """Compare CINN argsort against Paddle argsort for a 0-D input tensor."""

    def setUp(self):
        np.random.seed(2023)
        self.dtype = "float32"
        self.init_input()

    def init_input(self):
        # Rank-0 (scalar) input; axis=-1 is the permitted alias of axis 0
        # for 0-D tensors in the updated infershape rules.
        scalar = np.random.randint(-10, 10, []).astype(self.dtype)
        self.inputs = {"x": scalar}
        self.axis = -1
        self.target_shape = ()

    def build_paddle_program(self, target):
        tensor_x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
        sorted_indices = paddle.argsort(tensor_x, axis=self.axis)
        self.paddle_outputs = [sorted_indices]

    def build_cinn_program(self, target):
        net_builder = NetBuilder("argsort_op")
        input_x = net_builder.create_input(
            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
        )
        out = net_builder.argsort(input_x, self.axis, True)
        program = net_builder.build()
        forward_res = self.get_cinn_output(
            program, target, [input_x], [self.inputs["x"]], out
        )
        # Cast the raw index output to int64 so it matches Paddle's dtype.
        self.cinn_outputs = np.array([forward_res[0]]).astype("int64")
        # The op must keep the 0-D shape instead of promoting it to 1-D.
        self.assertEqual(forward_res[0].shape, self.target_shape)

    def test_check_results(self):
        self.check_outputs_and_grads()
class TestArgsortOp2(TestArgsortOp):
    """Same 0-D argsort check as the base case, but along axis 0."""

    def init_input(self):
        self.axis = 0
        self.target_shape = ()
        scalar = np.random.randint(-10, 10, []).astype(self.dtype)
        self.inputs = {"x": scalar}
@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestSortOp(OpTest):
    """Compare CINN sort against Paddle sort for a 0-D input tensor."""

    def setUp(self):
        np.random.seed(2023)
        self.dtype = "float32"
        self.init_input()

    def init_input(self):
        # Rank-0 (scalar) input; axis=-1 is the permitted alias of axis 0
        # for 0-D tensors in the updated infershape rules.
        scalar = np.random.randint(-10, 10, []).astype(self.dtype)
        self.inputs = {"x": scalar}
        self.axis = -1
        self.target_shape = ()

    def build_paddle_program(self, target):
        tensor_x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
        sorted_x = paddle.sort(tensor_x, axis=self.axis)
        self.paddle_outputs = [sorted_x]

    def build_cinn_program(self, target):
        net_builder = NetBuilder("sort_op")
        input_x = net_builder.create_input(
            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
        )
        out = net_builder.sort(input_x, self.axis, True)
        program = net_builder.build()
        forward_res = self.get_cinn_output(
            program, target, [input_x], [self.inputs["x"]], [out]
        )
        self.cinn_outputs = forward_res
        # The op must keep the 0-D shape instead of promoting it to 1-D.
        self.assertEqual(forward_res[0].shape, self.target_shape)

    def test_check_results(self):
        self.check_outputs_and_grads()
class TestSortOp2(TestSortOp):
    """Same 0-D sort check as the base case, but along axis 0."""

    def init_input(self):
        self.axis = 0
        self.target_shape = ()
        scalar = np.random.randint(-10, 10, []).astype(self.dtype)
        self.inputs = {"x": scalar}
@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestTopkOp(OpTest):
    """Compare CINN top_k against Paddle topk (k=1) for a 0-D input tensor."""

    def setUp(self):
        np.random.seed(2023)
        self.dtype = "float32"
        self.init_input()

    def init_input(self):
        # Rank-0 (scalar) input; axis=-1 is the permitted alias of axis 0
        # for 0-D tensors in the updated infershape rules.
        scalar = np.random.randint(-10, 10, []).astype(self.dtype)
        self.inputs = {"x": scalar}
        self.axis = -1
        self.target_shape = ()

    def build_paddle_program(self, target):
        tensor_x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
        out, indices = paddle.topk(tensor_x, k=1, axis=self.axis)
        self.paddle_outputs = [out, indices]

    def build_cinn_program(self, target):
        net_builder = NetBuilder("topk_op")
        input_x = net_builder.create_input(
            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
        )
        out = net_builder.top_k(input_x, 1, self.axis, True)
        program = net_builder.build()
        forward_res = self.get_cinn_output(
            program, target, [input_x], [self.inputs["x"]], [out[0], out[1]]
        )
        self.cinn_outputs = forward_res
        # Both the values and the indices must stay 0-D, not become 1-D.
        self.assertEqual(forward_res[0].shape, self.target_shape)
        self.assertEqual(forward_res[1].shape, self.target_shape)

    def test_check_results(self):
        self.check_outputs_and_grads()
class TestTopkOp2(TestTopkOp):
    """Same 0-D top_k check as the base case, but along axis 0."""

    def init_input(self):
        self.axis = 0
        self.target_shape = ()
        scalar = np.random.randint(-10, 10, []).astype(self.dtype)
        self.inputs = {"x": scalar}
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册