Commit 0bf799a5 authored by typhoonzero

wip testing

Parent b9c28df9
@@ -16,7 +16,7 @@ else()
   set(multi_devices_graph_builder_deps)
 endif()
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-    scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
+    scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
     simple_threadpool device_context)
@@ -35,22 +35,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs, bool distributed)
+    platform::NCCLContextMap *nccl_ctxs)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),
-      distributed_(distributed),
       nccl_ctxs_(nccl_ctxs) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes, bool distributed)
+    const std::vector<Scope *> &local_scopes)
     : loss_var_name_(loss_var_name),
       places_(places),
-      local_scopes_(local_scopes),
-      distributed_(distributed) {
+      local_scopes_(local_scopes) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
@@ -99,7 +97,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
 
       // append send op if program is distributed trainer main program.
       // always use the first device
-      if (is_forwarding && distributed_ && op->Type() == "send") {
+      if (!is_forwarding && op->Type() == "send") {
         auto &p = places_[0];
         auto *s = local_scopes_[0];
         size_t i = 0;
@@ -34,14 +34,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          platform::NCCLContextMap *nccl_ctxs,
-                          bool distributed = false);
+                          platform::NCCLContextMap *nccl_ctxs);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
-                          const std::vector<Scope *> &local_scopes,
-                          bool distributed = false);
+                          const std::vector<Scope *> &local_scopes);
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
@@ -55,7 +53,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   const std::vector<platform::Place> &places_;
   const std::vector<Scope *> &local_scopes_;
   std::unordered_set<std::string> grad_names_;
-  bool distributed_;
 
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;
@@ -48,13 +48,13 @@ class ParallelExecutor {
            const std::string& fetched_var_name,
            const std::unordered_map<std::string, LoDTensor>& feed_tensors);
 
+  void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
+
  private:
   void SplitTensorToPlaces(
       const std::unordered_map<std::string, LoDTensor>& feed_tensors);
 
   ParallelExecutorPrivate* member_;
-
-  void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
 };
 
 }  // namespace framework
@@ -107,7 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
   for (int i = 0; i < tensor_numel; ++i) {
     EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
   }
-  for (int64_t i = 0; i < rows2->size(); ++i) {
+  for (size_t i = 0; i < rows2->size(); ++i) {
     EXPECT_EQ(rows_data2[i], i);
   }
   EXPECT_EQ(slr2->height(), 1000);
@@ -554,6 +554,7 @@ All parameter, weight, gradient are variables in Paddle.
                       bcast_vars, main_program, loss_var_name,
                       scope, local_scopes, allow_op_delay);
         })
+      .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
       .def("local_scopes",
            [](ParallelExecutor &self) -> std::vector<Scope *> * {
              return &self.GetLocalScopes();
@@ -99,7 +99,7 @@ class ParallelExecutor(object):
         local_scopes = share_vars_from.executor.local_scopes(
         ) if share_vars_from else []
 
-        persistable_vars = [
+        self.persistable_vars = [
             v.name
             for v in filter(lambda var: var.persistable, main.list_vars())
         ]
@@ -112,7 +112,7 @@ class ParallelExecutor(object):
                 p.name for p in main.global_block().iter_parameters()
                 if not p.stop_gradient
             ]),
-            set(persistable_vars),
+            set(self.persistable_vars),
             main.desc,
             loss_name if loss_name else '',
             scope,
@@ -142,3 +142,6 @@ class ParallelExecutor(object):
         self.executor.run(fetch_list, fetch_var_name, feed_tensor_dict)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
         return [arr[i] for i in range(len(arr))]
+
+    def bcast_params(self):
+        self.executor.bcast_params(set(self.persistable_vars))
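
The new bcast_params() wrapper simply forwards the persistable variable names recorded in __init__ to the BCastParamsToGPUs binding exposed above. A minimal usage sketch, assuming `pe` is an already constructed fluid.ParallelExecutor; the surrounding training-loop details are placeholders, not part of this change:

    # The trainer has just updated its persistable variables in place
    # (for example after receiving fresh parameters in a distributed run).
    # Re-broadcast them so every device copy sees identical values again.
    pe.bcast_params()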