未验证 提交 f027b2ad 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[Refactor] refactored eager_gen.py PR #2 (#40907)

上级 5f6038ff
...@@ -50,6 +50,10 @@ yaml_types_mapping = { ...@@ -50,6 +50,10 @@ yaml_types_mapping = {
############################# #############################
### File Reader Helpers ### ### File Reader Helpers ###
############################# #############################
def AssertMessage(lhs_str, rhs_str):
return f"lhs: {lhs_str}, rhs: {rhs_str}"
def ReadFwdFile(filepath): def ReadFwdFile(filepath):
f = open(filepath, 'r') f = open(filepath, 'r')
contents = yaml.load(f, Loader=yaml.FullLoader) contents = yaml.load(f, Loader=yaml.FullLoader)
...@@ -62,10 +66,10 @@ def ReadBwdFile(filepath): ...@@ -62,10 +66,10 @@ def ReadBwdFile(filepath):
contents = yaml.load(f, Loader=yaml.FullLoader) contents = yaml.load(f, Loader=yaml.FullLoader)
ret = {} ret = {}
for content in contents: for content in contents:
assert 'backward_api' in content.keys(), AssertMessage('backward_api',
content.keys())
if 'backward_api' in content.keys(): if 'backward_api' in content.keys():
api_name = content['backward_api'] api_name = content['backward_api']
else:
assert False
ret[api_name] = content ret[api_name] = content
f.close() f.close()
...@@ -225,7 +229,7 @@ def ParseYamlReturns(string): ...@@ -225,7 +229,7 @@ def ParseYamlReturns(string):
), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping." ), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
ret_type = yaml_types_mapping[ret_type] ret_type = yaml_types_mapping[ret_type]
assert "Tensor" in ret_type assert "Tensor" in ret_type, AssertMessage("Tensor", ret_type)
ret_name = RemoveSpecialSymbolsInName(ret_name) ret_name = RemoveSpecialSymbolsInName(ret_name)
returns_list.append([ret_name, ret_type, i]) returns_list.append([ret_name, ret_type, i])
......
...@@ -16,6 +16,7 @@ import yaml ...@@ -16,6 +16,7 @@ import yaml
import re import re
import argparse import argparse
import os import os
import logging
from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info from codegen_utils import core_ops_returns_info, core_ops_args_info, core_ops_args_type_info
from codegen_utils import yaml_types_mapping from codegen_utils import yaml_types_mapping
from codegen_utils import ReadFwdFile, ReadBwdFile from codegen_utils import ReadFwdFile, ReadBwdFile
...@@ -30,6 +31,7 @@ from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromB ...@@ -30,6 +31,7 @@ from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromB
from codegen_utils import ParseYamlForward, ParseYamlBackward from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage
########### ###########
...@@ -398,14 +400,21 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -398,14 +400,21 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_contents = self.forward_api_contents forward_api_contents = self.forward_api_contents
grad_api_contents = self.grad_api_contents grad_api_contents = self.grad_api_contents
assert 'api' in forward_api_contents.keys() assert 'api' in forward_api_contents.keys(
assert 'args' in forward_api_contents.keys() ), "Unable to find \"api\" in api.yaml"
assert 'output' in forward_api_contents.keys() assert 'args' in forward_api_contents.keys(
assert 'backward' in forward_api_contents.keys() ), "Unable to find \"args\" in api.yaml"
assert 'output' in forward_api_contents.keys(
assert 'args' in grad_api_contents.keys() ), "Unable to find \"output\" in api.yaml"
assert 'output' in grad_api_contents.keys() assert 'backward' in forward_api_contents.keys(
assert 'forward' in grad_api_contents.keys() ), "Unable to find \"backward\" in api.yaml"
assert 'args' in grad_api_contents.keys(
), "Unable to find \"args\" in backward.yaml"
assert 'output' in grad_api_contents.keys(
), "Unable to find \"output\" in backward.yaml"
assert 'forward' in grad_api_contents.keys(
), "Unable to find \"forward\" in backward.yaml"
def ForwardsValidationCheck(self): def ForwardsValidationCheck(self):
forward_inputs_list = self.forward_inputs_list forward_inputs_list = self.forward_inputs_list
...@@ -424,8 +433,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -424,8 +433,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
orig_input_type = orig_forward_inputs_list[i][1] orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2] orig_input_pos = orig_forward_inputs_list[i][2]
assert forward_input_type == orig_input_type assert forward_input_type == orig_input_type, AssertMessage(
assert forward_input_pos == orig_input_pos forward_input_type, orig_input_type)
assert forward_input_pos == orig_input_pos, AssertMessage(
forward_input_pos, orig_input_pos)
for i in range(len(forward_attrs_list)): for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0] orig_attr_name = orig_forward_attrs_list[i][0]
...@@ -436,9 +447,12 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -436,9 +447,12 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_attr_type = forward_attrs_list[i][1] forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2] forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3] forward_attr_pos = forward_attrs_list[i][3]
assert orig_attr_type == forward_attr_type assert orig_attr_type == forward_attr_type, AssertMessage(
assert orig_attr_default == forward_attr_default orig_attr_type, forward_attr_type)
assert orig_attr_pos == forward_attr_pos assert orig_attr_default == forward_attr_default, AssertMessage(
orig_attr_default, forward_attr_default)
assert orig_attr_pos == forward_attr_pos, AssertMessage(
orig_attr_pos, forward_attr_pos)
for i in range(len(forward_returns_list)): for i in range(len(forward_returns_list)):
orig_return_type = orig_forward_returns_list[i][1] orig_return_type = orig_forward_returns_list[i][1]
...@@ -446,8 +460,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -446,8 +460,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_return_type = forward_returns_list[i][1] forward_return_type = forward_returns_list[i][1]
forward_return_pos = forward_returns_list[i][2] forward_return_pos = forward_returns_list[i][2]
assert orig_return_type == forward_return_type assert orig_return_type == forward_return_type, AssertMessage(
assert orig_return_pos == forward_return_pos orig_return_type, forward_return_type)
assert orig_return_pos == forward_return_pos, AssertMessage(
orig_return_pos, forward_return_pos)
# Check Order: Inputs, Attributes # Check Order: Inputs, Attributes
max_input_position = -1 max_input_position = -1
...@@ -456,7 +472,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -456,7 +472,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
max_attr_position = -1 max_attr_position = -1
for _, _, _, pos in forward_attrs_list: for _, _, _, pos in forward_attrs_list:
assert pos > max_input_position assert pos > max_input_position, AssertMessage(pos,
max_input_position)
max_attr_position = max(max_attr_position, pos) max_attr_position = max(max_attr_position, pos)
def BackwardValidationCheck(self): def BackwardValidationCheck(self):
...@@ -471,12 +488,14 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -471,12 +488,14 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
max_grad_tensor_position = -1 max_grad_tensor_position = -1
for _, (_, _, pos) in backward_grad_inputs_map.items(): for _, (_, _, pos) in backward_grad_inputs_map.items():
assert pos > max_fwd_input_position assert pos > max_fwd_input_position, AssertMessage(
pos, max_grad_tensor_position)
max_grad_tensor_position = max(max_grad_tensor_position, pos) max_grad_tensor_position = max(max_grad_tensor_position, pos)
max_attr_position = -1 max_attr_position = -1
for _, _, _, pos in backward_attrs_list: for _, _, _, pos in backward_attrs_list:
assert pos > max_grad_tensor_position assert pos > max_grad_tensor_position, AssertMessage(
pos, max_grad_tensor_position)
max_attr_position = max(max_attr_position, pos) max_attr_position = max(max_attr_position, pos)
def IntermediateValidationCheck(self): def IntermediateValidationCheck(self):
...@@ -491,7 +510,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -491,7 +510,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
len(forward_returns_list)) len(forward_returns_list))
for ret_name, _, pos in forward_returns_list: for ret_name, _, pos in forward_returns_list:
if ret_name in intermediate_outputs: if ret_name in intermediate_outputs:
assert pos in intermediate_positions assert pos in intermediate_positions, AssertMessage(
pos, intermediate_positions)
def CollectBackwardInfo(self): def CollectBackwardInfo(self):
forward_api_contents = self.forward_api_contents forward_api_contents = self.forward_api_contents
...@@ -505,9 +525,12 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -505,9 +525,12 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward( self.backward_inputs_list, self.backward_attrs_list, self.backward_returns_list = ParseYamlBackward(
backward_args_str, backward_returns_str) backward_args_str, backward_returns_str)
print("Parsed Backward Inputs List: ", self.backward_inputs_list)
print("Prased Backward Attrs List: ", self.backward_attrs_list) logging.info(
print("Parsed Backward Returns List: ", self.backward_returns_list) f"Parsed Backward Inputs List: {self.backward_inputs_list}")
logging.info(f"Prased Backward Attrs List: {self.backward_attrs_list}")
logging.info(
f"Parsed Backward Returns List: {self.backward_returns_list}")
def CollectForwardInfoFromBackwardContents(self): def CollectForwardInfoFromBackwardContents(self):
...@@ -530,7 +553,9 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -530,7 +553,9 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_fwd_name = FindForwardName(backward_input_name) backward_fwd_name = FindForwardName(backward_input_name)
if backward_fwd_name: if backward_fwd_name:
# Grad Input # Grad Input
assert backward_fwd_name in forward_outputs_position_map.keys() assert backward_fwd_name in forward_outputs_position_map.keys(
), AssertMessage(backward_fwd_name,
forward_outputs_position_map.keys())
matched_forward_output_type = forward_outputs_position_map[ matched_forward_output_type = forward_outputs_position_map[
backward_fwd_name][0] backward_fwd_name][0]
matched_forward_output_pos = forward_outputs_position_map[ matched_forward_output_pos = forward_outputs_position_map[
...@@ -556,7 +581,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -556,7 +581,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_input_type, False, backward_input_pos backward_input_type, False, backward_input_pos
] ]
else: else:
assert False, backward_input_name assert False, f"Cannot find {backward_input_name} in forward position map"
for backward_output in backward_returns_list: for backward_output in backward_returns_list:
backward_output_name = backward_output[0] backward_output_name = backward_output[0]
...@@ -564,9 +589,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -564,9 +589,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_output_pos = backward_output[2] backward_output_pos = backward_output[2]
backward_fwd_name = FindForwardName(backward_output_name) backward_fwd_name = FindForwardName(backward_output_name)
assert backward_fwd_name is not None assert backward_fwd_name is not None, f"Detected {backward_fwd_name} = None"
assert backward_fwd_name in forward_inputs_position_map.keys( assert backward_fwd_name in forward_inputs_position_map.keys(
), f"Unable to find {backward_fwd_name} in forward inputs" ), AssertMessage(backward_fwd_name,
forward_inputs_position_map.keys())
matched_forward_input_type = forward_inputs_position_map[ matched_forward_input_type = forward_inputs_position_map[
backward_fwd_name][0] backward_fwd_name][0]
...@@ -577,12 +603,15 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -577,12 +603,15 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_output_type, matched_forward_input_pos, backward_output_type, matched_forward_input_pos,
backward_output_pos backward_output_pos
] ]
print("Generated Backward Fwd Input Map: ", logging.info(
self.backward_forward_inputs_map) f"Generated Backward Fwd Input Map: {self.backward_forward_inputs_map}"
print("Generated Backward Grad Input Map: ", )
self.backward_grad_inputs_map) logging.info(
print("Generated Backward Grad Output Map: ", f"Generated Backward Grad Input Map: {self.backward_grad_inputs_map}"
self.backward_grad_outputs_map) )
logging.info(
f"Generated Backward Grad Output Map: {self.backward_grad_outputs_map}"
)
def GenerateNodeDeclaration(self): def GenerateNodeDeclaration(self):
forward_op_name = self.forward_api_name forward_op_name = self.forward_api_name
...@@ -642,7 +671,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -642,7 +671,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
set_tensor_wrapper_methods_str, set_attribute_methods_str, set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str) tensor_wrapper_members_str, attribute_members_str)
print("Generated Node Declaration: ", self.node_declaration_str) logging.info(f"Generated Node Declaration: {self.node_declaration_str}")
def GenerateNodeDefinition(self): def GenerateNodeDefinition(self):
namespace = self.namespace namespace = self.namespace
...@@ -710,7 +739,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -710,7 +739,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace, grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
backward_api_name, grad_api_args_str, returns_str) backward_api_name, grad_api_args_str, returns_str)
print("Generated Node Definition: ", self.node_definition_str) logging.info(f"Generated Node Definition: {self.node_definition_str}")
def GenerateForwardDefinition(self, is_inplaced): def GenerateForwardDefinition(self, is_inplaced):
namespace = self.namespace namespace = self.namespace
...@@ -813,8 +842,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -813,8 +842,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
dygraph_event_str, node_creation_str, returns_str) dygraph_event_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n" self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
print("Generated Forward Definition: ", self.forward_definition_str) logging.info(
print("Generated Forward Declaration: ", self.forward_declaration_str) f"Generated Forward Definition: {self.forward_definition_str}")
logging.info(
f"Generated Forward Declaration: {self.forward_declaration_str}")
def GenerateNodeCreationCodes(self, forward_call_str): def GenerateNodeCreationCodes(self, forward_call_str):
forward_api_name = self.forward_api_name forward_api_name = self.forward_api_name
...@@ -921,7 +952,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): ...@@ -921,7 +952,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
else: else:
if num_fwd_outputs > 1: if num_fwd_outputs > 1:
# Aligned with forward output position # Aligned with forward output position
assert name in forward_outputs_position_map.keys() assert name in forward_outputs_position_map.keys(
), AssertMessage(name, forward_outputs_position_map.keys())
fwd_output_pos = forward_outputs_position_map[name][1] fwd_output_pos = forward_outputs_position_map[name][1]
tw_name = f"std::get<{fwd_output_pos}>(api_result)" tw_name = f"std::get<{fwd_output_pos}>(api_result)"
else: else:
...@@ -1114,7 +1146,8 @@ class DygraphYamlGenerator(YamlGeneratorBase): ...@@ -1114,7 +1146,8 @@ class DygraphYamlGenerator(YamlGeneratorBase):
if 'backward' not in forward_api_contents.keys(): return None if 'backward' not in forward_api_contents.keys(): return None
backward_api_name = forward_api_contents['backward'] backward_api_name = forward_api_contents['backward']
assert backward_api_name in grad_api_dict.keys() assert backward_api_name in grad_api_dict.keys(), AssertMessage(
backward_api_name, grad_api_dict.keys())
backward_api_contents = grad_api_dict[backward_api_name] backward_api_contents = grad_api_dict[backward_api_name]
return backward_api_contents return backward_api_contents
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册