未验证 提交 f027b2ad 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[Refactor] refactored eager_gen.py PR #2 (#40907)

上级 5f6038ff
......@@ -50,6 +50,10 @@ yaml_types_mapping = {
#############################
### File Reader Helpers ###
#############################
def AssertMessage(lhs_str, rhs_str):
return f"lhs: {lhs_str}, rhs: {rhs_str}"
def ReadFwdFile(filepath):
f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader)
......@@ -62,10 +66,10 @@ def ReadBwdFile(filepath):
contents = yaml.load(f, Loader=yaml.FullLoader)
ret = {}
for content in contents:
assert 'backward_api' in content.keys(), AssertMessage('backward_api',
content.keys())
if 'backward_api' in content.keys():
api_name = content['backward_api']
else:
assert False
ret[api_name] = content
f.close()
......@@ -225,7 +229,7 @@ def ParseYamlReturns(string):
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type
assert "Tensor" in ret_type, AssertMessage("Tensor", ret_type)
ret_name = RemoveSpecialSymbolsInName(ret_name)
returns_list.append([ret_name, ret_type, i])
......
......@@ -16,6 +16,7 @@ import yaml
import re
import argparse
import os
import logging
from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info
from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, ReadBwdFile
......@@ -30,6 +31,7 @@ from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromB
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage
###########
......@@ -398,14 +400,21 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents
assert 'api' in forward_api_contents.keys()
assert 'args' in forward_api_contents.keys()
assert 'output' in forward_api_contents.keys()
assert 'backward' in forward_api_contents.keys()
assert 'args' in grad_api_contents.keys()
assert 'output' in grad_api_contents.keys()
assert 'forward' in grad_api_contents.keys()
assert 'api' in forward_api_contents.keys(
), "Unable to find \"api\" in api.yaml"
assert 'args' in forward_api_contents.keys(
), "Unable to find \"args\" in api.yaml"
assert 'output' in forward_api_contents.keys(
), "Unable to find \"output\" in api.yaml"
assert 'backward' in forward_api_contents.keys(
), "Unable to find \"backward\" in api.yaml"
assert 'args' in grad_api_contents.keys(
), "Unable to find \"args\" in backward.yaml"
assert 'output' in grad_api_contents.keys(
), "Unable to find \"output\" in backward.yaml"
assert 'forward' in grad_api_contents.keys(
), "Unable to find \"forward\" in backward.yaml"
def ForwardsValidationCheck(self):
forward_inputs_list = self.forward_inputs_list
......@@ -424,8 +433,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]
assert forward_input_type == orig_input_type
assert forward_input_pos == orig_input_pos
assert forward_input_type == orig_input_type, AssertMessage(
forward_input_type, orig_input_type)
assert forward_input_pos == orig_input_pos, AssertMessage(
forward_input_pos, orig_input_pos)
for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
......@@ -436,9 +447,12 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
assert orig_attr_type == forward_attr_type
assert orig_attr_default == forward_attr_default
assert orig_attr_pos == forward_attr_pos
assert orig_attr_type == forward_attr_type, AssertMessage(
orig_attr_type, forward_attr_type)
assert orig_attr_default == forward_attr_default, AssertMessage(
orig_attr_default, forward_attr_default)
assert orig_attr_pos == forward_attr_pos, AssertMessage(
orig_attr_pos, forward_attr_pos)
for i in range(len(forward_returns_list)):
orig_return_type = orig_forward_returns_list[i][1]
......@@ -446,8 +460,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_return_type = forward_returns_list[i][1]
forward_return_pos = forward_returns_list[i][2]
assert orig_return_type == forward_return_type
assert orig_return_pos == forward_return_pos
assert orig_return_type == forward_return_type, AssertMessage(
orig_return_type, forward_return_type)
assert orig_return_pos == forward_return_pos, AssertMessage(
orig_return_pos, forward_return_pos)
# Check Order: Inputs, Attributes
max_input_position = -1
......@@ -456,7 +472,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
max_attr_position = -1
for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position
assert pos > max_input_position, AssertMessage(pos,
max_input_position)
max_attr_position = max(max_attr_position, pos)
def BackwardValidationCheck(self):
......@@ -471,12 +488,14 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
max_grad_tensor_position = -1
for _, (_, _, pos) in backward_grad_inputs_map.items():
assert pos > max_fwd_input_position
assert pos > max_fwd_input_position, AssertMessage(
pos, max_grad_tensor_position)
max_grad_tensor_position = max(max_grad_tensor_position, pos)
max_attr_position = -1
for _, _, _, pos in backward_attrs_list:
assert pos > max_grad_tensor_position
assert pos > max_grad_tensor_position, AssertMessage(
pos, max_grad_tensor_position)
max_attr_position = max(max_attr_position, pos)
def IntermediateValidationCheck(self):
......@@ -491,7 +510,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
len(forward_returns_list))
for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs:
assert pos in intermediate_positions
assert pos in intermediate_positions, AssertMessage(
pos, intermediate_positions)
def CollectBackwardInfo(self):
forward_api_contents = self.forward_api_contents
......@@ -505,9 +525,12 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward(
backward_args_str, backward_returns_str)
print("Parsed Backward Inputs List: ", self.backward_inputs_list)
print("Prased Backward Attrs List: ", self.backward_attrs_list)
print("Parsed Backward Returns List: ", self.backward_returns_list)
logging.info(
f"Parsed Backward Inputs List: {self.backward_inputs_list}")
logging.info(f"Prased Backward Attrs List: {self.backward_attrs_list}")
logging.info(
f"Parsed Backward Returns List: {self.backward_returns_list}")
def CollectForwardInfoFromBackwardContents(self):
......@@ -530,7 +553,9 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_fwd_name = FindForwardName(backward_input_name)
if backward_fwd_name:
# Grad Input
assert backward_fwd_name in forward_outputs_position_map.keys()
assert backward_fwd_name in forward_outputs_position_map.keys(
), AssertMessage(backward_fwd_name,
forward_outputs_position_map.keys())
matched_forward_output_type = forward_outputs_position_map[
backward_fwd_name][0]
matched_forward_output_pos = forward_outputs_position_map[
......@@ -556,7 +581,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_input_type, False, backward_input_pos
]
else:
assert False, backward_input_name
assert False, f"Cannot find {backward_input_name} in forward position map"
for backward_output in backward_returns_list:
backward_output_name = backward_output[0]
......@@ -564,9 +589,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_output_pos = backward_output[2]
backward_fwd_name = FindForwardName(backward_output_name)
assert backward_fwd_name is not None
assert backward_fwd_name is not None, f"Detected {backward_fwd_name} = None"
assert backward_fwd_name in forward_inputs_position_map.keys(
), f"Unable to find {backward_fwd_name} in forward inputs"
), AssertMessage(backward_fwd_name,
forward_inputs_position_map.keys())
matched_forward_input_type = forward_inputs_position_map[
backward_fwd_name][0]
......@@ -577,12 +603,15 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_output_type, matched_forward_input_pos,
backward_output_pos
]
print("Generated Backward Fwd Input Map: ",
self.backward_forward_inputs_map)
print("Generated Backward Grad Input Map: ",
self.backward_grad_inputs_map)
print("Generated Backward Grad Output Map: ",
self.backward_grad_outputs_map)
logging.info(
f"Generated Backward Fwd Input Map: {self.backward_forward_inputs_map}"
)
logging.info(
f"Generated Backward Grad Input Map: {self.backward_grad_inputs_map}"
)
logging.info(
f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
)
def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name
......@@ -642,7 +671,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)
print("Generated Node Declaration: ", self.node_declaration_str)
logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
def GenerateNodeDefinition(self):
namespace = self.namespace
......@@ -710,7 +739,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
backward_api_name, grad_api_args_str, returns_str)
print("Generated Node Definition: ", self.node_definition_str)
logging.info(f"Generated Node Definition: {self.node_definition_str}")
def GenerateForwardDefinition(self, is_inplaced):
namespace = self.namespace
......@@ -813,8 +842,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
dygraph_event_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
print("Generated Forward Definition: ", self.forward_definition_str)
print("Generated Forward Declaration: ", self.forward_declaration_str)
logging.info(
f"Generated Forward Definition: {self.forward_definition_str}")
logging.info(
f"Generated Forward Declaration: {self.forward_declaration_str}")
def GenerateNodeCreationCodes(self, forward_call_str):
forward_api_name = self.forward_api_name
......@@ -921,7 +952,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
assert name in forward_outputs_position_map.keys()
assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else:
......@@ -1114,7 +1146,8 @@ class DygraphYamlGenerator(YamlGeneratorBase):
if 'backward' not in forward_api_contents.keys(): return None
backward_api_name = forward_api_contents['backward']
assert backward_api_name in grad_api_dict.keys()
assert backward_api_name in grad_api_dict.keys(), AssertMessage(
backward_api_name, grad_api_dict.keys())
backward_api_contents = grad_api_dict[backward_api_name]
return backward_api_contents
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册