未验证 提交 652cf430 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #9746 from typhoonzero/multigpumultinode

[Feature] Enable multi gpu distributed training of fluid
......@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
......@@ -15,7 +16,7 @@ else()
set(multi_devices_graph_builder_deps)
endif()
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
......@@ -54,6 +55,27 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
}
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
const platform::Place &p,
const size_t &i) const {
auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
}
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
auto graph = new SSAGraph();
......@@ -76,27 +98,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// append send op if program is distributed trainer main program.
// always use the first device
if (!is_forwarding && op->Type() == "send") {
auto &p = places_[0];
auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
// Create inputs for output on original place and no ssa output
// is created for send op.
CreateOpHandleIOs(&result, op, p, 0);
continue;
}
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
CreateOpHandleIOs(&result, op, p, i);
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var =
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(&result, op_handle, each_var_name, p, i);
}
auto var_names = op->OutputArgumentNames();
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
......
......@@ -14,6 +14,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
......@@ -41,6 +44,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i) const;
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/send_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place) {}
void SendOpHandle::RunImpl() {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]);
}
op_->Run(*local_scope_, place_);
}
std::string SendOpHandle::Name() const { return "send"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace details {
struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
std::string Name() const override;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -48,13 +48,13 @@ class ParallelExecutor {
const std::string& fetched_var_name,
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
private:
void SplitTensorToPlaces(
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
ParallelExecutorPrivate* member_;
void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
};
} // namespace framework
......
......@@ -65,9 +65,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
}
void ProcGetResponse(const VarHandle& var_h,
// const sendrecv::VariableMessage& ret_msg) {
const ::grpc::ByteBuffer& ret_msg) {
framework::Variable* outvar = NULL;
framework::Variable* outvar = nullptr;
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
}
......
......@@ -107,7 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
for (int64_t i = 0; i < rows2->size(); ++i) {
for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i);
}
EXPECT_EQ(slr2->height(), 1000);
......
......@@ -553,6 +553,7 @@ All parameter, weight, gradient are variables in Paddle.
bcast_vars, main_program, loss_var_name,
scope, local_scopes, allow_op_delay);
})
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
.def("local_scopes",
[](ParallelExecutor &self) -> std::vector<Scope *> * {
return &self.GetLocalScopes();
......
......@@ -255,6 +255,7 @@ class DistributeTranspiler:
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops)
self.program.sync_with_cpp()
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__()
return self.program
......
......@@ -100,9 +100,11 @@ class ParallelExecutor(object):
local_scopes = share_vars_from.executor.local_scopes(
) if share_vars_from else []
persistable_vars = [
self.persistable_vars = [
v.name
for v in filter(lambda var: var.persistable, main.list_vars())
for v in filter(lambda var: \
var.persistable and var.type != core.VarDesc.VarType.RAW,
main.list_vars())
]
self.executor = core.ParallelExecutor(
......@@ -113,7 +115,7 @@ class ParallelExecutor(object):
p.name for p in main.global_block().iter_parameters()
if not p.stop_gradient
]),
set(persistable_vars),
set(self.persistable_vars),
main.desc,
loss_name if loss_name else '',
scope,
......@@ -143,3 +145,6 @@ class ParallelExecutor(object):
self.executor.run(fetch_list, fetch_var_name, feed_tensor_dict)
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
return [arr[i] for i in range(len(arr))]
def bcast_params(self):
self.executor.bcast_params(set(self.persistable_vars))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册