提交 3723caba 编写于 作者: J jim19930609

Added EagerUtils helper functions for final state CodeGen

上级 1d755225
#add_subdirectory(final_state_generator) add_subdirectory(final_state_generator)
set(EAGER_GENERETOR_DEPS ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag) set(EAGER_GENERETOR_DEPS ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag)
......
...@@ -9,7 +9,7 @@ set(forwards_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager ...@@ -9,7 +9,7 @@ set(forwards_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager
set(nodes_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/node.cc") set(nodes_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/node.cc")
set(nodes_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/node.h") set(nodes_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/node.h")
execute_process( add_custom_target(eager_final_state_codegen
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py" COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py"
"--api_yaml_path=${api_yaml_path}" "--api_yaml_path=${api_yaml_path}"
"--backward_yaml_path=${backward_yaml_path}" "--backward_yaml_path=${backward_yaml_path}"
...@@ -21,4 +21,5 @@ execute_process( ...@@ -21,4 +21,5 @@ execute_process(
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_forwards_h_path} ${forwards_h_path} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_forwards_h_path} ${forwards_h_path}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_cc_path} ${nodes_cc_path} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_cc_path} ${nodes_cc_path}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path} ${nodes_h_path} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path} ${nodes_h_path}
VERBATIM
) )
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import yaml import yaml
import re import re
import argparse import argparse
import os
def ParseArguments(): def ParseArguments():
...@@ -93,8 +94,8 @@ def ReadBwdFile(filepath): ...@@ -93,8 +94,8 @@ def ReadBwdFile(filepath):
contents = yaml.load(f) contents = yaml.load(f)
ret = {} ret = {}
for content in contents: for content in contents:
assert 'grad_api' in content.keys() assert 'backward_api' in content.keys()
api_name = content['grad_api'] api_name = content['backward_api']
ret[api_name] = content ret[api_name] = content
return ret return ret
...@@ -435,10 +436,10 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -435,10 +436,10 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
aname, GetConstReference(atype), aname, saved_attr_name, aname) aname, GetConstReference(atype), aname, saved_attr_name, aname)
ATTRIBUTE_MEMBER_TEMPLATE = """ ATTRIBUTE_MEMBER_TEMPLATE = """
{} {}; {} {} = {};
""" """
attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format( attribute_members_str += ATTRIBUTE_MEMBER_TEMPLATE.format(
GetConstReference(atype), saved_attr_name) GetConstReference(atype), saved_attr_name, default_val)
# End: SetAttributes & Attribute Members # End: SetAttributes & Attribute Members
NODE_DECLARATION_TEMPLATE = """ NODE_DECLARATION_TEMPLATE = """
...@@ -491,15 +492,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, ...@@ -491,15 +492,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
tensor_wrapper_name = GetSavedName(name) tensor_wrapper_name = GetSavedName(name)
if is_fwd_input: if is_fwd_input:
grad_api_args[ grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, true)" grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, true) )"
else: else:
grad_api_args[ grad_api_args[
grad_api_position] = f"egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, false)" grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, false) )"
for _, (_, fwd_position, for _, (_, fwd_position,
grad_api_position) in backward_grad_input_map.items(): grad_api_position) in backward_grad_input_map.items():
grad_api_args[ grad_api_args[
grad_api_position] = f"*grads[{fwd_position}].Tensor().get()" grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( *grads[{fwd_position}] )"
for name, _, _, grad_api_position in backward_attrs_list: for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name) saved_attribute_name = GetSavedName(name)
...@@ -615,7 +616,7 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, ...@@ -615,7 +616,7 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# SetAttributes # SetAttributes
set_attributes_list = [] set_attributes_list = []
for name, _, _, _ in backward_attrs_list: for name, _, _, _ in backward_attrs_list:
set_attributes = " grad_node->SetAttribute{name}({name});" set_attributes = f" grad_node->SetAttribute{name}({name});"
set_attributes_list.append(set_attributes) set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list) set_attributes_str = "\n".join(set_attributes_list)
...@@ -727,7 +728,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -727,7 +728,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
inputs_args_list = ["" for i in range(num_inputs)] inputs_args_list = ["" for i in range(num_inputs)]
inputs_call_list = ["" for i in range(num_inputs)] inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items(): for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"*{name}.Tensor().get()" inputs_call_list[pos] = f"egr::EagerUtils::SyncToPtenTensors({name})"
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
inputs_args_list[pos] = f"const egr::EagerTensor& {name}" inputs_args_list[pos] = f"const egr::EagerTensor& {name}"
else: else:
...@@ -905,10 +906,17 @@ if __name__ == "__main__": ...@@ -905,10 +906,17 @@ if __name__ == "__main__":
# Collect Forward Inputs/Outputs # Collect Forward Inputs/Outputs
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForwardFromBackward( forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForwardFromBackward(
bwd_forward_str) bwd_forward_str)
print("Parsed Forward Inputs List: ", forward_inputs_list)
print("Prased Forward Attrs List: ", forward_attrs_list)
print("Parsed Forward Returns List: ", forward_returns_list)
# Collect Original Forward Inputs/Outputs and then perform validation checks # Collect Original Forward Inputs/Outputs and then perform validation checks
orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward( orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str) fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ", orig_forward_inputs_list)
print("Prased Original Forward Attrs List: ", orig_forward_attrs_list)
print("Parsed Original Forward Returns List: ",
orig_forward_returns_list)
# Forward Validation Checks # Forward Validation Checks
ForwardsValidationCheck(forward_inputs_list, forward_attrs_list, ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
...@@ -919,15 +927,25 @@ if __name__ == "__main__": ...@@ -919,15 +927,25 @@ if __name__ == "__main__":
# Parse Backward Inputs/Outputs # Parse Backward Inputs/Outputs
backward_inputs_list, backward_attrs_list, backward_returns_list = ParseYamlBackward( backward_inputs_list, backward_attrs_list, backward_returns_list = ParseYamlBackward(
bwd_args_str, bwd_returns_str) bwd_args_str, bwd_returns_str)
print("Parsed Backward Inputs List: ", backward_inputs_list)
print("Prased Backward Attrs List: ", backward_attrs_list)
print("Parsed Backward Returns List: ", backward_returns_list)
# Determine Forward Inputs/Outputs Position # Determine Forward Inputs/Outputs Position
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap( forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list) forward_inputs_list, forward_returns_list)
print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
forward_outputs_position_map)
# SlotName Matching # SlotName Matching
backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map = SlotNameMatching( backward_fwd_input_map, backward_grad_input_map, backward_grad_output_map = SlotNameMatching(
backward_inputs_list, backward_returns_list, backward_inputs_list, backward_returns_list,
forward_inputs_position_map, forward_outputs_position_map) forward_inputs_position_map, forward_outputs_position_map)
print("Generated Backward Fwd Input Map: ", backward_fwd_input_map)
print("Generated Backward Grad Input Map: ", backward_grad_input_map)
print("Generated Backward Grad Output Map: ", backward_grad_output_map)
# Backward Validation Check # Backward Validation Check
BackwardValidationCheck(backward_fwd_input_map, backward_grad_input_map, BackwardValidationCheck(backward_fwd_input_map, backward_grad_input_map,
...@@ -936,11 +954,13 @@ if __name__ == "__main__": ...@@ -936,11 +954,13 @@ if __name__ == "__main__":
# Node Declaration Generation # Node Declaration Generation
node_declaration_str += GenerateNodeDeclaration( node_declaration_str += GenerateNodeDeclaration(
fwd_api_name, backward_fwd_input_map, backward_attrs_list) fwd_api_name, backward_fwd_input_map, backward_attrs_list)
print("Generated Node Declaration: ", node_declaration_str)
node_definition_str += GenerateNodeDefinition( node_definition_str += GenerateNodeDefinition(
fwd_api_name, bwd_api_name, backward_fwd_input_map, fwd_api_name, bwd_api_name, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map, backward_grad_input_map, backward_grad_output_map,
backward_attrs_list) backward_attrs_list)
print("Generated Node Definition: ", node_definition_str)
# Node Definition Generation # Node Definition Generation
definition_declaration_pair = GenerateForwardDefinition( definition_declaration_pair = GenerateForwardDefinition(
...@@ -948,6 +968,8 @@ if __name__ == "__main__": ...@@ -948,6 +968,8 @@ if __name__ == "__main__":
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list) backward_grad_output_map, backward_attrs_list)
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0] forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1] forward_declaration_str += definition_declaration_pair[1]
...@@ -957,6 +979,12 @@ if __name__ == "__main__": ...@@ -957,6 +979,12 @@ if __name__ == "__main__":
forwards_h_path = args.forwards_h_path forwards_h_path = args.forwards_h_path
forwards_cc_path = args.forwards_cc_path forwards_cc_path = args.forwards_cc_path
for path in [
nodes_cc_path, nodes_h_path, forwards_h_path, forwards_cc_path
]:
if os.path.exists(path):
os.remove(path)
GenerateNodeCCFile(nodes_cc_path, node_definition_str) GenerateNodeCCFile(nodes_cc_path, node_definition_str)
GenerateNodeHFile(nodes_h_path, node_declaration_str) GenerateNodeHFile(nodes_h_path, node_declaration_str)
GenerateForwardCCFile(forwards_cc_path, forward_definition_str) GenerateForwardCCFile(forwards_cc_path, forward_definition_str)
......
...@@ -286,4 +286,43 @@ void EagerUtils::CheckAndRetainGrad( ...@@ -286,4 +286,43 @@ void EagerUtils::CheckAndRetainGrad(
} }
} }
paddle::experimental::Tensor EagerUtils::SyncToPtenTensors(
const egr::EagerTensor& tensor) {
const_cast<EagerTensor*>(&tensor)->SyncToTensor();
return *tensor.Tensor().get();
}
std::vector<paddle::experimental::Tensor> EagerUtils::SyncToPtenTensors(
const std::vector<egr::EagerTensor>& tensors) {
std::vector<paddle::experimental::Tensor> res;
size_t num = tensors.size();
res.reserve(num);
for (size_t i = 0; i < num; i++) {
const_cast<EagerTensor*>(&(tensors[i]))->SyncToTensor();
res.push_back(*tensors[i].Tensor().get());
}
return res;
}
egr::EagerTensor EagerUtils::CreateEagerTensorFromTensor(
const paddle::experimental::Tensor& tensor) {
egr::EagerTensor ret;
ret.set_tensor(std::make_shared<paddle::experimental::Tensor>(tensor));
return ret;
}
std::vector<egr::EagerTensor> EagerUtils::CreateEagerTensorFromTensor(
const std::vector<paddle::experimental::Tensor>& tensors) {
std::vector<egr::EagerTensor> res;
size_t num = tensors.size();
res.reserve(num);
for (size_t i = 0; i < num; i++) {
egr::EagerTensor tmp;
tmp.set_tensor(std::make_shared<paddle::experimental::Tensor>(tensors[i]));
res.emplace_back(std::move(tmp));
}
return res;
}
} // namespace egr } // namespace egr
...@@ -170,6 +170,16 @@ class EagerUtils { ...@@ -170,6 +170,16 @@ class EagerUtils {
static void CheckAndRetainGrad(const egr::EagerTensor& tensor); static void CheckAndRetainGrad(const egr::EagerTensor& tensor);
static void CheckAndRetainGrad(const std::vector<egr::EagerTensor>& tensors); static void CheckAndRetainGrad(const std::vector<egr::EagerTensor>& tensors);
static paddle::experimental::Tensor SyncToPtenTensors(
const egr::EagerTensor& tensor);
static std::vector<paddle::experimental::Tensor> SyncToPtenTensors(
const std::vector<egr::EagerTensor>& tensors);
static egr::EagerTensor CreateEagerTensorFromTensor(
const paddle::experimental::Tensor& tensor);
static std::vector<egr::EagerTensor> CreateEagerTensorFromTensor(
const std::vector<paddle::experimental::Tensor>& tensors);
}; };
} // namespace egr } // namespace egr
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册