未验证 提交 22a11a60 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] fix shadow output translator (#56365)

* fix shadow output translator

* fix coverage ci
上级 b5ac40ba
......@@ -1108,10 +1108,18 @@ struct ShadowOutputOpTranscriber : public OpTranscriber {
TranslationContext* param_map,
const OpDesc& op_desc,
ir::Program* program) override {
auto op_info = ctx->GetRegisteredOpInfo(ir::SetParameterOp::name());
std::vector<ir::OpResult> op_inputs;
auto legacy_input_vars = op_desc.Input("x", true);
auto defining_info = (*param_map)[legacy_input_vars[0]];
if (defining_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, program, defining_info, legacy_input_vars[0]);
defining_info = param_map->at(legacy_input_vars[0]).value;
}
op_inputs.push_back(defining_info.value);
ir::AttributeMap attribute_map = {
......@@ -1120,9 +1128,8 @@ struct ShadowOutputOpTranscriber : public OpTranscriber {
op_desc.GetAttrIfExists<std::string>("name"))},
};
auto create_op_info = ctx->GetRegisteredOpInfo(ir::SetParameterOp::name());
ir::Operation* operation =
ir::Operation::Create(op_inputs, attribute_map, {}, create_op_info);
ir::Operation::Create(op_inputs, attribute_map, {}, op_info);
program->block()->push_back(operation);
return operation;
......
......@@ -300,5 +300,30 @@ class TestGradAddOpTranscriber(unittest.TestCase):
_ = ir.translate_to_new_ir(main_program.desc)
class TestShadowOutputSlice(unittest.TestCase):
def test_op(self):
place = core.Place()
place.set_place(paddle.CPUPlace())
new_scope = paddle.static.Scope()
main_program = paddle.static.Program()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x = paddle.rand([3, 9, 5])
y = paddle.static.data(
name="y", shape=[3, 9, 5], dtype="float32"
)
_, out, _ = paddle.split(x, num_or_sections=3, axis=1)
helper = LayerHelper('shadow_output')
helper.append_op(
type="shadow_output",
inputs={"x": [out.name]},
outputs={"out": [y.name]},
attrs={"name": out.name},
)
l = ir.translate_to_new_ir(main_program.desc)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册