未验证 提交 eadc5d07 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【New IR] delete print program in test and delete add_n attribute c++ interface...

【New IR] delete print program in test and delete add_n attribute c++ interface to reply #56080 (#56120)

* refine program translator

* fix warning: not override

* fix bug

* merge new modifications

* modify by reviews

* resolve conflicts

* resolve conflicts

* fix

* fix

* fix conflicts

* pseudocode of backward

* modify test

* modify register op

* clear other code

* modify ci build bug

* reply review comments

* reply review comments

* delete print and add_n c++ interface

---------
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>
上级 bfc64801
...@@ -48,24 +48,6 @@ class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface> { ...@@ -48,24 +48,6 @@ class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface> {
void Verify(); void Verify();
ir::Value inputs() { return operand_source(0); } ir::Value inputs() { return operand_source(0); }
ir::OpResult out() { return result(0); } ir::OpResult out() { return result(0); }
ir::Attribute attribute(const std::string &name) {
{
PADDLE_ENFORCE(
attributes().count(name) > 0,
phi::errors::PreconditionNotMet("Attribute is not exist."));
return attributes().at(name);
}
}
template <typename T>
T attribute(const std::string &name) {
{
PADDLE_ENFORCE(
attributes().count(name) > 0 && attributes().at(name).isa<T>(),
phi::errors::PreconditionNotMet("Attribute is not right."));
return attributes().at(name).dyn_cast<T>();
}
}
static void InferMeta(phi::InferMetaContext *infer_meta); static void InferMeta(phi::InferMetaContext *infer_meta);
}; };
......
...@@ -43,7 +43,6 @@ class TestBuildOp(unittest.TestCase): ...@@ -43,7 +43,6 @@ class TestBuildOp(unittest.TestCase):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program): with paddle.ir.core.program_guard(newir_program):
out = paddle.mean(tanh_out) out = paddle.mean(tanh_out)
print(newir_program)
self.assertEqual(out.get_defining_op().name(), "pd.mean") self.assertEqual(out.get_defining_op().name(), "pd.mean")
self.assertEqual( self.assertEqual(
out.get_defining_op() out.get_defining_op()
...@@ -65,7 +64,6 @@ class TestBuildOp2(unittest.TestCase): ...@@ -65,7 +64,6 @@ class TestBuildOp2(unittest.TestCase):
out1 = paddle.mean(tanh_out) out1 = paddle.mean(tanh_out)
out2 = paddle.mean(tanh_out) out2 = paddle.mean(tanh_out)
out = paddle.add_n([out1, out2]) out = paddle.add_n([out1, out2])
print(newir_program)
self.assertEqual(out.get_defining_op().name(), "pd.add_n") self.assertEqual(out.get_defining_op().name(), "pd.add_n")
self.assertEqual( self.assertEqual(
out.get_defining_op() out.get_defining_op()
......
...@@ -88,6 +88,16 @@ class TesBackward(unittest.TestCase): ...@@ -88,6 +88,16 @@ class TesBackward(unittest.TestCase):
print(newir_program) print(newir_program)
self.assertEqual(newir_program.block().ops[-3].name(), "pd.full") self.assertEqual(newir_program.block().ops[-3].name(), "pd.full")
self.assertEqual(input_grad[0].get_defining_op().name(), "pd.tanh_grad")
self.assertEqual(
input_grad[0]
.get_defining_op()
.operands()[1]
.source()
.get_defining_op()
.name(),
"pd.mean_grad",
)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
# TODO(Ruting) test add_n op when add_n api and add_grad finished # TODO(Ruting) test add_n op when add_n api and add_grad finished
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册