未验证 提交 f736f151 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports unsqueeze, delete hack in Paddle's pass (#55336)

上级 14551c85
...@@ -755,8 +755,8 @@ std::shared_ptr<OpStrategy> StrategyForExpandDims( ...@@ -755,8 +755,8 @@ std::shared_ptr<OpStrategy> StrategyForExpandDims(
std::vector<std::vector<int>> InferShapeForExpandDims( std::vector<std::vector<int>> InferShapeForExpandDims(
const std::vector<std::vector<int>> &inputs_shape, const std::vector<std::vector<int>> &inputs_shape,
const framework::AttrMapType &attrs) { const framework::AttrMapType &attrs) {
CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) CHECK(!inputs_shape.empty())
<< "The input's shape size is 0! Please check again."; << "At least 1 input tensor for expand_dims operator.";
CHECK_EQ(inputs_shape.size(), 1U); CHECK_EQ(inputs_shape.size(), 1U);
const std::vector<int> &axes = const std::vector<int> &axes =
......
...@@ -32,32 +32,6 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const { ...@@ -32,32 +32,6 @@ void CinnZeroTensorTrickPass::ApplyImpl(ir::Graph* graph) const {
"assign_value", "assign_value",
"gaussian_random", "gaussian_random",
"set_value"}; "set_value"};
// NOTE: Hack squeeze2 0D-Tensor input
// If squeeze2 inputs 0D-Tensor and axes, The 0D-Tensor's shape will convert
// to 1D-Tensor, which could lead error. We hack squeeze2's axes attribute to
// resolve this. Change 0D-Tensor input to 1D-Tensor input and then make
// axes->axes[: -1]
for (const ir::Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->Type() == "unsqueeze2") {
if (n->Op()->HasAttr("axes")) {
auto axes =
PADDLE_GET_CONST(std::vector<int32_t>, n->Op()->GetAttr("axes"));
for (const ir::Node* var : n->inputs) {
if (var->Var() &&
var->Var()->GetType() == proto::VarType::LOD_TENSOR) {
std::vector<int64_t> shape = var->Var()->GetShape();
if (shape.empty()) {
axes.pop_back();
n->Op()->SetAttr("axes", axes);
VLOG(4) << "unsqueeze2 axes dims is full, fix dim -> dim[:-1] to "
"avoid 0D-Tensor input error";
}
}
}
}
}
}
// CINN ops in this white list support 0D-Tensor, wait-list = {"remainder"} // CINN ops in this white list support 0D-Tensor, wait-list = {"remainder"}
const std::unordered_set<std::string> white_op_list{"elementwise_add", const std::unordered_set<std::string> white_op_list{"elementwise_add",
"elementwise_sub", "elementwise_sub",
......
...@@ -668,6 +668,60 @@ class TestTransposeOp(OpTest): ...@@ -668,6 +668,60 @@ class TestTransposeOp(OpTest):
self.check_outputs_and_grads() self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestExpandDimsOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.unsqueeze_dim = [0]
self.target_shape = (1,)
def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.unsqueeze(x, self.unsqueeze_dim)
self.paddle_outputs = [out]
def build_cinn_program(self, target):
builder = NetBuilder("unsqueeze_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.expand_dims(x, self.unsqueeze_dim)
prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)
def test_check_results(self):
self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestExpandDimsOp2D(TestExpandDimsOp):
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.unsqueeze_dim = [0, 1]
self.target_shape = (
1,
1,
)
@OpTestTool.skip_if( @OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout." not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册