Unverified commit 276c159d authored by HongyuJia, committed by GitHub

[0D-Tensor] CINN supports broadcast_to, fix infershape (#55321)

Parent 885d1aec
@@ -322,10 +322,12 @@ Variable NetBuilder::Concat(const std::vector<Variable>& input_vars, int axis) {
 Variable NetBuilder::BroadcastTo(const Variable& operand,
                                  const std::vector<int>& out_shape) {
   auto x_shape_size = operand->shape.size();
+  if (x_shape_size == 0) {
+    VLOG(4) << "0D-Tensor " << operand->id << " broadcast to shape ("
+            << cinn::utils::Join(out_shape, ",") << ")";
+    return BroadcastTo(operand, out_shape, {0});
+  }
   auto y_shape_size = out_shape.size();
-  CHECK_GT(x_shape_size, 0)
-      << "Cannot broadcast a empty operand " << operand->id << " to "
-      << cinn::utils::Join(out_shape, ",");
   CHECK_LE(x_shape_size, y_shape_size)
       << "The broadcast_p's input shape dimension should less than the "
          "output's, "
......
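Reviewer note: the new early-return path lowers a 0D-Tensor broadcast as a single-axis broadcast along axis 0. A minimal NumPy sketch of the intended semantics (NumPy stands in for CINN here purely as an illustration; it is not part of this change):

import numpy as np

# A rank-0 (0D) array broadcasts to any target shape: its single value
# is replicated across every element, which is what the new
# BroadcastTo(operand, out_shape, {0}) path expresses in CINN.
x = np.array(2.5)                 # shape == (), the 0D case
out = np.broadcast_to(x, (3, 3))  # analogous to out_shape = {3, 3}
assert out.shape == (3, 3)
assert (out == 2.5).all()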
@@ -236,8 +236,13 @@ std::vector<shape_t> InferShapeForBroadcastTo(
   VLOG(3) << "broadcast input shape: " << utils::Join(inputs_shape[0], ", ");
   VLOG(3) << "broadcast out shape: " << utils::Join(out_shape, ", ");
   VLOG(3) << "broadcast_axes shape: " << utils::Join(broadcast_axes, ", ");
-  CHECK_EQ(inputs_shape[0].size(), broadcast_axes.size())
-      << "broadcast_axes's size should be same with the input shape's size";
+  if (inputs_shape[0].empty()) {
+    CHECK(broadcast_axes.size() == 1 && broadcast_axes[0] == 0)
+        << "broadcast_axes should be {0} when the input is a 0D-Tensor";
+  } else {
+    CHECK_EQ(inputs_shape[0].size(), broadcast_axes.size())
+        << "broadcast_axes's size should be the same as the input shape's size";
+  }
   CHECK_GE(out_shape.size(), broadcast_axes.size())
       << "broadcast_axes's size should be no more than out_shape's size";
......
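The relaxed check mirrors the convention above: a 0D input must come with broadcast_axes == {0}, while any other input still needs one broadcast axis per dimension. A hedged Python restatement of the rule (function name and signature are illustrative, not part of the CINN API):

def broadcast_axes_valid(input_shape, out_shape, broadcast_axes):
    """Sketch of the consistency rule enforced by InferShapeForBroadcastTo."""
    if len(input_shape) == 0:
        # 0D-Tensor: exactly one broadcast axis, and it must be axis 0.
        if not (len(broadcast_axes) == 1 and broadcast_axes[0] == 0):
            return False
    elif len(input_shape) != len(broadcast_axes):
        # Non-0D input: one broadcast axis per input dimension.
        return False
    # In all cases the axes cannot outnumber the output dimensions.
    return len(broadcast_axes) <= len(out_shape)


assert broadcast_axes_valid([], [3, 3], [0])         # 0D input, accepted
assert not broadcast_axes_valid([], [3, 3], [0, 1])  # 0D input, rejected
assert broadcast_axes_valid([3], [3, 3], [1])        # 1D input, accepted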
@@ -630,6 +630,60 @@ class TestScaleOp(OpTest):
         self.check_outputs_and_grads()


+@OpTestTool.skip_if(
+    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
+)
+class TestBroadcastToOp1D(OpTest):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.init_input()
+
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.broadcast_shape = [1]
+
+    def build_paddle_program(self, target):
+        x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
+        out = paddle.broadcast_to(x, shape=self.broadcast_shape)
+
+        self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("broadcast_to_op")
+        x = builder.create_input(
+            cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
+        )
+        out = builder.broadcast_to(x, self.broadcast_shape)
+
+        prog = builder.build()
+        res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
+
+        self.cinn_outputs = res
+        self.assertEqual(list(res[0].shape), list(self.broadcast_shape))
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+class TestBroadcastToOp2D(TestBroadcastToOp1D):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.broadcast_shape = [1, 1]
+
+
+class TestBroadcastToOp3D(TestBroadcastToOp1D):
+    def init_input(self):
+        self.inputs = {
+            "x": np.random.randint(-10, 10, []).astype(self.dtype),
+        }
+        self.broadcast_shape = [3, 3, 3]
+
+
 @OpTestTool.skip_if(
     not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
 )
......
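End to end, these tests exercise the eager behavior sketched below; this usage example assumes a Paddle build with 0D-Tensor support (2.5 or later):

import paddle

x = paddle.full([], 2.0)                    # 0D tensor: shape []
out = paddle.broadcast_to(x, shape=[3, 3])  # replicated to a 3x3 tensor
print(out.shape)                            # [3, 3]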