Commit ea669796 authored by Qiao Longfei

can run

Parent afda8401
......@@ -184,7 +184,7 @@ endif()
target_link_libraries(executor garbage_collector)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper)
......
......@@ -79,6 +79,8 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
......
......@@ -27,6 +27,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
// set the correct size of thread pool to each device.
......
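The constructor above only sizes a thread pool to the number of places; what makes the executor "async" is that each place then runs its own graph with no synchronization against the other places. Below is a minimal, self-contained sketch of that dispatch pattern, using std::async instead of Paddle's ::ThreadPool; RunGraphOnPlace and every other name in it is illustrative, not a Paddle API.

// Conceptual sketch only: one graph per place, each executed on its own
// worker without any cross-device barrier, which is the defining property
// of async mode. Not the actual AsyncSSAGraphExecutor implementation.
#include <future>
#include <iostream>
#include <vector>

// Stand-in for "run one SSA graph on one place with one local scope".
// In the real executor this would drive a ThreadedSSAGraphExecutor.
int RunGraphOnPlace(int place_id) {
  return place_id * 10 + 1;  // dummy per-place result
}

int main() {
  const int num_places = 4;
  std::vector<std::future<int>> futures;
  futures.reserve(num_places);

  // Launch each per-place graph independently; there is no all-reduce or
  // barrier between workers during a step.
  for (int i = 0; i < num_places; ++i) {
    futures.emplace_back(std::async(std::launch::async, RunGraphOnPlace, i));
  }

  // Only join at the very end to collect whatever each worker produced.
  for (int i = 0; i < num_places; ++i) {
    std::cout << "place " << i << " -> " << futures[i].get() << "\n";
  }
  return 0;
}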
......@@ -116,7 +116,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass;
if (strategy_.is_distribution_) {
if (strategy_.async_mode_) {
multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) {
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
......
......@@ -86,6 +86,7 @@ struct BuildStrategy {
// num_trainers is 1, so the current fields of build_strategy don't tell if
// it's a distributed model.
bool is_distribution_{false};
bool async_mode_{false};
int num_trainers_{1};
int trainer_id_{0};
std::vector<std::string> trainers_endpoints_;
......
......@@ -975,3 +975,5 @@ REGISTER_MULTI_DEVICES_PASS(
paddle::framework::details::AllReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
paddle::framework::details::DistSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(async_multi_devices_pass,
paddle::framework::details::AsyncSSAGraphBuilder);
......@@ -55,7 +55,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool UseGPU() const;
bool NeedCollectiveOps() const;
virtual bool NeedCollectiveOps() const;
bool IsScaleLossOp(ir::Node *node) const;
......@@ -116,6 +116,20 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
virtual void InsertPostprocessOps(ir::Graph *result) const {}
};
class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const {}
bool NeedCollectiveOps() const override { return false; }
virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
return false;
}
virtual void InsertPostprocessOps(ir::Graph *result) const {}
};
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
int GetVarDeviceID(const std::string &varname) const;
......
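NeedCollectiveOps() is made virtual precisely so AsyncSSAGraphBuilder can answer "no": the base builder drives graph construction and asks these hooks whether and how to insert collective (all-reduce/broadcast) ops, and the async builder leaves them all as no-ops, so each per-place graph trains without gradient aggregation. The following self-contained sketch shows that hook pattern with simplified stand-in names, not the real pass classes.

// Illustrative sketch of the hook pattern used by the builders above.
#include <iostream>
#include <string>

class GraphBuilderBase {
 public:
  virtual ~GraphBuilderBase() = default;

  void Build() {
    // The base class owns the control flow...
    if (NeedCollectiveOps()) {
      InsertCollectiveOp("w", "w@GRAD");
    }
    InsertPostprocessOps();
  }

 protected:
  // ...and subclasses customize it only through these hooks.
  virtual bool NeedCollectiveOps() const { return true; }
  virtual void InsertCollectiveOp(const std::string &p, const std::string &g) {
    std::cout << "all_reduce(" << p << ", " << g << ")\n";
  }
  virtual void InsertPostprocessOps() {}
};

class AsyncGraphBuilder : public GraphBuilderBase {
 protected:
  // Async mode: no gradient aggregation across devices at all.
  bool NeedCollectiveOps() const override { return false; }
  void InsertCollectiveOp(const std::string &, const std::string &) override {}
  void InsertPostprocessOps() override {}
};

int main() {
  GraphBuilderBase sync_builder;
  AsyncGraphBuilder async_builder;
  sync_builder.Build();   // prints an all_reduce insertion
  async_builder.Build();  // inserts nothing
  return 0;
}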
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
......@@ -282,10 +283,19 @@ ParallelExecutor::ParallelExecutor(
graphs.push_back(std::move(graph));
}
#else
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->nranks_, member_->use_cuda_);
graphs.push_back(std::move(graph));
if (build_strategy.async_mode_) {
for (size_t i = 0; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_);
graphs.push_back(std::move(graph));
}
} else {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->nranks_, member_->use_cuda_);
graphs.push_back(std::move(graph));
}
#endif
auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) {
......@@ -323,23 +333,31 @@ ParallelExecutor::ParallelExecutor(
"please don't pass loss_var_name.";
}
}
if (build_strategy.enable_parallel_graph_) {
if (build_strategy.async_mode_) {
VLOG(3) << "use AsyncSSAGraphExecutor";
member_->executor_.reset(new details::AsyncSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs)));
} else if (build_strategy.enable_parallel_graph_) {
VLOG(3) << "use ParallelSSAGraphExecutor";
member_->executor_.reset(new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs)));
} else {
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs[0])));
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_,
std::move(graphs[0])));
}
}
VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_)));
......@@ -401,14 +419,22 @@ void ParallelExecutor::BCastParamsToDevices(
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
// FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
if (member_->use_all_reduce_ || member_->use_cuda_ ||
var == "@LR_DECAY_COUNTER@") {
auto copy_memory = [&] {
t->Resize(dims);
t->mutable_data(cpu, main_tensor.type());
paddle::framework::TensorCopy(main_tensor, cpu, t);
};
auto share_memory = [&] { t->ShareDataWith(main_tensor); };
// FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
if (member_->build_strategy_.async_mode_) {
share_memory();
} else if (member_->use_all_reduce_ || member_->use_cuda_ ||
var == "@LR_DECAY_COUNTER@") {
copy_memory();
} else {
t->ShareDataWith(main_tensor);
share_memory();
}
}
}
......
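The branch above decides how a parameter reaches each local scope: the all-reduce / CUDA path makes a real copy (Resize, mutable_data, TensorCopy), while the async-mode and default CPU paths alias the tensor in the main scope via ShareDataWith, so CPU workers update the same storage in place. The minimal sketch below illustrates that distinction; FakeTensor is an illustrative stand-in for LoDTensor, not a Paddle type.

// Sharing aliases the same storage across scopes, so an update made by one
// worker is visible to all of them; copying gives each scope its own buffer.
#include <iostream>
#include <memory>
#include <vector>

struct FakeTensor {
  std::shared_ptr<std::vector<float>> data;

  // Alias the other tensor's allocation (ShareDataWith-style).
  void ShareDataWith(const FakeTensor &other) { data = other.data; }

  // Allocate and deep-copy (TensorCopy-style).
  void CopyFrom(const FakeTensor &other) {
    data = std::make_shared<std::vector<float>>(*other.data);
  }
};

int main() {
  FakeTensor main_tensor;
  main_tensor.data = std::make_shared<std::vector<float>>(1, 0.0f);

  FakeTensor shared, copied;
  shared.ShareDataWith(main_tensor);  // async-mode style: one shared buffer
  copied.CopyFrom(main_tensor);       // per-device copy style

  (*main_tensor.data)[0] = 42.0f;     // a "parameter update" by one worker

  std::cout << "shared sees " << (*shared.data)[0]   // 42: same storage
            << ", copy sees " << (*copied.data)[0]   // 0: its own storage
            << "\n";
  return 0;
}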
......@@ -1030,6 +1030,9 @@ All parameter, weight, gradient are variables in Paddle.
"is_distribution",
[](const BuildStrategy &self) { return self.is_distribution_; },
[](BuildStrategy &self, bool b) { self.is_distribution_ = b; })
.def_property("async_mode",
[](const BuildStrategy &self) { return self.async_mode_; },
[](BuildStrategy &self, bool b) { self.async_mode_ = b; })
.def_property(
"memory_early_delete",
[](const BuildStrategy &self) { return self.memory_early_delete_; },
......