Unverified commit abd2df4c, authored by Zhanlue Yang, committed by GitHub

[DoubleGrad PR #3] Supported higher-order GradNode generation (#41051)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* Fixed yaml typo
Parent 489a64ef
@@ -89,6 +89,10 @@ def FindForwardName(string):
     return string[:-5]
 
 
+def IsGradName(string):
+    return string.endswith("_grad")
+
+
 def IsPlainTensorType(string):
     plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor']
     if string in plain_tensor_types:
@@ -166,6 +170,12 @@ def GetForwardFunctionName(string):
     return f"{string}_final_state_dygraph_function"
 
 
+def TransformGradVarNameForDoubleGradGeneration(string):
+    if IsGradName(string):
+        string = "grad_" + string[:-5]
+    return string
+
+
 ######################
 ### Yaml Parsers ###
 ######################
...
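The two helpers added above carry the naming convention behind higher-order generation: when a GradNode's definition is reused as the "forward" of the next-order pass, every "x_grad"-style variable name is rewritten to "grad_x". Below is a minimal standalone sketch of that mapping; the two functions are copied verbatim from the diff, and the sample names in the asserts are illustrative, not taken from the generator's actual input.

def IsGradName(string):
    return string.endswith("_grad")


def TransformGradVarNameForDoubleGradGeneration(string):
    # "out_grad" -> "grad_out"; names without the "_grad" suffix pass through.
    if IsGradName(string):
        string = "grad_" + string[:-5]
    return string


assert TransformGradVarNameForDoubleGradGeneration("out_grad") == "grad_out"
assert TransformGradVarNameForDoubleGradGeneration("grad_out_grad") == "grad_grad_out"
assert TransformGradVarNameForDoubleGradGeneration("x") == "x"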
@@ -649,6 +649,16 @@
   kernel :
     func : put_along_axis_grad
 
+- backward_api : relu_double_grad
+  forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
+  args : (Tensor out, Tensor grad_x_grad)
+  output : Tensor(out_grad), Tensor(grad_out_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [out, out]
+  kernel :
+    func : relu_double_grad
+
 - backward_api : relu_grad
   forward : relu (Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
@@ -658,6 +668,7 @@
     param : [out]
   kernel :
     func : relu_grad
+  backward: relu_double_grad
 
 - backward_api : reshape_grad
   forward : reshape_with_xshape (Tensor x, ScalarArray shape) -> Tensor(out), Tensor(xshape)
...
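For reference, the new relu_double_grad entry (reached from relu_grad through the added "backward:" field) follows the standard relu double-grad math. A minimal NumPy sketch of that math, assuming the usual derivation and not reflecting Paddle's actual kernel: the first-order backward is grad_x = grad_out * (out > 0), so differentiating with respect to grad_out routes grad_x_grad through the same mask, while the mask is piecewise constant in out and its term vanishes.

import numpy as np


def relu_double_grad(out, grad_x_grad):
    # Mask of positions where relu was active (out > 0).
    mask = (out > 0).astype(out.dtype)
    # d(grad_x)/d(grad_out) applied to grad_x_grad reuses the mask.
    grad_out_grad = grad_x_grad * mask
    # The mask is piecewise constant in out, so this term is zero.
    out_grad = np.zeros_like(out)
    return out_grad, grad_out_grad


out = np.array([0.0, 1.5, 2.0])
grad_x_grad = np.array([0.3, 0.3, 0.3])
print(relu_double_grad(out, grad_x_grad))
# (array([0., 0., 0.]), array([0. , 0.3, 0.3]))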