未验证 提交 4e5ae23e 编写于 作者: H hong 提交者: GitHub

fix op translator reshape type (#56294)

上级 bcc5ce44
...@@ -1355,8 +1355,8 @@ struct ElementwiseTranscriber : public OpTranscriber { ...@@ -1355,8 +1355,8 @@ struct ElementwiseTranscriber : public OpTranscriber {
for (int i = 0; i <= append_size; i++) { for (int i = 0; i <= append_size; i++) {
y_new_shape.push_back(1); y_new_shape.push_back(1);
} }
dialect::Reshape_Op reshape_op = dialect::ReshapeOp reshape_op =
builder.Build<dialect::Reshape_Op>(y_value, y_new_shape); builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
y_new = reshape_op.out(); y_new = reshape_op.out();
VLOG(6) << "[" << op_desc.Type() << "] y_shape change from " VLOG(6) << "[" << op_desc.Type() << "] y_shape change from "
<< y_tensor_type.dims() << " to " << phi::make_ddim(y_new_shape); << y_tensor_type.dims() << " to " << phi::make_ddim(y_new_shape);
...@@ -1371,8 +1371,7 @@ struct ElementwiseTranscriber : public OpTranscriber { ...@@ -1371,8 +1371,7 @@ struct ElementwiseTranscriber : public OpTranscriber {
auto concat_op = auto concat_op =
builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0); builder.Build<dialect::ConcatOp>(y_true_shape_op.out(), 0);
auto y_new_shape = concat_op.out(); auto y_new_shape = concat_op.out();
auto reshape_op = auto reshape_op = builder.Build<dialect::ReshapeOp>(y_value, y_new_shape);
builder.Build<dialect::Reshape_Op>(y_value, y_new_shape);
y_new = reshape_op.out(); y_new = reshape_op.out();
} }
return {x_value, y_new}; return {x_value, y_new};
...@@ -1449,7 +1448,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber { ...@@ -1449,7 +1448,7 @@ struct ElementwiseGradTranscriber : public OpTranscriber {
ir::OpResult value = operation->result(idx); ir::OpResult value = operation->result(idx);
ir::Builder builder(ctx, operation->GetParent()); ir::Builder builder(ctx, operation->GetParent());
auto reshape_op = builder.Build<dialect::Reshape_Op>(value, y_shape); auto reshape_op = builder.Build<dialect::ReshapeOp>(value, y_shape);
(*param_map)[y_grad_var_name] = (*param_map)[y_grad_var_name] =
VariableDefiningInfo(reshape_op.out(), false, -1); VariableDefiningInfo(reshape_op.out(), false, -1);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册