# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re

from codegen_utils import (
    AssertMessage,
    FindForwardName,
    FunctionGeneratorBase,
    GeneratorBase,
    GetAutoGradMetaName,
    GetAutoGradMetaVectorName,
    GetConstReference,
    GetDygraphForwardFunctionName,
    GetGradNodeName,
    GetIndent,
    GetInplacedFunctionName,
    GetIntermediateAPIFunctionName,
    GetSavedName,
    IsPlainTensorType,
    IsVectorTensorType,
    ParseYamlBackward,
    ParseYamlForwardFromBackward,
    ParseYamlInplaceInfo,
    ReadBwdFile,
    RemoveConstAndReference,
    core_ops_args_info,
    core_ops_args_type_info,
    core_ops_returns_info,
    ops_to_fill_zero_for_empty_grads,
)

# Note: assign is an inplace api when the parameter(output) isn't none,
# so we should check parameter(output) against the inplace rule.
# But because there was no such check in old dygraph mode, in order to
# keep the code compatible, we also skip the inplace check in new dygraph
# temporarily; this will be fixed in the future.
inplace_check_blacklist = set(["assign_out_"])

# Black list of ops that do NOT need code generation
black_ops_list = [
    "conv2d",
    "conv2d_grad",
    "conv2d_grad_grad",
    "add_n",
    "add_n_grad",
]


#########
# Utils #
#########
def ParseArguments():
    parser = argparse.ArgumentParser(
        description='Eager Code Generator Args Parser'
    )
    parser.add_argument('--nodes_h_path', type=str)
    parser.add_argument('--nodes_cc_path', type=str)
    parser.add_argument('--forwards_h_path', type=str)
    parser.add_argument('--forwards_cc_path', type=str)
    parser.add_argument('--api_yaml_path', type=str)
    parser.add_argument('--backward_yaml_path', type=str)

    args = parser.parse_args()
    return args
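

# Typical invocation, for illustration only (the real paths and yaml files are
# supplied by the build system when this generator runs):
#   python eager_gen.py \
#       --api_yaml_path=ops.yaml --backward_yaml_path=backward.yaml \
#       --forwards_h_path=dygraph_functions.h --forwards_cc_path=dygraph_functions.cc \
#       --nodes_h_path=nodes.h --nodes_cc_path=nodes.cc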


######################
# Code Gen Templates #
######################
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """  void SetTensorWrapper{}(const paddle::experimental::Tensor& {}) {{
    {} = egr::TensorWrapper({}, {});
  }}
"""

SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """  void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
    for(const auto& eager_tensor : {}) {{
      {}.emplace_back(egr::TensorWrapper(eager_tensor, {}));
    }};
  }}
"""

PLAIN_TENSOR_MEMBER_TEMPLATE = """  egr::TensorWrapper {};
"""

VECTOR_TENSOR_MEMBER_TEMPLATE = """  std::vector<egr::TensorWrapper> {};
"""

CLEAR_TENSOR_WRAPPER_TEMPLATE = """    {}.clear();
"""

CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = """    for (auto& tw : {}) {{
      tw.clear();
    }}
"""

SET_ATTR_METHOD_TEMPLATE = """  void SetAttribute{}({} {}) {{
    {} = {};
  }}
"""

ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE = """  {} {} = {};
"""

ATTRIBUTE_MEMBER_TEMPLATE = """  {} {};
"""

NODE_DECLARATION_TEMPLATE = """
class {} : public egr::GradNodeBase {{
 public:
  {}() : egr::GradNodeBase() {{}}
  {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
      egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
  ~{}() override = default;

  virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> operator()(
      paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize>& grads, bool create_graph = false, bool is_new_grad = false) override;
  std::string name() override {{ return \"{}\"; }}

  void ClearTensorWrappers() override {{
{}
    SetIsTensorWrappersCleared(true);
  }}

  std::shared_ptr<GradNodeBase> Copy() const override {{
    auto copied_node = std::shared_ptr<{}>(new {}(*this));
    return copied_node;
  }}

  // SetTensorWrapperX, SetTensorWrapperY, ...
{}
  // SetAttributes
{}
 private:
  // TensorWrappers
{}
  // Attributes
{}}};
"""

GRAD_FUNCTION_TEMPLATE = """
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> {}::operator()(paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize>& grads, bool create_graph, bool is_new_grad) {{
  VLOG(3) << \"Running AD API GRAD: \" << \"{}\";
  // Fill Zero For GradIn Tensors
{}
  // Apply Gradient Hooks
  auto hooked_grads = ApplyGradientHooks(grads);

  // Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
  // Prepare Grad function call
{}
  // Runtime check if we need next grad
{}
  // Inplace Check
{}
  // Inplace Strategy
{}

  VLOG(5) << \"Running C++ API: \" << \"{}\";
  // Before log info
{}
  // Call grad_api function
{}
  // Check NaN and Inf if needed
{}
  // Get GradOut autograd_meta
{}
  // Create Grad Node
{}
  VLOG(4) << \"Finish AD API GRAD: {}";
  // LOG IF DEBUG
  {}
  // Return
{}
}}
"""

FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{
  VLOG(3) << \"Running AD API: \" << \"{}\";
  // Dygraph Record Event
{}
  // AMP Logic
{}
  // Layout autotune
{}
  // Get Input AutoGradMeta
{}

  VLOG(5) << \"Running C++ API: \" << \"{}\";
  // Before log info
{}
  // Forward API Call
{}
  // Check NaN and Inf if needed
{}
  // Get Outputs
{}
  // Get Output AutoGradMeta
{}
  bool trace_backward = egr::Controller::Instance().HasGrad();
  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});

  // Check Inplace if needed
{}{}
  // Node Creation
{}

  VLOG(4) << \"Finish AD API: {}";
  // LOG IF DEBUG
  {}
  // Returns
  return {};
}}
"""

AFTER_LOG_PRINT_TEMPLATE = """
  if(VLOG_IS_ON(4)){{
      const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s],  \\n Output: [%s] }} \";
      {}
      VLOG(4) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str, output_str);
  }}
"""

BEFORE_LOG_PRINT_TEMPLATE = """
  if(VLOG_IS_ON(3)){{
      const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s]}} \";
      {}
      VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
  }}
"""

FORWARD_ONLY_FUNCTION_TEMPLATE = """
{} {}({}) {{
  VLOG(3) << \"Running AD API: \" << \"{}\";
  // Dygraph Record Event
{}
  // AMP Logic
{}
  // Layout autotune
{}
  VLOG(5) << \"Running C++ API: \" << \"{}\";
  // Before log info
{}
  // Forward API Call
{}
  // Get Outputs
{}
  VLOG(4) << \"Finish AD API: {}";
  // LOG IF DEBUG
  {}
  // Returns
  return {};
}}
"""

FORWARD_BODY_TEMPLATE = """  if(require_any_grad) {{
{}
    egr::EagerUtils::PassStopGradient({});

    // Node Construction
{}
    // SetAttributes if needed
{}
    // Set TensorWrappers for Forward Inputs if needed
{}
    // SetGradOutMeta & SetEdges
{}
    // SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
    // Set TensorWrappers for Forward Outputs if needed
{}
  }}
"""

HIHGER_ORDER_DERIVATIVE_VALUE_TEMPLATE = """  if(trace_backward) {{
{}
    // Node Construction
{}
    // SetAttributes if needed
{}
    // Set TensorWrappers for Forward Inputs if needed
{}
    // SetGradOutMeta & SetEdges
{}
    // SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
    // Set TensorWrappers for Forward Outputs if needed
{}
  }}
"""

NAMESPACE_WRAPPER_TEMPLATE = """
namespace {} {{
    {}
}}
"""

NODE_CC_FILE_TEMPLATE = """
#include "glog/logging.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/backward/sparse_bw_api.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"
#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h"
#include "paddle/fluid/prim/api/all.h"
#include "paddle/fluid/prim/utils/utils.h"
DECLARE_bool(check_nan_inf);
{}
"""

NODE_H_FILE_TEMPLATE = """
#pragma once
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"

{}
"""

FORWARD_CC_FILE_TEMPLATE = """
#include "paddle/phi/api/lib/dygraph_api.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/eager_layout_auto_tune.h"
#include "paddle/phi/api/include/strings_api.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
DECLARE_bool(check_nan_inf);
{}
{}
"""

FORWARD_H_FILE_TEMPLATE = """
#pragma once
#include "glog/logging.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/phi/api/all.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/eager/to_static/run_program_op_func.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"

using CPUPlace = phi::CPUPlace;
{}
{}
"""

CORE_OPS_INFO_TEMPLATE = """
std::unordered_map<std::string, std::vector<std::string>> core_ops_args_info = {{
    {}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_args_type_info = {{
    {}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_returns_info = {{
    {}
}};

"""

CORE_OPS_DECLARATION_TEMPLATE = """
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_returns_info;

"""

CHECK_INPLACE_TEMPLATE = """
  egr::EagerUtils::CheckInplace({}, {}, require_any_grad);
"""

BUMP_INPLACE_VERSION_TEMPLATE = """
  // Bump Inplace Version
  {}.bump_inplace_version();
  VLOG(3) << \"Tensor(\" << {}.name() << \") uses Inplace Strategy.\";
"""

AMP_LOGIC_TEMPLATE = """  if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
    VLOG(5) << "Check and Prepare For AMP";
    {}
    paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
    {}
    {}
    {}
    {{
      paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentTracer(), paddle::imperative::AmpLevel::O0);
      {}
    }}
  }}
"""
LAYOUT_LOGIC_TEMPLATE = """
  if (egr::Controller::Instance().UseLayoutAutoTune()) {{
    paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> tensors_vector = {};
    {}
    {}
    VLOG(5) << "Check and Prepare For LAYOUT "<< op_name;
    paddle::imperative::LayoutAutotuneGuard guard(egr::Controller::Instance().GetCurrentTracer(), false);
    {}
    {}
    // Returns
    return {};
  }}
"""
CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE = """
  paddle::optional<paddle::experimental::Tensor> {}_optional;
  if({}.initialized()) {}_optional = paddle::make_optional<paddle::experimental::Tensor>({});
"""

CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE = """
  paddle::optional<paddle::experimental::Tensor> {}_optional;
  if( {}.impl() ) {}_optional = paddle::make_optional<paddle::experimental::Tensor>({});
"""

CREATE_RECOVER_OPTIONAL_VECTOR_TENSOR_TEMPLATE = """
  paddle::optional<std::vector<paddle::experimental::Tensor>> {}_optional;
  if( !{}.empty() ) {}_optional = paddle::make_optional<std::vector<paddle::experimental::Tensor>>({});
"""

CHECK_BACKWARD_INPLACE_TEMPLATE = """
  bool can_be_inplaced = false;
  if ({}.initialized()) {{
    VLOG(10) << {}.name() << "({}) use_count: " << {}.impl().use_count();
    if ({}.impl().use_count() == 1 || ({}.impl().use_count() == 2 && {}.impl().get() == {}.impl().get())) {{
      can_be_inplaced = true;
    }}
  }}"""

CHECK_NAN_AND_INF_TEMPLATE = """  if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }}
"""

inplace_optional_out_type_map = {
    "Tensor": "paddle::optional<paddle::experimental::Tensor>&",
    "std::vector<Tensor>": "paddle::optional<std::vector<paddle::experimental::Tensor>>&",
}


def ExtractForwardApiNameFormInvoke(invoke_config):
    api_name = invoke_config.split('(')[0]
    if api_name[-1] == '_':
        api_name = api_name[:-1]
    return re.search(
        r"(?P<api_name>[a-zA-Z0-9_]+)(?P<intermediate>_intermediate)?", api_name
    ).group('api_name')


def IsInvokeForwardApi(api_contents, forward_api_name_list):
    return (
        'invoke' in api_contents
        and ExtractForwardApiNameFormInvoke(api_contents['invoke'])
        in forward_api_name_list
    )
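

# Illustrative example (hypothetical yaml fragment): for an api entry with
#   invoke : concat_(x, axis)
# ExtractForwardApiNameFormInvoke returns "concat", so IsInvokeForwardApi is
# True only when "concat" appears in forward_api_name_list.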


#####################
# Generator Helpers #
#####################
def GenerateCoreOpInfoDeclaration():
    return CORE_OPS_DECLARATION_TEMPLATE


def GenerateCoreOpInfoDefinition():

    op_args_info_list = []
    for op_name, arg_list in core_ops_args_info.items():
        arg_str = ",".join(["\"" + v + "\"" for v in arg_list])
        op_args_info = f"{{ \"{op_name}\", {{ {arg_str} }} }},"
        op_args_info_list.append(op_args_info)

    op_types_info_list = []
    for op_name, type_list in core_ops_args_type_info.items():
        type_str = ",".join(["\"" + v + "\"" for v in type_list])
        op_types_info = f"{{ \"{op_name}\", {{ {type_str} }} }},"
        op_types_info_list.append(op_types_info)

    op_returns_info_list = []
    for op_name, return_list in core_ops_returns_info.items():
        return_str = ",".join(["\"" + v + "\"" for v in return_list])
        return_types_info = f"{{ \"{op_name}\", {{ {return_str} }} }},"
        op_returns_info_list.append(return_types_info)

    op_args_info_str = "\n".join(op_args_info_list)
    op_types_info_str = "\n".join(op_types_info_list)
    op_returns_info_str = "\n".join(op_returns_info_list)

    core_ops_info_definition_str = CORE_OPS_INFO_TEMPLATE.format(
        op_args_info_str, op_types_info_str, op_returns_info_str
    )

    return core_ops_info_definition_str
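

# Each generated map entry is a single brace-initializer line; an op "scale"
# taking args ["x", "bias"] (op and argument names here are illustrative)
# would contribute a line like:
#   { "scale", { "x","bias" } },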


###################
# Generator Class #
###################
class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
    def __init__(
        self,
        forward_api_contents,
        grad_api_contents,
        forward_apis_dict,
        namespace,
    ):
        self.forward_api_contents = forward_api_contents
        # Members from Parent:
        # self.namespace
        # self.forward_api_contents
        # self.forward_api_name
        # self.orig_forward_inputs_list
        # self.orig_forward_attrs_list
        # self.orig_forward_returns_list
        # self.forward_inputs_position_map
        # self.forward_outputs_position_map
        # self.optional_inputs
        # self.no_need_buffers
        # self.composite_func_info
        # self.intermediate_outputs
        # self.forward_inplace_map
        FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)

        self.forward_apis_dict = forward_apis_dict
        self.grad_api_contents = grad_api_contents

        # Raw Contents
        self.backward_forward_str = ""
        self.backward_api_name = ""

        self.forward_attrs_list = (
            []
        )  # [ [attr_name, attr_type, default_value, orig_position], ...]
        self.forward_inputs_list = (
            []
        )  # [ [arg_name, arg_type, orig_position], ...]
        self.forward_returns_list = (
            []
        )  # [ [ret_name, ret_type, orig_position], ...]

        self.backward_attrs_list = (
            []
        )  # [ [attr_name, attr_type, default_value, orig_position], ...]
        self.backward_inputs_list = (
            []
        )  # [ [arg_name, arg_type, orig_position], ...]
        self.backward_returns_list = (
            []
        )  # [ [ret_name, ret_type, orig_position], ...]

        # SlotNameMatched Backward Data
        self.backward_forward_inputs_map = (
            {}
        )  # { "name" : [type, is_fwd_input, orig_position] ...}
        self.backward_grad_inputs_map = (
            {}
        )  # { "name" : [type, fwd_position, orig_position] ...}
        self.backward_grad_outputs_map = (
            {}
        )  # { "name" : [type, fwd_position, orig_position] ...}

        self.backward_inplace_map = {}  # {name : name, ...}

    def ParseBackwardInplaceInfo(self):
        grad_api_contents = self.grad_api_contents
        if 'inplace' not in grad_api_contents.keys():
            return

        inplace_map_str = grad_api_contents['inplace']
        self.backward_inplace_map = ParseYamlInplaceInfo(inplace_map_str)
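        # Illustrative example (assumed yaml syntax): an entry such as
        #   inplace : (out_grad -> x_grad)
        # is parsed by ParseYamlInplaceInfo into {"out_grad": "x_grad"},
        # i.e. {inplace input name: inplace output name}.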

    def DygraphYamlValidationCheck(self):
        forward_api_contents = self.forward_api_contents
        grad_api_contents = self.grad_api_contents

        assert (
            'op' in forward_api_contents.keys()
        ), "Unable to find \"op\" in ops.yaml"
        assert (
            'args' in forward_api_contents.keys()
        ), "Unable to find \"args\" in ops.yaml"
        assert (
            'output' in forward_api_contents.keys()
        ), "Unable to find \"output\" in ops.yaml"

        if grad_api_contents is not None:
            assert (
                'backward' in forward_api_contents.keys()
            ), "Unable to find \"backward\" in ops.yaml"
            assert (
                'args' in grad_api_contents.keys()
            ), "Unable to find \"args\" in backward.yaml"
            assert (
                'output' in grad_api_contents.keys()
            ), "Unable to find \"output\" in backward.yaml"
            assert (
                'forward' in grad_api_contents.keys()
            ), "Unable to find \"forward\" in backward.yaml"

    def ForwardsValidationCheck(self):
        forward_inputs_list = self.forward_inputs_list
        forward_attrs_list = self.forward_attrs_list
        forward_returns_list = self.forward_returns_list

        orig_forward_inputs_list = self.orig_forward_inputs_list
        orig_forward_attrs_list = self.orig_forward_attrs_list
        orig_forward_returns_list = self.orig_forward_returns_list

        for i in range(len(forward_inputs_list)):
            forward_input_type = forward_inputs_list[i][1]
            forward_input_pos = forward_inputs_list[i][2]
            orig_input_type = orig_forward_inputs_list[i][1]
            orig_input_pos = orig_forward_inputs_list[i][2]

            assert forward_input_type == orig_input_type, AssertMessage(
                forward_input_type, orig_input_type
            )
            assert forward_input_pos == orig_input_pos, AssertMessage(
                forward_input_pos, orig_input_pos
            )

        for i in range(len(forward_attrs_list)):
            orig_attr_type = orig_forward_attrs_list[i][1]
            orig_attr_pos = orig_forward_attrs_list[i][3]
            forward_attr_type = forward_attrs_list[i][1]
            forward_attr_pos = forward_attrs_list[i][3]
            assert orig_attr_type == forward_attr_type, AssertMessage(
                orig_attr_type, forward_attr_type
            )
            assert orig_attr_pos == forward_attr_pos, AssertMessage(
                orig_attr_pos, forward_attr_pos
            )

        for i in range(len(forward_returns_list)):
            orig_return_type = orig_forward_returns_list[i][1]
            orig_return_pos = orig_forward_returns_list[i][2]
            forward_return_type = forward_returns_list[i][1]
            forward_return_pos = forward_returns_list[i][2]

            assert orig_return_type == forward_return_type, AssertMessage(
                orig_return_type, forward_return_type
            )
            assert orig_return_pos == forward_return_pos, AssertMessage(
                orig_return_pos, forward_return_pos
            )

        # Check Order: Inputs, Attributes
        max_input_position = -1
        for _, _, pos in forward_inputs_list:
            max_input_position = max(max_input_position, pos)

        for _, _, _, pos in forward_attrs_list:
            assert pos > max_input_position, AssertMessage(
                pos, max_input_position
            )

    def BackwardValidationCheck(self):
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_attrs_list = self.backward_attrs_list

        # Check Order: TensorWrappers, GradTensors, Attributes
        max_fwd_input_position = -1
        for _, (_, _, pos) in backward_forward_inputs_map.items():
            max_fwd_input_position = max(max_fwd_input_position, pos)

        max_grad_tensor_position = -1
        for _, (_, _, pos) in backward_grad_inputs_map.items():
            assert pos > max_fwd_input_position, AssertMessage(
                pos, max_grad_tensor_position
            )
            max_grad_tensor_position = max(max_grad_tensor_position, pos)

        max_attr_position = -1
        for _, _, _, pos in backward_attrs_list:
            assert pos > max_grad_tensor_position, AssertMessage(
                pos, max_grad_tensor_position
            )
            max_attr_position = max(max_attr_position, pos)

    def IntermediateValidationCheck(self):
        intermediate_outputs = self.intermediate_outputs
        forward_returns_list = self.forward_returns_list
        """
        Check whether intermediate_outputs are positioned
        at the very end of forward_returns_list
        """
        intermediate_positions = range(
            len(forward_returns_list) - len(intermediate_outputs),
            len(forward_returns_list),
        )
        for ret_name, _, pos in forward_returns_list:
            if ret_name in intermediate_outputs:
                assert pos in intermediate_positions, AssertMessage(
                    pos, intermediate_positions
                )

    def CollectBackwardInfo(self):
        forward_api_contents = self.forward_api_contents
        grad_api_contents = self.grad_api_contents

        self.backward_api_name = forward_api_contents['backward']
        self.backward_forward_str = grad_api_contents['forward']
        backward_args_str = grad_api_contents['args']
        backward_returns_str = grad_api_contents['output']

        (
            self.backward_inputs_list,
            self.backward_attrs_list,
            self.backward_returns_list,
        ) = ParseYamlBackward(backward_args_str, backward_returns_str)

        # Remove the output which is intermediate
        if 'intermediate' in grad_api_contents:
            backward_returns_list_new = []
            for return_item in self.backward_returns_list:
                if return_item[0] not in grad_api_contents['intermediate']:
                    backward_returns_list_new.append(return_item)
            self.backward_returns_list = backward_returns_list_new

    def CollectForwardInfoFromBackwardContents(self):

        backward_forward_str = self.backward_forward_str

        (
            self.forward_inputs_list,
            self.forward_attrs_list,
            self.forward_returns_list,
        ) = ParseYamlForwardFromBackward(backward_forward_str)

    def CollectForwardInfoFromYamlForward(self):
        (
            self.forward_inputs_list,
            self.forward_attrs_list,
            self.forward_returns_list,
        ) = ParseYamlForwardFromBackward(
            self.forward_api_contents['args']
            + " -> "
            + self.forward_api_contents['output']
        )
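        # The concatenated string follows the same "args -> output" shape that
        # ParseYamlForwardFromBackward expects, e.g. (illustrative):
        #   "(Tensor x, Tensor y) -> Tensor(out)"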

    def SlotNameMatching(self):
        backward_inputs_list = self.backward_inputs_list
        backward_returns_list = self.backward_returns_list
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map

        for backward_input in backward_inputs_list:
            backward_input_name = backward_input[0]
            backward_input_type = backward_input[1]
            backward_input_pos = backward_input[2]

            backward_fwd_name = FindForwardName(backward_input_name)
            if backward_fwd_name:
                # Grad Input
                assert (
                    backward_fwd_name in forward_outputs_position_map.keys()
                ), AssertMessage(
                    backward_fwd_name, forward_outputs_position_map.keys()
                )
                matched_forward_output_type = forward_outputs_position_map[
                    backward_fwd_name
                ][0]
                matched_forward_output_pos = forward_outputs_position_map[
                    backward_fwd_name
                ][1]

                self.backward_grad_inputs_map[backward_input_name] = [
                    backward_input_type,
                    matched_forward_output_pos,
                    backward_input_pos,
                ]
            else:
                # TensorWrapper Input
                if backward_input_name in forward_inputs_position_map.keys():
                    tensor_wrapper_type = forward_inputs_position_map[
                        backward_input_name
                    ][0]
                    self.backward_forward_inputs_map[backward_input_name] = [
                        backward_input_type,
                        True,
                        backward_input_pos,
                    ]

                elif backward_input_name in forward_outputs_position_map.keys():
                    tensor_wrapper_type = forward_outputs_position_map[
                        backward_input_name
                    ][0]
                    self.backward_forward_inputs_map[backward_input_name] = [
                        backward_input_type,
                        False,
                        backward_input_pos,
                    ]
                else:
                    assert (
                        False
                    ), f"Cannot find {backward_input_name} in forward position map"

        for backward_output in backward_returns_list:
            backward_output_name = backward_output[0]
            backward_output_type = backward_output[1]
            backward_output_pos = backward_output[2]

            backward_fwd_name = FindForwardName(backward_output_name)
            assert (
                backward_fwd_name is not None
            ), f"Detected {backward_fwd_name} = None"
            assert (
                backward_fwd_name in forward_inputs_position_map.keys()
            ), AssertMessage(
                backward_fwd_name, forward_inputs_position_map.keys()
            )

            matched_forward_input_type = forward_inputs_position_map[
                backward_fwd_name
            ][0]
            matched_forward_input_pos = forward_inputs_position_map[
                backward_fwd_name
            ][1]

            self.backward_grad_outputs_map[backward_output_name] = [
                backward_output_type,
                matched_forward_input_pos,
                backward_output_pos,
            ]
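
    # Illustrative matching (hypothetical op): for a backward signature
    # "matmul_grad(x, y, out_grad)", FindForwardName maps "out_grad" back to
    # the forward output "out" (behavior assumed from codegen_utils), so
    # "out_grad" lands in backward_grad_inputs_map, while "x" and "y" (no grad
    # suffix) are recorded as TensorWrapper inputs in backward_forward_inputs_map.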

    def GetPassStopGradientArgsList(self, forward_outputs_position_map):
        pass_stop_gradient_args_list = ["false"]
        for name, (_, _) in forward_outputs_position_map.items():
            output_autograd_meta_name = GetAutoGradMetaName(name)
            pass_stop_gradient_args_list.append(output_autograd_meta_name)
        pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
        return pass_stop_gradient_args_str

    def GenerateNodeCreationCodes(self, for_backward=False):
        forward_api_name = self.forward_api_name
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_grad_outputs_map = self.backward_grad_outputs_map
        backward_attrs_list = self.backward_attrs_list
        optional_inputs = self.optional_inputs
        is_composite_grad_api = (
            False if self.composite_func_info == {} else True
        )

        # Pass Stop Gradient Args
        pass_stop_gradient_args_str = self.GetPassStopGradientArgsList(
            forward_outputs_position_map
        )

        # Node Construction
        num_backward_inputs = len(forward_outputs_position_map.keys())
        num_backward_outputs = len(forward_inputs_position_map.keys())
        grad_node_name = GetGradNodeName(self.backward_api_name)

        # Helper
        indent = GetIndent(2)
        # NOTE(Aurelius74): DO NOT use make_shared here. Because some Node contains experimental::Scalar
        # which contains "complex128" as data. "complex128" is memory-aligned manually. But make_shared
        # request MEMALIGN for allocation (Maybe).
        # See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
        # and https://github.com/MRtrix3/mrtrix3/issues/957
        node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"

        # SetAttributes
        set_attributes_list = []
        forward_attrs_name_set = set()
        for name, _, _, _ in forward_attrs_list:
            forward_attrs_name_set.add(name)

        for name, _, default_val_attr, _ in backward_attrs_list:
            if name in forward_attrs_name_set:
                set_attributes = (
                    f"{indent}grad_node->SetAttribute{name}({name});"
                )
            else:
                set_attributes = f"{indent}grad_node->SetAttribute{name}({default_val_attr});"
            set_attributes_list.append(set_attributes)
        set_attributes_str = "\n".join(set_attributes_list)

        # SetTensorWrappers
        set_input_tensor_wrappers_list = []
        set_output_tensor_wrappers_list = []
        num_fwd_outputs = len(forward_outputs_position_map.keys())
        for name, (
            atype,
            is_fwd_input,
            pos,
        ) in backward_forward_inputs_map.items():
            is_optional = name in optional_inputs

            if is_fwd_input:
                if is_optional:
                    set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
                else:
                    set_tensor_wrappers = (
                        f"{indent}grad_node->SetTensorWrapper{name}({name});"
                    )
                set_input_tensor_wrappers_list.append(set_tensor_wrappers)
            else:  # Forward's output as backward's input
                if num_fwd_outputs > 1:
                    # Aligned with forward output position
                    assert (
                        name in forward_outputs_position_map.keys()
                    ), AssertMessage(name, forward_outputs_position_map.keys())

                if is_optional:
                    set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
                else:
                    set_tensor_wrappers = (
                        f"{indent}grad_node->SetTensorWrapper{name}({name});"
                    )
                set_output_tensor_wrappers_list.append(set_tensor_wrappers)
        set_input_tensor_wrappers_str = "\n".join(
            set_input_tensor_wrappers_list
        )
        set_output_tensor_wrappers_str = "\n".join(
            set_output_tensor_wrappers_list
        )

        # SetGradOutMeta & SetEdges
        grad_node_out_list = []
        set_grad_out_meta_list = []
        set_edges_list = []
        for name, (_, pos) in forward_inputs_position_map.items():
            # Has corresponding grad output
            has_corresponding_grad_output = False
            for _, (
                _,
                corresponding_pos,
                _,
            ) in backward_grad_outputs_map.items():
                if pos == corresponding_pos:
                    has_corresponding_grad_output = True
            if not has_corresponding_grad_output:
                continue

            grad_node_out_list.append(name)
            is_optional = name in self.optional_inputs
            if is_optional:
                set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
            else:
                set_grad_out_meta = (
                    f"{indent}grad_node->SetGradOutMeta({name}, {pos});"
                )

            set_grad_out_meta_list.append(set_grad_out_meta)
        set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)

        # SetOutRank & SetHistory & SetGradInMeta
        set_out_rank_list = []
        set_history_list = []
        set_grad_in_meta_list = []
        set_retain_grad_list = []
        num_outputs = len(forward_outputs_position_map.keys())
        for name, (_, pos) in forward_outputs_position_map.items():
            output_autograd_meta_name = GetAutoGradMetaName(name)
            set_out_rank = f"""{indent}if ({output_autograd_meta_name}) {{
{indent}  egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});
{indent}}}"""

            set_history = f"""{indent}if ({output_autograd_meta_name}) {{
{indent}  egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);
{indent}}}"""

            set_grad_in_meta = (
                f"{indent}grad_node->SetGradInMeta({name}, {pos});"
            )
            set_retain_grad = (
                f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
            )

            set_out_rank_list.append(set_out_rank)
            set_history_list.append(set_history)
            set_grad_in_meta_list.append(set_grad_in_meta)
            set_retain_grad_list.append(set_retain_grad)

        set_out_rank_str = "\n".join(set_out_rank_list)
        set_history_str = "\n".join(set_history_list)
        set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
        set_retain_grad_str = "\n".join(set_retain_grad_list)

        node_event_name = forward_api_name + " node_creation"
        node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
        if not for_backward:
            self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
                node_creation_event_str,
                pass_stop_gradient_args_str,
                node_construction_str,
                set_attributes_str,
                set_input_tensor_wrappers_str,
                set_grad_out_meta_str,
                set_out_rank_str,
                set_history_str,
                set_grad_in_meta_str,
                set_retain_grad_str,
                set_output_tensor_wrappers_str,
            )
        else:
            self.node_creation_str = (
                HIHGER_ORDER_DERIVATIVE_VALUE_TEMPLATE.format(
                    node_creation_event_str,
                    node_construction_str,
                    set_attributes_str,
                    set_input_tensor_wrappers_str,
                    set_grad_out_meta_str,
                    set_out_rank_str,
                    set_history_str,
                    set_grad_in_meta_str,
                    set_retain_grad_str,
                    set_output_tensor_wrappers_str,
                )
            )

        self.grad_node_out_list = grad_node_out_list

    def run(self):
        # Basic Validation Check
        self.DygraphYamlValidationCheck()

        ########################
        # Parsing Raw Contents #
        ########################
        # Parse forward and backward inplace_map
        self.ParseForwardInplaceInfo()
        if self.grad_api_contents is not None:
            self.ParseBackwardInplaceInfo()
            # Parse no_need_buffer
            self.ParseNoNeedBuffer()
            # Parse composite
            self.ParseComposite()

        # Parse optional_inputs
        self.ParseDispensable()

        # Parse intermediate_outputs
        self.ParseIntermediate()
        self.IntermediateValidationCheck()

        if self.grad_api_contents is not None:
            # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
            self.CollectBackwardInfo()

            # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
            self.CollectForwardInfoFromBackwardContents()

        if self.is_forward_only:
            self.CollectForwardInfoFromYamlForward()

        # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
        self.CollectOriginalForwardInfo()

        # Forwards Validation Check
        self.ForwardsValidationCheck()

        ###########################
        # Process Parsed Contents #
        ###########################
        # Initialize forward_inputs_position_map, forward_outputs_position_map
        self.DetermineForwardPositionMap(
            self.forward_inputs_list, self.forward_returns_list
        )

        if self.grad_api_contents is not None:
            # Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
            self.SlotNameMatching()
            # Backward Validation Check
            self.BackwardValidationCheck()


class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
    def __init__(
        self,
        forward_api_contents,
        grad_api_contents,
        forward_apis_dict,
        namespace,
    ):
        DygraphFunctionGeneratorBase.__init__(
            self,
            forward_api_contents,
            grad_api_contents,
            forward_apis_dict,
            namespace,
        )

        # Generated Results
        self.forward_definition_str = ""
        self.forward_declaration_str = ""

    def GenerateForwardLayoutAutotune(
        self,
        forward_api_name,
        amp_tensors_vector_list,
        layout_tensors_vector_optional_list,
        layout_autotune_list_str,
        returns_type_str,
        returns_str,
        amp_inputs_call_args_str,
    ):
        intermediate_outputs = self.intermediate_outputs
        forward_attrs_list = self.forward_attrs_list
        forward_outputs_position_map = self.forward_outputs_position_map
        num_outputs = len(forward_outputs_position_map.keys()) - len(
            intermediate_outputs
        )
        # for layout autotune attr
        lightly_sensitive_attr = [
            'axis',
            'axes',
            'dim',
            'dims',
            'start',
            'end',
            'stop',
            'perm',
1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165
        ]
        heavily_sensitive_attr = ['data_format', 'data_layout']
        layout_autotune_attr = []
        layout_autotune_attr_code_list = []
        layout_autotune_attr_type_list = []
        layout_autotune_attr_code_list.append(
            f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");\n"
        )

        lightly_flag = False
        heavily_flag = False
        for name, atype, default_val, pos in forward_attrs_list:
            for attr_name in lightly_sensitive_attr:
                if name.find(attr_name) != -1 and (
                    name not in layout_autotune_attr
                ):
                    lightly_flag = True
                    layout_autotune_attr.append(name)
                    layout_autotune_attr_type_list.append(atype)
            if lightly_flag is False:
                for attr_name in heavily_sensitive_attr:
                    if name.find(attr_name) != -1 and (
                        name not in layout_autotune_attr
                    ):
                        layout_autotune_attr.append(name)
                        layout_autotune_attr_type_list.append(atype)
                        heavily_flag = True
        if len(layout_autotune_attr) == 0:
            layout_autotune_attr_code_list.append(
                "auto transformer = egr::EagerLayoutAutotune(op_name, tensors_vector);\n"
            )
        elif len(layout_autotune_attr) == 1:
            layout_autotune_attr_code_list.append(
                f"auto transformer = egr::EagerLayoutAutotune<{layout_autotune_attr_type_list[0]}>(op_name, tensors_vector, &{layout_autotune_attr[0]});\n"
            )
        elif len(layout_autotune_attr) == 2:
            layout_autotune_attr_code_list.append(
                f"auto transformer = egr::EagerLayoutAutotune<{layout_autotune_attr_type_list[0]}, {layout_autotune_attr_type_list[1]}>(op_name, tensors_vector, &{layout_autotune_attr[0]}, &{layout_autotune_attr[1]});\n"
            )
        else:
            layout_autotune_attr_code_list.append(
                f"auto transformer = egr::EagerLayoutAutotune<{layout_autotune_attr_type_list[0]}>(op_name, tensors_vector,&{layout_autotune_attr[0]});\n"
            )
        # Out tensor
        layout_inputs_call_args_str = amp_inputs_call_args_str
        forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
        layout_tmp_result_list = []
        layout_autotune_outs_list = []
        result_name = "api_result"
        if num_outputs == 1:
            result_name = returns_str
            layout_autotune_outs_list.append(
1205 1206
                f"transformer -> SetOutTensorLayout(&{returns_str});\n"
            )
1207 1208 1209 1210 1211 1212 1213 1214
        else:
            for name, (rtype, pos) in forward_outputs_position_map.items():
                if name in intermediate_outputs:
                    continue
                layout_autotune_outs_list.append(
                    f"    auto& {name} = std::get<{len(layout_tmp_result_list)}>(api_result);\n"
                )
                layout_autotune_outs_list.append(
                    f"    transformer -> SetOutTensorLayout(&{name});\n"
                )
                layout_tmp_result_list.append(f"{name}")

        tensors_vector_list_str = (
            "{ " + ",".join(amp_tensors_vector_list) + " }"
        )

        if len(amp_tensors_vector_list) == 0:
            layout_logic_str = ""
        else:
            after_call_str = f"{returns_type_str} {result_name} = {forward_function_name}({layout_inputs_call_args_str});\n"
            layout_logic_str = LAYOUT_LOGIC_TEMPLATE.format(
                tensors_vector_list_str,
                "    ".join(layout_tensors_vector_optional_list),
                "    ".join(layout_autotune_attr_code_list)
                + "    "
                + layout_autotune_list_str,
                after_call_str,
                "    ".join(layout_autotune_outs_list),
                returns_str,
            )

        return layout_logic_str

    def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
        namespace = self.namespace
        if self.forward_api_name[-1] == '_' and not is_inplaced:
            return
        forward_api_name = (
            GetInplacedFunctionName(self.forward_api_name)
            if is_inplaced
            else self.forward_api_name
        )

        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list
        if not self.is_forward_only:
            backward_grad_outputs_map = self.backward_grad_outputs_map

        optional_inputs = self.optional_inputs
        intermediate_outputs = self.intermediate_outputs
        forward_inplace_map = self.forward_inplace_map if is_inplaced else {}
        indent = GetIndent(1)

        # Get Function Args
        num_inputs = len(forward_attrs_list) + len(
            forward_inputs_position_map.keys()
        )
        inputs_args_definition_list = ["" for i in range(num_inputs)]
        inputs_args_declaration_list = ["" for i in range(num_inputs)]
        inputs_call_list = ["" for i in range(num_inputs)]

        amp_inputs_call_list = ["" for i in range(num_inputs)]
        amp_tensors_vector_list = []
        amp_tensors_vector_optional_list = []
        amp_autocast_list = []
        amp_autocast_optional_list = []
        layout_autotune_list = []
        layout_autotune_optional_list = []
        layout_tensors_vector_optional_list = []
        for name, (ttype, pos) in forward_inputs_position_map.items():
            inputs_call_list[pos] = f"{name}"
            amp_inputs_call_list[pos] = f"new_{name}"
            is_optional = name in optional_inputs
            if IsPlainTensorType(ttype):
                if is_optional:
                    if (
                        self.is_forward_only
                        and is_inplaced
                        and forward_inplace_map
                        and name in forward_inplace_map.keys()
                    ):
                        arg_str = f"paddle::optional<paddle::experimental::Tensor>& {name}"
                    else:
                        arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
                    amp_tensors_vector_optional_list.append(
                        f"if ({name}) amp_tensors_vector.push_back({{ *{name} }});\n"
                    )
                    amp_autocast_optional_list.append(
                        f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                    )
                    layout_tensors_vector_optional_list.append(
                        f"if ({name}) tensors_vector.push_back({{ *{name} }});\n"
                    )
                    layout_autotune_optional_list.append(
                        f"auto new_{name} = transformer->TransInTensor(\"{name}\", {name});\n"
                    )
                else:
                    if (
                        is_inplaced
                        and forward_inplace_map
                        and name in forward_inplace_map.keys()
                    ):
                        arg_str = f"paddle::experimental::Tensor& {name}"
                        amp_tensors_vector_list.append(f"{{{name}}}")
                        amp_autocast_list.append(
                            f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                        )
                    else:
                        arg_str = f"const paddle::experimental::Tensor& {name}"
                        amp_tensors_vector_list.append(f"{{{name}}}")
                        amp_autocast_list.append(
                            f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                        )
                    layout_autotune_list.append(
                        f"auto new_{name} = transformer->TransInTensor(\"{name}\", {name});\n"
                    )
            else:
                assert IsVectorTensorType(ttype)
                if is_optional:
                    if (
                        self.is_forward_only
                        and is_inplaced
                        and forward_inplace_map
                        and name in forward_inplace_map.keys()
                    ):
                        arg_str = f"paddle::optional<std::vector<paddle::experimental::Tensor>>& {name}"
                    else:
                        arg_str = f"const paddle::optional<std::vector<paddle::experimental::Tensor>>& {name}"
                    amp_tensors_vector_optional_list.append(
                        f"if ({name}) amp_tensors_vector.push_back( *{name} );\n"
                    )
                    amp_autocast_optional_list.append(
                        f"auto new_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                    )
                    layout_autotune_optional_list.append(
                        f"auto new_{name} = transformer->TransInTensors(\"{name}\", {name});\n"
                    )
                else:
                    if (
                        is_inplaced
                        and forward_inplace_map
                        and name in forward_inplace_map.keys()
                    ):
                        arg_str = (
                            f"std::vector<paddle::experimental::Tensor>& {name}"
                        )
                    else:
                        arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
                    amp_tensors_vector_list.append(f"{name}")
                    amp_autocast_list.append(
                        f"auto new_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                    )
                    layout_autotune_list.append(
                        f"auto new_{name} = transformer->TransInTensors(\"{name}\", {name});\n"
                    )

            inputs_args_definition_list[pos] = arg_str
            inputs_args_declaration_list[pos] = arg_str

        # forward attrs
        for name, atype, default_val, pos in forward_attrs_list:
            inputs_call_list[pos] = name
            amp_inputs_call_list[pos] = name
            if default_val is not None:
                inputs_args_declaration_list[
                    pos
                ] = f"{atype} {name} = {default_val}"
            else:
                inputs_args_declaration_list[pos] = f"{atype} {name}"
            inputs_args_definition_list[pos] = f"{atype} {name}"

        inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
        inputs_args_definition_str = ", ".join(inputs_args_definition_list)
        inputs_call_args_str = ", ".join(inputs_call_list)

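        # The forward call below is redirected to the *_intermediate API variant when the op declares intermediate outputs.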
        # Forward Full Logic
        function_name = forward_api_name
        if len(intermediate_outputs) > 0:
            if is_inplaced:
                function_name = (
                    GetIntermediateAPIFunctionName(forward_api_name[:-1]) + '_'
                )
            else:
                function_name = GetIntermediateAPIFunctionName(function_name)

        api_out_type = "auto"
        if is_inplaced and len(forward_outputs_position_map) == 1:
            api_out_type = "auto&"
        forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
        num_outputs = len(forward_outputs_position_map.keys()) - len(
            intermediate_outputs
        )

        # Check Nan and Inf
        check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
            function_name, "api_result"
        )

        # Get Outputs
        get_outputs_str = ""
        for name, (rtype, pos) in forward_outputs_position_map.items():
            if num_outputs == 1 and len(intermediate_outputs) == 0:
                get_outputs_str += f"{indent}auto& {name} = api_result;\n"
            else:
                get_outputs_str += (
                    f"{indent}auto& {name} = std::get<{pos}>(api_result);\n"
                )

        # Get return type list & outputs
        returns_type_list = ["" for i in range(num_outputs)]
        returns_list = ["" for i in range(num_outputs)]
        for name, (rtype, pos) in forward_outputs_position_map.items():
            if name in intermediate_outputs:
                continue
            returns_list[pos] = f"{name}"

            if IsPlainTensorType(rtype):
                if (
                    is_inplaced
                    and forward_inplace_map
                    and name in forward_inplace_map.values()
                ):
                    ind = list(forward_inplace_map.values()).index(name)
                    if (
                        list(forward_inplace_map.keys())[ind]
                        in self.optional_inputs
                    ):
                        returns_type_list[pos] = inplace_optional_out_type_map[
                            rtype
                        ]
                    else:
                        returns_type_list[pos] = "paddle::experimental::Tensor&"
                else:
                    returns_type_list[pos] = "paddle::experimental::Tensor"
            else:
                assert IsVectorTensorType(rtype)
                if (
                    is_inplaced
                    and forward_inplace_map
                    and name in forward_inplace_map.values()
                ):
                    ind = list(forward_inplace_map.values()).index(name)
                    if (
                        list(forward_inplace_map.keys())[ind]
                        in self.optional_inputs
                    ):
                        returns_type_list[pos] = inplace_optional_out_type_map[
                            rtype
                        ]
                    else:
                        returns_type_list[
                            pos
                        ] = "std::vector<paddle::experimental::Tensor>&"
                else:
                    returns_type_list[
                        pos
                    ] = "std::vector<paddle::experimental::Tensor>"

        if num_outputs == 1:
            returns_str = returns_list[0]
            returns_type_str = returns_type_list[0]
        else:
            returns_type_str = ", ".join(returns_type_list)
            returns_type_str = f"std::tuple<{returns_type_str}>"
            returns_str = ", ".join(returns_list)
            returns_str = f"{returns_type_str}{{{returns_str}}}"

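        # For ops that have a backward definition, gather AutogradMeta for inputs/outputs and prepare grad-node creation.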
        # Node Creation Pre-Processing
1476
        inputs_names = []
1477
        if not self.is_forward_only:
Z
zyfncg 已提交
1478
            # 1. Get Input AutoGradMeta
1479 1480 1481 1482 1483
            inputs_autograd_meta_list = []
            compute_require_grad_args_list = ["trace_backward"]
            for name, (ttype, pos) in forward_inputs_position_map.items():
                # Has corresponding grad output
                has_corresponding_grad_output = False
1484 1485 1486 1487 1488
                for _, (
                    _,
                    corresponding_pos,
                    _,
                ) in backward_grad_outputs_map.items():
Z
zyfncg 已提交
1489 1490
                    if pos == corresponding_pos:
                        has_corresponding_grad_output = True
1491 1492 1493 1494 1495 1496 1497 1498
                if (
                    has_corresponding_grad_output
                    or (
                        name in forward_inplace_map
                        and forward_api_name not in inplace_check_blacklist
                    )
                    or self.is_forward_only
                ):
1499 1500 1501 1502 1503
                    input_autograd_meta_name = GetAutoGradMetaName(name)
                    if IsPlainTensorType(ttype):
                        input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
                    else:
                        assert IsVectorTensorType(ttype)
1504 1505 1506
                        input_autograd_meta_vec_name = (
                            GetAutoGradMetaVectorName(name)
                        )
1507 1508 1509 1510
                        input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
                        input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
                    inputs_autograd_meta_list.append(input_autograd_meta)
                    compute_require_grad_args_list.append(
1511 1512
                        input_autograd_meta_name
                    )
1513 1514 1515

            inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
            compute_require_grad_args_str = ",".join(
1516 1517
                compute_require_grad_args_list
            )
1518

Z
zyfncg 已提交
1519
            # 2. Get Output AutoGradMeta
1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532
            outputs_autograd_meta_list = []
            num_fwd_outputs = len(forward_outputs_position_map.keys())

            for name, (rtype, pos) in forward_outputs_position_map.items():
                output_autograd_meta_name = GetAutoGradMetaName(name)
                output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
                if num_fwd_outputs == 1:
                    if IsPlainTensorType(rtype):
                        output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
                    else:
                        assert IsVectorTensorType(rtype)
                        output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
                        output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
1533
                else:
1534 1535 1536 1537 1538 1539 1540
                    # Tuple api_result
                    if IsPlainTensorType(rtype):
                        output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
                    else:
                        assert IsVectorTensorType(rtype)
                        output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
                        output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
1541

1542 1543
                outputs_autograd_meta_list.append(output_autograd_meta)
            outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
1544

Z
zyfncg 已提交
1545 1546 1547 1548 1549 1550 1551
            # 3. Check Inplace
            check_inplace_str = ""
            bump_inplace_version_str = ""
            if is_inplaced:
                for inplace_name in forward_inplace_map.keys():
                    if forward_api_name not in inplace_check_blacklist:
                        inplace_autograd_meta_name = GetAutoGradMetaName(
1552 1553
                            inplace_name
                        )
Z
zyfncg 已提交
1554
                        check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
1555 1556 1557 1558 1559 1560 1561
                            inplace_name, inplace_autograd_meta_name
                        )
                    bump_inplace_version_str += (
                        BUMP_INPLACE_VERSION_TEMPLATE.format(
                            inplace_name, inplace_name
                        )
                    )
Z
zyfncg 已提交
1562 1563

            # Node Creation
1564 1565
            self.GenerateNodeCreationCodes()
            node_creation_str = self.node_creation_str
1566

1567
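        # Profiler event that wraps the whole generated dygraph forward function.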
        dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
        forward_ad_function_name = GetDygraphForwardFunctionName(
            forward_api_name
        )

        # Forward amp logic
        kernel_trans2_op_name_str = (
            f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");"
        )
        amp_tensors_vector_list_str = (
            "{ " + ",".join(amp_tensors_vector_list) + " }"
        )
        amp_tensors_vector_optional_list_str = "    ".join(
            amp_tensors_vector_optional_list
        )
        amp_get_dst_dtype_str = "auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
        amp_autocast_list_str = (
            "    ".join(amp_autocast_list)
            + "    "
            + "    ".join(amp_autocast_optional_list)
        )
        amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
        amp_call_str = (
            f"return {forward_ad_function_name}({amp_inputs_call_args_str});"
        )
        if is_inplaced or (forward_api_name == "cast"):
            amp_logic_str = "\n VLOG(5) << \" No AMP for {} because it is an inplace or cast api. \"; ".format(
                forward_ad_function_name
            )
        else:
            amp_logic_str = AMP_LOGIC_TEMPLATE.format(
                kernel_trans2_op_name_str,
                amp_tensors_vector_list_str,
                amp_tensors_vector_optional_list_str,
                amp_get_dst_dtype_str,
                amp_autocast_list_str,
                amp_call_str,
            )

        # Forward layout autotune
        layout_autotune_list_str = "    ".join(
            layout_autotune_list
        ) + "    ".join(layout_autotune_optional_list)
        layout_logic_str = self.GenerateForwardLayoutAutotune(
            forward_api_name,
            amp_tensors_vector_list,
            layout_tensors_vector_optional_list,
            layout_autotune_list_str,
            returns_type_str,
            returns_str,
            amp_inputs_call_args_str,
        )

        # Prepare input and output strings for logging
        var_str = f"\n{indent}  std::string input_str = \"\";"
        var_str += f"\n{indent}  std::string output_str = \"\";"
        for name, (ttype, pos) in forward_inputs_position_map.items():
            var_str += f"\n{indent}  const char* TENSOR_{name.upper()}_TEMPLATE = \" \\n( {name} , [%s]), \";"
            var_str += f"\n{indent}  std::string input_{name}_str = paddle::string::Sprintf(TENSOR_{name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({name}));"
            var_str += f"\n{indent}  input_str += input_{name}_str; "

        before_log_str = BEFORE_LOG_PRINT_TEMPLATE.format(var_str)
        for name, (ttype, pos) in forward_outputs_position_map.items():
            var_str += f"\n{indent}  const char* TENSOR_{name.upper()}_TEMPLATE = \" \\n( {name} , [%s]), \";"
            var_str += f"\n{indent}  std::string output_{name}_str = paddle::string::Sprintf(TENSOR_{name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({name}));"
            var_str += f"\n{indent}  output_str += output_{name}_str; "

        log_str = AFTER_LOG_PRINT_TEMPLATE.format(var_str)

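        # Forward-only ops use the lighter FORWARD_ONLY_FUNCTION_TEMPLATE; all other ops also emit autograd meta, inplace checks and grad-node creation.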
        # Generate forward_definition_str and forward_declaration_str
        if self.is_forward_only:
            if len(amp_tensors_vector_list) == 0:
                amp_logic_str = "\n VLOG(7) << \" No AMP for {} because it has no input. \"; ".format(
                    forward_ad_function_name
                )
            self.forward_definition_str += (
                FORWARD_ONLY_FUNCTION_TEMPLATE.format(
                    returns_type_str,
                    forward_ad_function_name,
                    inputs_args_definition_str,
                    forward_api_name,
                    dygraph_event_str,
                    amp_logic_str,
                    layout_logic_str,
                    forward_api_name,
                    before_log_str,
                    forward_call_str,
                    get_outputs_str,
                    forward_api_name,
                    log_str,
                    returns_str,
                )
            )
        else:
            self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
                returns_type_str,
                forward_ad_function_name,
                inputs_args_definition_str,
                forward_api_name,
                dygraph_event_str,
                amp_logic_str,
                layout_logic_str,
                inputs_autograd_meta_str,
                forward_api_name,
                before_log_str,
                forward_call_str,
                check_nan_inf_str,
                get_outputs_str,
                outputs_autograd_meta_str,
                compute_require_grad_args_str,
                check_inplace_str,
                bump_inplace_version_str,
                node_creation_str,
                forward_api_name,
                log_str,
                returns_str,
            )

        self.forward_declaration_str += f"{returns_type_str} {forward_ad_function_name}({inputs_args_declaration_str});\n"

    def GenerateInplacedForwardDygraphFunctions(self):
        # Inplaced Version Dygraph Function Generation
        forward_api_name = self.forward_api_name
        forward_api_contents = self.forward_api_contents

        if (
            forward_api_name != "sum"
            and "inplace" in forward_api_contents.keys()
        ):
            # Function Definition and Declaration Generation
            self.GenerateForwardDefinitionAndDeclaration(is_inplaced=True)
            self.UpdateCoreOpsInformation(is_inplaced=True)

    def UpdateCoreOpsInformation(self, is_inplaced):
        forward_api_name = (
            GetInplacedFunctionName(self.forward_api_name)
            if is_inplaced
            else self.forward_api_name
        )
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list

        num_args = len(forward_inputs_position_map.keys()) + len(
            forward_attrs_list
        )
        num_returns = len(forward_outputs_position_map.keys())

        fwd_api_name = "" + forward_api_name
        core_ops_returns_info[fwd_api_name] = ["" for i in range(num_returns)]
        core_ops_args_info[fwd_api_name] = ["" for i in range(num_args)]
        core_ops_args_type_info[fwd_api_name] = ["" for i in range(num_args)]

        for name, (ttype, pos) in forward_inputs_position_map.items():
            core_ops_args_info[fwd_api_name][pos] = name
            if IsPlainTensorType(ttype):
                core_ops_args_type_info[fwd_api_name][pos] = "tensor"
            else:
                assert IsVectorTensorType(ttype)
                core_ops_args_type_info[fwd_api_name][pos] = "list"

        for name, _, _, pos in forward_attrs_list:
            core_ops_args_info[fwd_api_name][pos] = name

        for name, (ttype, pos) in forward_outputs_position_map.items():
            core_ops_returns_info[fwd_api_name][pos] = name

    def run(self):
        super().run()

        ###################
        # Code Generation #
        ###################

        # Definition And Declaration
        self.GenerateForwardDefinitionAndDeclaration(is_inplaced=False)

        self.UpdateCoreOpsInformation(is_inplaced=False)

        self.GenerateInplacedForwardDygraphFunctions()


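# Generates the GradNode declaration/definition for one backward op and, when a next-level backward op exists, chains a higher-order grad node.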
class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
    def __init__(
        self,
        forward_api_contents,
        grad_api_contents,
        forward_apis_dict,
        namespace,
        next_grad_api_contents=None,
    ):
        DygraphFunctionGeneratorBase.__init__(
            self,
            forward_api_contents,
            grad_api_contents,
            forward_apis_dict,
            namespace,
        )

        # Record name mapping from forward_var_name to grad_var_names
        self.to_next_grad_name_mapping = {}  # {name : name}

        # Generated Results
        self.node_declaration_str = ""
        self.node_definition_str = ""
        self.next_grad_api_contents = next_grad_api_contents

    def TransformToNextGradName(self, string):
        name_mapping = self.to_next_grad_name_mapping
        if string in name_mapping.keys():
            return name_mapping[string]
        return string

    def ResetOptionalInputs(self):
        namespace = self.namespace
        grad_api_contents = self.grad_api_contents

        base_generator = FunctionGeneratorBase(grad_api_contents, namespace)
        base_generator.ParseDispensable()

        self.optional_inputs = base_generator.optional_inputs

    def RecordGrad2NextGradNameMapping(self, next_node_generator):
        next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
        next_orig_returns_list = next_node_generator.orig_forward_returns_list

        next_forward_inputs_list = next_node_generator.forward_inputs_list
        next_forward_returns_list = next_node_generator.forward_returns_list
        for i in range(len(next_orig_inputs_list)):
            grad_name = next_orig_inputs_list[i][0]
            next_forward_name = next_forward_inputs_list[i][0]
            self.to_next_grad_name_mapping[grad_name] = next_forward_name

        for i in range(len(next_orig_returns_list)):
            grad_ret_name = next_orig_returns_list[i][0]
            next_ret_name = next_forward_returns_list[i][0]
            self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name

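    # Treat this backward op as a fake forward op so that the next-order grad node (double grad) can be generated with the same machinery.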
    def GenerateHigherOrderNodeCreationCode(self):
        has_higher_order_node = False
        namespace = self.namespace
        grad_api_contents = self.grad_api_contents
        forward_apis_dict = self.forward_apis_dict
        next_grad_api_contents = self.next_grad_api_contents

        next_grad_node_creation_str = ""
        next_grad_node_out_list = []
        next_node_generator = None
        if next_grad_api_contents:
            # Fake forward_api_contents and backward_api_contents
            forward_api_contents = grad_api_contents
            forward_api_contents['op'] = forward_api_contents['backward_op']
            backward_api_contents = next_grad_api_contents

            next_node_generator = DygraphFunctionGeneratorBase(
                forward_api_contents,
                backward_api_contents,
                forward_apis_dict,
                namespace,
            )
            next_node_generator.run()
            next_node_generator.GenerateNodeCreationCodes(True)

            next_grad_node_creation_str = next_node_generator.node_creation_str
            next_grad_node_out_list = next_node_generator.grad_node_out_list

            self.RecordGrad2NextGradNameMapping(next_node_generator)

        is_invoke_forward_api = IsInvokeForwardApi(
            self.grad_api_contents, self.forward_apis_dict
        )
        is_composite_grad_api = (
            False if self.composite_func_info == {} else True
        )

        if is_composite_grad_api and next_grad_node_creation_str != '':
            next_grad_node_creation_str = f"""
 if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
    {next_grad_node_creation_str}
 }}
  """

        if next_node_generator is not None:
            has_higher_order_node = True
            return (
                has_higher_order_node,
                is_invoke_forward_api,
                is_composite_grad_api,
                next_grad_node_creation_str,
                next_grad_node_out_list,
                next_node_generator.backward_forward_inputs_map,
            )
        # TODO(Ruting): Integrate invoke and composite as composite so the rest branch can be covered
        elif not is_invoke_forward_api and not is_composite_grad_api:
            next_grad_node_creation_str = f"""  if(trace_backward) {{
    PADDLE_THROW(phi::errors::Unavailable(
    \"The Op {self.backward_api_name} doesn't have any grad \"
    \"op. If you don't intend calculating higher order \"
    \"derivatives, please set `create_graph` to False.\"));
  }}"""
        return (
            has_higher_order_node,
            is_invoke_forward_api,
            is_composite_grad_api,
            next_grad_node_creation_str,
            next_grad_node_out_list,
            None,
        )

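    # Emits the GradNode class declaration: TensorWrapper members, attribute members, and their setter/clear methods.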
    def GenerateNodeDeclaration(self):
        forward_op_name = self.forward_api_name
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_attrs_list = self.backward_attrs_list
        no_need_buffers = self.no_need_buffers

        # SetTensorWrapper Methods & TensorWrapper Members & ClearTensorWrappers
        set_tensor_wrapper_methods_str = ""
        tensor_wrapper_members_str = ""
        clear_tensor_wrapper_str = ""
        for tname, (
            ttype,
            is_fwd_input,
            _,
        ) in backward_forward_inputs_map.items():
            no_need_buffer = "true" if tname in no_need_buffers else "false"
            tensor_wrapper_name = GetSavedName(tname)
            if IsPlainTensorType(ttype):
                set_tensor_wrapper_methods_str += (
                    SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
                        tname, tname, tensor_wrapper_name, tname, no_need_buffer
                    )
                )

                tensor_wrapper_members_str += (
                    PLAIN_TENSOR_MEMBER_TEMPLATE.format(tensor_wrapper_name)
                )

                clear_tensor_wrapper_str += (
                    CLEAR_TENSOR_WRAPPER_TEMPLATE.format(tensor_wrapper_name)
                )

            else:
                assert IsVectorTensorType(ttype)
                set_tensor_wrapper_methods_str += (
                    SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
                        tname, tname, tname, tensor_wrapper_name, no_need_buffer
                    )
                )

                tensor_wrapper_members_str += (
                    VECTOR_TENSOR_MEMBER_TEMPLATE.format(tensor_wrapper_name)
                )

                clear_tensor_wrapper_str += (
                    CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format(
                        tensor_wrapper_name
                    )
                )

        # SetAttributes & Attribute Members
        set_attribute_methods_str = ""
        attribute_members_str = ""
        for aname, atype, default_val, _ in backward_attrs_list:
            saved_attr_name = GetSavedName(aname)
            set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format(
                aname, GetConstReference(atype), aname, saved_attr_name, aname
            )

            if default_val:
                attribute_members_str += (
                    ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format(
                        RemoveConstAndReference(atype),
                        saved_attr_name,
                        default_val,
                    )
                )
            else:
                attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
                    RemoveConstAndReference(atype), saved_attr_name
                )

        grad_node_name = GetGradNodeName(self.backward_api_name)
        self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
            grad_node_name,
            grad_node_name,
            grad_node_name,
            grad_node_name,
            grad_node_name,
            clear_tensor_wrapper_str,
            grad_node_name,
            grad_node_name,
            set_tensor_wrapper_methods_str,
            set_attribute_methods_str,
            tensor_wrapper_members_str,
            attribute_members_str,
        )

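    # Emits the GradNode::operator() body: recover saved inputs, call the grad API (plain, composite or invoked forward), then create the next grad node.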
    def GenerateNodeDefinition(
        self,
        has_higher_order_node,
        is_invoke_forward_api,
        is_composite_grad_api,
        next_grad_node_creation_str,
        next_grad_node_out_list,
        backward_forward_inputs_map_next,
    ):
        namespace = self.namespace
        forward_api_name = self.forward_api_name
        backward_api_name = self.backward_api_name
        composite_grad_api_name = (
            self.composite_func_info["name"] if is_composite_grad_api else None
        )
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_grad_outputs_map = self.backward_grad_outputs_map
        backward_attrs_list = self.backward_attrs_list
        backward_inplace_map = self.backward_inplace_map
        indent = GetIndent(1)
        need_gen_trace_backard_for_inplace = False

        # Construct grad_api function args
        # Order: TensorWrappers, GradTensors, Attributes
        grad_api_args_len = (
            len(backward_forward_inputs_map.keys())
            + len(backward_grad_inputs_map.keys())
            + len(backward_attrs_list)
        )
        grad_api_args = ["" for i in range(grad_api_args_len)]
        get_grad_in_args_list = []

        # Fill Grad Ins with Zero
        fill_zero_str = ""
        if backward_api_name in ops_to_fill_zero_for_empty_grads:
            fill_zero_str = (
                f"{indent}const auto& input_metas = this->InputMeta();\n"
            )
            for name, (
                ttype,
                fwd_position,
                grad_api_position,
            ) in backward_grad_inputs_map.items():
                if name in self.optional_inputs:
                    if IsPlainTensorType(ttype):
                        fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);\n"
                else:
                    if IsPlainTensorType(ttype):
                        fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);\n"
                    else:
                        fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);\n"

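        # Recover tensors saved through TensorWrapper; optional inputs get an extra paddle::optional wrapper before being passed to the grad API.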
        inplace_grad_input_str = ""
        inplace_check_str = ""
        optional_inplace_var_name = []
        # Grad Ins from TensorWrappers
        for (
            name,
            (backward_input_type, is_fwd_input, grad_api_position),
        ) in backward_forward_inputs_map.items():
            tensor_wrapper_name = GetSavedName(name)
            transformed_tensor_name = self.TransformToNextGradName(name)

            is_optional = name in self.optional_inputs
            tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name});"
            if backward_inplace_map and name in backward_inplace_map.keys():
                if has_higher_order_node:
                    if (
                        transformed_tensor_name
                        in backward_forward_inputs_map_next
                    ) and (
                        backward_forward_inputs_map_next[
                            transformed_tensor_name
                        ][1]
                    ):
                        optional_inplace_var_name.append(
                            transformed_tensor_name
                        )
                tensor_wrapper_intermidiate_tensor_str = (
                    f"(&this->{tensor_wrapper_name})->get_intermidiate_tensor()"
                )
                inplace_check_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
                    transformed_tensor_name,
                    transformed_tensor_name,
                    name,
                    transformed_tensor_name,
                    transformed_tensor_name,
                    transformed_tensor_name,
                    transformed_tensor_name,
                    tensor_wrapper_intermidiate_tensor_str,
                )
                inplace_grad_input_str = transformed_tensor_name
            if is_optional:
                if backward_input_type == "std::vector<Tensor>":
                    tensor_wrapper_recover_str += (
                        "\n"
                        + CREATE_RECOVER_OPTIONAL_VECTOR_TENSOR_TEMPLATE.format(
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                        )
                    )
                else:
                    tensor_wrapper_recover_str += (
                        "\n"
                        + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format(
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                        )
                    )

                grad_api_args[grad_api_position] = (
                    transformed_tensor_name + "_optional"
                )

            else:
                grad_api_args[grad_api_position] = transformed_tensor_name

            get_grad_in_args_list.append(tensor_wrapper_recover_str)

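        # Incoming gradients are read from hooked_grads; plain tensors may additionally be checked for backward inplace reuse.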
        # Grad Ins from grads
        for name, (
            ttype,
            fwd_position,
            grad_api_position,
        ) in backward_grad_inputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)

            is_optional = name in self.optional_inputs
            if IsPlainTensorType(ttype):
                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"

                # Inplace in backward op
                if backward_inplace_map and name in backward_inplace_map.keys():
                    if has_higher_order_node:
                        if (
                            transformed_tensor_name
                            in backward_forward_inputs_map_next
                        ) and (
                            backward_forward_inputs_map_next[
                                transformed_tensor_name
                            ][1]
                        ):
                            optional_inplace_var_name.append(
                                transformed_tensor_name
                            )
                    grads_tensor_str = f"grads[{fwd_position}][0]"
                    inplace_check_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
                        transformed_tensor_name,
                        transformed_tensor_name,
                        name,
                        transformed_tensor_name,
                        transformed_tensor_name,
                        transformed_tensor_name,
                        transformed_tensor_name,
                        grads_tensor_str,
                    )
                    inplace_grad_input_str = transformed_tensor_name

                if is_optional:
                    get_tensor_str += (
                        "\n"
                        + CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE.format(
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                        )
                    )
                    grad_api_args[
                        grad_api_position
                    ] = f"{transformed_tensor_name}_optional"
                else:
                    grad_api_args[grad_api_position] = transformed_tensor_name
            else:
                assert IsVectorTensorType(ttype)
                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
                grad_api_args[grad_api_position] = transformed_tensor_name

            get_grad_in_args_list.append(get_tensor_str)

        # Grad Attrs
        for name, _, _, grad_api_position in backward_attrs_list:
            saved_attribute_name = GetSavedName(name)
            get_attr_str = (
                f"{indent}auto& {name} = this->{saved_attribute_name};"
            )

            grad_api_args[grad_api_position] = name
            get_grad_in_args_list.append(get_attr_str)

        get_grad_in_args_str = "\n".join(get_grad_in_args_list)

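        # 'returns' holds one slot per forward input; each slot is sized from OutputMeta() before the grad API writes into it.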
        # Grad Function Call String
        slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
        grad_api_namespace = f"paddle::experimental::{namespace}"
        composite_grad_api_namespace = f"paddle::prim::{namespace}"
        grad_function_prepare_str = f"""
  const auto& out_metas = OutputMeta();
  paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> returns({slot_num_bwd_outputs});
  for (int i = 0; i < {slot_num_bwd_outputs}; ++i) {{
    out_metas[i].size() == 0 ? returns[i].resize(1) : returns[i].resize(out_metas[i].size());
  }}
"""
        inplace_for_grad_outs_str = ""
        optional_inplace_str = ""
        # Grad Outputs
        out_index = -1
        out_assign_str = ""
        for name, (
            ttype,
            fwd_position,
            grad_api_position,
        ) in backward_grad_outputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)
            out_index = out_index + 1
            if is_invoke_forward_api:
                if len(backward_grad_outputs_map) == 1:
                    out_assign_str += (
                        f"{indent}*api_output_{out_index} = api_output;\n"
                    )
                else:
                    out_assign_str += f"{indent}*api_output_{out_index} = std::get<{out_index}>(api_output);\n"
            else:
                grad_api_args.append(f"api_output_{out_index}")
            if inplace_grad_input_str in optional_inplace_var_name:
                optional_inplace_str = f"VLOG(6) << \"No Inplace should happen for wrapped input: {inplace_grad_input_str}\";"
            else:
                optional_inplace_str = f"""if (api_output_{out_index} != nullptr && can_be_inplaced) {{
      egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
    }}"""
            if IsPlainTensorType(ttype):

                if (
                    backward_inplace_map
                    and name in backward_inplace_map.values()
                ):
                    inplace_str = f""" if (api_output_{out_index} != nullptr && can_be_inplaced) {{
      egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
    }}"""
                    if has_higher_order_node:
                        inplace_for_grad_outs_str += f"""
  if (trace_backward) {{
    {optional_inplace_str}
  }} else {{
    {inplace_str}
  }}"""
                        need_gen_trace_backard_for_inplace = True
                    else:
                        inplace_for_grad_outs_str += inplace_str

                grad_function_prepare_str += f"""
  auto* api_output_{out_index} = (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][0].IsStopGradient()) ? nullptr : &returns[{fwd_position}][0];"""

            else:
                assert IsVectorTensorType(ttype)
                grad_function_prepare_str += f"""
  std::vector<paddle::experimental::Tensor*> api_output_{out_index};
  api_output_{out_index}.reserve(returns[{fwd_position}].size());
  for (size_t i = 0; i < returns[{fwd_position}].size(); ++i) {{
    if (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][i].IsStopGradient()) {{
      api_output_{out_index}.push_back(nullptr);
    }} else {{
      api_output_{out_index}.push_back(&returns[{fwd_position}][i]);
    }}
  }}"""

        grad_api_args_str = ", ".join(grad_api_args)
        composite_grad_api_args_str = ", ".join(grad_api_args)
        composite_template_name = "<paddle::experimental::Tensor>"

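        # Three dispatch paths for the backward call: re-invoke a forward API, use the composite (prim) grad API when enabled, or call the plain grad API.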
        if is_invoke_forward_api:
            autograd_api_out = "auto"
            if (
                len(self.backward_inplace_map) > 0
                and len(backward_grad_outputs_map) == 1
            ):
                autograd_api_out = "auto&"
            forward_api_name = (
                self.grad_api_contents['invoke'].split('(')[0].strip()
            )
            autograd_api = self.grad_api_contents['invoke'].replace(
                forward_api_name,
                GetDygraphForwardFunctionName(forward_api_name),
                1,
            )
            grad_function_call_str = f"""
  if (trace_backward) {{
  {indent}{autograd_api_out} api_output = {autograd_api};
  {out_assign_str}}} else {{
  {indent}{autograd_api_out} api_output = paddle::experimental::{self.namespace}{self.grad_api_contents['invoke']};
  {out_assign_str}{indent}}}
  """
        # TODO(Ruting): use composite only when we don't have a backward kernel in the future.
        elif is_composite_grad_api:
            grad_function_call_str = f"""
  if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
  {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
  VLOG(4) << "Composite api {composite_grad_api_name} is called ";
  }}else{{
  {indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});
  VLOG(4) << "Fused api {backward_api_name} is called ";
  }}
  """
        else:
            grad_function_call_str = f"""
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"""

        # Check Nan and Inf
        check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
            backward_api_name, "returns"
        )

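        # trace_backward is only materialized when a next grad node, an invoked forward API, or inplace handling actually needs it.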
        # Prepare for Node Creation if Necessary
        outputs_autograd_meta_str = ""
        compute_require_next_grad_str = ""
        if (
            len(next_grad_node_creation_str) > 0
            or is_invoke_forward_api
            or need_gen_trace_backard_for_inplace
        ):
            compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"

        # 3. Get Output AutoGradMeta
        outputs_autograd_meta_list = []
        # TODO(jiabin): Optimize this with SetStopGradient instead of Pass Stop gradient

        num_fwd_outputs = len(backward_grad_outputs_map.keys())
        for name, (
            rtype,
            pos,
            grad_api_position,
        ) in backward_grad_outputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)

            output_autograd_meta_name = GetAutoGradMetaName(
                transformed_tensor_name
            )
            output_autograd_meta_vec_name = GetAutoGradMetaVectorName(
                transformed_tensor_name
            )
            if IsPlainTensorType(rtype):
                output_autograd_meta = f"""
  auto& {transformed_tensor_name} = returns[{pos}][0];
  egr::AutogradMeta* {output_autograd_meta_name} = returns[{pos}][0].initialized() ? egr::EagerUtils::autograd_meta(&{transformed_tensor_name}) : nullptr;
  if ({output_autograd_meta_name}) {output_autograd_meta_name}->SetStopGradient(false);
  """

            else:
                assert IsVectorTensorType(rtype)
                if has_higher_order_node > 0:
                    output_autograd_meta = f"""
    auto& {transformed_tensor_name} = returns[{pos}];
    std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});
    std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};
    for(auto* meta : {output_autograd_meta_vec_name}){{
        meta->SetStopGradient(false);
    }}
"""
                else:
                    output_autograd_meta = f"""
    auto& {transformed_tensor_name} = returns[{pos}];
    std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});
    for(auto* meta : {output_autograd_meta_vec_name}){{
        meta->SetStopGradient(false);
    }}
"""
            outputs_autograd_meta_list.append(output_autograd_meta)

        outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

        returns_str = f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
        returns_str += f"{indent}return returns;\n"

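        # Build the logging strings and fill GRAD_FUNCTION_TEMPLATE with every fragment generated above.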
        grad_node_name = GetGradNodeName(self.backward_api_name)
        # Prepare input and output strings for logging
        var_str = f"\n{indent}  std::string input_str = \"\";"
        var_str += f"\n{indent}  std::string output_str = \"\";"
        for name, (
            ttype,
            fwd_position,
            grad_api_position,
        ) in backward_grad_inputs_map.items():
            new_name = self.TransformToNextGradName(name)
            var_str += f"\n{indent}  const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n( {new_name} , [%s]), \";"
            var_str += f"\n{indent}  std::string input_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));"
            var_str += f"\n{indent}  input_str += input_{new_name}_str; "

        for (
            name,
            (backward_input_type, is_fwd_input, grad_api_position),
        ) in backward_forward_inputs_map.items():
            new_name = self.TransformToNextGradName(name)
            var_str += f"\n{indent}  const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n( {new_name} , [%s]), \";"
            var_str += f"\n{indent}  std::string input_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));"
            var_str += f"\n{indent}  input_str += input_{new_name}_str; "

        before_log_str = BEFORE_LOG_PRINT_TEMPLATE.format(var_str)

        for name, (
            ttype,
            fwd_position,
            grad_api_position,
        ) in backward_grad_outputs_map.items():
            new_name = self.TransformToNextGradName(name)
            var_str += f"\n{indent}  const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n ( {new_name} , [%s]), \";"
            var_str += f"\n{indent}  std::string output_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));"
            var_str += f"\n{indent}  output_str += output_{new_name}_str; "

        log_str = AFTER_LOG_PRINT_TEMPLATE.format(var_str)

        self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
            grad_node_name,
            self.backward_api_name,
            fill_zero_str,
            get_grad_in_args_str,
            grad_function_prepare_str,
            compute_require_next_grad_str,
            inplace_check_str,
            inplace_for_grad_outs_str,
            self.backward_api_name,
            before_log_str,
            grad_function_call_str,
            check_nan_inf_str,
            outputs_autograd_meta_str,
            next_grad_node_creation_str,
            self.backward_api_name,
            log_str,
            returns_str,
        )

    def run(self):
        super().run()

        self.ResetOptionalInputs()

        ###################
        # Code Generation #
        ###################
        # Higher-order GradNode generation
        (
            has_higher_order_node,
            is_invoke_forward_api,
            is_composite_grad_api,
            next_grad_node_creation_str,
            next_grad_node_out_list,
            backward_forward_inputs_map,
        ) = self.GenerateHigherOrderNodeCreationCode()

        self.GenerateNodeDeclaration()

        self.GenerateNodeDefinition(
            has_higher_order_node,
            is_invoke_forward_api,
            is_composite_grad_api,
            next_grad_node_creation_str,
            next_grad_node_out_list,
            backward_forward_inputs_map,
        )


class DygraphForwardAndNodesGenerator(GeneratorBase):
    def __init__(self, api_yaml_path, backward_yaml_path):
        # Parent members:
        # self.namespace
        # self.api_yaml_path
        # self.forward_api_list
        GeneratorBase.__init__(self, api_yaml_path)

        self.backward_yaml_path = backward_yaml_path
        self.grad_api_dict = {}

        self.forward_declaration_str = ""
        self.forward_definition_str = ""

        self.node_declaration_str = ""
        self.node_definition_str = ""

    def CollectIsForwardOnly(self, forward_api_contents):
        self.is_forward_only = (
            False if 'backward' in forward_api_contents.keys() else True
        )

    def ParseYamlContents(self):
        self.ParseForwardYamlContents()

        backward_yaml_path = self.backward_yaml_path

        # string api is forward only and has no backward yaml
        if backward_yaml_path is not None:
            self.grad_api_dict = ReadBwdFile(backward_yaml_path)

    def GetBackwardAPIContents(self, forward_api_contents):
        grad_api_dict = self.grad_api_dict

        if 'backward' not in forward_api_contents.keys():
            return None

        backward_api_name = forward_api_contents['backward']
        assert backward_api_name in grad_api_dict.keys(), AssertMessage(
            backward_api_name, grad_api_dict.keys()
        )
        backward_api_contents = grad_api_dict[backward_api_name]

        return backward_api_contents

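    # Walks every forward op: generates its dygraph forward function, then follows the backward chain to emit stacked GradNodes.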
    def GenerateCode(self):
        forward_api_list = self.forward_api_list
        forward_apis_dict = {}
        for api_item in forward_api_list:
            forward_apis_dict[api_item['op']] = api_item
        namespace = self.namespace

        for forward_api_contents in forward_api_list:
            if forward_api_contents['op'] in black_ops_list:
                continue

            self.CollectIsForwardOnly(forward_api_contents)

            if self.is_forward_only:
                backward_api_contents = None
            else:
                backward_api_contents = self.GetBackwardAPIContents(
                    forward_api_contents
                )

            # Generate Dygraph Forward Function
            function_generator = DygraphForwardFunctionGenerator(
                forward_api_contents,
                backward_api_contents,
                forward_apis_dict,
                namespace,
            )
            function_generator.run()

            self.forward_definition_str += (
                function_generator.forward_definition_str + "\n"
            )
            self.forward_declaration_str += (
                function_generator.forward_declaration_str + "\n"
            )

            # Generate Dygraph GradNode Function
            while True:
                if backward_api_contents is None:
                    break
                next_grad_api_contents = self.GetBackwardAPIContents(
                    backward_api_contents
                )

                node_generator = DygraphNodeGenerator(
                    forward_api_contents,
                    backward_api_contents,
                    forward_apis_dict,
                    namespace,
                    next_grad_api_contents,
                )
                node_generator.run()
                self.node_declaration_str += (
                    node_generator.node_declaration_str + "\n"
                )
                self.node_definition_str += (
                    node_generator.node_definition_str + "\n"
                )

                if next_grad_api_contents is None:
                    break

                # Detect if there exists higher-order GradNode
                forward_api_contents = backward_api_contents

                # Fake forward_api_content
                forward_api_contents['op'] = forward_api_contents['backward_op']
                backward_api_contents = next_grad_api_contents

        if len(namespace) > 0:
            if namespace.endswith("::"):
                namespace = namespace[:-2]
            self.forward_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.forward_definition_str
            )
            self.forward_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.forward_declaration_str
            )
            self.node_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.node_declaration_str
            )
            self.node_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.node_definition_str
            )

    def run(self):
        self.ParseYamlContents()

        self.InferNameSpace()

        self.GenerateCode()


################
# File Writers #
################
def GenerateNodeCCFile(filepath, node_definition_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    file_contents = NODE_CC_FILE_TEMPLATE.format(node_definition_str)
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateNodeHFile(filepath, node_declaration_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    file_contents = NODE_H_FILE_TEMPLATE.format(node_declaration_str)
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateForwardCCFile(filepath, forward_definition_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    core_ops_info_str = GenerateCoreOpInfoDefinition()
    file_contents = FORWARD_CC_FILE_TEMPLATE.format(
        core_ops_info_str, forward_definition_str
    )
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateForwardHFile(filepath, forward_function_declaration_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    core_ops_info_str = GenerateCoreOpInfoDeclaration()
    file_contents = FORWARD_H_FILE_TEMPLATE.format(
        core_ops_info_str, forward_function_declaration_str
    )
    with open(filepath, 'a') as f:
        f.write(file_contents)


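# Entry point: run code generation for every (ops yaml, backward yaml) pair and write the four output files.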
if __name__ == "__main__":
    args = ParseArguments()

    api_yaml_paths = args.api_yaml_path.split(",")
    backward_yaml_paths = args.backward_yaml_path.split(",")

    # Generate per Dygraph API
    node_declaration_str = ""
    node_definition_str = ""

    forward_declaration_str = ""
    forward_definition_str = ""

    for i in range(len(api_yaml_paths)):
        api_yaml_path = api_yaml_paths[i]

        # string api is forward only
        if not api_yaml_path.endswith('strings_ops.yaml'):
            backward_yaml_path = backward_yaml_paths[i]
        else:
            backward_yaml_path = None

        generator = DygraphForwardAndNodesGenerator(
            api_yaml_path, backward_yaml_path
        )
        generator.run()

        node_declaration_str += generator.node_declaration_str + "\n"
        node_definition_str += generator.node_definition_str + "\n"

        forward_declaration_str += generator.forward_declaration_str + "\n"
        forward_definition_str += generator.forward_definition_str + "\n"

    # Generate Files
    nodes_h_path = args.nodes_h_path
    nodes_cc_path = args.nodes_cc_path
    forwards_h_path = args.forwards_h_path
    forwards_cc_path = args.forwards_cc_path

    GenerateNodeCCFile(nodes_cc_path, node_definition_str)
    GenerateNodeHFile(nodes_h_path, node_declaration_str)
    GenerateForwardCCFile(forwards_cc_path, forward_definition_str)
    GenerateForwardHFile(forwards_h_path, forward_declaration_str)