Unverified commit 9401173e, authored by ShenLiang, committed by GitHub

Remove scale loss before reduce in dygraph (#30807)

Parent 0020d915
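This commit moves gradient averaging out of Python (where the loss was scaled by 1/nranks before backward) and into the C++ Reducer, which now divides each gradient group by nranks before the sum-allreduce. A minimal sketch of the equivalence this relies on, in plain C++ with no Paddle or NCCL dependencies (AllReduceSum is an illustrative stand-in for a real sum-allreduce, not a Paddle API):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for a sum-allreduce across ranks: element-wise sum of all buffers.
std::vector<float> AllReduceSum(const std::vector<std::vector<float>>& per_rank) {
  std::vector<float> out(per_rank[0].size(), 0.f);
  for (const auto& buf : per_rank)
    for (size_t i = 0; i < out.size(); ++i) out[i] += buf[i];
  return out;
}

int main() {
  const int64_t nranks = 4;
  std::vector<std::vector<float>> grads = {
      {1.f, 2.f}, {3.f, 4.f}, {5.f, 6.f}, {7.f, 8.f}};

  // New behavior: each rank divides its local gradient by nranks (DivNRanks),
  // then the reducer issues a sum-allreduce.
  auto divided = grads;
  for (auto& buf : divided)
    for (auto& v : buf) v /= static_cast<float>(nranks);
  auto averaged = AllReduceSum(divided);

  // Old behavior: scale the loss (hence every gradient) by 1/nranks up front,
  // which is the same as dividing the summed gradient by nranks.
  auto reference = AllReduceSum(grads);
  for (auto& v : reference) v /= static_cast<float>(nranks);

  for (size_t i = 0; i < averaged.size(); ++i) assert(averaged[i] == reference[i]);
  return 0;
}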
@@ -12,7 +12,7 @@ if(NOT WIN32)
   if(WITH_NCCL)
     cc_library(imperative_all_reduce SRCS all_reduce.cc DEPS collective_helper device_context selected_rows tensor)
     cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce var_type_traits)
-    cc_library(reducer SRCS reducer.cc DEPS layer imperative_all_reduce)
+    nv_library(reducer SRCS reducer.cc reducer.cu DEPS layer imperative_all_reduce)
   endif()
   if(WITH_XPU_BKCL)
     cc_library(bkcl_context SRCS bkcl_context.cc DEPS collective_helper device_context tensor var_type_traits)
...
@@ -66,6 +66,8 @@ class ParallelContext {
   inline int GetNRings() const { return strategy_.nrings_; }

+  inline int64_t GetNRanks() const { return strategy_.nranks_; }
+
  protected:
   ParallelStrategy strategy_;
   platform::Place place_;
...
@@ -28,6 +28,29 @@ namespace paddle {
 namespace imperative {

 #if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
+// Divide the gradients by nranks so the later sum-allreduce yields a mean.
+void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
+  framework::Tensor *tensor =
+      is_sparse_
+          ? sparse_contents_->GetMutable<framework::SelectedRows>()
+                ->mutable_value()
+          : dense_contents_.GetMutable<framework::LoDTensor>();
+  if (platform::is_gpu_place(tensor->place())) {
+#if defined(PADDLE_WITH_NCCL)
+    DivNRanks(tensor, nranks, context);
+#endif
+  } else if (platform::is_cpu_place(tensor->place())) {
+    framework::VisitDataTypeSmall(
+        dtype_, DivNRanksForAllReduce<platform::CPUDeviceContext>(
+                    tensor, nranks, context));
+  } else if (platform::is_xpu_place(tensor->place())) {
+#ifdef PADDLE_WITH_XPU_BKCL
+    // TODO(liuyuhui): support dividing by nranks on XPU in the future.
+#endif
+  }
+}
+
 template <typename DeviceContext, typename T>
 static void ConcatTensorsForAllReduce(
     const DeviceContext &context,
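VisitDataTypeSmall above is Paddle's runtime-dtype dispatcher: it switches on the tensor's dtype and invokes the visitor's templated apply<T>() with the matching C++ type. A hedged, self-contained sketch of that pattern (DType, VisitDataType, and DivVisitor are illustrative names, not Paddle's API):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

enum class DType { kFloat32, kFloat64 };

// Runtime dtype picks which template instantiation of apply<T>() runs.
template <typename Visitor>
void VisitDataType(DType dtype, Visitor visitor) {
  switch (dtype) {
    case DType::kFloat32: visitor.template apply<float>(); break;
    case DType::kFloat64: visitor.template apply<double>(); break;
    default: throw std::runtime_error("unsupported dtype");
  }
}

// Mirrors DivNRanksForAllReduce: apply<T>() divides a typed buffer by nranks.
struct DivVisitor {
  void* data_;
  size_t n_;
  int64_t nranks_;
  template <typename T>
  void apply() const {
    T* p = static_cast<T*>(data_);
    for (size_t i = 0; i < n_; ++i) p[i] /= static_cast<T>(nranks_);
  }
};

int main() {
  std::vector<float> grad = {2.f, 4.f};
  VisitDataType(DType::kFloat32, DivVisitor{grad.data(), grad.size(), 2});
  std::cout << grad[0] << " " << grad[1] << "\n";  // prints: 1 2
}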
@@ -276,6 +299,7 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
       find_unused_vars_(find_unused_vars) {
   VLOG(3) << "Start construct the Reducer ...";
   nrings_ = parallel_ctx->GetNRings();
+  nranks_ = parallel_ctx->GetNRanks();
   // initialize groups
   InitializeGroups(group_indices);
   for (size_t global_var_index = 0; global_var_index < vars_.size();
@@ -444,7 +468,7 @@ void Reducer::PrepareForBackward(
   PADDLE_ENFORCE_EQ(
       all_group_ready_, false,
       platform::errors::PreconditionNotMet(
-          "Please note that all ``forward`` outputs derived from the module "
+          "Please note that all forward outputs derived from the module "
           "parameters must participate in the calculation of losses and "
           "subsequent gradient calculations. If not, the wrapper will hang, "
          "waiting for autograd to generate gradients for these parameters. "
@@ -631,6 +655,7 @@ void Reducer::MarkGroupReady(size_t group_index) {
     if (group.sparse_contents_ != nullptr) {
       VLOG(3) << "sparse group [" << next_group_
               << "] start allreduce in ring[" << run_order << "]";
+      group.DivNRanks(*parallel_ctx_->GetDeviceContext(run_order), nranks_);
       parallel_ctx_->AllReduceByStream(
           *group.sparse_contents_, group.sparse_contents_, run_order, false);
     } else {
@@ -654,6 +679,7 @@ void Reducer::MarkGroupReady(size_t group_index) {
         parallel_ctx_->WaitComm(run_order);
       }
 #endif
+      group.DivNRanks(*parallel_ctx_->GetDeviceContext(run_order), nranks_);
       // Start allreduce
       parallel_ctx_->AllReduceByStream(
...
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/reducer.h"
namespace paddle {
namespace imperative {
#if defined(PADDLE_WITH_NCCL)
void Group::DivNRanks(framework::Tensor *tensor, int64_t nranks,
const platform::DeviceContext &context) {
framework::VisitDataTypeSmall(
dtype_, DivNRanksForAllReduce<platform::CUDADeviceContext>(tensor, nranks,
context));
}
#endif
} // namespace imperative
} // namespace paddle
@@ -29,10 +29,12 @@
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"

 namespace paddle {
 namespace platform {
 class DeviceContext;
 }  // namespace platform

 namespace imperative {
@@ -46,6 +48,37 @@ namespace paddle {
 namespace imperative {

 #if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
+template <typename T>
+struct DivNRanksFunctor {
+  DivNRanksFunctor(int64_t nranks, T* output)
+      : nranks_(nranks), output_(output) {}
+  HOSTDEVICE void operator()(size_t idx) const {
+    output_[idx] /= static_cast<T>(nranks_);
+  }
+  int64_t nranks_;
+  T* output_;
+};
+
+template <typename Dex>
+struct DivNRanksForAllReduce {
+  framework::Tensor* in_;
+  int64_t nranks_;
+  const platform::DeviceContext& ctx_;
+  DivNRanksForAllReduce(framework::Tensor* in, int64_t nranks,
+                        const platform::DeviceContext& ctx)
+      : in_(in), nranks_(nranks), ctx_(ctx) {}
+  template <typename T>
+  void apply() const {
+    T* data = in_->mutable_data<T>(ctx_.GetPlace());
+    platform::ForRange<Dex> for_range(static_cast<const Dex&>(ctx_),
+                                      static_cast<size_t>(in_->numel()));
+    DivNRanksFunctor<T> functor(nranks_, data);
+    for_range(functor);
+  }
+};
+
 class Group {
  public:
   // Here, we use dense_contents_ & sparse_contents_ to
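DivNRanksForAllReduce above pairs a per-element functor with platform::ForRange, which applies the functor to every index on whatever device the context belongs to (a CUDA kernel launch for CUDADeviceContext, a plain loop on CPU). A hedged CPU-only sketch of the same pattern; ForRangeCPU is an illustrative stand-in for the real ForRange:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// CPU stand-in for platform::ForRange: invoke the functor once per index.
template <typename Functor>
void ForRangeCPU(size_t limit, Functor func) {
  for (size_t i = 0; i < limit; ++i) func(i);
}

// Same shape as the DivNRanksFunctor in the diff, minus the HOSTDEVICE marker.
template <typename T>
struct DivNRanksFunctor {
  DivNRanksFunctor(int64_t nranks, T* output)
      : nranks_(nranks), output_(output) {}
  void operator()(size_t idx) const {
    output_[idx] /= static_cast<T>(nranks_);
  }
  int64_t nranks_;
  T* output_;
};

int main() {
  std::vector<float> grad = {2.f, 4.f, 6.f};
  ForRangeCPU(grad.size(), DivNRanksFunctor<float>(2, grad.data()));
  for (float v : grad) std::cout << v << " ";  // prints: 1 2 3
  std::cout << "\n";
}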
@@ -77,6 +110,12 @@ class Group {
   // context is used to select the stream for split
   void SplitTensors(const platform::DeviceContext& context);

+  // CUDA-only overload, implemented in reducer.cu.
+  void DivNRanks(framework::Tensor* tensor, int64_t nranks,
+                 const platform::DeviceContext& context);
+
+  void DivNRanks(const platform::DeviceContext& context, int64_t nranks);
+
   friend std::ostream& operator<<(std::ostream&, const Group&);
 };
@@ -122,7 +161,6 @@ class Reducer {
  private:
   std::vector<std::shared_ptr<imperative::VarBase>> vars_;
   std::vector<std::vector<size_t>> group_indices_;
-  static std::shared_ptr<Reducer> s_instance_;
   std::vector<Group> groups_;
   size_t next_group_ = 0;
   platform::Place place_;
@@ -132,6 +170,7 @@ class Reducer {
   std::vector<VariableLocator> variable_locators_;
   int nrings_ = 1;
+  int64_t nranks_ = -1;

   // Following variables are to help rebuild group
   // TODO(shenliang03): Support rebuild in the future.
...
@@ -99,6 +99,8 @@ void GroupConcatSplit(Place place, size_t size) {
       .mutable_data(place, group.dtype_);

   group.ConcatTensors(*dev_ctx);
+  group.DivNRanks(*dev_ctx, 1);
+
   framework::Tensor tmp;
   framework::TensorCopySync(*tensor, cpu_place, &tmp);
   auto* data = tmp.data<T>();
...
@@ -308,6 +308,7 @@ def _split_tensors(coalesced_grads_and_grad_vars):
 def scale_loss(loss):
+    # TODO(liuyuhui) Currently only for xpu. Will be removed in the future.
     if not ParallelEnv().world_size > 1:
         return loss
...
@@ -170,7 +170,8 @@ def monkey_patch_varbase():
         """
         if framework.in_dygraph_mode():
             if paddle.distributed.get_world_size() > 1:
-                scaled_loss = scale_loss(self)
-                scaled_loss._run_backward(framework._dygraph_tracer(),
-                                          retain_graph)
+                if paddle.is_compiled_with_xpu():
+                    # TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
+                    scaled_loss = scale_loss(self)
+                    scaled_loss._run_backward(framework._dygraph_tracer(),
+                                              retain_graph)
...
@@ -519,7 +519,8 @@ class TestParallelDyGraphRunnerBase(object):
             loss.backward()
             opt.minimize(loss)
-            model.clear_gradients()
+            if not args.accumulate_gradient:
+                model.clear_gradients()
         print_to_out(out_losses)

     def run_trainer_with_spawn(self, args):
@@ -594,7 +595,8 @@
             loss.backward()
             opt.step()
-            opt.clear_grad()
+            if not args.accumulate_gradient:
+                opt.clear_grad()
         print_to_out(out_losses)
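The new --accumulate_gradient flag makes the test trainer skip clearing gradients between iterations, so consecutive backward passes sum into the same gradient buffers. A hedged single-variable sketch of why that equals one larger batch (plain C++, illustrative numbers):

#include <cassert>

int main() {
  // For loss = w * x, the gradient w.r.t. w is just x.
  double grad = 0.0;

  // Two backward passes without clearing: gradients accumulate.
  grad += 3.0;  // micro-batch 1, x = 3
  grad += 5.0;  // micro-batch 2, x = 5

  // One optimizer step on the accumulated gradient matches the step a
  // single combined batch (x = 3 + 5) would have produced.
  assert(grad == 8.0);
  return 0;
}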
@@ -625,6 +627,7 @@ def runtime_main(test_class):
     parser.add_argument('--use_cuda', action='store_true')
     parser.add_argument('--use_xpu', action='store_true')
     parser.add_argument('--use_dgc', action='store_true')
+    parser.add_argument('--accumulate_gradient', action='store_true')
     parser.add_argument('--use_reduce', action='store_true')
     parser.add_argument('--dc_asgd', action='store_true')
     parser.add_argument('--hogwild', action='store_true')
@@ -722,6 +725,7 @@ class TestDistBase(unittest.TestCase):
         self._use_hallreduce = False
         self._save_model = False
         self._fuse_all_reduce = None
+        self._accumulate_gradient = False
         self._setup_config()

         global DIST_UT_PORT
@@ -845,6 +849,9 @@ class TestDistBase(unittest.TestCase):
         if len(devices) > 1 and self._use_dgc:
             cmd += " --use_dgc"

+        if self._accumulate_gradient:
+            cmd += " --accumulate_gradient"
+
         env_local.update(envs)
         print("local_cmd: {}, env: {}".format(cmd, env_local))
@@ -1011,6 +1018,9 @@ class TestDistBase(unittest.TestCase):
         if self._use_dgc:
             tr_cmd += " --use_dgc"

+        if self._accumulate_gradient:
+            tr_cmd += " --accumulate_gradient"
+
         if self._pipeline_mode:
             tr_cmd += " --use_pipeline"
         if self._mp_mode:
...
@@ -60,7 +60,6 @@ class TestFleetDygraphSingle(unittest.TestCase):
             outputs = dp_layer(inputs)
             labels = paddle.randn([10, 1], 'float32')
             loss = loss_fn(outputs, labels)
-            loss = dp_layer.scale_loss(loss)
             loss.backward()

             adam.step()
             adam.clear_grad()
...
@@ -66,12 +66,13 @@ class TestParallelDygraphMnistSpawn(TestDistSpawnRunner):
             self.check_dist_result_with_spawn(test_class=TestMnist, delta=1e-5)


-class TestFleetDygraphMnist(TestDistBase):
+class TestParallelDygraphMnistAccGrad(TestDistBase):
     def _setup_config(self):
         self._sync_mode = False
         self._nccl2_mode = True
         self._dygraph = True
-        self._use_fleet_api = True
+        self._accumulate_gradient = True

     def test_mnist(self):
         if fluid.core.is_compiled_with_cuda():
...
@@ -48,5 +48,21 @@ class TestParallelDygraphTransformerSpawn(TestDistSpawnRunner):
             test_class=TestTransformer, delta=1e-5)


+class TestParallelDygraphTransformerAccGrad(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._nccl2_mode = True
+        self._dygraph = True
+        self._accumulate_gradient = True
+
+    def test_transformer(self):
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place(
+                "parallel_dygraph_transformer.py",
+                delta=1e-5,
+                check_error_log=True,
+                log_name=flag_name)
+
+
 if __name__ == "__main__":
     unittest.main()