diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index e487c3a6d4ebe33964916a3a6f76fbbb659cf223..29d727fc8cba5f023683bb7a6082c8ba6a06bf3a 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -23,7 +23,7 @@ import os ######################## ops_to_fill_zero_for_empty_grads = set([ "split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad", - "sigmoid_triple_grad" + "sigmoid_triple_grad, add_double_grad" ]) # For API dispatch used at python-level diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index e4f07e6da89667e65f41e3bab294987a2a8e47b4..2154b5d6a4898308ac60f7b592ed34d580a804dd 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -205,6 +205,7 @@ FORWARD_FUNCTION_TEMPLATE = \ #endif }} // Forward API Call + VLOG(3) << \"Final State Running: \" << \"{}\"; {} // Get Outputs {} @@ -505,15 +506,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): for i in range(len(forward_attrs_list)): orig_attr_type = orig_forward_attrs_list[i][1] - orig_attr_default = orig_forward_attrs_list[i][2] orig_attr_pos = orig_forward_attrs_list[i][3] forward_attr_type = forward_attrs_list[i][1] - forward_attr_default = forward_attrs_list[i][2] forward_attr_pos = forward_attrs_list[i][3] assert orig_attr_type == forward_attr_type, AssertMessage( orig_attr_type, forward_attr_type) - assert orig_attr_default == forward_attr_default, AssertMessage( - orig_attr_default, forward_attr_default) assert orig_attr_pos == forward_attr_pos, AssertMessage( orig_attr_pos, forward_attr_pos) @@ -753,6 +750,15 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): set_grad_out_meta_list = [] set_edges_list = [] for name, (_, pos) in forward_inputs_position_map.items(): + # Has corresponding grad output + has_corresponding_grad_output = False + for _, (_, corresponding_pos, + _) in backward_grad_outputs_map.items(): + if pos == corresponding_pos: + has_corresponding_grad_output = True + if not has_corresponding_grad_output: + continue + input_autograd_meta_name = GetAutoGradMetaName(name) is_optional = (name in self.optional_inputs) if is_optional: @@ -1063,9 +1069,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format( returns_type_str, forward_function_name, inputs_args_definition_str, dygraph_event_str, amp_logic_str, inputs_autograd_meta_str, - forward_call_str, get_outputs_str, outputs_autograd_meta_str, - compute_require_grad_args_str, check_inplace_str, - bump_inplace_version_str, node_creation_str, returns_str) + forward_function_name, forward_call_str, get_outputs_str, + outputs_autograd_meta_str, compute_require_grad_args_str, + check_inplace_str, bump_inplace_version_str, node_creation_str, + returns_str) self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" logging.info( @@ -1439,28 +1446,18 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase): compute_require_grad_str += f"{indent}bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});" # Construct grad_api returns - num_bwd_outputs = len(backward_grad_outputs_map.keys()) slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys()) returns_str = f"{indent}std::vector> returns({slot_num_bwd_outputs});\n" for name, (ttype, fwd_position, grad_api_position) in backward_grad_outputs_map.items(): transformed_tensor_name = self.TransformToNextGradName(name) - # Infer Grad API Return Type - if num_bwd_outputs == 1: - # Single tensor output, return as is - if IsPlainTensorType(ttype): - returns_str += f"{indent}returns[0] = {{ {transformed_tensor_name} }};\n" - else: - assert IsVectorTensorType(ttype) - returns_str += f"{indent}returns[0] = {transformed_tensor_name};\n" + # Rearrange output order accordingly + if IsPlainTensorType(ttype): + returns_str += f"{indent}returns[{fwd_position}] = {{ {transformed_tensor_name} }};\n" else: - # Rearrange output order accordingly - if IsPlainTensorType(ttype): - returns_str += f"{indent}returns[{fwd_position}] = {{ {transformed_tensor_name} }};\n" - else: - assert IsVectorTensorType(ttype) - returns_str += f"{indent}returns[{fwd_position}] = {transformed_tensor_name};\n" + assert IsVectorTensorType(ttype) + returns_str += f"{indent}returns[{fwd_position}] = {transformed_tensor_name};\n" returns_str += f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" returns_str += f"{indent}return returns;\n" diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index c0b4a7e712948b89de17daf0485ef3d4c1c6e8b7..4afe8ff105e7611d9b275bee5bedf7b87edaca15 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -485,6 +485,7 @@ std::unordered_map getInDegreeMap( } } } + return node_in_degree_map; } @@ -526,6 +527,7 @@ std::vector RunBackward( bool allow_unused = false, const std::vector& no_grad_vars = {}) { VLOG(6) << "Start Backward"; + // *Gradient Hook should happen at node-level // *Inplace version check should perform at node-level // *Cross-batch accumulation happens at forward pass @@ -729,6 +731,16 @@ std::vector RunBackward( continue; } + auto* next_node = next_node_shared.get(); + if (!node_input_buffers_dict.count(next_node)) { + const auto& input_meta = next_node->InputMeta(); + auto grad_tensor_holder = + std::make_unique(input_meta); + VLOG(6) << "Construct GradTensorHolder for grad node: " + << next_node->name(); + node_input_buffers_dict[next_node] = std::move(grad_tensor_holder); + } + PADDLE_ENFORCE_LT( j, grad_output_tensors[i].size(), paddle::platform::errors::Fatal( @@ -748,15 +760,6 @@ std::vector RunBackward( << ", rank: " << j << " 's name is: " << grad_output_tensor.name(); - auto* next_node = next_node_shared.get(); - if (!node_input_buffers_dict.count(next_node)) { - const auto& input_meta = next_node->InputMeta(); - auto grad_tensor_holder = - std::make_unique(input_meta); - VLOG(6) << "Construct GradTensorHolder for grad node: " - << next_node->name(); - node_input_buffers_dict[next_node] = std::move(grad_tensor_holder); - } VLOG(6) << "Sum grad inputs for edge slot: " << edge_rank.first << ", rank: " << edge_rank.second; node_input_buffers_dict[next_node]->add( diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc index 1548272f8622c4817718311b4db57b0f00583738..f452d9ffb7e8950ba2beec7b6cb1b7b09e746851 100644 --- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc @@ -63,9 +63,9 @@ void AddGradKernel(const Context& dev_ctx, template void AddDoubleGradKernel(const Context& dev_ctx, const DenseTensor& y, + const DenseTensor& dout, paddle::optional ddx, paddle::optional ddy, - const DenseTensor& dout, int axis, DenseTensor* ddout) { phi::AddDoubleGradImpl(dev_ctx, y, ddx, ddy, dout, axis, ddout); diff --git a/paddle/phi/kernels/elementwise_grad_kernel.h b/paddle/phi/kernels/elementwise_grad_kernel.h index 979bb61c2e3cab464e3ed6b9895e8e4914ef8583..0e730fbfbfa4de7fddc29d648b8a40d5e3e31951 100644 --- a/paddle/phi/kernels/elementwise_grad_kernel.h +++ b/paddle/phi/kernels/elementwise_grad_kernel.h @@ -31,9 +31,9 @@ void AddGradKernel(const Context& dev_ctx, template void AddDoubleGradKernel(const Context& dev_ctx, const DenseTensor& y, + const DenseTensor& dout, paddle::optional ddx, paddle::optional ddy, - const DenseTensor& dout, int axis, DenseTensor* ddout); diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 3750e4b2bd61e1310077510ca6c19dc100d7e935..fae7978d3d2ea0518879224364335eea68b3a831 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -56,9 +56,9 @@ void AddGradKernel(const Context& dev_ctx, template void AddDoubleGradKernel(const Context& dev_ctx, const DenseTensor& y, + const DenseTensor& dout, paddle::optional ddx, paddle::optional ddy, - const DenseTensor& dout, int axis, DenseTensor* ddout) { phi::AddDoubleGradImpl(dev_ctx, y, ddx, ddy, dout, axis, ddout); diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index 5ab71c0cd0fde118a2eba60aefa64ccab570b7bf..0a58d86b05b06be6da363b0b274c8efdaedfe06a 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -115,7 +115,7 @@ KernelSignature ElementwiseAddGradOpArgumentMapping( KernelSignature ElementwiseAddDoubleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "add_double_grad", {"Y", "DDX", "DDY", "DOut"}, {"axis"}, {"DDOut"}); + "add_double_grad", {"Y", "DOut", "DDX", "DDY"}, {"axis"}, {"DDOut"}); } KernelSignature ElementwiseAddTripleGradOpArgumentMapping( diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 8b80444fe9011f8b5e0f13c4fc47f1bddcbd07ab..5b305325f3d2dfc7fb92598ab78273c8f7feb733 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -15,7 +15,7 @@ from __future__ import print_function from .. import core -from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator +from ..framework import Variable, convert_np_dtype_to_dtype_, _varbase_creator, _in_legacy_dygraph, in_dygraph_mode from ..layers.layer_function_generator import OpProtoHolder from . import no_grad from .. import framework @@ -62,6 +62,15 @@ _complex_dtypes = [ _already_patch_varbase = False _already_patch_eager_tensor = False +# Dispatch to final state Python-C functions +_final_state_op_type_mapping = { + "elementwise_add": "final_state_add", + "elementwise_sub": "final_state_subtract", + "elementwise_div": "final_state_divide", + "elementwise_mul": "final_state_multiply", + "matmul_v2": "final_state_matmul", +} + def monkey_patch_math_varbase(): """ @@ -105,10 +114,15 @@ def monkey_patch_math_varbase(): """ if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - return _C_ops.cast(self, 'in_dtype', self.dtype, 'out_dtype', dtype) + + if _in_legacy_dygraph(): + return _C_ops.cast(self, 'in_dtype', self.dtype, 'out_dtype', dtype) + return _C_ops.final_state_cast(self, dtype) def _scalar_elementwise_op_(var, scale, bias): - return _C_ops.scale(var, 'scale', scale, 'bias', bias) + if _in_legacy_dygraph(): + return _C_ops.scale(var, 'scale', scale, 'bias', bias) + return _C_ops.final_state_scale(var, float(scale), bias, True) def _neg_(var): return _scalar_elementwise_op_(var, -1.0, 0.0) @@ -164,7 +178,10 @@ def monkey_patch_math_varbase(): perm = [] for i in range(len(var.shape)): perm.insert(0, i) - out, _ = _C_ops.transpose2(var, 'axis', perm) + if _in_legacy_dygraph(): + out, _ = _C_ops.transpose2(var, 'axis', perm) + else: + out = _C_ops.final_state_transpose(var, perm) return out def _scalar_add_(var, value): @@ -270,11 +287,13 @@ def monkey_patch_math_varbase(): # 4. calculation axis = -1 - if framework._in_eager_mode_ and op_type == 'elementwise_add': - math_op = getattr(_C_ops, 'final_state_add') + if in_dygraph_mode( + ) and op_type in _final_state_op_type_mapping.keys(): + math_op = getattr(_C_ops, _final_state_op_type_mapping[op_type]) + return math_op(self, other_var) else: math_op = getattr(_C_ops, op_type) - return math_op(self, other_var, 'axis', axis) + return math_op(self, other_var, 'axis', axis) comment = OpProtoHolder.instance().get_op_proto(op_type).comment diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 949964b8281fdca3a9ba4e066f3891970cb321f0..47f40a2e6a5af1567765357efc893d841637574b 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9037,7 +9037,10 @@ def relu(x, name=None): # [[0. 0. ] # [1. 2.6]] """ - if _non_static_mode(): + + if in_dygraph_mode(): + return _C_ops.final_state_relu(x) + if _in_legacy_dygraph(): return _C_ops.relu(x) check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu') diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index c9e41fe93ebe1cdba69a120bc74bc9dbae9afc46..00b192b2a057b6bd2180d83ca94a7cd1e06ad8e5 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -385,26 +385,23 @@ class TestDygraphDoubleGrad(TestCase): (x_np > 0) * 2).astype('float32') self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - if not _in_legacy_dygraph(): - pass - else: - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward(retain_graph=True) + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) + loss.backward(retain_graph=True) + x_grad_actual = x.gradient() + x_grad_expected = (2.0 / float(numel) * + (x_np + dx_expected * + (x_np > 0) * 2 / float(numel))).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + + for i in range(5): + loss.backward(retain_graph=True) x_grad_actual = x.gradient() - x_grad_expected = (2.0 / float(numel) * ( + x_grad_expected = (i + 2) * (2.0 / float(numel) * ( x_np + dx_expected * (x_np > 0) * 2 / float(numel))).astype('float32') self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) - for i in range(5): - loss.backward(retain_graph=True) - x_grad_actual = x.gradient() - x_grad_expected = (i + 2) * (2.0 / float(numel) * ( - x_np + dx_expected * - (x_np > 0) * 2 / float(numel))).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) - def test_example_with_gradient_accumulation_and_create_graph(self): with _test_eager_guard(): self.func_example_with_gradient_accumulation_and_create_graph() @@ -426,7 +423,10 @@ class TestDygraphDoubleGrad(TestCase): del y1, z, w dx_actual, = self.grad( - [w_mean], [x], create_graph=True, no_grad_vars=[y2]) + [w_mean], [x], + retain_graph=True, + create_graph=True, + no_grad_vars=[y2]) self.assertFalse(y2.stop_gradient) self.assertFalse(dx_actual.stop_gradient) @@ -435,17 +435,14 @@ class TestDygraphDoubleGrad(TestCase): (x_np > 0) * 2).astype('float32') self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - if not _in_legacy_dygraph(): - pass - else: - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward() + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) + loss.backward() - x_grad_actual = x.gradient() - x_grad_expected = (2.0 / float(numel) * ( - x_np + dx_expected * - (x_np > 0) * 4 / float(numel))).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + x_grad_actual = x.gradient() + x_grad_expected = (2.0 / float(numel) * + (x_np + dx_expected * + (x_np > 0) * 4 / float(numel))).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) def test_example_with_gradient_accumulation_and_no_grad_vars(self): with _test_eager_guard(): @@ -476,15 +473,12 @@ class TestDygraphDoubleGrad(TestCase): self.assertTrue(np.allclose(dx_actual.numpy(), dx_expected)) - if not _in_legacy_dygraph(): - pass - else: - loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) - loss.backward() + loss = fluid.layers.reduce_mean(dx_actual * dx_actual + x * x) + loss.backward() - x_grad_actual = x.gradient() - x_grad_expected = (2.0 * x_np / float(numel)).astype('float32') - self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) + x_grad_actual = x.gradient() + x_grad_expected = (2.0 * x_np / float(numel)).astype('float32') + self.assertTrue(np.allclose(x_grad_actual, x_grad_expected)) def test_example_with_gradient_accumulation_and_not_create_graph(self): with _test_eager_guard(): diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index db4e08792a43c21d71d73bbf913e424dc39edb9d..006b2bf698bd4637f0936e92b4bd75513e484ea4 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -30,6 +30,18 @@ kernel : func : acosh_grad +- backward_api : add_double_grad + forward : add_grad (Tensor x, Tensor y, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y) + args : (Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1) + output : Tensor(grad_out_grad) + infer_meta : + func : UnchangedInferMeta + param : [grad_out] + kernel : + func : add_double_grad + optional : grad_x_grad, grad_y_grad + backward : add_triple_grad + - backward_api : add_grad forward : add (Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1) @@ -40,6 +52,7 @@ kernel : func : add_grad no_need_buffer : x, y + backward : add_double_grad - backward_api : add_n_grad forward : add_n (Tensor[] x) -> Tensor(out) @@ -48,6 +61,16 @@ invoke : add_n_grad_impl(x, out_grad) no_need_buffer : x +- backward_api : add_triple_grad + forward : add_double_grad (Tensor y, Tensor grad_out, Tensor grad_grad_x, Tensor grad_grad_y, int axis = -1) -> Tensor(grad_grad_out) + args : (Tensor grad_grad_x, Tensor grad_grad_y, Tensor grad_grad_out_grad, int axis = -1) + output : Tensor(grad_grad_x_grad), Tensor(grad_grad_y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [grad_grad_x, grad_grad_y] + kernel : + func : add_triple_grad + - backward_api : addmm_grad forward : addmm (Tensor input, Tensor x, Tensor y, float alpha, float beta) -> Tensor(out) args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha, float beta) @@ -940,6 +963,12 @@ kernel : func : mean_all_grad +- backward_api : mean_double_grad + forward: mean_grad (Tensor x, Tensor grad_out, int64_t[] dims={}, bool keep_dim=false, bool reduce_all = false) -> Tensor(grad_x) + args : (Tensor grad_x_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + output : Tensor(grad_out_grad) + invoke : mean(grad_x_grad, dims, keep_dim) + - backward_api : mean_grad forward: mean (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) args : (Tensor x, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) @@ -949,6 +978,7 @@ param: [x] kernel : func : mean_grad + backward : mean_double_grad no_need_buffer : x - backward_api : meshgrid_grad @@ -1020,6 +1050,17 @@ output : Tensor[](ins_grad) invoke : multiplex_grad_impl(ins, ids, out_grad) +- backward_api : multiply_double_grad + forward : multiply_grad (Tensor x, Tensor y, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y) + args : (Tensor x, Tensor y, Tensor grad_out, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1) + output : Tensor(x_grad), Tensor(y_grad), Tensor(grad_out_grad) + infer_meta : + func : GeneralTernaryGradInferMeta + param : [x, y, grad_out] + kernel : + func : multiply_double_grad + optional : grad_x_grad, grad_y_grad + - backward_api : multiply_grad forward : multiply (Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1) @@ -1029,6 +1070,7 @@ param : [x, y] kernel : func : multiply_grad + backward : multiply_double_grad - backward_api : mv_grad forward : mv (Tensor x, Tensor vec) -> Tensor(out) @@ -1179,10 +1221,10 @@ - backward_api : relu_double_grad forward : relu_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x) args : (Tensor out, Tensor grad_x_grad) - output : Tensor(out_grad), Tensor(grad_out_grad) + output : Tensor(grad_out_grad) infer_meta : - func : GeneralBinaryGradInferMeta - param : [out, out] + func : UnchangedInferMeta + param : [out] kernel : func : relu_double_grad @@ -1255,11 +1297,35 @@ kernel : func : round_grad +- backward_api : rsqrt_grad + forward : rsqrt (Tensor x) -> Tensor(out) + args : (Tensor out, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] + kernel : + func : rsqrt_grad + +- backward_api : scale_double_grad + forward : scale_grad (Tensor grad_out, Scalar scale, float bias, bool bias_after_scale) -> Tensor(grad_x) + args : (Tensor grad_x_grad, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true) + output : Tensor(grad_out_grad) + invoke : scale(grad_x_grad, scale, 0.0, bias_after_scale) + backward : scale_triple_grad + - backward_api : scale_grad forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out) args : (Tensor out_grad, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true) output : Tensor(x_grad) invoke : scale(out_grad, scale, 0.0, bias_after_scale) + backward : scale_double_grad + +- backward_api : scale_triple_grad + forward : scale_double_grad (Tensor grad_grad_x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(grad_grad_out) + args : (Tensor grad_grad_out_grad, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true) + output : Tensor(grad_grad_x_grad) + invoke : scale(grad_grad_out_grad, scale, 0.0, bias_after_scale) - backward_api : scatter_grad forward : scatter (Tensor x, Tensor index, Tensor updates, bool overwrite) -> Tensor(out)