未验证 提交 caaaf2f0 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Fixed performance issue regarding BackwardRun using add_final_state_dygraph (#41912)

上级 5c91010d
......@@ -1471,9 +1471,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_node_name = GetGradNodeName(forward_api_name)
if len(grad_node_creation_str) == 0:
grad_node_creation_str = f"if(create_graph) VLOG(3) << \"Higher order grad node for {grad_node_name} has not been implemented yet.\";"
self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
......
......@@ -766,7 +766,8 @@ std::vector<paddle::experimental::Tensor> RunBackward(
<< ", rank: " << edge_rank.second;
node_input_buffers_dict[next_node]->add(
edge_rank.first, edge_rank.second, grad_output_tensor);
edge_rank.first, edge_rank.second, grad_output_tensor,
create_graph);
// Update queue
node_in_degree_map[next_node]--;
......
......@@ -72,7 +72,8 @@ void GradTensorHolder::CopyValueFromTensor(
}
void GradTensorHolder::add(size_t slot_id, size_t rank,
const paddle::experimental::Tensor& t) {
const paddle::experimental::Tensor& t,
bool create_graph) {
// TODO(jiabin): We need to deal with empty input_buffer with slot size not
// empty;
PADDLE_ENFORCE(slot_id < buffer_.size(),
......@@ -113,8 +114,12 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
if (t.is_dense_tensor()) {
if (buffer_tensor.is_dense_tensor()) {
buffer_tensor = add_final_state_dygraph_function(t, buffer_tensor);
if (create_graph) {
buffer_tensor = add_final_state_dygraph_function(t, buffer_tensor);
} else {
paddle::imperative::TensorAdd<paddle::experimental::Tensor>(
t, &buffer_tensor);
}
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with
......
......@@ -45,7 +45,8 @@ class GradTensorHolder {
GradTensorHolder& operator=(const GradTensorHolder& other) = default;
// Create new tensor and copy tensor->impl
void add(size_t slot_id, size_t rank, const paddle::experimental::Tensor& t);
void add(size_t slot_id, size_t rank, const paddle::experimental::Tensor& t,
bool create_graph = false);
void CopyValueFromTensor(size_t slot_id, size_t rank,
const paddle::experimental::Tensor& t,
bool fill_one = false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册