# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import yaml
import re
import argparse
import os
import logging

from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info
from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, ReadBwdFile
from codegen_utils import FindGradName, FindForwardName, GetSavedName, GetGradNodeName
from codegen_utils import IsPlainTensorType, IsVectorTensorType
from codegen_utils import GetConstReference, RemoveConstAndReference
from codegen_utils import GetDygraphForwardFunctionName, GetIntermediateAPIFunctionName
from codegen_utils import GetAutoGradMetaName, GetAutoGradMetaVectorName
from codegen_utils import RemoveSpecialSymbolsInName, RecoverBaseNameOfInplaceFunction
from codegen_utils import GetInplacedFunctionName
from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromBackward
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage, GetIndent


###########
## Utils ##
###########
def ParseArguments():
    parser = argparse.ArgumentParser(
        description='Eager Code Generator Args Parser')
    parser.add_argument('--nodes_h_path', type=str)
    parser.add_argument('--nodes_cc_path', type=str)
    parser.add_argument('--forwards_h_path', type=str)
    parser.add_argument('--forwards_cc_path', type=str)
    parser.add_argument('--api_yaml_path', type=str)
    parser.add_argument('--backward_yaml_path', type=str)

    args = parser.parse_args()
    return args


########################
## Code Gen Templates ##
########################
SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
"""
   void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
     {} = egr::TensorWrapper({}, full_reserved, {});
   }}
"""

PLAIN_TENSOR_MEMBER_TEMPLATE = \
"""
       egr::TensorWrapper {};
"""

CLEAR_TENSOR_WRAPPER_TEMPLATE = \
"""
       {}.clear();
"""

SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
"""
       void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
         for(const auto& eager_tensor : {}) {{
            {}.emplace_back( egr::TensorWrapper(eager_tensor, full_reserved, {}) );
         }};
       }}
"""

VECTOR_TENSOR_MEMBER_TEMPLATE = \
"""
       std::vector<egr::TensorWrapper> {};
"""

CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE = \
"""
       for (auto& tw : {}) {{
         tw.clear();
       }}
"""

SET_ATTR_METHOD_TEMPLATE = \
"""
       void SetAttribute{}({} {}) {{
         {} = {};
       }}
"""

ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE = \
"""
       {} {} = {};
"""

ATTRIBUTE_MEMBER_TEMPLATE = \
"""
       {} {};
"""

NODE_DECLARATION_TEMPLATE = \
"""
class {} : public egr::GradNodeBase {{
 public:
  {}() : egr::GradNodeBase() {{}}
  {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
      egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
  ~{}() override = default;

  virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
      std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
  std::string name() override {{ return \" {} \"; }}

  void ClearTensorWrappers() override {{
      {}
      SetIsTensorWrappersCleared(true);
  }}

  std::shared_ptr<GradNodeBase> Copy() const override {{
      auto copied_node = std::shared_ptr<{}>(new {}(*this));

      return copied_node;
  }}

  // SetTensorWrapperX, SetTensorWrapperY, ...
  {}
  // SetAttributes
  {}

 private:
  // TensorWrappers
  {}

  // Attributes
  {}
}};
"""

GRAD_FUNCTION_TEMPLATE = \
"""
std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
    // Fill Zero For GradIn Tensors
{}

    // Apply Gradient Hooks
    auto hooked_grads = ApplyGradientHooks(grads);

    // Collect GradIn Tensors, Attrs and Recovered TensorWrappers
{}

    // Call grad_api function
    VLOG(3) << \"Final State Running: \" << \"{}\";
{}

    // Get Output
{}

    // Get GradIn autograd_meta
{}

    // Get GradOut autograd_meta
{}

    // Compute Require Grad
{}

    // Create Grad Node
{}

    // Return
{}

}}
"""

FORWARD_FUNCTION_TEMPLATE = \
"""
{} {}({}) {{
    // Dygraph Record Event
{}
    // AMP Logic
{}

    // Get Input AutoGradMeta
{}
    // Forward API Call
{}
    // Get Outputs
{}
    // Get Output AutoGradMeta
{}
    bool trace_backward = egr::Controller::Instance().HasGrad();
    bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});

    // Check Inplace & Bump Inplace Version
{}
{}
    // Node Creation
{}

    // Returns
    return {};
}}

"""

FORWARD_BODY_TEMPLATE = \
"""
    if(require_any_grad) {{
{}
      egr::EagerUtils::PassStopGradient({});

      // Node Construction
{}
      // SetAttributes
{}
      // SetTensorWrappers
{}
      // SetGradOutMeta & SetEdges
{}
{}
      // SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
    }}
"""

NAMESPACE_WRAPPER_TEMPLATE = \
"""
namespace {} {{
    {}
}}
"""

NODE_CC_FILE_TEMPLATE = \
"""
#include "glog/logging.h"
#include "paddle/phi/api/all.h"
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/backward/sparse_bw_api.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"

#include "paddle/phi/api/include/sparse_api.h"

{}
"""

NODE_H_FILE_TEMPLATE = \
"""
#pragma once
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/eager/grad_node_info.h"

{}
"""

FORWARD_CC_FILE_TEMPLATE = \
"""
#include "paddle/phi/api/lib/dygraph_api.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"

#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"

{}
{}
"""

FORWARD_H_FILE_TEMPLATE = \
"""
#pragma once
#include "glog/logging.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/phi/api/all.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/eager/to_static/run_program_op_func.h"

{}
{}
"""

CORE_OPS_INFO_TEMPLATE = \
"""
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info = {{
    {}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info = {{
    {}
}};
std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info = {{
    {}
}};

"""

CORE_OPS_DECLARATION_TEMPLATE = \
"""
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_args_type_info;
extern std::unordered_map<std::string, std::vector<std::string>> core_ops_final_state_returns_info;

"""

CHECK_INPLACE_TEMPLATE = \
"""
    egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n
"""

BUMP_INPLACE_VERSION_TEMPLATE = \
"""
    // Bump Inplace Version
    {}.bump_inplace_version();
    VLOG(3) << \"Tensor(\" << {}.name() << \") uses Inplace Strategy.\";\n
"""

AMP_LOGIC_TEMPLATE = \
"""
    if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
        VLOG(5) << "Check and Prepare For AMP";
        {}
        std::vector<std::vector<paddle::experimental::Tensor>> amp_tensors_vector = {};
        {}
        {}
        {}
        {{
            paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentTracer(), paddle::imperative::AmpLevel::O0);
            {}
        }}
    }}
"""

CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE = \
"""
    paddle::optional<const paddle::experimental::Tensor&> {}_optional = paddle::none;
    if({}.initialized()) {}_optional = paddle::make_optional<const paddle::experimental::Tensor&>({});
"""

CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE = \
"""
    paddle::optional<const paddle::experimental::Tensor&> {}_optional = paddle::none;
    if( {}.impl() ) {}_optional = paddle::make_optional<const paddle::experimental::Tensor&>({});
"""

#######################
## Generator Helpers ##
#######################
def GenerateCoreOpInfoDeclaration():
    return CORE_OPS_DECLARATION_TEMPLATE


def GenerateCoreOpInfoDefinition():

    op_args_info_list = []
    for op_name, arg_list in core_ops_args_info.items():
        arg_str = ",".join(["\"" + v + "\"" for v in arg_list])
        op_args_info = f"{{ \"{op_name}\", {{ {arg_str} }} }},"
        op_args_info_list.append(op_args_info)

    op_types_info_list = []
    for op_name, type_list in core_ops_args_type_info.items():
        type_str = ",".join(["\"" + v + "\"" for v in type_list])
        op_types_info = f"{{ \"{op_name}\", {{ {type_str} }} }},"
        op_types_info_list.append(op_types_info)

    op_returns_info_list = []
    for op_name, return_list in core_ops_returns_info.items():
        return_str = ",".join(["\"" + v + "\"" for v in return_list])
        return_types_info = f"{{ \"{op_name}\", {{ {return_str} }} }},"
        op_returns_info_list.append(return_types_info)

    op_args_info_str = "\n".join(op_args_info_list)
    op_types_info_str = "\n".join(op_types_info_list)
    op_returns_info_str = "\n".join(op_returns_info_list)

    core_ops_info_definition_str = CORE_OPS_INFO_TEMPLATE.format(
        op_args_info_str, op_types_info_str, op_returns_info_str)

    return core_ops_info_definition_str


#####################
## Generator Class ##
#####################
class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
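    """Base generator shared by forward-function and grad-node generation.

    Parses one forward api entry (api.yaml) together with its backward entry
    (backward.yaml), validates both, builds the name/position maps, and
    produces the node-creation code used by the two subclasses below.
    """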
    def __init__(self, forward_api_contents, grad_api_contents, namespace):
        self.forward_api_contents = forward_api_contents
        # Members from Parent:
        #self.namespace
        #self.forward_api_contents
        #self.forward_api_name
        #self.orig_forward_inputs_list
        #self.orig_forward_attrs_list
        #self.orig_forward_returns_list
        #self.forward_inputs_position_map
        #self.forward_outputs_position_map
        #self.optional_inputs
        #self.no_need_buffers
        #self.intermediate_outputs
        #self.inplace_map
        FunctionGeneratorBase.__init__(self, forward_api_contents, namespace)

        self.grad_api_contents = grad_api_contents

        # Raw Contents
        self.backward_forward_str = ""
        self.backward_api_name = ""

        self.forward_attrs_list = [
        ]  #[ [attr_name, attr_type, default_value, orig_position], ...]
        self.forward_inputs_list = [
        ]  #[ [arg_name, arg_type, orig_position], ...]
        self.forward_returns_list = [
        ]  #[ [ret_name, ret_type, orig_position], ...]

        self.backward_inputs_list = [
        ]  #[ [attr_name, attr_type, default_value, orig_position], ...]
        self.backward_attrs_list = [
        ]  #[ [arg_name, arg_type, orig_position], ...]
        self.backward_returns_list = [
        ]  #[ [ret_name, ret_type, orig_position], ...]

        # SlotNameMatched Backward Data
        self.backward_forward_inputs_map = {
        }  #{ "name" : [type, is_fwd_input, orig_position] ...}
        self.backward_grad_inputs_map = {
        }  #{ "name" : [type, fwd_position, orig_position] ...}
        self.backward_grad_outputs_map = {
        }  #{ "name" : [type, fwd_position, orig_position] ...}

    def DygraphYamlValidationCheck(self):
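        """Assert that the required keys exist in the api.yaml and backward.yaml entries."""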
        forward_api_contents = self.forward_api_contents
        grad_api_contents = self.grad_api_contents

        assert 'api' in forward_api_contents.keys(
        ), "Unable to find \"api\" in api.yaml"
        assert 'args' in forward_api_contents.keys(
        ), "Unable to find \"args\" in api.yaml"
        assert 'output' in forward_api_contents.keys(
        ), "Unable to find \"output\" in api.yaml"
        assert 'backward' in forward_api_contents.keys(
        ), "Unable to find \"backward\" in api.yaml"

        assert 'args' in grad_api_contents.keys(
        ), "Unable to find \"args\" in backward.yaml"
        assert 'output' in grad_api_contents.keys(
        ), "Unable to find \"output\" in backward.yaml"
        assert 'forward' in grad_api_contents.keys(
        ), "Unable to find \"forward\" in backward.yaml"

    def ForwardsValidationCheck(self):
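        """Check that the forward signature recovered from backward.yaml matches the
        original api.yaml entry (types and positions) and that attributes follow inputs."""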
        forward_inputs_list = self.forward_inputs_list
        forward_attrs_list = self.forward_attrs_list
        forward_returns_list = self.forward_returns_list

        orig_forward_inputs_list = self.orig_forward_inputs_list
        orig_forward_attrs_list = self.orig_forward_attrs_list
        orig_forward_returns_list = self.orig_forward_returns_list

        for i in range(len(forward_inputs_list)):
            forward_input_type = forward_inputs_list[i][1]
            forward_input_pos = forward_inputs_list[i][2]
            orig_input_type = orig_forward_inputs_list[i][1]
            orig_input_pos = orig_forward_inputs_list[i][2]

            assert forward_input_type == orig_input_type, AssertMessage(
                forward_input_type, orig_input_type)
            assert forward_input_pos == orig_input_pos, AssertMessage(
                forward_input_pos, orig_input_pos)

        for i in range(len(forward_attrs_list)):
            orig_attr_type = orig_forward_attrs_list[i][1]
            orig_attr_default = orig_forward_attrs_list[i][2]
            orig_attr_pos = orig_forward_attrs_list[i][3]
            forward_attr_type = forward_attrs_list[i][1]
            forward_attr_default = forward_attrs_list[i][2]
            forward_attr_pos = forward_attrs_list[i][3]
            assert orig_attr_type == forward_attr_type, AssertMessage(
                orig_attr_type, forward_attr_type)
            assert orig_attr_default == forward_attr_default, AssertMessage(
                orig_attr_default, forward_attr_default)
            assert orig_attr_pos == forward_attr_pos, AssertMessage(
                orig_attr_pos, forward_attr_pos)

        for i in range(len(forward_returns_list)):
            orig_return_type = orig_forward_returns_list[i][1]
            orig_return_pos = orig_forward_returns_list[i][2]
            forward_return_type = forward_returns_list[i][1]
            forward_return_pos = forward_returns_list[i][2]

            assert orig_return_type == forward_return_type, AssertMessage(
                orig_return_type, forward_return_type)
            assert orig_return_pos == forward_return_pos, AssertMessage(
                orig_return_pos, forward_return_pos)

        # Check Order: Inputs, Attributes
        max_input_position = -1
        for _, _, pos in forward_inputs_list:
            max_input_position = max(max_input_position, pos)

        max_attr_position = -1
        for _, _, _, pos in forward_attrs_list:
            assert pos > max_input_position, AssertMessage(pos,
                                                           max_input_position)
            max_attr_position = max(max_attr_position, pos)

    def BackwardValidationCheck(self):
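        """Check the argument order of the backward api: TensorWrappers, then grad tensors, then attributes."""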
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_attrs_list = self.backward_attrs_list

        # Check Order: TensorWrappers, GradTensors, Attributes
        max_fwd_input_position = -1
        for _, (_, _, pos) in backward_forward_inputs_map.items():
            max_fwd_input_position = max(max_fwd_input_position, pos)

        max_grad_tensor_position = -1
        for _, (_, _, pos) in backward_grad_inputs_map.items():
            assert pos > max_fwd_input_position, AssertMessage(
                pos, max_fwd_input_position)
            max_grad_tensor_position = max(max_grad_tensor_position, pos)

        max_attr_position = -1
        for _, _, _, pos in backward_attrs_list:
            assert pos > max_grad_tensor_position, AssertMessage(
                pos, max_grad_tensor_position)
            max_attr_position = max(max_attr_position, pos)

    def IntermediateValidationCheck(self):
        intermediate_outputs = self.intermediate_outputs
        forward_returns_list = self.forward_returns_list
        """
        Check whether intermediate_outputs are positioned
        at the very end of forward_returns_list
        """
        intermediate_positions = range(
            len(forward_returns_list) - len(intermediate_outputs),
            len(forward_returns_list))
        for ret_name, _, pos in forward_returns_list:
            if ret_name in intermediate_outputs:
                assert pos in intermediate_positions, AssertMessage(
                    pos, intermediate_positions)

    def CollectBackwardInfo(self):
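        """Read the backward api name and its 'forward' string, then parse the backward args/returns."""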
        forward_api_contents = self.forward_api_contents
        grad_api_contents = self.grad_api_contents

        self.backward_api_name = forward_api_contents['backward']
        self.backward_forward_str = grad_api_contents['forward']

        backward_args_str = grad_api_contents['args']
        backward_returns_str = grad_api_contents['output']

        self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward(
            backward_args_str, backward_returns_str)

        logging.info(
            f"Parsed Backward Inputs List: {self.backward_inputs_list}")
        logging.info(f"Prased Backward Attrs List: {self.backward_attrs_list}")
        logging.info(
            f"Parsed Backward Returns List: {self.backward_returns_list}")

    def CollectForwardInfoFromBackwardContents(self):

        backward_forward_str = self.backward_forward_str

        self.forward_inputs_list, self.forward_attrs_list, self.forward_returns_list = ParseYamlForwardFromBackward(
            backward_forward_str)

    def SlotNameMatching(self):
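        """Match backward inputs/returns against forward outputs/inputs by name and fill the backward_* maps."""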
        backward_inputs_list = self.backward_inputs_list
        backward_returns_list = self.backward_returns_list
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map

        for backward_input in backward_inputs_list:
            backward_input_name = backward_input[0]
            backward_input_type = backward_input[1]
            backward_input_pos = backward_input[2]

            backward_fwd_name = FindForwardName(backward_input_name)
            if backward_fwd_name:
                # Grad Input
                assert backward_fwd_name in forward_outputs_position_map.keys(
                ), AssertMessage(backward_fwd_name,
                                 forward_outputs_position_map.keys())
                matched_forward_output_type = forward_outputs_position_map[
                    backward_fwd_name][0]
                matched_forward_output_pos = forward_outputs_position_map[
                    backward_fwd_name][1]

                self.backward_grad_inputs_map[backward_input_name] = [
                    backward_input_type, matched_forward_output_pos,
                    backward_input_pos
                ]
            else:
                # TensorWrapper Input
                if backward_input_name in forward_inputs_position_map.keys():
                    tensor_wrapper_type = forward_inputs_position_map[
                        backward_input_name][0]
                    self.backward_forward_inputs_map[backward_input_name] = [
                        backward_input_type, True, backward_input_pos
                    ]

                elif backward_input_name in forward_outputs_position_map.keys():
                    tensor_wrapper_type = forward_outputs_position_map[
                        backward_input_name][0]
                    self.backward_forward_inputs_map[backward_input_name] = [
                        backward_input_type, False, backward_input_pos
                    ]
                else:
                    assert False, f"Cannot find {backward_input_name} in forward position map"

        for backward_output in backward_returns_list:
            backward_output_name = backward_output[0]
            backward_output_type = backward_output[1]
            backward_output_pos = backward_output[2]

            backward_fwd_name = FindForwardName(backward_output_name)
            assert backward_fwd_name is not None, f"Cannot find forward name for backward output {backward_output_name}"
            assert backward_fwd_name in forward_inputs_position_map.keys(
            ), AssertMessage(backward_fwd_name,
                             forward_inputs_position_map.keys())

            matched_forward_input_type = forward_inputs_position_map[
                backward_fwd_name][0]
            matched_forward_input_pos = forward_inputs_position_map[
                backward_fwd_name][1]

            self.backward_grad_outputs_map[backward_output_name] = [
                backward_output_type, matched_forward_input_pos,
                backward_output_pos
            ]
        logging.info(
            f"Generated Backward Fwd Input Map: {self.backward_forward_inputs_map}"
        )
        logging.info(
            f"Generated Backward Grad Input Map: {self.backward_grad_inputs_map}"
        )
        logging.info(
            f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
        )

    def GenerateNodeCreationCodes(self):
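        """Generate the code block that constructs the grad node and wires it up
        (attributes, tensor wrappers, edges, out rank, history, grad meta) when require_any_grad is true."""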
        forward_api_name = self.forward_api_name
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_grad_outputs_map = self.backward_grad_outputs_map
        backward_attrs_list = self.backward_attrs_list
        optional_inputs = self.optional_inputs

        # Pass Stop Gradient Args
        pass_stop_gradient_args_list = ["false"]
        for name, (_, _) in forward_outputs_position_map.items():
            output_autograd_meta_name = GetAutoGradMetaName(name)
            pass_stop_gradient_args_list.append(output_autograd_meta_name)
        pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)

        # Node Construction
        num_backward_inputs = len(forward_outputs_position_map.keys())
        num_backward_outputs = len(forward_inputs_position_map.keys())
        grad_node_name = GetGradNodeName(forward_api_name)

        # Helper
        indent = GetIndent(2)
        # NOTE(Aurelius74): DO NOT use make_shared here. Because some Node contains experimental::Scalar
        # which contains "complex128" as data. "complex128" is memory-aligned manually. But make_shared
        # request MEMALIGN for allocation (Maybe).
        # See https://stackoverflow.com/questions/31228656/how-can-shared-ptr-disrupt-alignment
        # and https://github.com/MRtrix3/mrtrix3/issues/957
        node_construction_str = f"{indent}auto grad_node = std::shared_ptr<{grad_node_name}>(new {grad_node_name}({num_backward_inputs}, {num_backward_outputs}));"

        # SetAttributes
        set_attributes_list = []
        forward_attrs_name_set = set()
        for name, _, _, _ in forward_attrs_list:
            forward_attrs_name_set.add(name)

        for name, _, default_val_attr, _ in backward_attrs_list:
            if name in forward_attrs_name_set:
                set_attributes = f"{indent}grad_node->SetAttribute{name}({name});"
            else:
                set_attributes = f"{indent}grad_node->SetAttribute{name}({default_val_attr});"
            set_attributes_list.append(set_attributes)
        set_attributes_str = "\n".join(set_attributes_list)

        # SetTensorWrappers
        set_tensor_wrappers_list = []
        num_fwd_outputs = len(forward_outputs_position_map.keys())
        for name, (atype, is_fwd_input,
                   pos) in backward_forward_inputs_map.items():
            is_optional = (name in optional_inputs)

            if is_fwd_input:
                need_input_data = "false" if name in self.no_need_buffers else "true"
                if is_optional:
                    set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
                else:
                    set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, {need_input_data});"
            else:
                if num_fwd_outputs > 1:
                    # Aligned with forward output position
                    assert name in forward_outputs_position_map.keys(
                    ), AssertMessage(name, forward_outputs_position_map.keys())
                    fwd_output_pos = forward_outputs_position_map[name][1]

                if is_optional:
                    set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
                else:
                    set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);"
            set_tensor_wrappers_list.append(set_tensor_wrappers)
        set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)

        # SetGradOutMeta & SetEdges
        set_grad_out_meta_list = []
        set_edges_list = []
        for name, (_, pos) in forward_inputs_position_map.items():
            input_autograd_meta_name = GetAutoGradMetaName(name)
            is_optional = (name in self.optional_inputs)
            if is_optional:
                set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});"
                set_edges = f"{indent}if({name}.get_ptr() != nullptr)  grad_node->AddEdges({input_autograd_meta_name}, {pos});"
            else:
                set_grad_out_meta = f"{indent}grad_node->SetGradOutMeta({name}, {pos});"
                set_edges = f"{indent}grad_node->AddEdges({input_autograd_meta_name}, {pos});"

            set_grad_out_meta_list.append(set_grad_out_meta)
            set_edges_list.append(set_edges)
        set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
        set_edges_str = "\n".join(set_edges_list)

        # SetOutRank & SetHistory & SetGradInMeta
        set_out_rank_list = []
        set_history_list = []
        set_grad_in_meta_list = []
        set_retain_grad_list = []
        num_outputs = len(forward_outputs_position_map.keys())
        for name, (_, pos) in forward_outputs_position_map.items():
            output_autograd_meta_name = GetAutoGradMetaName(name)
            set_out_rank = f"{indent}egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
            set_history = f"{indent}egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"

            set_retain_grad = f"{indent}egr::EagerUtils::CheckAndRetainGrad({name});"
            set_grad_in_meta = f"{indent}grad_node->SetGradInMeta({name}, {pos});"
            set_out_rank_list.append(set_out_rank)
            set_history_list.append(set_history)
            set_grad_in_meta_list.append(set_grad_in_meta)
            set_retain_grad_list.append(set_retain_grad)

        set_out_rank_str = "\n".join(set_out_rank_list)
        set_history_str = "\n".join(set_history_list)
        set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
        set_retain_grad_str = "\n".join(set_retain_grad_list)

        node_event_name = forward_api_name + " node_creation"
        node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"

        self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
            node_creation_event_str, pass_stop_gradient_args_str,
            node_construction_str, set_attributes_str, set_tensor_wrappers_str,
            set_grad_out_meta_str, set_edges_str, set_out_rank_str,
            set_history_str, set_grad_in_meta_str, set_retain_grad_str)

    def run(self):
        # Basic Validation Check
        self.DygraphYamlValidationCheck()

        ##########################
        ## Parsing Raw Contents ##
        ##########################
        # Parse inplace_map
        self.ParseInplaceInfo()

        # Parse no_need_buffer
        self.ParseNoNeedBuffer()

        # Parse optional_inputs
        self.ParseDispensable()

        # Parse intermediate_outputs
        self.ParseIntermediate()
        self.IntermediateValidationCheck()

        # Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
        self.CollectBackwardInfo()

        # Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
        self.CollectForwardInfoFromBackwardContents()

        # Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
        self.CollectOriginalForwardInfo()

        # Forwards Validation Check
        self.ForwardsValidationCheck()

        #############################
        ## Process Parsed Contents ##
        #############################
        # Initialize forward_inputs_position_map, forward_outputs_position_map
        self.DetermineForwardPositionMap(self.forward_inputs_list,
                                         self.forward_returns_list)

        # Initialize forward_inputs_position_map, forward_outputs_position_map
        self.SlotNameMatching()

        # Backward Validation Check
        self.BackwardValidationCheck()


class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
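    """Generates the dygraph forward function definition and declaration
    (including AMP, inplace and node-creation logic) for a single api entry."""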
    def __init__(self, forward_api_contents, grad_api_contents, namespace):
        DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
                                              grad_api_contents, namespace)

        # Generated Results
        self.forward_definition_str = ""
        self.forward_declaration_str = ""

    def GenerateForwardDefinition(self, is_inplaced):
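        """Generate the C++ forward function body and its declaration for this api, optionally the inplaced variant."""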
        namespace = self.namespace
        forward_api_name = GetInplacedFunctionName(
            self.forward_api_name) if is_inplaced else self.forward_api_name
        backward_api_name = self.backward_api_name
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_grad_outputs_map = self.backward_grad_outputs_map
        backward_attrs_list = self.backward_attrs_list
        optional_inputs = self.optional_inputs
        intermediate_outputs = self.intermediate_outputs
        inplace_map = self.inplace_map if is_inplaced else {}
        indent = GetIndent(1)

        # Get Function Args
        num_inputs = len(forward_attrs_list) + len(
            forward_inputs_position_map.keys())
        inputs_args_definition_list = ["" for i in range(num_inputs)]
        inputs_args_declaration_list = ["" for i in range(num_inputs)]
        inputs_call_list = ["" for i in range(num_inputs)]
        amp_inputs_call_list = ["" for i in range(num_inputs)]
        amp_tensors_vector_list = []
        amp_tensors_vector_optional_list = []
        amp_autocast_list = []
        amp_autocast_optional_list = []
        for name, (ttype, pos) in forward_inputs_position_map.items():
            inputs_call_list[pos] = f"{name}"
            amp_inputs_call_list[pos] = f"NEW_{name}"
            is_optional = (name in optional_inputs)
            if IsPlainTensorType(ttype):
                if is_optional:
                    arg_str = f"const paddle::optional<const paddle::experimental::Tensor&> {name}"
                    amp_tensors_vector_optional_list.append(
                        f"if ({name}.get_ptr() != nullptr) amp_tensors_vector.push_back({{ *({name}.get_ptr()) }});\n"
                    )
                    amp_autocast_optional_list.append(
                        f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n"
                    )
                else:
                    if is_inplaced and inplace_map and name in inplace_map.keys(
                    ):
                        arg_str = f"paddle::experimental::Tensor& {name}"
                        amp_tensors_vector_list.append(f"{{{name}}}")
                        amp_autocast_list.append(
                            f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                        )
                    else:
                        arg_str = f"const paddle::experimental::Tensor& {name}"
                        amp_tensors_vector_list.append(f"{{{name}}}")
                        amp_autocast_list.append(
                            f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                        )
            else:
                assert IsVectorTensorType(ttype)
                arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
                amp_tensors_vector_list.append(f"{name}")
                amp_autocast_list.append(
                    f"auto NEW_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name);\n"
                )

            inputs_args_definition_list[pos] = arg_str
            inputs_args_declaration_list[pos] = arg_str

        for name, atype, default_val, pos in forward_attrs_list:
            inputs_call_list[pos] = name
            amp_inputs_call_list[pos] = name
            if default_val is not None:
                inputs_args_declaration_list[
                    pos] = f"{atype} {name} = {default_val}"
            else:
                inputs_args_declaration_list[pos] = f"{atype} {name}"
            inputs_args_definition_list[pos] = f"{atype} {name}"

        inputs_args_declaration_str = ", ".join(inputs_args_declaration_list)
        inputs_args_definition_str = ", ".join(inputs_args_definition_list)
        inputs_call_args_str = ", ".join(inputs_call_list)

        # Forward Full Logic
        function_name = forward_api_name
        if len(intermediate_outputs) > 0:
            if is_inplaced:
                function_name = GetIntermediateAPIFunctionName(
                    forward_api_name[:-1]) + '_'
            else:
                function_name = GetIntermediateAPIFunctionName(function_name)

        forward_call_str = f"{indent}auto api_result = paddle::experimental::{namespace}{function_name}({inputs_call_args_str});"
        num_outputs = len(forward_outputs_position_map.keys()) - len(
            intermediate_outputs)

        # Get Outputs
        get_outputs_str = ""
        for name, (rtype, pos) in forward_outputs_position_map.items():
            if num_outputs == 1 and len(intermediate_outputs) == 0:
                get_outputs_str += f"{indent}auto& {name} = api_result;\n"
            else:
                get_outputs_str += f"{indent}auto& {name} = std::get<{pos}>(api_result);\n"

        # Get return type list & outputs
        returns_type_list = ["" for i in range(num_outputs)]
        returns_list = ["" for i in range(num_outputs)]
        for name, (rtype, pos) in forward_outputs_position_map.items():
            if name in intermediate_outputs:
                continue
            returns_list[pos] = f"{name}"

            if IsPlainTensorType(rtype):
                returns_type_list[pos] = "paddle::experimental::Tensor"
            else:
                assert IsVectorTensorType(rtype)
                returns_type_list[
                    pos] = "std::vector<paddle::experimental::Tensor>"

        if num_outputs == 1:
            returns_str = returns_list[0]
            returns_type_str = returns_type_list[0]
        else:
            returns_type_str = ", ".join(returns_type_list)
            returns_type_str = f"std::tuple<{returns_type_str}>"
            returns_str = ", ".join(returns_list)
            returns_str = f"std::make_tuple({returns_str})"

        # Node Creation Pre-Processing
        # 1. Get Input AutoGradMeta
        inputs_autograd_meta_list = []
        compute_require_grad_args_list = ["trace_backward"]
        for name, (ttype, pos) in forward_inputs_position_map.items():
            input_autograd_meta_name = GetAutoGradMetaName(name)
            if IsPlainTensorType(ttype):
                input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
            else:
                assert IsVectorTensorType(ttype)
                input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
                input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
                input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

            inputs_autograd_meta_list.append(input_autograd_meta)
            compute_require_grad_args_list.append(input_autograd_meta_name)
        inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
        compute_require_grad_args_str = ",".join(compute_require_grad_args_list)

        # 2. Get Output AutoGradMeta
        outputs_autograd_meta_list = []
        num_fwd_outputs = len(forward_outputs_position_map.keys())
        for name, (rtype, pos) in forward_outputs_position_map.items():
            output_autograd_meta_name = GetAutoGradMetaName(name)
            output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
            if num_fwd_outputs == 1:
                if IsPlainTensorType(rtype):
                    output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
                else:
                    assert IsVectorTensorType(rtype)
                    output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
                    output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
            else:
                # Tuple api_result
                if IsPlainTensorType(rtype):
                    output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
                else:
                    assert IsVectorTensorType(rtype)
                    output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
                    output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

            outputs_autograd_meta_list.append(output_autograd_meta)
        outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

        # 3. Check Inplace
        check_inplace_str = ""
        bump_inplace_version_str = ""
        if is_inplaced:
            for inplace_name in inplace_map.keys():
                inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
                check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
                    inplace_name, inplace_autograd_meta_name)
                bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
                    inplace_name, inplace_name)

        self.GenerateNodeCreationCodes()

        node_creation_str = self.node_creation_str
        dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
        forward_function_name = GetDygraphForwardFunctionName(forward_api_name)

        # Forward amp logic
        kernel_trans2_op_name_str = f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");"
        amp_tensors_vector_list_str = "{ " + ",".join(
            amp_tensors_vector_list) + " }"
        amp_tensors_vector_optional_list_str = "".join(
            amp_tensors_vector_optional_list)
        amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
        amp_autocast_list_str = "        ".join(
            amp_autocast_list) + "        " + "        ".join(
                amp_autocast_optional_list)
        amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
        amp_call_str = f"return {forward_function_name}({amp_inputs_call_args_str});"
        if is_inplaced or (forward_api_name == "cast"):
            amp_logic_str = ""
        else:
            amp_logic_str = AMP_LOGIC_TEMPLATE.format(
                kernel_trans2_op_name_str, amp_tensors_vector_list_str,
                amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
                amp_autocast_list_str, amp_call_str)

        self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
            returns_type_str, forward_function_name, inputs_args_definition_str,
            dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
            forward_call_str, get_outputs_str, outputs_autograd_meta_str,
            compute_require_grad_args_str, check_inplace_str,
            bump_inplace_version_str, node_creation_str, returns_str)
        self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"

        logging.info(
            f"Generated Forward Definition: {self.forward_definition_str}")
        logging.info(
            f"Generated Forward Declaration: {self.forward_declaration_str}")

    def GenerateInplacedForwardDygraphFunctions(self):
        # Inplaced Version Dygraph Function Generation
        forward_api_name = self.forward_api_name
        forward_api_contents = self.forward_api_contents

        if forward_api_name != "sum" and "inplace" in forward_api_contents.keys(
        ):
            # Node Definition Generation
            self.GenerateForwardDefinition(is_inplaced=True)
            self.UpdateCoreOpsInformation(is_inplaced=True)

    def UpdateCoreOpsInformation(self, is_inplaced):
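        """Record the args/attrs/returns metadata of this api in the core_ops_final_state_* maps."""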
        forward_api_name = GetInplacedFunctionName(
            self.forward_api_name) if is_inplaced else self.forward_api_name
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
        forward_attrs_list = self.forward_attrs_list

        num_args = len(forward_inputs_position_map.keys()) + len(
            forward_attrs_list)
        num_returns = len(forward_outputs_position_map.keys())

        final_state_fwd_api_name = "final_state_" + forward_api_name
        core_ops_returns_info[
            final_state_fwd_api_name] = ["" for i in range(num_returns)]
        core_ops_args_info[
            final_state_fwd_api_name] = ["" for i in range(num_args)]
        core_ops_args_type_info[
            final_state_fwd_api_name] = ["" for i in range(num_args)]
        for name, (ttype, pos) in forward_inputs_position_map.items():
            core_ops_args_info[final_state_fwd_api_name][pos] = name
            if IsPlainTensorType(ttype):
                core_ops_args_type_info[final_state_fwd_api_name][
                    pos] = "tensor"
            else:
                assert IsVectorTensorType(ttype)
                core_ops_args_type_info[final_state_fwd_api_name][pos] = "list"

        for name, _, _, pos in forward_attrs_list:
            core_ops_args_info[final_state_fwd_api_name][pos] = name

        for name, (ttype, pos) in forward_outputs_position_map.items():
            core_ops_returns_info[final_state_fwd_api_name][pos] = name

    def run(self):
        super().run()

        #####################
        ## Code Generation ##
        #####################
        self.GenerateForwardDefinition(is_inplaced=False)

        self.UpdateCoreOpsInformation(is_inplaced=False)

        self.GenerateInplacedForwardDygraphFunctions()


class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
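    """Generates the GradNode declaration and definition for a single backward api entry,
    optionally chaining to the next (higher-order) backward entry."""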
    def __init__(self,
                 forward_api_contents,
                 grad_api_contents,
                 namespace,
                 next_grad_api_contents=None):
        DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
                                              grad_api_contents, namespace)

        # Record name mapping from forward_api_name to grad_api_names
        self.to_next_grad_name_mapping = {}  # {name : name}

        # Generated Results
        self.node_declaration_str = ""
        self.node_definition_str = ""
        self.next_grad_api_contents = next_grad_api_contents

    def TransformToNextGradName(self, string):
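        """Map a name used by this grad api to its name in the next (higher-order) grad api, if such a mapping exists."""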
        name_mapping = self.to_next_grad_name_mapping
        if string in name_mapping.keys():
            return name_mapping[string]
        return string

    def ResetOptionalInputs(self):
        namespace = self.namespace
        grad_api_contents = self.grad_api_contents

        base_generator = FunctionGeneratorBase(grad_api_contents, namespace)
        base_generator.ParseDispensable()

        self.optional_inputs = base_generator.optional_inputs

    def RecordGrad2NextGradNameMapping(self, next_node_generator):
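        """Build the name mapping from this grad api's inputs/returns to the names used by the next grad api."""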
        next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
        next_orig_returns_list = next_node_generator.orig_forward_returns_list

        next_forward_inputs_list = next_node_generator.forward_inputs_list
        next_forward_returns_list = next_node_generator.forward_returns_list
        for i in range(len(next_orig_inputs_list)):
            grad_name = next_orig_inputs_list[i][0]
            next_forward_name = next_forward_inputs_list[i][0]
            self.to_next_grad_name_mapping[grad_name] = next_forward_name

        for i in range(len(next_orig_returns_list)):
            grad_ret_name = next_orig_returns_list[i][0]
            next_ret_name = next_forward_returns_list[i][0]
            self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name

    def GenerateHigherOrderNodeCreationCode(self):
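        """If a next (double) grad api exists, run a generator for it and return the node-creation code for the higher-order grad node."""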
        namespace = self.namespace
        grad_api_contents = self.grad_api_contents
        next_grad_api_contents = self.next_grad_api_contents

        grad_node_creation_str = ""
        if next_grad_api_contents:
            forward_api_contents = grad_api_contents
            forward_api_contents['api'] = forward_api_contents['backward_api']
            backward_api_contents = next_grad_api_contents

            next_node_generator = DygraphFunctionGeneratorBase(
                forward_api_contents, backward_api_contents, namespace)
            next_node_generator.run()
            next_node_generator.GenerateNodeCreationCodes()
            grad_node_creation_str = next_node_generator.node_creation_str

            self.RecordGrad2NextGradNameMapping(next_node_generator)

        return grad_node_creation_str

    def GenerateNodeDeclaration(self):
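        """Generate the GradNode class declaration: SetTensorWrapper/SetAttribute methods plus the corresponding member fields."""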
        forward_op_name = self.forward_api_name
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_attrs_list = self.backward_attrs_list
        no_need_buffers = self.no_need_buffers

        # SetTensorWrapper Methods & TensorWrapper Members
        set_tensor_wrapper_methods_str = ""
        tensor_wrapper_members_str = ""
        clear_tensor_wrapper_str = ""
        for tname, (ttype, is_fwd_input,
                    _) in backward_forward_inputs_map.items():
            no_need_buffer = "true" if tname in no_need_buffers else "false"
            tensor_wrapper_name = GetSavedName(tname)
            if IsPlainTensorType(ttype):
                set_tensor_wrapper_methods_str += SET_PLAIN_TENSOR_WRAPPER_TEMPLATE.format(
                    tname, tname, tensor_wrapper_name, tname, no_need_buffer)

                tensor_wrapper_members_str += PLAIN_TENSOR_MEMBER_TEMPLATE.format(
                    tensor_wrapper_name)

                clear_tensor_wrapper_str += CLEAR_TENSOR_WRAPPER_TEMPLATE.format(
                    tensor_wrapper_name)

            else:
                assert IsVectorTensorType(ttype)
                set_tensor_wrapper_methods_str += SET_VECTOR_TENSOR_WRAPPER_TEMPLATE.format(
                    tname, tname, tname, tensor_wrapper_name, no_need_buffer)

                tensor_wrapper_members_str += VECTOR_TENSOR_MEMBER_TEMPLATE.format(
                    tensor_wrapper_name)

                clear_tensor_wrapper_str += CLEAR_VECTOR_TENSOR_WRAPPERS_TEMPLATE.format(
                    tensor_wrapper_name)

        # SetAttributes & Attribute Members
        set_attribute_methods_str = ""
        attribute_members_str = ""
        for aname, atype, default_val, _ in backward_attrs_list:
            saved_attr_name = GetSavedName(aname)
            set_attribute_methods_str += SET_ATTR_METHOD_TEMPLATE.format(
                aname, GetConstReference(atype), aname, saved_attr_name, aname)

            if default_val:
                attribute_members_str += ATTRIBUTE_MEMBER_WITH_DEFAULT_TEMPLATE.format(
                    RemoveConstAndReference(atype), saved_attr_name,
                    default_val)
            else:
                attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
                    RemoveConstAndReference(atype), saved_attr_name)

        grad_node_name = GetGradNodeName(forward_op_name)
        self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
            grad_node_name, grad_node_name, grad_node_name, grad_node_name,
            grad_node_name, clear_tensor_wrapper_str, grad_node_name,
            grad_node_name, set_tensor_wrapper_methods_str,
            set_attribute_methods_str, tensor_wrapper_members_str,
            attribute_members_str)

        logging.info(f"Generated Node Declaration: {self.node_declaration_str}")

    def GenerateNodeDefinition(self, grad_node_creation_str):
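        """Generate the GradNode::operator() definition: fill zero for empty grads, recover tensor wrappers,
        collect grad inputs and attributes, and call the corresponding grad api."""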
        namespace = self.namespace
        forward_api_name = self.forward_api_name
        backward_api_name = self.backward_api_name
        backward_forward_inputs_map = self.backward_forward_inputs_map
        backward_grad_inputs_map = self.backward_grad_inputs_map
        backward_grad_outputs_map = self.backward_grad_outputs_map
        backward_attrs_list = self.backward_attrs_list
        indent = GetIndent(1)

        # Construct grad_api function args
        # Order: TensorWrappers, GradTensors, Attributes
        grad_api_args_len = len(backward_forward_inputs_map.keys()) + len(
            backward_grad_inputs_map.keys()) + len(backward_attrs_list)
        grad_api_args = ["" for i in range(grad_api_args_len)]
        get_grad_in_args_list = []

        # Fill Grad Ins with Zero
        fill_zero_str = ""
        if backward_api_name in ops_to_fill_zero_for_empty_grads:
            fill_zero_str = f"{indent}egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"

        # Grad Ins from TensorWrappers
        for name, (_, is_fwd_input,
                   grad_api_position) in backward_forward_inputs_map.items():
            tensor_wrapper_name = GetSavedName(name)
            transformed_tensor_name = self.TransformToNextGradName(name)

            is_optional = (name in self.optional_inputs)
            tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
            if is_optional:
                tensor_wrapper_recover_str += "\n" + CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE.format(
                    transformed_tensor_name, transformed_tensor_name,
                    transformed_tensor_name, transformed_tensor_name)

                grad_api_args[
                    grad_api_position] = transformed_tensor_name + "_optional"

            else:
                grad_api_args[grad_api_position] = transformed_tensor_name

            get_grad_in_args_list.append(tensor_wrapper_recover_str)

        # Grad Ins from grads
        for name, (ttype, fwd_position,
                   grad_api_position) in backward_grad_inputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)

            is_optional = (name in self.optional_inputs)
            if IsPlainTensorType(ttype):
                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];"

                if is_optional:
                    get_tensor_str += "\n" + CREATE_PLAIN_OPTIONAL_TENSOR_TEMPLATE.format(
                        transformed_tensor_name, transformed_tensor_name,
                        transformed_tensor_name, transformed_tensor_name)
                    grad_api_args[
                        grad_api_position] = f"{transformed_tensor_name}_optional"
                else:
                    grad_api_args[grad_api_position] = transformed_tensor_name
            else:
                assert IsVectorTensorType(ttype)
                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}];"
                grad_api_args[grad_api_position] = transformed_tensor_name

            get_grad_in_args_list.append(get_tensor_str)

        # Grad Attrs
        for name, _, _, grad_api_position in backward_attrs_list:
            saved_attribute_name = GetSavedName(name)
            get_attr_str = f"{indent}auto& {name} = this->{saved_attribute_name};"

            grad_api_args[grad_api_position] = name
            get_grad_in_args_list.append(get_attr_str)

        get_grad_in_args_str = "\n".join(get_grad_in_args_list)
        grad_api_args_str = ", ".join(grad_api_args)

        # Grad Function Call String
        grad_api_namespace = f"paddle::experimental::{namespace}"
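        # The emitted call looks roughly like the following (names for a
        # hypothetical "matmul" op, shown for illustration only):
        #   auto grad_api_result = paddle::experimental::matmul_grad(x, y, grad_out, transpose_x, transpose_y);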
        grad_function_call_str = f"{indent}auto grad_api_result = {grad_api_namespace}{backward_api_name}({grad_api_args_str});"

        # Get Grad Outputs
        get_outputs_str = ""
        num_outputs = len(backward_grad_outputs_map.keys())
        for name, (ttype, fwd_position,
                   grad_api_position) in backward_grad_outputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)

            if num_outputs == 1:
                get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
            else:
                if IsPlainTensorType(ttype):
                    get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}][0];"
                else:
                    assert IsVectorTensorType(ttype)
                    get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result[{grad_api_position}];"
            get_outputs_str += get_tensor_str + "\n"

        # Prepare for Node Creation if Necessary
        inputs_autograd_meta_str = ""
        outputs_autograd_meta_str = ""
        compute_require_grad_str = ""
        if len(grad_node_creation_str) > 0:
            # 1. Get Input AutoGradMeta
            inputs_autograd_meta_list = []
            compute_require_grad_args_list = ["trace_backward"]
            for name, (ttype, pos,
                       grad_api_position) in backward_grad_inputs_map.items():
                transformed_tensor_name = self.TransformToNextGradName(name)

                input_autograd_meta_name = GetAutoGradMetaName(
                    transformed_tensor_name)
                if IsPlainTensorType(ttype):
                    input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
                else:
                    assert IsVectorTensorType(ttype)
                    input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
                        transformed_tensor_name)
                    input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
                    input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

                inputs_autograd_meta_list.append(input_autograd_meta)
                compute_require_grad_args_list.append(input_autograd_meta_name)

            # 2. Get TensorWrapper AutoGradMeta
            for name, (ttype, _, pos) in backward_forward_inputs_map.items():
                transformed_tensor_name = self.TransformToNextGradName(name)

                input_autograd_meta_name = GetAutoGradMetaName(
                    transformed_tensor_name)
                if IsPlainTensorType(ttype):
                    input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});"
                else:
                    assert IsVectorTensorType(ttype)
                    input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
                        transformed_tensor_name)
                    input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({transformed_tensor_name});\n"
                    input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"

                inputs_autograd_meta_list.append(input_autograd_meta)
                compute_require_grad_args_list.append(input_autograd_meta_name)
            inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
            compute_require_grad_args_str = ",".join(
                compute_require_grad_args_list)

            # 3. Get Output AutoGradMeta
            outputs_autograd_meta_list = []
            num_fwd_outputs = len(backward_grad_outputs_map.keys())
            for name, (rtype, pos, _) in backward_grad_outputs_map.items():
                transformed_tensor_name = self.TransformToNextGradName(name)

                output_autograd_meta_name = GetAutoGradMetaName(
                    transformed_tensor_name)
                output_autograd_meta_vec_name = GetAutoGradMetaVectorName(
                    transformed_tensor_name)
                if num_fwd_outputs == 1:
                    if IsPlainTensorType(rtype):
                        output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
                    else:
                        assert IsVectorTensorType(rtype)
                        output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
                        output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
                else:
                    # Tuple api_result
                    if IsPlainTensorType(rtype):
                        output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});"
                    else:
                        assert IsVectorTensorType(rtype)
                        output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name});\n"
                        output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

                outputs_autograd_meta_list.append(output_autograd_meta)
            outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)

            compute_require_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
            compute_require_grad_str += f"{indent}bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({compute_require_grad_args_str});"

        # Construct grad_api returns
        num_bwd_outputs = len(backward_grad_outputs_map.keys())
        slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys())
        returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
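        # "returns" holds one slot per forward input; each grad output below is
        # written into the slot of the forward input it differentiates (fwd_position).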
        for name, (ttype, fwd_position,
                   grad_api_position) in backward_grad_outputs_map.items():
            transformed_tensor_name = self.TransformToNextGradName(name)

            # Infer Grad API Return Type
            if num_bwd_outputs == 1:
                # Single tensor output, return as is
                if IsPlainTensorType(ttype):
                    returns_str += f"{indent}returns[0] = {{ {transformed_tensor_name} }};\n"
                else:
                    assert IsVectorTensorType(ttype)
                    returns_str += f"{indent}returns[0] = {transformed_tensor_name};\n"
            else:
                # Rearrange output order accordingly
                if IsPlainTensorType(ttype):
                    returns_str += f"{indent}returns[{fwd_position}] = {{ {transformed_tensor_name} }};\n"
                else:
                    assert IsVectorTensorType(ttype)
                    returns_str += f"{indent}returns[{fwd_position}] = {transformed_tensor_name};\n"

        returns_str += f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n"
        returns_str += f"{indent}return returns;\n"

        grad_node_name = GetGradNodeName(forward_api_name)

        if len(grad_node_creation_str) == 0:
            grad_node_creation_str = f"if(create_graph) VLOG(3) << \"Higher order grad node for {grad_node_name} has not been implemented yet.\";"

        self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
            grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
            grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
            outputs_autograd_meta_str, compute_require_grad_str,
            grad_node_creation_str, returns_str)

        logging.info(f"Generated Node Definition: {self.node_definition_str}")

    def run(self):
        super().run()

        self.ResetOptionalInputs()

        #####################
        ## Code Generation ##
        #####################
        # Higher-order GradNode generation
        grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()

        self.GenerateNodeDeclaration()

        self.GenerateNodeDefinition(grad_node_creation_str)


class DygraphYamlGenerator(YamlGeneratorBase):
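    # Drives generation for one (api.yaml, backward.yaml) pair: parse both
    # files, then emit dygraph forward functions plus GradNode declarations
    # and definitions for every forward API that declares a backward.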
    def __init__(self, api_yaml_path, backward_yaml_path):
        # Parent members: 
        # self.namespace
        # self.api_yaml_path
        # self.forward_api_list
        YamlGeneratorBase.__init__(self, api_yaml_path)

        self.backward_yaml_path = backward_yaml_path
        self.grad_api_dict = {}

        self.forward_definition_str = ""
        self.forward_declaration_str = ""
        self.node_declaration_str = ""
        self.node_definition_str = ""

    def ParseYamlContents(self):
        self.ParseForwardYamlContents()

        backward_yaml_path = self.backward_yaml_path
        self.grad_api_dict = ReadBwdFile(backward_yaml_path)

    def GetBackwardAPIContents(self, forward_api_contents):
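        # Look up the grad API entry referenced by this forward API's
        # 'backward' field; return None when the op declares no backward.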
        grad_api_dict = self.grad_api_dict

        if 'backward' not in forward_api_contents.keys(): return None

        backward_api_name = forward_api_contents['backward']
        assert backward_api_name in grad_api_dict.keys(), AssertMessage(
            backward_api_name, grad_api_dict.keys())
        backward_api_contents = grad_api_dict[backward_api_name]

        return backward_api_contents

    def GenerateCode(self):
        forward_api_list = self.forward_api_list
        grad_api_dict = self.grad_api_dict
        namespace = self.namespace

        for forward_api_contents in forward_api_list:
            backward_api_contents = self.GetBackwardAPIContents(
                forward_api_contents)
            if backward_api_contents is None: continue

            # Generate Dygraph Forward Function
            function_generator = DygraphForwardFunctionGenerator(
                forward_api_contents, backward_api_contents, namespace)
            function_generator.run()

            self.forward_definition_str += function_generator.forward_definition_str + "\n"
            self.forward_declaration_str += function_generator.forward_declaration_str + "\n"

            while True:
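                # Each pass emits the GradNode for the current backward API; if
                # that API declares its own 'backward', keep walking the chain
                # so higher-order grad nodes are generated as well.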
                next_grad_api_contents = self.GetBackwardAPIContents(
                    backward_api_contents)

                node_generator = DygraphNodeGenerator(
                    forward_api_contents, backward_api_contents, namespace,
                    next_grad_api_contents)
                node_generator.run()
                self.node_declaration_str += node_generator.node_declaration_str + "\n"
                self.node_definition_str += node_generator.node_definition_str + "\n"

                if next_grad_api_contents is None: break

                # Detect whether a higher-order GradNode exists
                forward_api_contents = backward_api_contents

                # Fake forward_api_contents: reuse the backward API as the forward API for the next-order node
                forward_api_contents['api'] = forward_api_contents[
                    'backward_api']
                backward_api_contents = next_grad_api_contents

        if len(namespace) > 0:
            if namespace.endswith("::"):
                namespace = namespace[:-2]
            self.forward_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.forward_definition_str)
            self.forward_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.forward_declaration_str)
            self.node_declaration_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.node_declaration_str)
            self.node_definition_str = NAMESPACE_WRAPPER_TEMPLATE.format(
                namespace, self.node_definition_str)

    def run(self):
        self.ParseYamlContents()

        self.InferNameSpace()

        self.GenerateCode()


##################
## File Writers ##
##################
def GenerateNodeCCFile(filepath, node_definition_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    file_contents = NODE_CC_FILE_TEMPLATE.format(node_definition_str)
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateNodeHFile(filepath, node_declaration_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    file_contents = NODE_H_FILE_TEMPLATE.format(node_declaration_str)
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateForwardCCFile(filepath, forward_definition_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    core_ops_info_str = GenerateCoreOpInfoDefinition()
    file_contents = FORWARD_CC_FILE_TEMPLATE.format(core_ops_info_str,
                                                    forward_definition_str)
    with open(filepath, 'a') as f:
        f.write(file_contents)


def GenerateForwardHFile(filepath, forward_function_declaration_str):
    if os.path.exists(filepath):
        os.remove(filepath)

    core_ops_info_str = GenerateCoreOpInfoDeclaration()
    file_contents = FORWARD_H_FILE_TEMPLATE.format(
        core_ops_info_str, forward_function_declaration_str)
    with open(filepath, 'a') as f:
        f.write(file_contents)


if __name__ == "__main__":
    args = ParseArguments()

    api_yaml_paths = args.api_yaml_path.split(",")
    backward_yaml_paths = args.backward_yaml_path.split(",")

    # Generate per Dygraph API
    node_declaration_str = ""
    node_definition_str = ""
    forward_definition_str = ""
    forward_declaration_str = ""

    for i in range(len(api_yaml_paths)):
        api_yaml_path = api_yaml_paths[i]
        backward_yaml_path = backward_yaml_paths[i]

        generator = DygraphYamlGenerator(api_yaml_path, backward_yaml_path)
        generator.run()

        node_declaration_str += generator.node_declaration_str + "\n"
        node_definition_str += generator.node_definition_str + "\n"
        forward_definition_str += generator.forward_definition_str + "\n"
        forward_declaration_str += generator.forward_declaration_str + "\n"

    # Generate Files
    nodes_h_path = args.nodes_h_path
    nodes_cc_path = args.nodes_cc_path
    forwards_h_path = args.forwards_h_path
    forwards_cc_path = args.forwards_cc_path

    GenerateNodeCCFile(nodes_cc_path, node_definition_str)
    GenerateNodeHFile(nodes_h_path, node_declaration_str)
    GenerateForwardCCFile(forwards_cc_path, forward_definition_str)
    GenerateForwardHFile(forwards_h_path, forward_declaration_str)
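
# Example invocation (paths are illustrative only):
#   python eager_gen.py \
#     --api_yaml_path=path/to/api.yaml --backward_yaml_path=path/to/backward.yaml \
#     --nodes_h_path=path/to/nodes.h --nodes_cc_path=path/to/nodes.cc \
#     --forwards_h_path=path/to/dygraph_functions.h --forwards_cc_path=path/to/dygraph_functions.cc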