未验证 提交 015285fd 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports argmax, fix infershape (#55489)

* [0D-Tensor] CINN supports argmax, fix infershape

* [0D-Tensor] CINN supports argmax, fix infershape
上级 c7ba0312
...@@ -186,17 +186,22 @@ std::vector<shape_t> InferShapeForArgmax( ...@@ -186,17 +186,22 @@ std::vector<shape_t> InferShapeForArgmax(
const framework::AttrMapType &attrs) { const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_shape.size(), 1UL); CHECK_EQ(inputs_shape.size(), 1UL);
auto ndim = inputs_shape[0].size(); auto ndim = inputs_shape[0].size();
CHECK_GT(ndim, 0) << "tensor's dim must be more than 0";
int axis; int axis;
bool keep_dim; bool keep_dim;
CHECK(attrs.find("axis") != attrs.end()); CHECK(attrs.find("axis") != attrs.end());
axis = absl::get<int>(attrs.at("axis")); axis = absl::get<int>(attrs.at("axis"));
if (axis < 0) { if (ndim > 0) {
axis = static_cast<int>(ndim) + axis; if (axis < 0) {
axis = static_cast<int>(ndim) + axis;
}
CHECK_LT(axis, ndim) << "Axis must be less than tensor's dim";
CHECK_GE(axis, 0) << "Axis must be more than 0";
} else {
// 0D Tensor
CHECK(axis == 0 || axis == -1)
<< "Axis must be 0 or -1 if input tensor is 0-dim";
} }
CHECK_LT(axis, ndim) << "Axis must be less than tensor's dim";
CHECK_GE(axis, 0) << "Axis must be more than 0";
CHECK(attrs.find("keep_dim") != attrs.end()); CHECK(attrs.find("keep_dim") != attrs.end());
keep_dim = absl::get<bool>(attrs.at("keep_dim")); keep_dim = absl::get<bool>(attrs.at("keep_dim"));
...@@ -215,10 +220,7 @@ std::vector<shape_t> InferShapeForArgmax( ...@@ -215,10 +220,7 @@ std::vector<shape_t> InferShapeForArgmax(
if (keep_dim) { if (keep_dim) {
CHECK_EQ(ndim, out_shapes.size()); CHECK_EQ(ndim, out_shapes.size());
} else { } else {
CHECK_EQ(ndim - 1, out_shapes.size()); CHECK(ndim - 1 == out_shapes.size() || ndim == 0 && out_shapes.empty());
}
if (out_shapes.empty()) {
out_shapes.push_back(1);
} }
return {out_shapes}; return {out_shapes};
......
...@@ -674,6 +674,82 @@ class TestCastOp(OpTest): ...@@ -674,6 +674,82 @@ class TestCastOp(OpTest):
self.check_outputs_and_grads() self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestArgmaxOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.param = (0,)
self.target_shape = ()
def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.argmax(x, *self.param)
out = paddle.cast(out, 'int32')
self.paddle_outputs = [out]
def build_cinn_program(self, target):
builder = NetBuilder("squeeze_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.argmax(x, *self.param)
prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)
def test_check_results(self):
self.check_outputs_and_grads()
class TestArgmaxOp2(TestArgmaxOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.param = (-1,)
self.target_shape = ()
class TestArgmaxOp1D(TestArgmaxOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, [5]).astype(self.dtype),
}
self.param = (0,)
self.target_shape = ()
class TestArgmaxOp2D(TestArgmaxOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, [3, 5]).astype(self.dtype),
}
self.param = (0,)
self.target_shape = (5,)
class TestArgmaxOp2DKeepDim(TestArgmaxOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, [3, 5]).astype(self.dtype),
}
self.param = (0, True)
self.target_shape = (1, 5)
@OpTestTool.skip_if( @OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout." not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册