未验证 提交 d1b74ba5 编写于 作者: H HongyuJia 提交者: GitHub

[0D-Tensor] CINN supports transpose, add special case to expand_zero_dim_pass (#55379)

上级 69e9f03e
......@@ -34,33 +34,73 @@ class ExpandZeroDimPass : public ProgramPass {
const std::unordered_set<std::string>& fetch_ids,
const common::Target& target) override {
NetBuilder builder("expand_zero_dim_builder");
for (auto var : program->GetInputs()) {
if (var->shape.empty()) {
var->shape.push_back(1);
}
builder.CreateInput(var);
}
for (int i = 0; i < program->size(); ++i) {
auto& instr = (*program)[i];
if (instr->op_type == "transpose") {
builder.AppendInstruction(HandleTranspose(instr));
continue;
}
for (auto& input : instr->inputs) {
if (input->shape.empty()) {
VLOG(4) << "Change input 0D-Tensor " << input->id << " to 1D-Tensor";
VLOG(4) << "Change " << instr->op_type << "'s input 0D-Tensor "
<< input->id << " to 1D-Tensor";
input->shape.push_back(1);
}
}
for (auto& output : instr->outputs) {
if (output->shape.empty()) {
VLOG(4) << "Change output 0D-Tensor " << output->id
<< " to 1D-Tensor";
VLOG(4) << "Change " << instr->op_type << "'s output 0D-Tensor "
<< output->id << " to 1D-Tensor";
output->shape.push_back(1);
}
}
builder.AppendInstruction(instr);
}
for (auto var : program->GetInputs()) {
if (var->shape.empty()) {
VLOG(4) << "Change program's input 0D-Tensor " << var->id
<< " to 1D-Tensor";
var->shape.push_back(1);
}
builder.CreateInput(var);
}
*program = builder.Build();
}
void Clear() override {}
private:
// Before: out-0D = transpose(x-0D, [])
// After: out-1D = transpose(x-1D, [1])
Instruction HandleTranspose(const Instruction& instr) {
Instruction new_instr = instr;
bool has_0d_input = false;
for (auto& input : new_instr->inputs) {
if (input->shape.empty()) {
VLOG(4) << "Change transpose's input 0D-Tensor " << input->id
<< " to 1D-Tensor";
input->shape.push_back(1);
has_0d_input = true;
}
}
for (auto& output : new_instr->outputs) {
if (output->shape.empty()) {
VLOG(4) << "Change transpose's output 0D-Tensor " << output->id
<< " to 1D-Tensor";
output->shape.push_back(1);
}
}
if (has_0d_input) {
std::vector<int32_t> axis =
new_instr.GetAttrs<std::vector<int32_t>>("axis");
CHECK(axis.empty()) << "transpose's axis should be empty when inputs "
"0D-Tensor! Please check setting.\n";
axis.push_back(0);
VLOG(4) << "Change Transpose's attribute axis from [] to [1]";
new_instr.SetAttr<std::vector<int32_t>>("axis", axis);
}
return new_instr;
}
};
} // namespace pass
......
......@@ -1137,8 +1137,8 @@ std::vector<framework::shape_t> InferShapeForTranspose(
const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
std::vector<framework::shape_t> result;
CHECK(!inputs_shape.empty() && !inputs_shape[0].empty())
<< "The input's shape size is 0! Please check again.";
CHECK(!inputs_shape.empty())
<< "The input's shape is empty! Please check again.";
if (attrs.find("axis") != attrs.end()) {
auto axis = absl::get<std::vector<int>>(attrs.at("axis"));
CHECK_EQ(axis.size(), inputs_shape[0].size())
......
......@@ -630,6 +630,44 @@ class TestScaleOp(OpTest):
self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestTransposeOp(OpTest):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.init_input()
def init_input(self):
self.inputs = {
"x": np.random.randint(-10, 10, []).astype(self.dtype),
}
self.target_shape = ()
def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=False)
out = paddle.transpose(x, [])
self.paddle_outputs = [out]
def build_cinn_program(self, target):
builder = NetBuilder("transpose_op")
x = builder.create_input(
cinn_dtype_convert(self.dtype), self.inputs["x"].shape, "x"
)
out = builder.transpose(x, [])
prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]], [out])
self.cinn_outputs = res
self.assertEqual(res[0].shape, self.target_shape)
def test_check_results(self):
self.check_outputs_and_grads()
@OpTestTool.skip_if(
not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册