Unverified commit 984eacb3, authored by ShenLiang and committed by GitHub

[DataParallel]Support control flow in new DP (#40593)

* fix bug

* fix bug
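For context, a minimal usage sketch of the control-flow scenario this change targets, condensed from the test added in this commit; the module name, parameter names, and shapes below are illustrative, and a two-GPU launch (e.g. via paddle.distributed.launch) is assumed:

import paddle
import paddle.distributed as dist

class BranchNet(paddle.nn.Layer):
    # Illustrative module: data-dependent control flow in forward() means that,
    # on a given step and rank, only one of w1 / w2 receives a gradient.
    def __init__(self):
        super(BranchNet, self).__init__()
        self.w1 = self.create_parameter(shape=[10, 20], dtype="float32")
        self.w2 = self.create_parameter(shape=[10, 20], dtype="float32")

    def forward(self, x):
        if float(x.sum()) > 0:
            return paddle.matmul(x, self.w1)
        return paddle.matmul(x, self.w2)

dist.init_parallel_env()
# find_unused_parameters=True makes the new eager reducer traverse the backward
# graph each step, mark parameters that produced no gradient on this rank, and
# keep the fused allreduce groups aligned across ranks.
model = paddle.DataParallel(BranchNet(), find_unused_parameters=True)
loss = model(paddle.rand([5, 10])).sum()
loss.backward()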
Parent commit: 755a6c53
cc_library(processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api)
-cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api)
cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup phi phi_api string_helper)
if (WITH_DISTRIBUTE)
cc_library(processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper)
......
@@ -17,6 +17,20 @@
namespace paddle {
namespace distributed {
static Backend TransToBackend(platform::Place place) {
static const std::map<phi::AllocationType, Backend> type_backend = {
{phi::AllocationType::GPU, Backend::GPU},
{phi::AllocationType::CPU, Backend::CPU},
};
phi::AllocationType type = place.GetType();
auto it = type_backend.find(type);
PADDLE_ENFORCE_EQ(it != type_backend.end(), true,
platform::errors::InvalidArgument(
"Place type (%s) is not supported. ", place));
return it->second;
}
std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
const std::vector<Tensor> tensors,
const std::vector<bool> &is_sparse_gradient,
@@ -297,10 +311,18 @@ EagerReducer::EagerReducer(
std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
accumulation_grad_node->RegisterReduceHook(
std::make_shared<egr::CppTensorVoidHook>(reduce_hook));
gradnode_index_map_[grad_node.get()] = global_var_index;
}
vars_marked_ready_.resize(tensors_.size(), false);
local_used_vars_.resize(tensors_.size(), 0);
if (find_unused_vars_each_step_) {
global_used_vars_ = paddle::experimental::empty(
ScalarArray({static_cast<int32_t>(tensors_.size())}), DataType::INT32,
TransToBackend(inner_place_));
}
}
std::shared_ptr<egr::GradNodeBase> EagerReducer::GetGradNodeFromTensor(
@@ -341,21 +363,10 @@ void EagerReducer::InitializeGroups(
} else {
// process the dense gradient.
InitializeDenseGroups(tensor_indices_, &group);
-experimental::Backend backend;
-switch (inner_place_.GetType()) {
-case phi::AllocationType::GPU:
-backend = experimental::Backend::GPU;
-break;
-case phi::AllocationType::CPU:
-backend = experimental::Backend::CPU;
-break;
-default:
-PADDLE_THROW(platform::errors::Unimplemented(
-"Place type (%s) is not supported. ", inner_place_));
-break;
-}
// experimental::Backend backend = TransToBackend(inner_place_);
group.dense_contents_ = paddle::experimental::empty(
-ScalarArray({group.all_length_}), group.dtype_, backend);
ScalarArray({group.all_length_}), group.dtype_,
TransToBackend(inner_place_));
}
// map tensors to this group by VariableLocator
@@ -418,6 +429,53 @@ void EagerReducer::InitializeDenseGroups(
p_group->all_length_ = all_length;
}
void EagerReducer::TraverseBackwardGraph(const std::vector<Tensor> &outputs) {
std::queue<egr::GradNodeBase *> queue;
std::set<egr::GradNodeBase *> visited;
for (const auto &output : outputs) {
auto *auto_grad_meta =
static_cast<egr::AutogradMeta *>(output.get_autograd_meta());
if (!auto_grad_meta) continue;
auto shared_grad_node = auto_grad_meta->GetMutableGradNode();
if (shared_grad_node == nullptr || shared_grad_node.get() == nullptr ||
auto_grad_meta->StopGradient()) {
continue;
}
egr::GradNodeBase *grad_node = shared_grad_node.get();
queue.emplace(grad_node);
}
while (!queue.empty()) {
egr::GradNodeBase *node = queue.front();
queue.pop();
const std::vector<std::vector<egr::Edge>> &edges = node->GetEdges();
for (size_t i = 0; i < edges.size(); i++) {
for (size_t j = 0; j < edges[i].size(); j++) {
const egr::Edge &edge = edges[i][j];
auto next_node_shared = edge.GetMutableGradNode();
if (!next_node_shared || !next_node_shared.get()) {
continue;
}
auto *next_node = next_node_shared.get();
const bool was_inserted = visited.insert(next_node).second;
if (was_inserted) {
queue.emplace(next_node);
}
}
}
}
for (const auto &it : gradnode_index_map_) {
if (visited.count(it.first) == 0) {
unused_vars_.push_back(it.second);
VLOG(3) << "[Rank " << process_group_->GetRank() << "]: "
<< "Tensor " << tensors_[it.second].name() << " at index "
<< it.second << " is marked as unused.";
}
}
}
void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
VLOG(3) << "after forward, then reset count for backward.";
grad_need_hooks_ = true;
@@ -429,6 +487,51 @@ void EagerReducer::PrepareForBackward(const std::vector<Tensor> &outputs) {
// reinitialize vars_marked_ready_ for next iteration
vars_marked_ready_.clear();
vars_marked_ready_.resize(tensors_.size(), false);
PADDLE_ENFORCE_EQ(
groups_need_finalize_, false,
platform::errors::PreconditionNotMet(
"A serious error has occurred here. Please "
"set find_unused_parameters=True to traverse backward graph "
"in each step to prepare reduce in advance. If you have "
"set, There may be several reasons for this error: "
"1) Please note that all forward outputs derived from the module "
"parameters must participate in the calculation of losses and "
"subsequent gradient calculations. If not, the wrapper will hang, "
"waiting for autograd to generate gradients for these parameters. "
"you can use detach or stop_gradient to make the unused parameters "
"detached from the autograd graph. "
"2) Used multiple forwards and one backward. You may be able to wrap "
"multiple forwards in a model."));
// The first var to trigger the unused parameter
has_marked_unused_vars_ = false;
if (find_unused_vars_once_ || find_unused_vars_each_step_) {
unused_vars_.clear();
TraverseBackwardGraph(outputs);
// only check once in first step
find_unused_vars_once_ = false;
}
if (find_unused_vars_each_step_ && unused_vars_.empty()) {
LOG_FIRST_N(WARNING, 1)
<< "All parameters are involved in the backward pass. "
"It is recommended to set find_unused_parameters to False "
"to improve performance. However, if unused parameters "
"appear in subsequent iterative training, then an error "
"will occur. Please make it clear that in the subsequent "
"training, there will be no parameters that are not used "
"in the backward pass, and then set find_unused_parameters";
}
if (unused_vars_.size() == tensors_.size()) {
LOG_FIRST_N(WARNING, 1)
<< "There is no parameter in the device involved "
"in the backward calculation. If there are "
"parameters on other devices involved in the "
"backward, then a serious error will occur here.";
}
}
void EagerReducer::AddDistHook(size_t var_index) {
@@ -446,36 +549,104 @@ void EagerReducer::AddDistHook(size_t var_index) {
auto &tensor = tensors_[var_index];
const auto &grad_node = GetGradNodeFromTensor(&tensor);
-VLOG(3) << "Var[" << var_index << "] [" << (*grad_node).name()
-<< "] arrived and triggered disthook";
VLOG(3) << "Tensor[" << var_index << "] [" << tensors_[var_index].name()
<< "@Grad] arrived and triggered disthook";
local_used_vars_[var_index] = 1;
if (!has_marked_unused_vars_) {
has_marked_unused_vars_ = true;
for (const auto unused_index : unused_vars_) {
MarkVarReady(unused_index, false);
}
}
MarkVarReady(var_index, true);
}
void EagerReducer::MarkVarReady(const size_t var_index,
const bool is_used_var) {
VLOG(3) << "Tensor[" << var_index << "][" << tensors_[var_index].name()
<< "] is marked ready.";
// error happened, if the var is ready before.
if (vars_marked_ready_[var_index]) {
auto error_info = string::Sprintf(
"Error happened, when parameter[%d][%s] has been ready before. "
"Please set find_unused_parameters=True to traverse backward graph "
"in each step to prepare reduce in advance. If you have set, "
"there may be several reasons for this error: "
"1) In multiple reentrant backward phase, some parameters are reused."
"2) Using model parameters outside of forward function. Please "
"make sure that model parameters are not shared in concurrent "
"forward-backward passes.",
var_index, tensors_[var_index].name());
PADDLE_ENFORCE_EQ(has_marked_unused_vars_, false,
platform::errors::PreconditionNotMet(error_info));
error_info +=
"3) Unused parameters retrieval is incorrect. "
"The return value of forward will be used to retrieve"
" the unused parameters of the entire model. These "
"gradients of unused parameters will not be synchronized "
"between multiple cards. However, if the unused "
"parameters participate in the backward calculation "
"again at a later time (e.g. after the forward function, "
"the loss calculation uses the unused "
"paramters of the forward and trigger backward), "
"its gradient will be wrong.";
PADDLE_ENFORCE_EQ(has_marked_unused_vars_, true,
platform::errors::PreconditionNotMet(error_info));
} else {
vars_marked_ready_[var_index] = true;
}
groups_need_finalize_ = true;
const auto &var_locator = variable_locators_[var_index];
const auto group_index = var_locator.group_index;
const auto inside_group_index = var_locator.inside_group_index;
auto &group = groups_[group_index];
auto &group_tensor = group.dense_tensors_[inside_group_index];
-auto *autograd_meta = tensors_[var_index].get_autograd_meta();
-auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
-group_tensor
-.ShareDataWith(
-*(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
-.Resize({grad_tensor.numel()});
-vars_marked_ready_[var_index] = true;
const auto length = group.length_[inside_group_index];
if (is_used_var) {
auto *autograd_meta = tensors_[var_index].get_autograd_meta();
auto &grad_tensor = static_cast<egr::AutogradMeta *>(autograd_meta)->Grad();
group_tensor
.ShareDataWith(
*(std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor.impl())))
.Resize({grad_tensor.numel()});
} else {
// TODO(shenliang03): maybe save the memory by avoiding tensor construction
if (!group_tensor.initialized()) {
group_tensor.Resize({static_cast<int64_t>(length)});
group_tensor.mutable_data(inner_place_, group.dtype_);
}
if (HasGrad(var_index)) {
VLOG(3) << "Tensor[" << tensors_[var_index].name() << "] has grad";
auto grad_tensor = egr::EagerUtils::mutable_grad(tensors_[var_index]);
group_tensor
.ShareDataWith(*(
std::dynamic_pointer_cast<phi::DenseTensor>(grad_tensor->impl())))
.Resize({length});
} else {
VLOG(3) << "Tensor[" << tensors_[var_index].name()
<< "] doesn't have grad";
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(inner_place_);
group_tensor.Resize({static_cast<int64_t>(length)});
phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0);
}
}
if (--group.pending_ == 0) {
// can start allreduce
MarkGroupReady(group_index);
}
if (next_group_ == groups_.size()) {
FinalizeBackward();
}
}
void EagerReducer::MarkGroupReady(size_t group_index) {
@@ -501,6 +672,92 @@ void EagerReducer::MarkGroupReady(size_t group_index) {
}
}
bool EagerReducer::HasGrad(size_t var_index) {
auto grad = egr::EagerUtils::mutable_grad(tensors_[var_index]);
if (grad && grad->is_initialized()) {
return true;
} else {
return false;
}
}
void EagerReducer::ProcessUnusedDenseVars() {
// The calculation stream must be used here to
// avoid conflicts with communication.
VLOG(3) << "Local used vars : "
<< string::join_strings(local_used_vars_, ',');
const auto *dev_ctx =
platform::DeviceContextPool::Instance().Get(inner_place_);
auto *global_used_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(global_used_vars_.impl())
.get();
framework::TensorFromVector<int32_t>(local_used_vars_, *dev_ctx,
global_used_tensor);
distributed::AllreduceOptions opts;
opts.reduce_op = ReduceOp::SUM;
std::vector<Tensor> reduce_tensors = {global_used_vars_};
process_group_->AllReduce(reduce_tensors, opts)->Synchronize();
framework::TensorToVector<int>(*global_used_tensor, *dev_ctx,
&local_used_vars_);
dev_ctx->Wait();
// sync compute stream to get global used var message,
// but maybe affect speed performance
VLOG(3) << "Global used vars : "
<< string::join_strings(local_used_vars_, ',');
for (const auto var_index : unused_vars_) {
const bool global_unused = (local_used_vars_[var_index] == 0);
// global used but local unused, set grad
VLOG(3) << "[Rank " << process_group_->GetRank() << "]: "
<< "Var [" << var_index << "] [" << tensors_[var_index].name()
<< "] global_unused: " << global_unused
<< " has grad: " << HasGrad(var_index);
if (!global_unused) {
VLOG(3) << "Set Tensor[" << var_index << "]'s Grad for [Rank "
<< process_group_->GetRank() << "]";
const auto &var_locator = variable_locators_[var_index];
const auto group_index = var_locator.group_index;
const auto &group = groups_[group_index];
const auto inside_group_index = var_locator.inside_group_index;
auto &src_tensor = group.dense_tensors_[inside_group_index];
Tensor grad_value(std::make_shared<phi::DenseTensor>(src_tensor));
auto dest_var_base = tensors_[var_index];
auto grad_tensor = egr::EagerUtils::mutable_grad(dest_var_base);
grad_tensor->copy_(grad_value, inner_place_, true);
grad_tensor->reshape(dest_var_base.shape());
}
}
}
void EagerReducer::FinalizeBackward() {
groups_need_finalize_ = false;
grad_need_hooks_ = false;
for (auto &group : groups_) {
group.task->Synchronize();
}
for (auto &group : groups_) {
group.SplitTensors(inner_place_);
}
if (find_unused_vars_each_step_) {
ProcessUnusedDenseVars();
local_used_vars_.clear();
local_used_vars_.resize(tensors_.size(), 0);
VLOG(3) << "ProcessUnusedDenseVars is finished.";
}
VLOG(3) << "In the batch, Reducer is finished.";
}
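As an aside, the unused-variable bookkeeping in ProcessUnusedDenseVars above boils down to an integer SUM allreduce over per-rank usage masks; a toy, communication-free Python illustration of that arithmetic (the rank data below is made up):

# Each rank reports a 0/1 mask of locally used parameters (local_used_vars_).
# The element-wise SUM allreduce gives the global usage count: a parameter is
# globally unused only if no rank produced a gradient for it, while a locally
# unused but globally used parameter still contributes a zero-filled gradient
# so the fused buffers stay aligned across ranks.
local_used = {
    0: [1, 0, 1],  # rank 0 used parameters 0 and 2
    1: [1, 1, 0],  # rank 1 used parameters 0 and 1
}
global_used = [sum(col) for col in zip(*local_used.values())]
print(global_used)                                       # [2, 1, 1]
print([i for i, c in enumerate(global_used) if c == 0])  # globally unused: []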
void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
const int curr_group_index) {
// The overall timeline: concat > div_nranks > allreduce > split
@@ -513,24 +770,14 @@ void EagerReducer::FusedAllReduceSchedule(EagerGroup *group,
group->ConcatTensors(inner_place_);
// div nranks
-double scaling = 1.0 / nranks_;
-paddle::experimental::scale_(group->dense_contents_, scaling, 0.0, false);
paddle::experimental::scale_(group->dense_contents_, 1.0 / nranks_, 0.0,
false);
// all_reduce
std::vector<Tensor> reduce_tensors = {group->dense_contents_};
-tasks_.push_back(process_group_->AllReduce(reduce_tensors, opts));
group->task = process_group_->AllReduce(reduce_tensors, opts);
-if (tasks_.size() == groups_.size()) {
-for (size_t index = 0; index < tasks_.size(); index++) {
-auto &task = tasks_.back();
-task->Synchronize();
-tasks_.pop_back();
-}
-for (size_t index = 0; index < groups_.size(); index++) {
-auto &group = groups_[index];
-group.SplitTensors(inner_place_);
-}
-}
// split in FinalizeBackward()
}
std::ostream &operator<<(std::ostream &out, const EagerGroup &group) {
......
@@ -28,6 +28,8 @@
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/utils/string/string_helper.h"
namespace paddle {
namespace distributed {
@@ -35,6 +37,7 @@ using Tensor = paddle::experimental::Tensor;
using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;
using ScalarArray =
paddle::experimental::ScalarArrayBase<paddle::experimental::Tensor>;
using Backend = paddle::experimental::Backend;
std::vector<std::vector<size_t>> Eager_AssignGroupBySize(
const std::vector<Tensor>, const std::vector<bool> &is_sparse_gradient,
@@ -61,6 +64,9 @@ class EagerGroup {
// external message of group
phi::DataType dtype_;
// help to sync
std::shared_ptr<ProcessGroup::Task> task;
// context is used to select the stream for concat
void ConcatTensors(const platform::Place &);
@@ -98,6 +104,10 @@ class EagerReducer {
void MarkVarReady(const size_t var_index, const bool is_used_var);
void MarkGroupReady(const size_t group_index);
void FusedAllReduceSchedule(EagerGroup *group, const int curr_group_index);
void FinalizeBackward();
void TraverseBackwardGraph(const std::vector<Tensor> &outputs);
void ProcessUnusedDenseVars();
bool HasGrad(size_t var_index);
private:
std::vector<Tensor> tensors_;
@@ -105,7 +115,6 @@ class EagerReducer {
std::vector<bool> is_sparse_gradient_;
std::shared_ptr<distributed::ProcessGroup> process_group_;
std::vector<size_t> group_size_limits_;
-bool find_unused_vars_each_step_;
std::vector<EagerGroup> groups_;
std::vector<TensorLocator> variable_locators_;
@@ -113,12 +122,20 @@
platform::Place inner_place_;
size_t next_group_ = 0;
int64_t nranks_ = -1;
-std::vector<std::shared_ptr<paddle::distributed::ProcessGroup::Task>> tasks_;
bool grad_need_hooks_{false};
std::vector<bool> vars_marked_ready_;
-std::vector<int> local_used_vars_;
std::vector<int32_t> local_used_vars_;
// Following variables are to help unused vars
std::vector<size_t> unused_vars_;
std::map<egr::GradNodeBase *, size_t> gradnode_index_map_;
bool has_marked_unused_vars_{false};
bool find_unused_vars_each_step_{false};
bool find_unused_vars_once_{true};
bool groups_need_finalize_{false};
Tensor global_used_vars_;
};
} // namespace distributed
......
@@ -327,23 +327,25 @@ static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
grad = meta->MutableGrad();
}
if (grad->impl()) {
if (grad->is_selected_rows()) {
auto selected_rows =
std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
if (selected_rows->mutable_value()->IsInitialized()) {
selected_rows->mutable_rows()->clear();
selected_rows->mutable_value()->clear();
}
} else if (grad->is_dense_tensor()) {
if (grad->initialized()) {
if (set_to_zero) {
grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
} else {
VLOG(4) << "Gradient of " << self->tensor.name()
<< " is initialized, will be released.";
auto dense_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
dense_tensor->MoveMemoryHolder();
}
}
}
}
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import os
import paddle
import numpy as np
import paddle.distributed as dist
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.dygraph.parallel import ParallelEnv
import paddle.fluid.core as core
paddle.seed(1024)
np.random.seed(2021)
batch = 5
in_dim = 10
out_dim = 20
def init_process_group(strategy=None):
nranks = ParallelEnv().nranks
rank = ParallelEnv().local_rank
is_master = True if rank == 0 else False
store = paddle.fluid.core.TCPStore("127.0.0.1", 6174, is_master, nranks)
group = core.ProcessGroupNCCL(store, rank, nranks)
return group
class SimpleNet(fluid.Layer):
def __init__(self, train_id):
super(SimpleNet, self).__init__()
self.w1 = self.create_parameter(
shape=[in_dim, out_dim], dtype="float32")
self.w2 = self.create_parameter(
shape=[in_dim, out_dim], dtype="float32")
self.share_net = Linear(out_dim, 10)
self.unused_param = self.create_parameter(
shape=[out_dim, in_dim], dtype="float64")
# just for test sync_params_buffers
# self.register_buffer("queue", paddle.randn([10, 5]))
# self.queue = paddle.nn.functional.normalize(self.queue, axis=0)
# self.register_buffer("queue_ptr", paddle.zeros([1], 'int64'))
self.trainer_id = train_id
def forward(self, x):
is_use = (paddle.equal_all(
x, paddle.ones(shape=(batch, in_dim))).numpy()[0] and
self.trainer_id == 1)
if is_use:
tmp = paddle.matmul(x, self.w1)
else:
tmp = paddle.matmul(x, self.w2)
return self.share_net(tmp)
class TestDistTraning(unittest.TestCase):
def test_multiple_gpus(self):
dist.init_parallel_env()
self.trainer_id = dist.get_rank()
process_group = init_process_group()
self.pg = process_group
with _test_eager_guard():
model_a = SimpleNet(self.trainer_id)
model_b = SimpleNet(self.trainer_id)
state_dict = model_a.state_dict()
model_b.set_state_dict(state_dict)
model_a = paddle.DataParallel(
model_a,
find_unused_parameters=True,
process_group=process_group)
model_b = paddle.DataParallel(
model_b,
find_unused_parameters=True,
process_group=process_group)
ones_input = paddle.ones(shape=(batch, in_dim))
ones_input.stop_gradient = True
w1_grad_sum = np.zeros((in_dim, out_dim), dtype='float32')
w2_grad_sum = np.zeros((in_dim, out_dim), dtype='float32')
for step_id in range(5):
print("==============", step_id)
random_input = paddle.rand(shape=(batch, in_dim))
random_input.stop_gradient = True
if step_id % 2 == 0:
out_a = model_a(random_input)
out_b = model_b(random_input)
else:
out_a = model_a(ones_input)
out_b = model_b(ones_input)
out_a.sum().backward()
out_b.sum().backward()
self.check_gradient(model_a.parameters())
self.check_gradient(model_b.parameters())
# test acc gradient
w1_grad_sum = self.check_acc(model_a._layers.w1.grad,
w1_grad_sum,
model_b._layers.w1.grad)
w2_grad_sum = self.check_acc(model_a._layers.w2.grad,
w2_grad_sum,
model_b._layers.w2.grad)
model_a.clear_gradients()
def check_acc(self, grad, grad_sum, acc_grad):
if grad is not None:
grad_sum = grad_sum + grad.numpy()
acc_grad = acc_grad.numpy() if acc_grad is not None else None
np.testing.assert_allclose(grad_sum, acc_grad, rtol=1e-6)
return grad_sum
def print_trainer_0(self, *args):
if self.trainer_id == 0:
print(*args)
def broadcast_param(self, param, root):
self.pg.broadcast(param, root)
return param
def check_gradient(self, params):
other_param = []
for param in params:
if param.trainable and (param.grad is not None):
grad = param.grad
other_grad = self.broadcast_param(grad, root=1)
if self.trainer_id == 0:
np.testing.assert_allclose(other_grad.numpy(), grad.numpy())
if __name__ == '__main__':
unittest.main()
@@ -205,5 +205,10 @@ class TestDataParallelInEagerMode(TestMultipleGpus):
self.run_mnist_2gpu('parallel_dygraph_dataparallel_in_eager_mode.py')
class TestGradientCheckInEagerMode(TestMultipleGpus):
def test_multiple_gpus_dynamic(self):
self.run_mnist_2gpu('parallel_dygraph_gradient_check_in_eager_mode.py')
if __name__ == "__main__":
unittest.main()