未验证 提交 ca4aea2c 编写于 作者: C chenjian 提交者: GitHub

fix new dygraph record event (#41715)

* fix new dygraph record event

* refine name

* fix

* fix

* fix according to review
上级 c239f15a
...@@ -2480,7 +2480,7 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -2480,7 +2480,7 @@ static std::string GenerateGradNodeHeaderContents(
"%s\n" "%s\n"
" SetIsTensorWrappersCleared(true);\n" " SetIsTensorWrappersCleared(true);\n"
" }\n" " }\n"
" std::string name() override { return \" GradNode%s \"; } \n " " std::string name() override { return \"GradNode%sMid\"; } \n "
"\n" "\n"
"std::shared_ptr<GradNodeBase> Copy() const override {{\n " "std::shared_ptr<GradNodeBase> Copy() const override {{\n "
" auto copied_node = std::shared_ptr<GradNode%s>(new " " auto copied_node = std::shared_ptr<GradNode%s>(new "
......
...@@ -137,7 +137,7 @@ def RemoveConstAndReference(string): ...@@ -137,7 +137,7 @@ def RemoveConstAndReference(string):
def GetGradNodeName(string): def GetGradNodeName(string):
return f"FinalGradNode{string}" return f"GradNode{string}Final"
def GetDygraphForwardFunctionName(string): def GetDygraphForwardFunctionName(string):
......
...@@ -120,7 +120,7 @@ class {} : public egr::GradNodeBase {{ ...@@ -120,7 +120,7 @@ class {} : public egr::GradNodeBase {{
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()( virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override; std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }} std::string name() override {{ return \"{}\"; }}
void ClearTensorWrappers() override {{ void ClearTensorWrappers() override {{
{} {}
...@@ -804,7 +804,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase): ...@@ -804,7 +804,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_retain_grad_str = "\n".join(set_retain_grad_list) set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = forward_api_name + " node_creation" node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n" node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
self.node_creation_str = FORWARD_BODY_TEMPLATE.format( self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
node_creation_event_str, pass_stop_gradient_args_str, node_creation_event_str, pass_stop_gradient_args_str,
......
...@@ -643,7 +643,7 @@ std::vector<paddle::experimental::Tensor> RunBackward( ...@@ -643,7 +643,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
VLOG(6) << "Running GradNode:" << node->name(); VLOG(6) << "Running GradNode:" << node->name();
paddle::platform::RecordEvent node_record_event( paddle::platform::RecordEvent node_record_event(
std::string(typeid(*node).name()) + " grad_node", std::string((*node).name()) + " grad_node",
paddle::platform::TracerEventType::Operator, 1); paddle::platform::TracerEventType::Operator, 1);
if (queue.size() > 1 && node_in_degree_map[node] != 0) { if (queue.size() > 1 && node_in_degree_map[node] != 0) {
......
...@@ -744,7 +744,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self ...@@ -744,7 +744,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
{code_indent} using kernel_signature = {kernel_signature}; {code_indent} using kernel_signature = {kernel_signature};
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); {code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} {{ {code_indent} {{
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::Operator, 1); {code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); {code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} }} {code_indent} }}
...@@ -771,7 +771,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self ...@@ -771,7 +771,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
{code_indent} using kernel_signature = {kernel_signature}; {code_indent} using kernel_signature = {kernel_signature};
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>(); {code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} {{ {code_indent} {{
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::Operator, 1); {code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args}); {code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} }} {code_indent} }}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册