未验证 提交 f5830c05 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports `fill_constant`, fix infershape and pass (#55563)

* [0D-Tensor] CINN supports fill_constant, fix infershape and pass

* fix infershape of fill_constant

* add back fill_constant to zero_tensor_trick_pass
上级 97ec1d84
......@@ -39,6 +39,9 @@ class ExpandZeroDimPass : public ProgramPass {
if (instr->op_type == "transpose") {
builder.AppendInstruction(HandleTranspose(instr));
continue;
} else if (instr->op_type == "fill_constant") {
builder.AppendInstruction(HandleFillConstant(instr));
continue;
}
for (auto& input : instr->inputs) {
if (input->shape.empty()) {
......@@ -101,6 +104,20 @@ class ExpandZeroDimPass : public ProgramPass {
}
return new_instr;
}
// Before: out-0D = fill_constant([], 123.456, "out", "float32")
// After: out-1D = fill_constant([1], 123.456, "out", "float32")
Instruction HandleFillConstant(const Instruction& instr) {
Instruction new_instr = instr;
std::vector<int32_t> shape =
new_instr.GetAttrs<std::vector<int32_t>>("shape");
if (shape.empty()) {
shape.push_back(1);
VLOG(4) << "Change fill_constant's attribute shape from [] to [1]";
}
new_instr.SetAttr<std::vector<int32_t>>("shape", shape);
return new_instr;
}
};
} // namespace pass
......
......@@ -393,7 +393,6 @@ std::vector<shape_t> InferShapeForFillConstant(
const framework::AttrMapType &attrs) {
CHECK(attrs.count("shape"));
auto shape = absl::get<std::vector<int>>(attrs.at("shape"));
CHECK(!shape.empty()) << "shape attr is empty!";
return {shape};
}
......
......@@ -1318,6 +1318,33 @@ class TestReshapeOp1DTo0D(TestReshapeOp):
self.target_shape = []
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestFillConstantOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.target_shape = ()
def build_paddle_program(self, target):
out = paddle.full([], 123.456, "float32")
self.paddle_outputs = [out]
def build_cinn_program(self, target):
builder = NetBuilder("fill_constant_op")
out = builder.fill_constant([], 123.456, "out", "float32")
prog = builder.build()
res = self.get_cinn_output(prog, target, [], [], [out])
self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)
def test_check_results(self):
self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册