未验证 提交 4be96ad9 编写于 作者: C chenjian 提交者: GitHub

fix new dygraph record event (#41715) (#41771)

* fix new dygraph record event

* refine name

* fix

* fix

* fix according to review
上级 33583dde
......@@ -2481,7 +2481,7 @@ static std::string GenerateGradNodeHeaderContents(
"%s\n"
" SetIsTensorWrappersCleared(true);\n"
" }\n"
" std::string name() override { return \" GradNode%s \"; } \n "
" std::string name() override { return \"GradNode%sMid\"; } \n "
"\n"
"std::shared_ptr<GradNodeBase> Copy() const override {{\n "
" auto copied_node = std::shared_ptr<GradNode%s>(new "
......
......@@ -136,7 +136,7 @@ def RemoveConstAndReference(string):
def GetGradNodeName(string):
return f"FinalGradNode{string}"
return f"GradNode{string}Final"
def GetDygraphForwardFunctionName(string):
......
......@@ -120,7 +120,7 @@ class {} : public egr::GradNodeBase {{
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
std::string name() override {{ return \" {} \"; }}
std::string name() override {{ return \"{}\"; }}
void ClearTensorWrappers() override {{
{}
......@@ -791,7 +791,7 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::Operator, 1);\n"
node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
node_creation_event_str, pass_stop_gradient_args_str,
......
......@@ -641,7 +641,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
VLOG(6) << "Running GradNode:" << node->name();
paddle::platform::RecordEvent node_record_event(
std::string(typeid(*node).name()) + " grad_node",
std::string((*node).name()) + " grad_node",
paddle::platform::TracerEventType::Operator, 1);
if (queue.size() > 1 && node_in_degree_map[node] != 0) {
......
......@@ -734,7 +734,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
{code_indent} using kernel_signature = {kernel_signature};
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} {{
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::Operator, 1);
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} }}
......@@ -761,7 +761,7 @@ PADDLE_API {self.gene_return_type_code()} {self.get_api_func_name() + '_'}({self
{code_indent} using kernel_signature = {kernel_signature};
{code_indent} auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{code_indent} {{
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::Operator, 1);
{code_indent} paddle::platform::RecordEvent kernel_record_event(\"{api_func_name} compute\", paddle::platform::TracerEventType::OperatorInner, 1);
{code_indent} (*kernel_fn)({kernel_args}, {outputs_args});
{code_indent} }}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册