未验证 提交 abd2df4c 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[DoubleGrad PR #3] Supported higher-order GradNode generation (#41051)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* Fixed yaml typo
上级 489a64ef
......@@ -89,6 +89,10 @@ def FindForwardName(string):
return string[:-5]
def IsGradName(string):
return string.endswith("_grad")
def IsPlainTensorType(string):
plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor']
if string in plain_tensor_types:
......@@ -166,6 +170,12 @@ def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"
def TransformGradVarNameForDoubleGradGeneration(string):
if IsGradName(string):
string = "grad_" + string[:-5]
return string
######################
### Yaml Parsers ###
######################
......
......@@ -31,6 +31,7 @@ from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromB
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import TransformGradVarNameForDoubleGradGeneration
from codegen_utils import AssertMessage
......@@ -146,15 +147,38 @@ NODE_DECLARATION_TEMPLATE = \
}};
"""
FUNCTION_TEMPLATE = \
GRAD_FUNCTION_TEMPLATE = \
"""
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
// Fill Zero For GradIn Tensors
{}
// Apply Gradient Hooks
auto hooked_grads = ApplyGradientHooks(grads);
// Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
// Call grad_api function
VLOG(3) << \"Final State Running: \" << \"{}\";
auto grad_api_returns = {}{}({});
{}
// Get Output
{}
// Get GradIn autograd_meta
{}
// Get GradOut autograd_meta
{}
// Compute Require Grad
{}
// Create Grad Node
{}
// Return
{}
}}
"""
......@@ -170,11 +194,14 @@ FORWARD_FUNCTION_TEMPLATE = \
// Get Input AutoGradMeta
{}
// Forward API Call
{}
// Get Outputs
{}
// Get Output AutoGradMeta
{}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
// Check Inplace & Bump Inplace Version
{}
{}
......@@ -225,6 +252,7 @@ NODE_CC_FILE_TEMPLATE = \
#include "paddle/phi/api/backward/sparse_bw_api.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
......@@ -689,14 +717,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
tw_name = f"api_result"
if is_optional:
set_tensor_wrappers = f" if({tw_name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({tw_name}.get_ptr()), false);"
set_tensor_wrappers = f" if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
......@@ -729,12 +754,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad({name});"
set_grad_in_meta = f" grad_node->SetGradInMeta({name}, {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
set_grad_in_meta_list.append(set_grad_in_meta)
......@@ -898,20 +919,24 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
function_name = GetIntermediateAPIFunctionName(function_name)
forward_call_str = f"auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
# Get return type list & outputs
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
# Get Outputs
get_outputs_str = ""
for name, (rtype, pos) in forward_outputs_position_map.items():
if num_outputs == 1 and len(intermediate_outputs) == 0:
get_outputs_str += f"auto& {name} = api_result;\n"
else:
get_outputs_str += f"auto& {name} = std::get<{pos}>(api_result);\n"
# Get return type list & outputs
returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)]
for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
continue
if num_outputs == 1 and len(intermediate_outputs) == 0:
returns_list[0] = f"api_result"
else:
# Tuple api_result
returns_list[pos] = f"std::get<{pos}>(api_result)"
returns_list[pos] = f"{name}"
if IsPlainTensorType(rtype):
returns_type_list[pos] = "paddle::experimental::Tensor"
......@@ -956,26 +981,24 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));\n"
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
# 3. ComputeRequireGrad & PassStopGradient
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
# 4. Check Inplace
# 3. Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
if is_inplaced:
......@@ -1015,7 +1038,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
forward_call_str, outputs_autograd_meta_str,
forward_call_str, get_outputs_str, outputs_autograd_meta_str,
compute_require_grad_args_str, check_inplace_str,
bump_inplace_version_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
......@@ -1083,13 +1106,18 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
def __init__(self, forward_api_contents, grad_api_contents, namespace):
def __init__(self,
forward_api_contents,
grad_api_contents,
namespace,
next_grad_api_contents=None):
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)
# Generated Results
self.node_declaration_str = ""
self.node_definition_str = ""
self.next_grad_api_contents = next_grad_api_contents
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
......@@ -1151,7 +1179,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
def GenerateNodeDefinition(self):
def GenerateNodeDefinition(self, grad_node_creation_str):
namespace = self.namespace
forward_api_name = self.forward_api_name
backward_api_name = self.backward_api_name
......@@ -1165,62 +1193,183 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_api_args_len = len(backward_forward_inputs_map.keys()) + len(
backward_grad_inputs_map.keys()) + len(backward_attrs_list)
grad_api_args = ["" for i in range(grad_api_args_len)]
get_grad_in_args_list = []
# Fill Grad Ins with Zero
fill_zero_str = ""
if forward_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
# Grad Ins from TensorWrappers
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
is_optional = (name in self.optional_inputs)
if is_optional:
grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverOptionalTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
else:
grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr)"
for _, (ttype, fwd_position,
tensor_wrapper_recover_str = f"auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr);"
grad_api_args[grad_api_position] = transformed_tensor_name
get_grad_in_args_list.append(tensor_wrapper_recover_str)
# Grad Ins from grads
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
if IsPlainTensorType(ttype):
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}][0]"
get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"
else:
assert IsVectorTensorType(ttype)
grad_api_args[
grad_api_position] = f"hooked_grads[{fwd_position}]"
get_tensor_str = f"auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
grad_api_args[grad_api_position] = transformed_tensor_name
get_grad_in_args_list.append(get_tensor_str)
# Grad Attrs
for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name)
grad_api_args[grad_api_position] = f"this->{saved_attribute_name}"
get_attr_str = f"auto& {name} = this->{saved_attribute_name};"
grad_api_args[grad_api_position] = name
get_grad_in_args_list.append(get_attr_str)
get_grad_in_args_str = "\n".join(get_grad_in_args_list)
grad_api_args_str = ", ".join(grad_api_args)
# Grad Function Call String
grad_api_namespace = f"paddle::experimental::{namespace}"
grad_function_call_str = f"auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});"
# Get Grad Outputs
get_outputs_str = ""
num_outputs = len(backward_grad_outputs_map.keys())
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
if num_outputs == 1:
get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result;"
else:
get_tensor_str = f"auto& {transformed_tensor_name} = grad_api_result[{fwd_position}];"
get_outputs_str += get_tensor_str + "\n"
# Prepare for Node Creation if Necessary
inputs_autograd_meta_str = ""
outputs_autograd_meta_str = ""
compute_require_grad_str = ""
if len(grad_node_creation_str) > 0:
# 1. Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
transformed_tensor_name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
# 2. Get TensorWrapper AutoGradMeta
for name, (ttype, _, pos), in backward_forward_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
if IsPlainTensorType(ttype):
input_autograd_meta = f" egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
transformed_tensor_name)
input_autograd_meta = f" std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
input_autograd_meta += f" std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(
compute_require_grad_args_list)
# 3. Get Output AutoGradMeta
outputs_autograd_meta_list = []
num_fwd_outputs = len(backward_grad_outputs_map.keys())
for name, (rtype, pos, _) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
output_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(
transformed_tensor_name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
compute_require_grad_str = "bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
compute_require_grad_str += f"bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});"
# Construct grad_api returns
num_bwd_outputs = len(backward_grad_outputs_map.keys())
slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
returns_str = f"std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
for _, (ttype, fwd_position,
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
# Infer Grad API Return Type
if num_bwd_outputs == 1:
# Single tensor output, return as is
if IsPlainTensorType(ttype):
returns_str += "returns[0] = { grad_api_returns };\n"
returns_str += f"returns[0] = {{ {transformed_tensor_name} }};\n"
else:
assert IsVectorTensorType(ttype)
returns_str += "returns[0] = grad_api_returns;\n"
returns_str += f"returns[0] = {transformed_tensor_name};\n"
else:
# Rearrange output order accordingly
returns_str += f"returns[{fwd_position}] = grad_api_returns[{grad_api_position}];\n"
returns_str += f"returns[{fwd_position}] = {transformed_tensor_name};\n"
returns_str += f"if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(forward_api_name)
fill_zero_str = ""
if forward_api_name in ops_to_fill_zero_for_empty_grads:
fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
grad_api_namespace = f"paddle::experimental::{namespace}"
self.node_definition_str = FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
backward_api_name, grad_api_args_str, returns_str)
self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
outputs_autograd_meta_str, compute_require_grad_str,
grad_node_creation_str, returns_str)
logging.info(f"Generated Node Definition: {self.node_definition_str}")
......@@ -1231,7 +1380,22 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
## Code Generation ##
#####################
self.GenerateNodeDeclaration()
self.GenerateNodeDefinition()
namespace = self.namespace
grad_node_creation_str = ""
next_grad_api_contents = self.next_grad_api_contents
if next_grad_api_contents:
forward_api_contents = self.grad_api_contents
forward_api_contents['api'] = forward_api_contents['backward_api']
backward_api_contents = next_grad_api_contents
next_node_generator = DygraphFunctionGeneratorBase(
forward_api_contents, backward_api_contents, namespace)
next_node_generator.run()
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str
self.GenerateNodeDefinition(grad_node_creation_str)
class DygraphYamlGenerator(YamlGeneratorBase):
......@@ -1278,19 +1442,35 @@ class DygraphYamlGenerator(YamlGeneratorBase):
forward_api_contents)
if backward_api_contents is None: continue
# Generate Dygraph Forward Function
function_generator = DygraphForwardFunctionGenerator(
forward_api_contents, backward_api_contents, namespace)
function_generator.run()
node_generator = DygraphNodeGenerator(
forward_api_contents, backward_api_contents, namespace)
node_generator.run()
self.forward_definition_str += function_generator.forward_definition_str + "\n"
self.forward_declaration_str += function_generator.forward_declaration_str + "\n"
while True:
next_grad_api_contents = self.GetBackwardAPIContents(
backward_api_contents)
node_generator = DygraphNodeGenerator(
forward_api_contents, backward_api_contents, namespace,
next_grad_api_contents)
node_generator.run()
self.node_declaration_str += node_generator.node_declaration_str + "\n"
self.node_definition_str += node_generator.node_definition_str + "\n"
if next_grad_api_contents is None: break
# Detect if there exists higher-order GradNode
forward_api_contents = backward_api_contents
# Fake forward_api_content
forward_api_contents['api'] = forward_api_contents[
'backward_api']
backward_api_contents = next_grad_api_contents
if len(namespace) > 0:
if namespace.endswith("::"):
namespace = namespace[:-2]
......
......@@ -649,6 +649,16 @@
kernel :
func : put_along_axis_grad
- backward_api : relu_double_grad
forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor out, Tensor grad_x_grad)
output : Tensor(out_grad), Tensor(grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [out, out]
kernel :
func : relu_double_grad
- backward_api : relu_grad
forward : relu (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
......@@ -658,6 +668,7 @@
param : [out]
kernel :
func : relu_grad
backward: relu_double_grad
- backward_api : reshape_grad
forward : reshape_with_xshape (Tensor x, ScalarArray shape) -> Tensor(out), Tensor(xshape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册