未验证 提交 bb0df468 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports squeeze, fix infershape and GetPositiveAxes (#55333)

上级 de9318a3
......@@ -705,10 +705,6 @@ std::vector<std::vector<int>> InferShapeForSqueeze(
VLOG(4) << "The output calculated in Squeeze: "
<< cinn::utils::Join(output_shape, ", ");
if (output_shape.size() == 0) {
output_shape.push_back(1);
}
return {output_shape};
}
......
......@@ -23,7 +23,7 @@ std::vector<int> GetPositiveAxes(const std::vector<int>& axes, int rank) {
std::vector<int> new_axes(axes.size());
for (int i = 0; i < axes.size(); ++i) {
int axis = axes[i] + (axes[i] < 0 ? rank : 0);
CHECK(axis >= 0 && axis < rank)
CHECK(axis >= 0 && (rank == 0 || axis < rank))
<< "The axis should in [" << -rank << ", " << rank << "), but axes["
<< i << "]=" << axes[i] << " not.";
new_axes[i] = axis;
......
......@@ -713,5 +713,62 @@ class TestDropoutOp(OpTest):
self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestSqueezeOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.squeeze_axex = [0]
self.target_shape = ()
def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.squeeze(x, axis=self.squeeze_axex)
self.paddle_outputs = [out]
def build_cinn_program(self, target):
builder = NetBuilder("squeeze_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.squeeze(x, self.squeeze_axex)
prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)
def test_check_results(self):
self.check_outputs_and_grads()
class TestSqueezeOp1D(TestSqueezeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, [1]).astype(self.dtype),
}
self.squeeze_axex = []
self.target_shape = ()
class TestSqueezeOp2D(TestSqueezeOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, [1, 1]).astype(self.dtype),
}
self.squeeze_axex = [0, 1]
self.target_shape = ()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册