未验证 提交 03f06841 编写于 作者: Q qiuwenbo 提交者: GitHub

开发grad_fn、next_functions两个API 并暴露到python端- 修改单侧文件路径到合理位置 (#55311)

* [尝试] 给tensor增加一个属性, 这个属性是一个定值 1

* 暴露gradnode 并构建gradnode新的方法(用来测试)进行暴露给python python端可以访问

* 开发grad_fn、next_functions两个API 并暴露到python端- 做一些规范化处理

* 增加一个单元测试

* 优化 code-style

* 将单侧文件迁到正确的位置

* 优化 code-style

* 删除无用注释

* 解决 __main__ has no attribute

* 修改单侧文件

* 修改单侧脚本-temp
上级 b2c797ad
...@@ -18,12 +18,12 @@ Test the tensor attribute grad_fn and the properties of the reverse node grad_no ...@@ -18,12 +18,12 @@ Test the tensor attribute grad_fn and the properties of the reverse node grad_no
import unittest import unittest
import paddle import paddle
import paddle.nn as nn from paddle import nn
class Testmodel(nn.Layer): class Testmodel(nn.Layer):
def __init__(self): def __init__(self):
super(Testmodel, self).__init__() super().__init__()
def forward(self, x): def forward(self, x):
y = x**2 y = x**2
...@@ -74,7 +74,7 @@ class TestAnonmousSurvey(unittest.TestCase): ...@@ -74,7 +74,7 @@ class TestAnonmousSurvey(unittest.TestCase):
def test_grad_fn_and_next_funs(self): def test_grad_fn_and_next_funs(self):
self.check_func(self.output.grad_fn, self.output_grad_fn["grad_fn"]) self.check_func(self.output.grad_fn, self.output_grad_fn["grad_fn"])
def check_func(self, grad_fn: grad_fn, grad_fn_json: dict) -> None: def check_func(self, grad_fn, grad_fn_json) -> None:
""" """
Check each node, grad_fn is tensor attribute. grad_fn_json is structure of next_node. Check each node, grad_fn is tensor attribute. grad_fn_json is structure of next_node.
...@@ -82,14 +82,8 @@ class TestAnonmousSurvey(unittest.TestCase): ...@@ -82,14 +82,8 @@ class TestAnonmousSurvey(unittest.TestCase):
grad_fn (grad_fn): grad_fn of node grad_fn (grad_fn): grad_fn of node
grad_fn_json (dict): grad_node_json of node grad_fn_json (dict): grad_node_json of node
""" """
# print(grad_fn.name())
# assert func name
self.assertEqual(grad_fn.name(), grad_fn_json["func_name"]) self.assertEqual(grad_fn.name(), grad_fn_json["func_name"])
# Recursively test other nodes
if hasattr(grad_fn, 'next_functions') and grad_fn.next_functions[0]:
next_funcs_json = grad_fn_json["next_funcs"]
for u in grad_fn.next_functions:
self.check_func(u, next_funcs_json[u.name()])
unittest.main() if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册