# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import argparse
import os
from codegen_utils import (
    core_ops_returns_info,
    core_ops_args_info,
    core_ops_args_type_info,
)
from codegen_utils import ReadBwdFile
from codegen_utils import FindForwardName, GetGradNodeName, GetSavedName
from codegen_utils import IsPlainTensorType, IsVectorTensorType
from codegen_utils import GetConstReference, RemoveConstAndReference
from codegen_utils import (
    GetDygraphForwardFunctionName,
    GetIntermediateAPIFunctionName,
)
from codegen_utils import GetAutoGradMetaName, GetAutoGradMetaVectorName
from codegen_utils import GetInplacedFunctionName
from codegen_utils import ParseYamlForwardFromBackward
from codegen_utils import ParseYamlBackward
from codegen_utils import ParseYamlInplaceInfo
from codegen_utils import FunctionGeneratorBase, GeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage, GetIndent

# Note: assign is an inplace api when parameter(output) isn't none,
# so we should check parameter(output) against the inplace rule.
# But because there was no such check in the old dygraph mode, in order to
# keep the code compatible, we also skip the inplace check in new dygraph temporarily;
# this will be fixed in the future.
inplace_check_blacklist = set(["assign_out_"])
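# For example (an illustrative sketch of the case described above): paddle.assign(x, output=y)
# dispatches to assign_out_ and writes into the pre-allocated y in place, so the generated
# egr::EagerUtils::CheckInplace call is skipped for this op.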

# Black list of ops that do NOT need code generation
black_ops_list = [
    "conv2d",
    "conv2d_grad",
    "conv2d_grad_grad",
    "add_n",
    "add_n_grad",
]


###########
## Utils ##
###########
def ParseArguments():
    parser = argparse.ArgumentParser(
        description='Eager Code Generator Args Parser'
    )
    parser.add_argument('--nodes_h_path', type=str)
    parser.add_argument('--nodes_cc_path', type=str)
    parser.add_argument('--forwards_h_path', type=str)
    parser.add_argument('--forwards_cc_path', type=str)
    parser.add_argument('--api_yaml_path', type=str)
    parser.add_argument('--backward_yaml_path', type=str)

    args = parser.parse_args()
    return args
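
# A typical invocation, with illustrative paths only (the real ones are supplied by the
# build system), looks roughly like:
#   python eager_gen.py \
#       --api_yaml_path ops.yaml --backward_yaml_path backward.yaml \
#       --nodes_h_path nodes.h --nodes_cc_path nodes.cc \
#       --forwards_h_path dygraph_functions.h --forwards_cc_path dygraph_functions.cc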


########################
## Code Gen Templates ##
########################
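# As an illustration (a sketch, not verbatim generator output): assuming the slots of
# SET_PLAIN_TENSOR_WRAPPER_TEMPLATE below are filled as (name, name, member, name,
# no_need_buffer flag), a hypothetical input "X" with member "X_" expands to:
#   void SetTensorWrapperX(const paddle::experimental::Tensor& X) {
#     X_ = egr::TensorWrapper(X, false);
#   }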
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = """  void SetTensorWrapper{}(const paddle::experimental::Tensor& {}) {{
    {} = egr::TensorWrapper({}, {});
  }}
"""

SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = """  void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
    for(const auto& eager_tensor : {}) {{
      {}.emplace_back(egr::TensorWrapper(eager_tensor, {}));
    }};
  }}
"""

PLAIN_TENSOR_MEMBER_TEMPLATE = """  egr::TensorWrapper {};
"""

VECTOR_TENSOR_MEMBER_TEMPLATE = """  std::vector<egr::TensorWrapper> {};
"""

CLEAR_TENSOR_WRAPPER_TEMPLATE = """    {}.clear();
"""

CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = """    for (auto& tw : {}) {{
      tw.clear();
    }}
"""

SET_ATTR_METHOD_TEMPLATE = """  void SetAttribute{}({} {}) {{
    {} = {};
  }}
"""

ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE = """  {} {} = {};
"""

ATTRIBUTE_MEMBER_TEMPLATE = """  {} {};
"""

NODE_DECLARATION_TEMPLATE = """
class {} : public egr::GradNodeBase {{
 public:
  {}() : egr::GradNodeBase() {{}}
  {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
      egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
  ~{}() override = default;

  virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> operator()(
      paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize>& grads, bool create_graph = false, bool is_new_grad = false) override;
  std::string name() override {{ return \"{}\"; }}

  void ClearTensorWrappers() override {{
{}
    SetIsTensorWrappersCleared(true);
  }}

  std::shared_ptr<GradNodeBase> Copy() const override {{
    auto copied_node = std::shared_ptr<{}>(new {}(*this));
    return copied_node;
  }}

  // SetTensorWrapperX, SetTensorWrapperY, ...
{}
  // SetAttributes
{}
 private:
  // TensorWrappers
{}
  // Attributes
{}}};
"""

GRAD_FUNCTION_TEMPLATE = """
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> {}::operator()(paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize>& grads, bool create_graph, bool is_new_grad) {{
  VLOG(3) << \"Running AD API GRAD: \" << \"{}\";
  // Fill Zero For GradIn Tensors
{}
  // Apply Gradient Hooks
  auto hooked_grads = ApplyGradientHooks(grads);

  // Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}
  // Prepare Grad function call
{}
  // Runtime check if we need next grad
{}
  // Inplace Check
{}
  // Inplace Strategy
{}

  VLOG(5) << \"Running C++ API: \" << \"{}\";
  // Before log info
{}
  // Call grad_api function
{}
  // Check NaN and Inf if needed
{}
  // Get GradOut autograd_meta
{}
  // Create Grad Node
{}
  VLOG(4) << \"Finish AD API GRAD: {}";
  // LOG IF DEBUG
  {}
  // Return
{}
}}
"""

FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{
  VLOG(3) << \"Running AD API: \" << \"{}\";
  // Dygraph Record Event
{}
  // AMP Logic
{}
  // Layout autotune
{}
  // Get Input AutoGradMeta
{}

  VLOG(5) << \"Running C++ API: \" << \"{}\";
 // Before log info
{}
 // Forward API Call
{}
  // Check NaN and Inf if needed
{}
  // Get Outputs
{}
  // Get Output AutoGradMeta
{}
  bool trace_backward = egr::Controller::Instance().HasGrad();
  bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});

  // Check Inplace if needed
{}{}
  // Node Creation
{}

  VLOG(4) << \"Finish AD API: {}";
  // LOG IF DEBUG
  {}
  // Returns
  return {};
}}
"""

AFTER_LOG_PRINT_TEMPLATE = """
  if(VLOG_IS_ON(4)){{
      const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s],  Output: [%s] }} \";
      {}
      VLOG(4) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str, output_str);
  }}
"""

BEFORE_LOG_PRINT_TEMPLATE = """
  if(VLOG_IS_ON(3)){{
      const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s]}} \";
      {}
      VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
  }}
"""

FORWARD_ONLY_FUNCTION_TEMPLATE = """
{} {}({}) {{
  VLOG(3) << \"Running AD API: \" << \"{}\";
  // Dygraph Record Event
{}
  // AMP Logic
{}
  // Layout autotune
{}
  VLOG(5) << \"Running C++ API: \" << \"{}\";
  // Before log info
{}
  // Forward API Call
{}
  // Get Outputs
{}
  VLOG(4) << \"Finish AD API: {}";
  // LOG IF DEBUG
  {}
  // Returns
  return {};
}}
"""

FORWARD_BODY_TEMPLATE = """  if(require_any_grad) {{
{}
    egr::EagerUtils::PassStopGradient({});

    // Node Construction
{}
    // SetAttributes if needed
{}
    // Set TensorWrappers for Forward Inputs if needed
{}
    // SetGradOutMeta & SetEdges
{}
    // SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
    // Set TensorWrappers for Forward Outputs if needed
{}
  }}
"""

HIHGER_ORDER_DERIVATIVE_VALUE_TEMPLATE = """  if(trace_backward) {{
{}
    // Node Construction
{}
    // SetAttributes if needed
{}
    // Set TensorWrappers for Forward Inputs if needed
{}
    // SetGradOutMeta & SetEdges
{}
    // SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
    // Set TensorWrappers for Forward Outputs if needed
{}
  }}
"""

NAMESPACE_WRAPPER_TEMPLATE = """
namespace {} {{
    {}
}}
"""

NODE_CC_FILE_TEMPLATE = """
#include "glog/logging.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/backward/sparse_bw_api.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"
DECLARE_bool(check_nan_inf);
{}
"""

NODE_H_FILE_TEMPLATE = """
#pragma once
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h"

{}
"""

FORWARD_CC_FILE_TEMPLATE = """
#include "paddle/phi/api/lib/dygraph_api.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/eager_layout_auto_tune.h"
#include "paddle/phi/api/include/strings_api.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
DECLARE_bool(check_nan_inf);
{}
{}
"""

FORWARD_H_FILE_TEMPLATE = """
#pragma once
#include "glog/logging.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/phi/api/all.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/eager/to_static/run_program_op_func.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"

using CPUPlace = phi::CPUPlace;
{}
{}
"""

CORE_OPS_INFO_TEMPLATE = """
std::unordered_map<std::string, std::vector<std::string>> core_ops_args_info = {{
    {}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_args_type_info = {{
    {}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_returns_info = {{
    {}
}};

"""

CORE_OPS_DECLARATION_TEMPLATE = """
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_returns_info;

"""

CHECK_INPLACE_TEMPLATE = """
  egr::EagerUtils::CheckInplace({}, {}, require_any_grad);
"""

BUMP_INPLACE_VERSION_TEMPLATE = """
  // Bump Inplace Version
  {}.bump_inplace_version();
  VLOG(3) << \"Tensor(\" << {}.name() << \") uses Inplace Strategy.\";
"""

AMP_LOGIC_TEMPLATE = """  if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
    VLOG(5) << "Check and Prepare For AMP";
    {}
    paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
    {}
    {}
    {}
    {{
      paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentTracer(), paddle::imperative::AmpLevel::O0);
      {}
    }}
  }}
"""
LAYOUT_LOGIC_TEMPLATE = """
  if (egr::Controller::Instance().UseLayoutAutoTune()) {{
    paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> tensors_vector = {};
    {}
    {}
    VLOG(5) << "Check and Prepare For LAYOUT "<< op_name;
    paddle::imperative::LayoutAutotuneGuard guard(egr::Controller::Instance().GetCurrentTracer(), false);
    {}
    {}
    // Returns
    return {};
  }}
"""
CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE = """
  paddle::optional<paddle::experimental::Tensor> {}_optional;
  if({}.initialized()) {}_optional = paddle::make_optional<paddle::experimental::Tensor>({});
"""

CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE = """
  paddle::optional<paddle::experimental::Tensor> {}_optional;
  if( {}.impl() ) {}_optional = paddle::make_optional<paddle::experimental::Tensor>({});
"""

CREATE_RECOVER_OPTIONAL_VECTOR_TENSOR_TEMPLATE = """
  paddle::optional<std::vector<paddle::experimental::Tensor>> {}_optional;
  if( !{}.empty() ) {}_optional = paddle::make_optional<std::vector<paddle::experimental::Tensor>>({});
"""

CHECK_BACKWARD_INPLACE_TEMPLATE = """
  bool can_be_inplaced = false;
  if ({}.initialized()) {{
    VLOG(10) << {}.name() << "({}) use_count: " << {}.impl().use_count();
    if ({}.impl().use_count() == 1 || ({}.impl().use_count() == 2 && {}.impl().get() == {}.impl().get())) {{
      can_be_inplaced = true;
    }}
  }}"""

CHECK_NAN_AND_INF_TEMPLATE = """  if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }}
"""

inplace_optional_out_type_map = {
    "Tensor": "paddle::optional<paddle::experimental::Tensor>&",
    "std::vector<Tensor>": "paddle::optional<std::vector<paddle::experimental::Tensor>>&",
}


def ExtractForwardApiNameFormInvoke(invoke_config):
    api_name = invoke_config.split('(')[0]
    if api_name[-1] == '_':
        api_name = api_name[:-1]
    return re.search(
        r"(?P<api_name>[a-zA-Z0-9_]+)(?P<intermediate>_intermediate)?", api_name
    ).group('api_name')
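# Illustrative behavior (hypothetical invoke strings): "concat(x, axis)" -> "concat";
# a trailing inplace '_' is stripped first, so "zeros_like_(x)" -> "zeros_like".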


def IsInvokeForwardApi(api_contents, forward_api_name_list):
    return (
        'invoke' in api_contents
        and ExtractForwardApiNameFormInvoke(api_contents['invoke'])
        in forward_api_name_list
    )


#######################
## Generator Helpers ##
#######################
def GenerateCoreOpInfoDeclaration():
    return CORE_OPS_DECLARATION_TEMPLATE


def GenerateCoreOpInfoDefinition():

    op_args_info_list = []
    for op_name, arg_list in core_ops_args_info.items():
        arg_str = ",".join(["\"" + v + "\"" for v in arg_list])
        op_args_info = f"{{ \"{op_name}\", {{ {arg_str} }} }},"
        op_args_info_list.append(op_args_info)

    op_types_info_list = []
    for op_name, type_list in core_ops_args_type_info.items():
        type_str = ",".join(["\"" + v + "\"" for v in type_list])
        op_types_info = f"{{ \"{op_name}\", {{ {type_str} }} }},"
        op_types_info_list.append(op_types_info)

    op_returns_info_list = []
    for op_name, return_list in core_ops_returns_info.items():
        return_str = ",".join(["\"" + v + "\"" for v in return_list])
        return_types_info = f"{{ \"{op_name}\", {{ {return_str} }} }},"
        op_returns_info_list.append(return_types_info)

    op_args_info_str = "\n".join(op_args_info_list)
    op_types_info_str = "\n".join(op_types_info_list)
    op_returns_info_str = "\n".join(op_returns_info_list)

    core_ops_info_definition_str = CORE_OPS_INFO_TEMPLATE.format(
        op_args_info_str, op_types_info_str, op_returns_info_str
    )

    return core_ops_info_definition_str
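
# Illustrative output: each map entry generated above has the form
#   { "op_name", { "arg0","arg1",... } },
# e.g. for a hypothetical op:  { "scale", { "x","scale","bias","bias_after_scale" } },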


#####################
## Generator Class ##
#####################
class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
    def __init__(
        self,
        forward_api_contents,
        grad_api_contents,
        forward_apis_dict,
        namespace,
    ):
        self.forward_api_contents = forward_api_contents
        # Members from Parent:
        # self.namespace
        # self.forward_api_contents
        # self.forward_api_name
        # self.orig_forward_inputs_list
        # self.orig_forward_attrs_list
        # self.orig_forward_returns_list
        # self.forward_inputs_position_map
        # self.forward_outputs_position_map
        # self.optional_inputs
        # self.no_need_buffers
        # self.intermediate_outputs
        # self.forward_inplace_map
        FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)

        self.forward_apis_dict = forward_apis_dict
        self.grad_api_contents = grad_api_contents

        # Raw Contents
        self.backward_forward_str = ""
        self.backward_api_name = ""

        self.forward_attrs_list = (
            []
        )  # [ [attr_name, attr_type, default_value, orig_position], ...]
        self.forward_inputs_list = (
            []
        )  # [ [arg_name, arg_type, orig_position], ...]
        self.forward_returns_list = (
            []
        )  # [ [ret_name, ret_type, orig_position], ...]

        self.backward_attrs_list = (
            []
        )  # [ [attr_name, attr_type, default_value, orig_position], ...]
        self.backward_inputs_list = (
            []
        )  # [ [arg_name, arg_type, orig_position], ...]
        self.backward_returns_list = (
            []
        )  # [ [ret_name, ret_type, orig_position], ...]
573 574

        # SlotNameMatched Backward Data
575 576 577 578 579 580 581 582 583 584 585
        self.backward_forward_inputs_map = (
            {}
        )  # { "name" : [type, is_fwd_input, orig_position] ...}
        self.backward_grad_inputs_map = (
            {}
        )  # { "name" : [type, fwd_position, orig_position] ...}
        self.backward_grad_outputs_map = (
            {}
        )  # { "name" : [type, fwd_position, orig_position] ...}

        self.backward_inplace_map = {}  # {name : name, ...}
586 587 588

    def ParseBackwardInplaceInfo(self):
        grad_api_contents = self.grad_api_contents
589 590
        if 'inplace' not in grad_api_contents.keys():
            return
591 592 593 594

        inplace_map_str = grad_api_contents['inplace']
        self.backward_inplace_map = ParseYamlInplaceInfo(inplace_map_str)

595 596 597 598
    def DygraphYamlValidationCheck(self):
        forward_api_contents = self.forward_api_contents
        grad_api_contents = self.grad_api_contents

599 600
        assert (
            'op' in forward_api_contents.keys()
601
        ), "Unable to find \"op\" in ops.yaml"
602 603
        assert (
            'args' in forward_api_contents.keys()
C
605 606
        assert (
            'output' in forward_api_contents.keys()
C
608

609
        if grad_api_contents is not None:
610 611
            assert (
                'backward' in forward_api_contents.keys()
C
613 614
            assert (
                'args' in grad_api_contents.keys()
615
            ), "Unable to find \"args\" in backward.yaml"
616 617
            assert (
                'output' in grad_api_contents.keys()
618
            ), "Unable to find \"output\" in backward.yaml"
619 620
            assert (
                'forward' in grad_api_contents.keys()
621
            ), "Unable to find \"forward\" in backward.yaml"
622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637

    def ForwardsValidationCheck(self):
        forward_inputs_list = self.forward_inputs_list
        forward_attrs_list = self.forward_attrs_list
        forward_returns_list = self.forward_returns_list

        orig_forward_inputs_list = self.orig_forward_inputs_list
        orig_forward_attrs_list = self.orig_forward_attrs_list
        orig_forward_returns_list = self.orig_forward_returns_list

        for i in range(len(forward_inputs_list)):
            forward_input_type = forward_inputs_list[i][1]
            forward_input_pos = forward_inputs_list[i][2]
            orig_input_type = orig_forward_inputs_list[i][1]
            orig_input_pos = orig_forward_inputs_list[i][2]

            assert forward_input_type == orig_input_type, AssertMessage(
                forward_input_type, orig_input_type
            )
            assert forward_input_pos == orig_input_pos, AssertMessage(
                forward_input_pos, orig_input_pos
            )

        for i in range(len(forward_attrs_list)):
            orig_attr_type = orig_forward_attrs_list[i][1]
            orig_attr_pos = orig_forward_attrs_list[i][3]
            forward_attr_type = forward_attrs_list[i][1]
            forward_attr_pos = forward_attrs_list[i][3]
            assert orig_attr_type == forward_attr_type, AssertMessage(
                orig_attr_type, forward_attr_type
            )
            assert orig_attr_pos == forward_attr_pos, AssertMessage(
                orig_attr_pos, forward_attr_pos
            )

        for i in range(len(forward_returns_list)):
            orig_return_type = orig_forward_returns_list[i][1]
            orig_return_pos = orig_forward_returns_list[i][2]
            forward_return_type = forward_returns_list[i][1]
            forward_return_pos = forward_returns_list[i][2]

            assert orig_return_type == forward_return_type, AssertMessage(
                orig_return_type, forward_return_type
            )
            assert orig_return_pos == forward_return_pos, AssertMessage(
                orig_return_pos, forward_return_pos
            )

        # Check Order: Inputs, Attributes
        max_input_position = -1
        for _, _, pos in forward_inputs_list:
            max_input_position = max(max_input_position, pos)

        for _, _, _, pos in forward_attrs_list:
            assert pos > max_input_position, AssertMessage(
                pos, max_input_position
            )

    def BackwardValidationCheck(self):
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_attrs_list = self.backward_attrs_list

        # Check Order: TensorWrappers, GradTensors, Attributes
        max_fwd_input_position = -1
        for _, (_, _, pos) in backward_forward_inputs_map.items():
            max_fwd_input_position = max(max_fwd_input_position, pos)

        max_grad_tensor_position = -1
        for _, (_, _, pos) in backward_grad_inputs_map.items():
            assert pos > max_fwd_input_position, AssertMessage(
                pos, max_fwd_input_position
            )
            max_grad_tensor_position = max(max_grad_tensor_position, pos)

        max_attr_position = -1
        for _, _, _, pos in backward_attrs_list:
            assert pos > max_grad_tensor_position, AssertMessage(
                pos, max_grad_tensor_position
            )
            max_attr_position = max(max_attr_position, pos)

    def IntermediateValidationCheck(self):
        intermediate_outputs = self.intermediate_outputs
        forward_returns_list = self.forward_returns_list
        """
        Check whether intermediate_outputs are positioned
        at the very end of forward_returns_list
        """
        intermediate_positions = range(
            len(forward_returns_list) - len(intermediate_outputs),
            len(forward_returns_list),
        )
        for ret_name, _, pos in forward_returns_list:
            if ret_name in intermediate_outputs:
                assert pos in intermediate_positions, AssertMessage(
                    pos, intermediate_positions
                )

    def CollectBackwardInfo(self):
        forward_api_contents = self.forward_api_contents
        grad_api_contents = self.grad_api_contents

        self.backward_api_name = forward_api_contents['backward']
        self.backward_forward_str = grad_api_contents['forward']
        backward_args_str = grad_api_contents['args']
        backward_returns_str = grad_api_contents['output']

        (
            self.backward_inputs_list,
            self.backward_attrs_list,
            self.backward_returns_list,
        ) = ParseYamlBackward(backward_args_str, backward_returns_str)

    def CollectForwardInfoFromBackwardContents(self):

        backward_forward_str = self.backward_forward_str

        (
            self.forward_inputs_list,
            self.forward_attrs_list,
            self.forward_returns_list,
        ) = ParseYamlForwardFromBackward(backward_forward_str)

    def CollectForwardInfoFromYamlForward(self):
        (
            self.forward_inputs_list,
            self.forward_attrs_list,
            self.forward_returns_list,
        ) = ParseYamlForwardFromBackward(
            self.forward_api_contents['args']
            + " -> "
            + self.forward_api_contents['output']
        )

    def SlotNameMatching(self):
        backward_inputs_list = self.backward_inputs_list
        backward_returns_list = self.backward_returns_list
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map

        for backward_input in backward_inputs_list:
            backward_input_name = backward_input[0]
            backward_input_type = backward_input[1]
            backward_input_pos = backward_input[2]

            backward_fwd_name = FindForwardName(backward_input_name)
            if backward_fwd_name:
                # Grad Input
                assert (
                    backward_fwd_name in forward_outputs_position_map.keys()
                ), AssertMessage(
                    backward_fwd_name, forward_outputs_position_map.keys()
                )
                matched_forward_output_type = forward_outputs_position_map[
                    backward_fwd_name
                ][0]
                matched_forward_output_pos = forward_outputs_position_map[
                    backward_fwd_name
                ][1]

                self.backward_grad_inputs_map[backward_input_name] = [
                    backward_input_type,
                    matched_forward_output_pos,
                    backward_input_pos,
                ]
            else:
                # TensorWrapper Input
                if backward_input_name in forward_inputs_position_map.keys():
                    tensor_wrapper_type = forward_inputs_position_map[
                        backward_input_name
                    ][0]
                    self.backward_forward_inputs_map[backward_input_name] = [
                        backward_input_type,
                        True,
                        backward_input_pos,
                    ]

                elif backward_input_name in forward_outputs_position_map.keys():
                    tensor_wrapper_type = forward_outputs_position_map[
                        backward_input_name
                    ][0]
                    self.backward_forward_inputs_map[backward_input_name] = [
                        backward_input_type,
                        False,
                        backward_input_pos,
                    ]
                else:
                    assert (
                        False
                    ), f"Cannot find {backward_input_name} in forward position map"

        for backward_output in backward_returns_list:
            backward_output_name = backward_output[0]
            backward_output_type = backward_output[1]
            backward_output_pos = backward_output[2]

            backward_fwd_name = FindForwardName(backward_output_name)
            assert (
                backward_fwd_name is not None
            ), f"Detected {backward_fwd_name} = None"
            assert (
                backward_fwd_name in forward_inputs_position_map.keys()
            ), AssertMessage(
                backward_fwd_name, forward_inputs_position_map.keys()
            )

            matched_forward_input_type = forward_inputs_position_map[
                backward_fwd_name
            ][0]
            matched_forward_input_pos = forward_inputs_position_map[
                backward_fwd_name
            ][1]

            self.backward_grad_outputs_map[backward_output_name] = [
                backward_output_type,
                matched_forward_input_pos,
                backward_output_pos,
            ]

    def GetPassStopGradientArgsList(self, forward_outputs_position_map):
        pass_stop_gradient_args_list = ["false"]
        for name, (_, _) in forward_outputs_position_map.items():
            output_autograd_meta_name = GetAutoGradMetaName(name)
            pass_stop_gradient_args_list.append(output_autograd_meta_name)
        pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
        return pass_stop_gradient_args_str
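
    # Note: the string built above is "false" followed by one GetAutoGradMetaName(<output>)
    # variable per forward output, comma-joined; it fills the
    # egr::EagerUtils::PassStopGradient({}) slot of FORWARD_BODY_TEMPLATE.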

    def GenerateNodeCreationCodes(self, for_backward=False):
        forward_api_name = self.forward_api_name
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_grad_outputs_map = self.backward_grad_outputs_map
        backward_attrs_list = self.backward_attrs_list
        optional_inputs = self.optional_inputs

        # Pass Stop Gradient Args
        pass_stop_gradient_args_str = self.GetPassStopGradientArgsList(
            forward_outputs_position_map
        )

        # Node Construction
        num_backward_inputs = len(forward_outputs_position_map.keys())
        num_backward_outputs = len(forward_inputs_position_map.keys())
        grad_node_name = GetGradNodeName(self.backward_api_name)

        # Helper
        indent = GetIndent(2)
        # NOTE(Aurelius74): DO NOT use make_shared here. Some Node contains experimental::Scalar,
        # which holds "complex128" as data. "complex128" is memory-aligned manually, but make_shared
        # (maybe) requests MEMALIGN for allocation.
        # See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
        # and https://github.com/MRtrix3/mrtrix3/issues/957
        node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"

        # SetAttributes
        set_attributes_list = []
        forward_attrs_name_set = set()
        for name, _, _, _ in forward_attrs_list:
            forward_attrs_name_set.add(name)

        for name, _, default_val_attr, _ in backward_attrs_list:
            if name in forward_attrs_name_set:
                set_attributes = (
                    f"{indent}grad_node->SetAttribute{name}({name});"
                )
            else:
                set_attributes = f"{indent}grad_node->SetAttribute{name}({default_val_attr});"
            set_attributes_list.append(set_attributes)
        set_attributes_str = "\n".join(set_attributes_list)

        # SetTensorWrappers
        set_input_tensor_wrappers_list = []
        set_output_tensor_wrappers_list = []
        num_fwd_outputs = len(forward_outputs_position_map.keys())
        for name, (
            atype,
            is_fwd_input,
            pos,
        ) in backward_forward_inputs_map.items():
            is_optional = name in optional_inputs

            if is_fwd_input:
                if is_optional:
                    set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
                else:
                    set_tensor_wrappers = (
                        f"{indent}grad_node->SetTensorWrapper{name}({name});"
                    )
                set_input_tensor_wrappers_list.append(set_tensor_wrappers)
            else:  # Forward's output as backward's input
                if num_fwd_outputs > 1:
                    # Aligned with forward output position
                    assert (
                        name in forward_outputs_position_map.keys()
                    ), AssertMessage(name, forward_outputs_position_map.keys())

                if is_optional:
                    set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper{name}(*{name});"
                else:
                    set_tensor_wrappers = (
                        f"{indent}grad_node->SetTensorWrapper{name}({name});"
                    )
                set_output_tensor_wrappers_list.append(set_tensor_wrappers)
        set_input_tensor_wrappers_str = "\n".join(
            set_input_tensor_wrappers_list
        )
        set_output_tensor_wrappers_str = "\n".join(
            set_output_tensor_wrappers_list
        )

        # SetGradOutMeta & SetEdges
        grad_node_out_list = []
        set_grad_out_meta_list = []
        set_edges_list = []
        for name, (_, pos) in forward_inputs_position_map.items():
            # Has corresponding grad output
            has_corresponding_grad_output = False
            for _, (
                _,
                corresponding_pos,
                _,
            ) in backward_grad_outputs_map.items():
                if pos == corresponding_pos:
                    has_corresponding_grad_output = True
            if not has_corresponding_grad_output:
                continue

            grad_node_out_list.append(name)
            is_optional = name in self.optional_inputs
            if is_optional:
                set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
            else:
                set_grad_out_meta = (
                    f"{indent}grad_node->SetGradOutMeta({name}, {pos});"
                )

            set_grad_out_meta_list.append(set_grad_out_meta)
        set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)

        # SetOutRank & SetHistory & SetGradInMeta
        set_out_rank_list = []
        set_history_list = []
        set_grad_in_meta_list = []
        set_retain_grad_list = []
        num_outputs = len(forward_outputs_position_map.keys())
        for name, (_, pos) in forward_outputs_position_map.items():
            output_autograd_meta_name = GetAutoGradMetaName(name)
            set_out_rank = f"""{indent}if ({output_autograd_meta_name}) {{
{indent}  egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});
{indent}}}"""

            set_history = f"""{indent}if ({output_autograd_meta_name}) {{
{indent}  egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);
{indent}}}"""

            set_grad_in_meta = (
                f"{indent}grad_node->SetGradInMeta({name}, {pos});"
            )
            set_retain_grad = (
                f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
            )

            set_out_rank_list.append(set_out_rank)
            set_history_list.append(set_history)
            set_grad_in_meta_list.append(set_grad_in_meta)
            set_retain_grad_list.append(set_retain_grad)

        set_out_rank_str = "\n".join(set_out_rank_list)
        set_history_str = "\n".join(set_history_list)
        set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
        set_retain_grad_str = "\n".join(set_retain_grad_list)

        node_event_name = forward_api_name + " node_creation"
        node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
        if not for_backward:
            self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
                node_creation_event_str,
                pass_stop_gradient_args_str,
                node_construction_str,
                set_attributes_str,
                set_input_tensor_wrappers_str,
                set_grad_out_meta_str,
                set_out_rank_str,
                set_history_str,
                set_grad_in_meta_str,
                set_retain_grad_str,
                set_output_tensor_wrappers_str,
            )
        else:
            self.node_creation_str = (
                HIHGER_ORDER_DERIVATIVE_VALUE_TEMPLATE.format(
                    node_creation_event_str,
                    node_construction_str,
                    set_attributes_str,
                    set_input_tensor_wrappers_str,
                    set_grad_out_meta_str,
                    set_out_rank_str,
                    set_history_str,
                    set_grad_in_meta_str,
                    set_retain_grad_str,
                    set_output_tensor_wrappers_str,
                )
            )

        self.grad_node_out_list = grad_node_out_list

    def run(self):
        # Basic Validation Check
        self.DygraphYamlValidationCheck()

        ##########################
        ## Parsing Raw Contents ##
        ##########################
        # Parse forward and backward inplace_map
        self.ParseForwardInplaceInfo()
        if self.grad_api_contents is not None:
            self.ParseBackwardInplaceInfo()
            # Parse no_need_buffer
            self.ParseNoNeedBuffer()

        # Parse optional_inputs
        self.ParseDispensable()

        # Parse intermediate_outputs
        self.ParseIntermediate()
        self.IntermediateValidationCheck()

1053 1054 1055
        if self.grad_api_contents is not None:
            # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
            self.CollectBackwardInfo()

            # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
            self.CollectForwardInfoFromBackwardContents()

        if self.is_forward_only:
            self.CollectForwardInfoFromYamlForward()

        # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
        self.CollectOriginalForwardInfo()

        # Forwards Validation Check
        self.ForwardsValidationCheck()

        #############################
        ## Process Parsed Contents ##
        #############################
        # Initialize forward_inputs_position_map, forward_outputs_position_map
1073 1074 1075
        self.DetermineForwardPositionMap(
            self.forward_inputs_list, self.forward_returns_list
        )

        if self.grad_api_contents is not None:
            # Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
            self.SlotNameMatching()
            # Backward Validation Check
            self.BackwardValidationCheck()


class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
    def __init__(
        self,
        forward_api_contents,
        grad_api_contents,
        forward_apis_dict,
        namespace,
    ):
        DygraphFunctionGeneratorBase.__init__(
            self,
            forward_api_contents,
            grad_api_contents,
            forward_apis_dict,
            namespace,
        )

        # Generated Results
        self.forward_definition_str = ""
        self.forward_declaration_str = ""

    def GenerateForwardLayoutAutotune(
        self,
        forward_api_name,
        amp_tensors_vector_list,
        layout_tensors_vector_optional_list,
        layout_autotune_list_str,
        returns_type_str,
        returns_str,
        amp_inputs_call_args_str,
    ):
        intermediate_outputs = self.intermediate_outputs
        forward_attrs_list = self.forward_attrs_list
        forward_outputs_position_map = self.forward_outputs_position_map
        num_outputs = len(forward_outputs_position_map.keys()) - len(
            intermediate_outputs
        )
        # for layout autotune attr
        lightly_sensitive_attr = [
            'axis',
            'axes',
            'dim',
            'dims',
            'start',
            'end',
            'stop',
            'perm',
        ]
        heavily_sensitive_attr = ['data_format', 'data_layout']
        layout_autotune_attr = []
        layout_autotune_attr_code_list = []
        layout_autotune_attr_type_list = []
        layout_autotune_attr_code_list.append(
            f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");\n"
        )

        lightly_flag = False
        heavily_flag = False
        for name, atype, default_val, pos in forward_attrs_list:
            for attr_name in lightly_sensitive_attr:
                if name.find(attr_name) != -1 and (
                    name not in layout_autotune_attr
                ):
                    lightly_flag = True
                    layout_autotune_attr.append(name)
                    layout_autotune_attr_type_list.append(atype)
            if lightly_flag is False:
                for attr_name in heavily_sensitive_attr:
                    if name.find(attr_name) != -1 and (
                        name not in layout_autotune_attr
                    ):
                        layout_autotune_attr.append(name)
                        layout_autotune_attr_type_list.append(atype)
                        heavily_flag = True
        if len(layout_autotune_attr) == 0:
            layout_autotune_attr_code_list.append(
                "auto transformer = egr::EagerLayoutAutotune(op_name, tensors_vector);\n"
            )
        elif len(layout_autotune_attr) == 1:
            layout_autotune_attr_code_list.append(
                f"auto transformer = egr::EagerLayoutAutotune<{layout_autotune_attr_type_list[0]}>(op_name, tensors_vector, &{layout_autotune_attr[0]});\n"
            )
        elif len(layout_autotune_attr) == 2:
            layout_autotune_attr_code_list.append(
                f"auto transformer = egr::EagerLayoutAutotune<{layout_autotune_attr_type_list[0]}, {layout_autotune_attr_type_list[1]}>(op_name, tensors_vector, &{layout_autotune_attr[0]}, &{layout_autotune_attr[1]});\n"
            )
        else:
            layout_autotune_attr_code_list.append(
                f"auto transformer = egr::EagerLayoutAutotune<{layout_autotune_attr_type_list[0]}>(op_name, tensors_vector,&{layout_autotune_attr[0]});\n"
            )
        # Out tensor
        layout_inputs_call_args_str = amp_inputs_call_args_str
        forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
        layout_tmp_result_list = []
        layout_autotune_outs_list = []
        result_name = "api_result"
        if num_outputs == 1:
            result_name = returns_str
            layout_autotune_outs_list.append(
                f"transformer -> SetOutTensorLayout(&{returns_str});\n"
            )
        else:
            for name, (rtype, pos) in forward_outputs_position_map.items():
                if name in intermediate_outputs:
                    continue
                layout_autotune_outs_list.append(
                    f"    auto& {name} = std::get<{len(layout_tmp_result_list)}>(api_result);\n"
                )
                layout_autotune_outs_list.append(
                    f"    transformer -> SetOutTensorLayout(&{name});\n"
                )
                layout_tmp_result_list.append(f"{name}")

        tensors_vector_list_str = (
            "{ " + ",".join(amp_tensors_vector_list) + " }"
        )

        if len(amp_tensors_vector_list) == 0:
            layout_logic_str = ""
        else:
            after_call_str = f"{returns_type_str} {result_name} = {forward_function_name}({layout_inputs_call_args_str});\n"
            layout_logic_str = LAYOUT_LOGIC_TEMPLATE.format(
                tensors_vector_list_str,
                "    ".join(layout_tensors_vector_optional_list),
                "    ".join(layout_autotune_attr_code_list)
                + "    "
                + layout_autotune_list_str,
                after_call_str,
                "    ".join(layout_autotune_outs_list),
                returns_str,
            )

        return layout_logic_str

    def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
        namespace = self.namespace
        if self.forward_api_name[-1] == '_' and not is_inplaced:
            return
        forward_api_name = (
            GetInplacedFunctionName(self.forward_api_name)
            if is_inplaced
            else self.forward_api_name
        )

        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list
        if not self.is_forward_only:
            backward_grad_outputs_map = self.backward_grad_outputs_map

        optional_inputs = self.optional_inputs
        intermediate_outputs = self.intermediate_outputs
        forward_inplace_map = self.forward_inplace_map if is_inplaced else {}
        indent = GetIndent(1)

        # Get Function Args
        num_inputs = len(forward_attrs_list) + len(
            forward_inputs_position_map.keys()
        )
        inputs_args_definition_list = ["" for i in range(num_inputs)]
        inputs_args_declaration_list = ["" for i in range(num_inputs)]
        inputs_call_list = ["" for i in range(num_inputs)]

        amp_inputs_call_list = ["" for i in range(num_inputs)]
        amp_tensors_vector_list = []
        amp_tensors_vector_optional_list = []
        amp_autocast_list = []
        amp_autocast_optional_list = []
        layout_autotune_list = []
        layout_autotune_optional_list = []
        layout_tensors_vector_optional_list = []
        for name, (ttype, pos) in forward_inputs_position_map.items():
            inputs_call_list[pos] = f"{name}"
            amp_inputs_call_list[pos] = f"new_{name}"
            is_optional = name in optional_inputs
            if IsPlainTensorType(ttype):
                if is_optional:
                    if (
                        self.is_forward_only
                        and is_inplaced
                        and forward_inplace_map
                        and name in forward_inplace_map.keys()
                    ):
                        arg_str = f"paddle::optional<paddle::experimental::Tensor>& {name}"
                    else:
                        arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
                    amp_tensors_vector_optional_list.append(
                        f"if ({name}) amp_tensors_vector.push_back({{ *{name} }});\n"
                    )
                    amp_autocast_optional_list.append(
                        f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                    )
                    layout_tensors_vector_optional_list.append(
                        f"if ({name}) tensors_vector.push_back({{ *{name} }});\n"
                    )
                    layout_autotune_optional_list.append(
                        f"auto new_{name} = transformer->TransInTensor(\"{name}\", {name});\n"
                    )
                else:
                    if (
                        is_inplaced
                        and forward_inplace_map
                        and name in forward_inplace_map.keys()
                    ):
                        arg_str = f"paddle::experimental::Tensor& {name}"
                        amp_tensors_vector_list.append(f"{{{name}}}")
                        amp_autocast_list.append(
                            f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                        )
                    else:
                        arg_str = f"const paddle::experimental::Tensor& {name}"
                        amp_tensors_vector_list.append(f"{{{name}}}")
                        amp_autocast_list.append(
                            f"auto new_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                        )
                    layout_autotune_list.append(
                        f"auto new_{name} = transformer->TransInTensor(\"{name}\", {name});\n"
                    )
            else:
                assert IsVectorTensorType(ttype)
                if is_optional:
                    if (
                        self.is_forward_only
                        and is_inplaced
                        and forward_inplace_map
                        and name in forward_inplace_map.keys()
                    ):
                        arg_str = f"paddle::optional<std::vector<paddle::experimental::Tensor>>& {name}"
                    else:
                        arg_str = f"const paddle::optional<std::vector<paddle::experimental::Tensor>>& {name}"
                    amp_tensors_vector_optional_list.append(
                        f"if ({name}) amp_tensors_vector.push_back( *{name} );\n"
                    )
                    amp_autocast_optional_list.append(
                        f"auto new_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                    )
                    layout_autotune_optional_list.append(
                        f"auto new_{name} = transformer->TransInTensors(\"{name}\", {name});\n"
                    )
                else:
                    if (
                        is_inplaced
                        and forward_inplace_map
                        and name in forward_inplace_map.keys()
                    ):
                        arg_str = (
                            f"std::vector<paddle::experimental::Tensor>& {name}"
                        )
                    else:
                        arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
                    amp_tensors_vector_list.append(f"{name}")
                    amp_autocast_list.append(
                        f"auto new_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                    )
                    layout_autotune_list.append(
                        f"auto new_{name} = transformer->TransInTensors(\"{name}\", {name});\n"
                    )

            inputs_args_definition_list[pos] = arg_str
            inputs_args_declaration_list[pos] = arg_str

        # forward attrs
        for name, atype, default_val, pos in forward_attrs_list:
            inputs_call_list[pos] = name
            amp_inputs_call_list[pos] = name
            if default_val is not None:
                inputs_args_declaration_list[
                    pos
                ] = f"{atype} {name} = {default_val}"
            else:
                inputs_args_declaration_list[pos] = f"{atype} {name}"
            inputs_args_definition_list[pos] = f"{atype} {name}"

        inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
        inputs_args_definition_str = ", ".join(inputs_args_definition_list)
        inputs_call_args_str = ", ".join(inputs_call_list)

        # Forward Full Logic
        function_name = forward_api_name
        if len(intermediate_outputs) > 0:
            if is_inplaced:
                function_name = (
                    GetIntermediateAPIFunctionName(forward_api_name[:-1]) + '_'
                )
            else:
                function_name = GetIntermediateAPIFunctionName(function_name)

        api_out_type = "auto"
        if is_inplaced and len(forward_outputs_position_map) == 1:
            api_out_type = "auto&"
        forward_call_str = f"{indent}{api_out_type} api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
        num_outputs = len(forward_outputs_position_map.keys()) - len(
            intermediate_outputs
        )

        # Check Nan and Inf
        check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
            function_name, "api_result"
        )

        # Get Outputs
        get_outputs_str = ""
        for name, (rtype, pos) in forward_outputs_position_map.items():
            if num_outputs == 1 and len(intermediate_outputs) == 0:
                get_outputs_str += f"{indent}auto& {name} = api_result;\n"
            else:
                get_outputs_str += (
                    f"{indent}auto& {name} = std::get<{pos}>(api_result);\n"
                )
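        # Illustrative sketch of the generated unpacking for a hypothetical two-output op
        # (output names "out" and "middle" are assumptions):
        #   auto& out = std::get<0>(api_result);
        #   auto& middle = std::get<1>(api_result);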

        # Get return type list & outputs
        returns_type_list = ["" for i in range(num_outputs)]
        returns_list = ["" for i in range(num_outputs)]
        for name, (rtype, pos) in forward_outputs_position_map.items():
            if name in intermediate_outputs:
                continue
            returns_list[pos] = f"{name}"

            if IsPlainTensorType(rtype):
                if (
                    is_inplaced
                    and forward_inplace_map
                    and name in forward_inplace_map.values()
                ):
                    ind = list(forward_inplace_map.values()).index(name)
                    if (
                        list(forward_inplace_map.keys())[ind]
                        in self.optional_inputs
                    ):
                        returns_type_list[pos] = inplace_optional_out_type_map[
                            rtype
                        ]
                    else:
                        returns_type_list[pos] = "paddle::experimental::Tensor&"
                else:
                    returns_type_list[pos] = "paddle::experimental::Tensor"
            else:
                assert IsVectorTensorType(rtype)
                if (
                    is_inplaced
                    and forward_inplace_map
                    and name in forward_inplace_map.values()
                ):
                    ind = list(forward_inplace_map.values()).index(name)
                    if (
                        list(forward_inplace_map.keys())[ind]
                        in self.optional_inputs
                    ):
                        returns_type_list[pos] = inplace_optional_out_type_map[
                            rtype
                        ]
                    else:
                        returns_type_list[
                            pos
                        ] = "std::vector<paddle::experimental::Tensor>&"
                else:
                    returns_type_list[
                        pos
                    ] = "std::vector<paddle::experimental::Tensor>"

        if num_outputs == 1:
            returns_str = returns_list[0]
            returns_type_str = returns_type_list[0]
        else:
            returns_type_str = ", ".join(returns_type_list)
            returns_type_str = f"std::tuple<{returns_type_str}>"
            returns_str = ", ".join(returns_list)
            returns_str = f"{returns_type_str}{{{returns_str}}}"
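        # e.g. for two plain, non-inplace outputs this yields (illustrative only, names assumed):
        #   returns_type_str: "std::tuple<paddle::experimental::Tensor, paddle::experimental::Tensor>"
        #   returns_str:      "std::tuple<paddle::experimental::Tensor, paddle::experimental::Tensor>{out, middle}"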

        # Node Creation Pre-Processing
        inputs_names = []
        if not self.is_forward_only:
            # 1. Get Input AutoGradMeta
            inputs_autograd_meta_list = []
            compute_require_grad_args_list = ["trace_backward"]
            for name, (ttype, pos) in forward_inputs_position_map.items():
                # Has corresponding grad output
                has_corresponding_grad_output = False
                for _, (
                    _,
                    corresponding_pos,
                    _,
                ) in backward_grad_outputs_map.items():
                    if pos == corresponding_pos:
                        has_corresponding_grad_output = True
                if (
                    has_corresponding_grad_output
                    or (
                        name in forward_inplace_map
                        and forward_api_name not in inplace_check_blacklist
                    )
                    or self.is_forward_only
                ):
                    input_autograd_meta_name = GetAutoGradMetaName(name)
                    if IsPlainTensorType(ttype):
                        input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
                    else:
                        assert IsVectorTensorType(ttype)
                        input_autograd_meta_vec_name = (
                            GetAutoGradMetaVectorName(name)
                        )
                        input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
                        input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
                    inputs_autograd_meta_list.append(input_autograd_meta)
                    compute_require_grad_args_list.append(
                        input_autograd_meta_name
                    )

            inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
            compute_require_grad_args_str = ",".join(
                compute_require_grad_args_list
            )

            # 2. Get Output AutoGradMeta
            outputs_autograd_meta_list = []
            num_fwd_outputs = len(forward_outputs_position_map.keys())

            for name, (rtype, pos) in forward_outputs_position_map.items():
                output_autograd_meta_name = GetAutoGradMetaName(name)
                output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
                if num_fwd_outputs == 1:
                    if IsPlainTensorType(rtype):
                        output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
                    else:
                        assert IsVectorTensorType(rtype)
                        output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
                        output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
                else:
                    # Tuple api_result
                    if IsPlainTensorType(rtype):
                        output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
                    else:
                        assert IsVectorTensorType(rtype)
                        output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
                        output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

                outputs_autograd_meta_list.append(output_autograd_meta)
            outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

            # 3. Check Inplace
            check_inplace_str = ""
            bump_inplace_version_str = ""
            if is_inplaced:
                for inplace_name in forward_inplace_map.keys():
                    if forward_api_name not in inplace_check_blacklist:
                        inplace_autograd_meta_name = GetAutoGradMetaName(
                            inplace_name
                        )
                        check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
                            inplace_name, inplace_autograd_meta_name
                        )
                    bump_inplace_version_str += (
                        BUMP_INPLACE_VERSION_TEMPLATE.format(
                            inplace_name, inplace_name
                        )
                    )

            # Node Creation
            self.GenerateNodeCreationCodes()
            node_creation_str = self.node_creation_str

        dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
        forward_ad_function_name = GetDygraphForwardFunctionName(
            forward_api_name
        )

        # Forward amp logic
        kernel_trans2_op_name_str = (
            f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");"
        )
        amp_tensors_vector_list_str = (
            "{ " + ",".join(amp_tensors_vector_list) + " }"
        )
        amp_tensors_vector_optional_list_str = "    ".join(
            amp_tensors_vector_optional_list
        )
        amp_get_dst_dtype_str = "auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
        amp_autocast_list_str = (
            "    ".join(amp_autocast_list)
            + "    "
            + "    ".join(amp_autocast_optional_list)
        )
        amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
        amp_call_str = (
            f"return {forward_ad_function_name}({amp_inputs_call_args_str});"
        )
        if is_inplaced or (forward_api_name == "cast"):
            amp_logic_str = "\n VLOG(5) << \" No AMP for {} because it is an inplace or cast api. \"; ".format(
                forward_ad_function_name
            )
        else:
            amp_logic_str = AMP_LOGIC_TEMPLATE.format(
                kernel_trans2_op_name_str,
                amp_tensors_vector_list_str,
                amp_tensors_vector_optional_list_str,
                amp_get_dst_dtype_str,
                amp_autocast_list_str,
                amp_call_str,
            )

        # Forward layout autotune
        layout_autotune_list_str = "    ".join(
            layout_autotune_list
        ) + "    ".join(layout_autotune_optional_list)
        layout_logic_str = self.GenerateForwardLayoutAutotune(
            forward_api_name,
            amp_tensors_vector_list,
            layout_tensors_vector_optional_list,
            layout_autotune_list_str,
            returns_type_str,
            returns_str,
            amp_inputs_call_args_str,
        )

        # Prepare input/output strings for logging
        var_str = f"\n{indent}  std::string input_str = \"\";"
        var_str += f"\n{indent}  std::string output_str = \"\";"
        for name, (ttype, pos) in forward_inputs_position_map.items():
            var_str += f"\n{indent}  const char* TENSOR_{name.upper()}_TEMPLATE = \" \\n( {name} , [%s]), \";"
            var_str += f"\n{indent}  std::string input_{name}_str = paddle::string::Sprintf(TENSOR_{name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({name}));"
            var_str += f"\n{indent}  input_str += input_{name}_str; "

        before_log_str = BEFORE_LOG_PRINT_TEMPLATE.format(var_str)
        for name, (ttype, pos) in forward_outputs_position_map.items():
            var_str += f"\n{indent}  const char* TENSOR_{name.upper()}_TEMPLATE = \" \\n( {name} , [%s]), \";"
            var_str += f"\n{indent}  std::string output_{name}_str = paddle::string::Sprintf(TENSOR_{name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({name}));"
            var_str += f"\n{indent}  output_str += output_{name}_str; "

        log_str = AFTER_LOG_PRINT_TEMPLATE.format(var_str)

        # Generate forward_definition_str and forward_declaration_str
        if self.is_forward_only:
            if len(amp_tensors_vector_list) == 0:
                amp_logic_str = "\n VLOG(7) << \" No AMP for {} because it has no input. \"; ".format(
                    forward_ad_function_name
                )
            self.forward_definition_str += (
                FORWARD_ONLY_FUNCTION_TEMPLATE.format(
                    returns_type_str,
                    forward_ad_function_name,
                    inputs_args_definition_str,
                    forward_api_name,
                    dygraph_event_str,
                    amp_logic_str,
                    layout_logic_str,
                    forward_api_name,
                    before_log_str,
                    forward_call_str,
                    get_outputs_str,
                    forward_api_name,
                    log_str,
                    returns_str,
                )
            )
        else:
            self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
                returns_type_str,
                forward_ad_function_name,
                inputs_args_definition_str,
                forward_api_name,
                dygraph_event_str,
                amp_logic_str,
                layout_logic_str,
                inputs_autograd_meta_str,
                forward_api_name,
                before_log_str,
                forward_call_str,
                check_nan_inf_str,
                get_outputs_str,
                outputs_autograd_meta_str,
                compute_require_grad_args_str,
                check_inplace_str,
                bump_inplace_version_str,
                node_creation_str,
                forward_api_name,
                log_str,
                returns_str,
            )

        self.forward_declaration_str += f"{returns_type_str} {forward_ad_function_name}({inputs_args_declaration_str});\n"

    def GenerateInplacedForwardDygraphFunctions(self):
        # Inplaced Version Dygraph Function Generation
        forward_api_name = self.forward_api_name
        forward_api_contents = self.forward_api_contents

        if (
            forward_api_name != "sum"
            and "inplace" in forward_api_contents.keys()
        ):
            # Function Definition and Declaration Generation
            self.GenerateForwardDefinitionAndDeclaration(is_inplaced=True)
            self.UpdateCoreOpsInformation(is_inplaced=True)

    def UpdateCoreOpsInformation(self, is_inplaced):
        forward_api_name = (
            GetInplacedFunctionName(self.forward_api_name)
            if is_inplaced
            else self.forward_api_name
        )
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list

        num_args = len(forward_inputs_position_map.keys()) + len(
            forward_attrs_list
        )
        num_returns = len(forward_outputs_position_map.keys())

        fwd_api_name = "" + forward_api_name
        core_ops_returns_info[fwd_api_name] = ["" for i in range(num_returns)]
        core_ops_args_info[fwd_api_name] = ["" for i in range(num_args)]
        core_ops_args_type_info[fwd_api_name] = ["" for i in range(num_args)]

        for name, (ttype, pos) in forward_inputs_position_map.items():
            core_ops_args_info[fwd_api_name][pos] = name
            if IsPlainTensorType(ttype):
                core_ops_args_type_info[fwd_api_name][pos] = "tensor"
            else:
                assert IsVectorTensorType(ttype)
                core_ops_args_type_info[fwd_api_name][pos] = "list"

        for name, _, _, pos in forward_attrs_list:
            core_ops_args_info[fwd_api_name][pos] = name

        for name, (ttype, pos) in forward_outputs_position_map.items():
            core_ops_returns_info[fwd_api_name][pos] = name
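        # Illustrative sketch of the recorded entries for the hypothetical "matmul" op
        # (names are assumptions; attribute slots keep their initial "" in the type info):
        #   core_ops_args_info["matmul"]      == ["x", "y", "transpose_x"]
        #   core_ops_args_type_info["matmul"] == ["tensor", "tensor", ""]
        #   core_ops_returns_info["matmul"]   == ["out"]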

    def run(self):
        super().run()

        #####################
        ## Code Generation ##
        #####################

        # Definition And Declaration
        self.GenerateForwardDefinitionAndDeclaration(is_inplaced=False)

        self.UpdateCoreOpsInformation(is_inplaced=False)

        self.GenerateInplacedForwardDygraphFunctions()


class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
    def __init__(
        self,
        forward_api_contents,
        grad_api_contents,
        forward_apis_dict,
        namespace,
        next_grad_api_contents=None,
    ):
        DygraphFunctionGeneratorBase.__init__(
            self,
            forward_api_contents,
            grad_api_contents,
            forward_apis_dict,
            namespace,
        )

        # Record name mapping from forward_var_name to grad_var_names
        self.to_next_grad_name_mapping = {}  # {name : name}

        # Generated Results
        self.node_declaration_str = ""
        self.node_definition_str = ""
        self.next_grad_api_contents = next_grad_api_contents

    def TransformToNextGradName(self, string):
        name_mapping = self.to_next_grad_name_mapping
        if string in name_mapping.keys():
            return name_mapping[string]
        return string
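        # e.g. if the next-order node renamed "grad_x" to "x" (hypothetical mapping),
        # TransformToNextGradName("grad_x") returns "x"; unmapped names pass through unchanged.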

    def ResetOptionalInputs(self):
        namespace = self.namespace
        grad_api_contents = self.grad_api_contents

        base_generator = FunctionGeneratorBase(grad_api_contents, namespace)
        base_generator.ParseDispensable()

        self.optional_inputs = base_generator.optional_inputs

    def RecordGrad2NextGradNameMapping(self, next_node_generator):
        next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
        next_orig_returns_list = next_node_generator.orig_forward_returns_list

        next_forward_inputs_list = next_node_generator.forward_inputs_list
        next_forward_returns_list = next_node_generator.forward_returns_list
        for i in range(len(next_orig_inputs_list)):
            grad_name = next_orig_inputs_list[i][0]
            next_forward_name = next_forward_inputs_list[i][0]
            self.to_next_grad_name_mapping[grad_name] = next_forward_name

        for i in range(len(next_orig_returns_list)):
            grad_ret_name = next_orig_returns_list[i][0]
            next_ret_name = next_forward_returns_list[i][0]
            self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name

    def GenerateHigherOrderNodeCreationCode(self):
        has_higher_order_node = False
        namespace = self.namespace
        grad_api_contents = self.grad_api_contents
        forward_apis_dict = self.forward_apis_dict
        next_grad_api_contents = self.next_grad_api_contents

        next_grad_node_creation_str = ""
        next_grad_node_out_list = []
        next_node_generator = None
        if next_grad_api_contents:
            # Fake forward_api_contents and backward_api_contents
            forward_api_contents = grad_api_contents
            forward_api_contents['op'] = forward_api_contents['backward_op']
            backward_api_contents = next_grad_api_contents

            next_node_generator = DygraphFunctionGeneratorBase(
                forward_api_contents,
                backward_api_contents,
                forward_apis_dict,
                namespace,
            )
            next_node_generator.run()
            next_node_generator.GenerateNodeCreationCodes(True)

            next_grad_node_creation_str = next_node_generator.node_creation_str
            next_grad_node_out_list = next_node_generator.grad_node_out_list

            self.RecordGrad2NextGradNameMapping(next_node_generator)

        is_invoke_forward_api = IsInvokeForwardApi(
            self.grad_api_contents, self.forward_apis_dict
        )
        if next_node_generator is not None:
            has_higher_order_node = True
            return (
                has_higher_order_node,
                is_invoke_forward_api,
                next_grad_node_creation_str,
                next_grad_node_out_list,
                next_node_generator.backward_forward_inputs_map,
            )
        elif not is_invoke_forward_api:
            next_grad_node_creation_str = f"""  if(trace_backward) {{
    PADDLE_THROW(phi::errors::Unavailable(
    \"The Op {self.backward_api_name} doesn't have any grad \"
    \"op. If you don't intend to calculate higher order \"
    \"derivatives, please set `create_graph` to False.\"));
  }}"""
        return (
            has_higher_order_node,
            is_invoke_forward_api,
            next_grad_node_creation_str,
            next_grad_node_out_list,
            None,
        )
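        # The caller receives a 5-tuple:
        #   (has_higher_order_node, is_invoke_forward_api,
        #    next_grad_node_creation_str, next_grad_node_out_list,
        #    the next node's backward_forward_inputs_map, or None when there is no next node)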

    def GenerateNodeDeclaration(self):
        forward_op_name = self.forward_api_name
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_attrs_list = self.backward_attrs_list
        no_need_buffers = self.no_need_buffers

        # SetTensorWrapper Methods & TensorWrapper Members & ClearTensorWrappers
        set_tensor_wrapper_methods_str = ""
        tensor_wrapper_members_str = ""
        clear_tensor_wrapper_str = ""
        for tname, (
            ttype,
            is_fwd_input,
            _,
        ) in backward_forward_inputs_map.items():
            no_need_buffer = "true" if tname in no_need_buffers else "false"
            tensor_wrapper_name = GetSavedName(tname)
            if IsPlainTensorType(ttype):
                set_tensor_wrapper_methods_str += (
                    SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
                        tname, tname, tensor_wrapper_name, tname, no_need_buffer
                    )
                )

                tensor_wrapper_members_str += (
                    PLAIN_TENSOR_MEMBER_TEMPLATE.format(tensor_wrapper_name)
                )

                clear_tensor_wrapper_str += (
                    CLEAR_TENSOR_WRAPPER_TEMPLATE.format(tensor_wrapper_name)
                )

            else:
                assert IsVectorTensorType(ttype)
                set_tensor_wrapper_methods_str += (
                    SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
                        tname, tname, tname, tensor_wrapper_name, no_need_buffer
                    )
                )

                tensor_wrapper_members_str += (
                    VECTOR_TENSOR_MEMBER_TEMPLATE.format(tensor_wrapper_name)
                )

                clear_tensor_wrapper_str += (
                    CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format(
                        tensor_wrapper_name
                    )
                )

        # SetAttributes & Attribute Members
        set_attribute_methods_str = ""
        attribute_members_str = ""
        for aname, atype, default_val, _ in backward_attrs_list:
            saved_attr_name = GetSavedName(aname)
            set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format(
                aname, GetConstReference(atype), aname, saved_attr_name, aname
            )

            if default_val:
                attribute_members_str += (
                    ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format(
                        RemoveConstAndReference(atype),
                        saved_attr_name,
                        default_val,
                    )
                )
            else:
                attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
                    RemoveConstAndReference(atype), saved_attr_name
                )

        grad_node_name = GetGradNodeName(self.backward_api_name)
        self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
            grad_node_name,
            grad_node_name,
            grad_node_name,
            grad_node_name,
            grad_node_name,
            clear_tensor_wrapper_str,
            grad_node_name,
            grad_node_name,
            set_tensor_wrapper_methods_str,
            set_attribute_methods_str,
            tensor_wrapper_members_str,
            attribute_members_str,
        )

    def GenerateNodeDefinition(
        self,
        has_higher_order_node,
        is_invoke_forward_api,
        next_grad_node_creation_str,
        next_grad_node_out_list,
        backward_forward_inputs_map_next,
    ):
        namespace = self.namespace
        forward_api_name = self.forward_api_name
        backward_api_name = self.backward_api_name
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_grad_outputs_map = self.backward_grad_outputs_map
        backward_attrs_list = self.backward_attrs_list
        backward_inplace_map = self.backward_inplace_map
        indent = GetIndent(1)

        # Construct grad_api function args
        # Order: TensorWrappers, GradTensors, Attributes
        grad_api_args_len = (
            len(backward_forward_inputs_map.keys())
            + len(backward_grad_inputs_map.keys())
            + len(backward_attrs_list)
        )
        grad_api_args = ["" for i in range(grad_api_args_len)]
        get_grad_in_args_list = []

        # Fill Grad Ins with Zero
        fill_zero_str = ""
        if backward_api_name in ops_to_fill_zero_for_empty_grads:
            fill_zero_str = (
                f"{indent}const auto& input_metas = this->InputMeta();\n"
            )
            for name, (
                ttype,
                fwd_position,
                grad_api_position,
            ) in backward_grad_inputs_map.items():
                if name in self.optional_inputs:
                    if IsPlainTensorType(ttype):
                        fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);\n"
                else:
                    if IsPlainTensorType(ttype):
                        fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}][0], input_metas[{fwd_position}][0]);\n"
                    else:
                        fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);\n"
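        # Illustrative sketch of one generated line (assuming the grad input sits in slot 0):
        #   egr::EagerUtils::FillZeroForEmptyGradInput(&grads[0][0], input_metas[0][0]);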

        inplace_grad_input_str = ""
        inplaced_tensor_wrapper = False
        inplace_check_str = ""
        optional_inplace_var_name = []
        # Grad Ins from TensorWrappers
        for (
            name,
            (backward_input_type, is_fwd_input, grad_api_position),
        ) in backward_forward_inputs_map.items():
            tensor_wrapper_name = GetSavedName(name)
            transformed_tensor_name = self.TransformToNextGradName(name)

            is_optional = name in self.optional_inputs
            tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name});"
            if backward_inplace_map and name in backward_inplace_map.keys():
                if has_higher_order_node:
                    if (
                        transformed_tensor_name
                        in backward_forward_inputs_map_next
                    ) and (
                        backward_forward_inputs_map_next[
                            transformed_tensor_name
                        ][1]
                    ):
                        optional_inplace_var_name.append(
                            transformed_tensor_name
                        )
                tensor_wrapper_intermidiate_tensor_str = (
                    f"(&this->{tensor_wrapper_name})->get_intermidiate_tensor()"
                )
                inplace_check_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
                    transformed_tensor_name,
                    transformed_tensor_name,
                    name,
                    transformed_tensor_name,
                    transformed_tensor_name,
                    transformed_tensor_name,
                    transformed_tensor_name,
                    tensor_wrapper_intermidiate_tensor_str,
                )
                inplace_grad_input_str = transformed_tensor_name
            if is_optional:
                if backward_input_type == "std::vector<Tensor>":
                    tensor_wrapper_recover_str += (
                        "\n"
                        + CREATE_RECOVER_OPTIONAL_VECTOR_TENSOR_TEMPLATE.format(
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                        )
                    )
                else:
                    tensor_wrapper_recover_str += (
                        "\n"
                        + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format(
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                        )
                    )

                grad_api_args[grad_api_position] = (
                    transformed_tensor_name + "_optional"
                )

            else:
                grad_api_args[grad_api_position] = transformed_tensor_name

            get_grad_in_args_list.append(tensor_wrapper_recover_str)

        # Grad Ins from grads
        for name, (
            ttype,
            fwd_position,
            grad_api_position,
        ) in backward_grad_inputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)

            is_optional = name in self.optional_inputs
            if IsPlainTensorType(ttype):
                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"

                # Inplace in backward op
                if backward_inplace_map and name in backward_inplace_map.keys():
                    if has_higher_order_node:
                        if (
                            transformed_tensor_name
                            in backward_forward_inputs_map_next
                        ) and (
                            backward_forward_inputs_map_next[
                                transformed_tensor_name
                            ][1]
                        ):
                            optional_inplace_var_name.append(
                                transformed_tensor_name
                            )
                    grads_tensor_str = f"grads[{fwd_position}][0]"
                    inplace_check_str += CHECK_BACKWARD_INPLACE_TEMPLATE.format(
                        transformed_tensor_name,
                        transformed_tensor_name,
                        name,
                        transformed_tensor_name,
                        transformed_tensor_name,
                        transformed_tensor_name,
                        transformed_tensor_name,
                        grads_tensor_str,
                    )
                    inplace_grad_input_str = transformed_tensor_name

                if is_optional:
                    get_tensor_str += (
                        "\n"
                        + CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE.format(
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                            transformed_tensor_name,
                        )
                    )
                    grad_api_args[
                        grad_api_position
                    ] = f"{transformed_tensor_name}_optional"
                else:
                    grad_api_args[grad_api_position] = transformed_tensor_name
            else:
                assert IsVectorTensorType(ttype)
                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
                grad_api_args[grad_api_position] = transformed_tensor_name

            get_grad_in_args_list.append(get_tensor_str)

        # Grad Attrs
        for name, _, _, grad_api_position in backward_attrs_list:
            saved_attribute_name = GetSavedName(name)
            get_attr_str = (
                f"{indent}auto& {name} = this->{saved_attribute_name};"
            )

            grad_api_args[grad_api_position] = name
            get_grad_in_args_list.append(get_attr_str)

        get_grad_in_args_str = "\n".join(get_grad_in_args_list)

        # Grad Function Call String
        slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
        grad_api_namespace = f"paddle::experimental::{namespace}"
        grad_function_prepare_str = f"""
  const auto& out_metas = OutputMeta();
  paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> returns({slot_num_bwd_outputs});
  for (int i = 0; i < {slot_num_bwd_outputs}; ++i) {{
    out_metas[i].size() == 0 ? returns[i].resize(1) : returns[i].resize(out_metas[i].size());
  }}
"""
        inplace_for_grad_outs_str = ""
        optional_inplace_str = ""
        # Grad Outputs
        out_index = -1
        out_assign_str = ""
        for name, (
            ttype,
            fwd_position,
            grad_api_position,
        ) in backward_grad_outputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)
            out_index = out_index + 1
            if is_invoke_forward_api:
                if len(backward_grad_outputs_map) == 1:
                    out_assign_str += (
                        f"{indent}*api_output_{out_index} = api_output;\n"
                    )
                else:
                    out_assign_str += f"{indent}*api_output_{out_index} = std::get<{out_index}>(api_output);\n"
            else:
                grad_api_args.append(f"api_output_{out_index}")
            if inplace_grad_input_str in optional_inplace_var_name:
                optional_inplace_str = f"VLOG(6) << \"No Inplace should happen for wrapped input: {inplace_grad_input_str}\";"
            else:
                optional_inplace_str = f"""if (api_output_{out_index} != nullptr && can_be_inplaced) {{
      egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
    }}"""
            if IsPlainTensorType(ttype):

                if (
                    backward_inplace_map
                    and name in backward_inplace_map.values()
                ):
                    inplace_str = f""" if (api_output_{out_index} != nullptr && can_be_inplaced) {{
      egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index});
    }}"""
                    if has_higher_order_node:
                        inplace_for_grad_outs_str += f"""
  if (trace_backward) {{
    {optional_inplace_str}
  }} else {{
    {inplace_str}
  }}"""
                    else:
                        inplace_for_grad_outs_str += inplace_str

                grad_function_prepare_str += f"""
  auto* api_output_{out_index} = (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][0].IsStopGradient()) ? nullptr : &returns[{fwd_position}][0];"""

            else:
                assert IsVectorTensorType(ttype)
                grad_function_prepare_str += f"""
  std::vector<paddle::experimental::Tensor*> api_output_{out_index};
  api_output_{out_index}.reserve(returns[{fwd_position}].size());
  for (size_t i = 0; i < returns[{fwd_position}].size(); ++i) {{
    if (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][i].IsStopGradient()) {{
      api_output_{out_index}.push_back(nullptr);
    }} else {{
      api_output_{out_index}.push_back(&returns[{fwd_position}][i]);
    }}
  }}"""

        grad_api_args_str = ", ".join(grad_api_args)

        if is_invoke_forward_api:
            autograd_api_out = "auto"
            if (
                len(self.backward_inplace_map) > 0
                and len(backward_grad_outputs_map) == 1
            ):
                autograd_api_out = "auto&"
            forward_api_name = (
                self.grad_api_contents['invoke'].split('(')[0].strip()
            )
            autograd_api = self.grad_api_contents['invoke'].replace(
                forward_api_name,
                GetDygraphForwardFunctionName(forward_api_name),
                1,
            )
            grad_function_call_str = f"""
  if (trace_backward) {{
  {indent}{autograd_api_out} api_output = {autograd_api};
  {out_assign_str}}} else {{
  {indent}{autograd_api_out} api_output = paddle::experimental::{self.namespace}{self.grad_api_contents['invoke']};
  {out_assign_str}{indent}}}
  """
        else:
            grad_function_call_str = f"""
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"""

        # Check Nan and Inf
        check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
            backward_api_name, "returns"
        )

        # Prepare for Node Creation if Necessary
        outputs_autograd_meta_str = ""
        compute_require_next_grad_str = ""
        if len(next_grad_node_creation_str) > 0 or is_invoke_forward_api:
            compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"

        # 3. Get Output AutoGradMeta
        outputs_autograd_meta_list = []
        # TODO(jiabin): Optimize this with SetStopGradient instead of Pass Stop gradient

        num_fwd_outputs = len(backward_grad_outputs_map.keys())
        for name, (
            rtype,
            pos,
            grad_api_position,
        ) in backward_grad_outputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)

            output_autograd_meta_name = GetAutoGradMetaName(
                transformed_tensor_name
            )
            output_autograd_meta_vec_name = GetAutoGradMetaVectorName(
                transformed_tensor_name
            )
            if IsPlainTensorType(rtype):
                output_autograd_meta = f"""
  auto& {transformed_tensor_name} = returns[{pos}][0];
  egr::AutogradMeta* {output_autograd_meta_name} = returns[{pos}][0].initialized() ? egr::EagerUtils::autograd_meta(&{transformed_tensor_name}) : nullptr;
  if ({output_autograd_meta_name}) {output_autograd_meta_name}->SetStopGradient(false);
  """

            else:
                assert IsVectorTensorType(rtype)
                if has_higher_order_node:
                    output_autograd_meta = f"""
    auto& {transformed_tensor_name} = returns[{pos}];
    std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});
    std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};
    for(auto* meta : {output_autograd_meta_vec_name}){{
        meta->SetStopGradient(false);
    }}
"""
                else:
                    output_autograd_meta = f"""
    auto& {transformed_tensor_name} = returns[{pos}];
    std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});
    for(auto* meta : {output_autograd_meta_vec_name}){{
        meta->SetStopGradient(false);
    }}
"""
            outputs_autograd_meta_list.append(output_autograd_meta)

        outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

        returns_str = f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
        returns_str += f"{indent}return returns;\n"

        grad_node_name = GetGradNodeName(self.backward_api_name)
        # Prepare input/output strings for logging
        var_str = f"\n{indent}  std::string input_str = \"\";"
        var_str += f"\n{indent}  std::string output_str = \"\";"
        for name, (
            ttype,
            fwd_position,
            grad_api_position,
        ) in backward_grad_inputs_map.items():
            new_name = self.TransformToNextGradName(name)
            var_str += f"\n{indent}  const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n( {new_name} , [%s]), \";"
            var_str += f"\n{indent}  std::string input_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));"
            var_str += f"\n{indent}  input_str += input_{new_name}_str; "

        for (
            name,
            (backward_input_type, is_fwd_input, grad_api_position),
        ) in backward_forward_inputs_map.items():
            new_name = self.TransformToNextGradName(name)
            var_str += f"\n{indent}  const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n( {new_name} , [%s]), \";"
            var_str += f"\n{indent}  std::string input_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));"
            var_str += f"\n{indent}  input_str += input_{new_name}_str; "

        before_log_str = BEFORE_LOG_PRINT_TEMPLATE.format(var_str)

        for name, (
            ttype,
            fwd_position,
            grad_api_position,
        ) in backward_grad_outputs_map.items():
            new_name = self.TransformToNextGradName(name)
            var_str += f"\n{indent}  const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n ( {new_name} , [%s]), \";"
            var_str += f"\n{indent}  std::string output_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));"
            var_str += f"\n{indent}  output_str += output_{new_name}_str; "

        log_str = AFTER_LOG_PRINT_TEMPLATE.format(var_str)

        self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
            grad_node_name,
            self.backward_api_name,
            fill_zero_str,
            get_grad_in_args_str,
            grad_function_prepare_str,
            compute_require_next_grad_str,
            inplace_check_str,
            inplace_for_grad_outs_str,
            self.backward_api_name,
            before_log_str,
            grad_function_call_str,
            check_nan_inf_str,
            outputs_autograd_meta_str,
            next_grad_node_creation_str,
            self.backward_api_name,
            log_str,
            returns_str,
        )

    def run(self):
        super().run()

        self.ResetOptionalInputs()

        #####################
        ## Code Generation ##
        #####################
        # Higher-order GradNode generation
        (
            has_higher_order_node,
            is_invoke_forward_api,
            next_grad_node_creation_str,
            next_grad_node_out_list,
            backward_forward_inputs_map,
        ) = self.GenerateHigherOrderNodeCreationCode()

        self.GenerateNodeDeclaration()

        self.GenerateNodeDefinition(
            has_higher_order_node,
            is_invoke_forward_api,
            next_grad_node_creation_str,
            next_grad_node_out_list,
            backward_forward_inputs_map,
        )


class DygraphForwardAndNodesGenerator(GeneratorBase):
    def __init__(self, api_yaml_path, backward_yaml_path):
        # Parent members:
        # self.namespace
        # self.api_yaml_path
        # self.forward_api_list
        GeneratorBase.__init__(self, api_yaml_path)

        self.backward_yaml_path = backward_yaml_path
        self.grad_api_dict = {}

        self.forward_declaration_str = ""
        self.forward_definition_str = ""

        self.node_declaration_str = ""
        self.node_definition_str = ""

    def CollectIsForwardOnly(self, forward_api_contents):
        self.is_forward_only = (
            False if 'backward' in forward_api_contents.keys() else True
        )

    def ParseYamlContents(self):
        self.ParseForwardYamlContents()

        backward_yaml_path = self.backward_yaml_path

        # string ops are forward-only, so there is no corresponding backward yaml
        if backward_yaml_path is not None:
            self.grad_api_dict = ReadBwdFile(backward_yaml_path)

    def GetBackwardAPIContents(self, forward_api_contents):
        grad_api_dict = self.grad_api_dict

        if 'backward' not in forward_api_contents.keys():
            return None

        backward_api_name = forward_api_contents['backward']
        assert backward_api_name in grad_api_dict.keys(), AssertMessage(
            backward_api_name, grad_api_dict.keys()
        )
        backward_api_contents = grad_api_dict[backward_api_name]

        return backward_api_contents
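        # e.g. a forward entry carrying `backward: matmul_grad` (hypothetical name) resolves here
        # to the parsed contents of the "matmul_grad" entry from the backward yaml file.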

    def GenerateCode(self):
        forward_api_list = self.forward_api_list
        grad_api_dict = self.grad_api_dict
        forward_apis_dict = {}
        for api_item in forward_api_list:
            forward_apis_dict[api_item['op']] = api_item
        namespace = self.namespace

        for forward_api_contents in forward_api_list:
            if forward_api_contents['op'] in black_ops_list:
                continue

            self.CollectIsForwardOnly(forward_api_contents)

            if self.is_forward_only:
                backward_api_contents = None
            else:
                backward_api_contents = self.GetBackwardAPIContents(
                    forward_api_contents
                )

            # Generate Dygraph Forward Function
            function_generator = DygraphForwardFunctionGenerator(
                forward_api_contents,
                backward_api_contents,
                forward_apis_dict,
                namespace,
            )
            function_generator.run()

            self.forward_definition_str += (
                function_generator.forward_definition_str + "\n"
            )
            self.forward_declaration_str += (
                function_generator.forward_declaration_str + "\n"
            )

            # Generate Dygraph GradNode Function
            while True:
                if backward_api_contents is None:
                    break
                next_grad_api_contents = self.GetBackwardAPIContents(
                    backward_api_contents
                )

                node_generator = DygraphNodeGenerator(
                    forward_api_contents,
                    backward_api_contents,
                    forward_apis_dict,
                    namespace,
                    next_grad_api_contents,
                )
                node_generator.run()
                self.node_declaration_str += (
                    node_generator.node_declaration_str + "\n"
                )
                self.node_definition_str += (
                    node_generator.node_definition_str + "\n"
                )

                if next_grad_api_contents is None:
                    break

                # Detect if there exists higher-order GradNode
                forward_api_contents = backward_api_contents

                # Fake forward_api_content
                forward_api_contents['op'] = forward_api_contents['backward_op']
                backward_api_contents = next_grad_api_contents

        if len(namespace) > 0:
            if namespace.endswith("::"):
                namespace = namespace[:-2]
            self.forward_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.forward_definition_str
            )
            self.forward_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.forward_declaration_str
            )
            self.node_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.node_declaration_str
            )
            self.node_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.node_definition_str
            )

    def run(self):
        self.ParseYamlContents()

        self.InferNameSpace()

        self.GenerateCode()


##################
## File Writers ##
##################
def GenerateNodeCCFile(filepath, node_definition_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    file_contents = NODE_CC_FILE_TEMPLATE.format(node_definition_str)
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateNodeHFile(filepath, node_declaration_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    file_contents = NODE_H_FILE_TEMPLATE.format(node_declaration_str)
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateForwardCCFile(filepath, forward_definition_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    core_ops_info_str = GenerateCoreOpInfoDefinition()
    file_contents = FORWARD_CC_FILE_TEMPLATE.format(
        core_ops_info_str, forward_definition_str
    )
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateForwardHFile(filepath, forward_function_declaration_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    core_ops_info_str = GenerateCoreOpInfoDeclaration()
    file_contents = FORWARD_H_FILE_TEMPLATE.format(
        core_ops_info_str, forward_function_declaration_str
    )
    with open(filepath, 'a') as f:
        f.write(file_contents)


if __name__ == "__main__":
    args = ParseArguments()

    api_yaml_paths = args.api_yaml_path.split(",")
    backward_yaml_paths = args.backward_yaml_path.split(",")

    # Generate per Dygraph API
    node_declaration_str = ""
    node_definition_str = ""

    forward_declaration_str = ""
    forward_definition_str = ""

    for i in range(len(api_yaml_paths)):
        api_yaml_path = api_yaml_paths[i]

        # string api is forward only
        if not api_yaml_path.endswith('strings_ops.yaml'):
            backward_yaml_path = backward_yaml_paths[i]
        else:
            backward_yaml_path = None

        generator = DygraphForwardAndNodesGenerator(
            api_yaml_path, backward_yaml_path
        )
        generator.run()

        node_declaration_str += generator.node_declaration_str + "\n"
        node_definition_str += generator.node_definition_str + "\n"

        forward_declaration_str += generator.forward_declaration_str + "\n"
        forward_definition_str += generator.forward_definition_str + "\n"

    # Generate Files
    nodes_h_path = args.nodes_h_path
    nodes_cc_path = args.nodes_cc_path
    forwards_h_path = args.forwards_h_path
    forwards_cc_path = args.forwards_cc_path

    GenerateNodeCCFile(nodes_cc_path, node_definition_str)
    GenerateNodeHFile(nodes_h_path, node_declaration_str)
    GenerateForwardCCFile(forwards_cc_path, forward_definition_str)
    GenerateForwardHFile(forwards_h_path, forward_declaration_str)
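
# Illustrative invocation (the flag names come from ParseArguments; the file paths below are
# placeholders, not the values used by the real build system):
#   python eager_gen.py \
#     --api_yaml_path=ops.yaml --backward_yaml_path=backward.yaml \
#     --forwards_h_path=dygraph_functions.h --forwards_cc_path=dygraph_functions.cc \
#     --nodes_h_path=nodes.h --nodes_cc_path=nodes.cc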