Commit a4951843 authored by sneaxiy

inference feed partial data, test=develop

Parent 6dadb5de
@@ -9,7 +9,7 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr
 cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope computation_op_handle share_tensor_buffer_functor)
 cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
-cc_library(multi_devices_helper INTERFACE SRCS multi_devices_helper.cc DEPS graph graph_helper)
+cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
@@ -65,6 +65,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
 cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
 set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
+    multi_devices_helper
     sequential_execution_pass
     modify_op_lock_and_record_event_pass
     all_reduce_deps_pass
@@ -72,7 +73,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
     eager_deletion_pass
     buffer_shared_inplace_op_pass
     buffer_shared_cross_op_memory_reuse_pass
-    set_reader_device_count_pass)
+    set_reader_device_info_pass)
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
@@ -66,7 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
     AppendMultiDevPass();
-    AppendSetReaderDeviceCountPass();
+    AppendSetReaderDeviceIndexPass();
     AppendMultiGraphOptPasses();
     AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
@@ -225,8 +225,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
                  &strategy_);
   }
-  void AppendSetReaderDeviceCountPass() {
-    AppendPass("set_reader_device_count_pass");
+  void AppendSetReaderDeviceIndexPass() {
+    AppendPass("set_reader_device_index_pass");
   }
   void AppendPrintGraphPass(const std::string &pass_name,
@@ -399,12 +399,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                    "GPU, skipped.";
         continue;
       }
-    } else if (pass->Type() == "set_reader_device_count_pass") {
+    } else if (pass->Type() == "set_reader_device_index_pass") {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
-      pass->Erase(kLocalScopes);
-      pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
-                                                    &local_scopes);
     }
     VLOG(1) << "Start Apply Pass " << pass->Type();
     graph = pass->Apply(graph);
@@ -441,7 +438,7 @@ USE_PASS(fuse_sgd_op_pass);
 USE_PASS(fuse_momentum_op_pass);
 USE_PASS(fuse_all_reduce_op_pass);
 USE_PASS(runtime_context_cache_pass);
-USE_PASS(set_reader_device_count_pass);
+USE_PASS(set_reader_device_index_pass);
 #ifdef PADDLE_WITH_MKLDNN
 USE_PASS(mkldnn_placement_pass);
 #endif
......
@@ -34,6 +34,8 @@ class ComputationOpHandle : public OpHandleBase {
   OperatorBase *GetOp() { return op_.get(); }
+  const OperatorBase *GetOp() const { return op_.get(); }
+
   std::string Name() const override;
   const Scope *GetScope() const { return scope_; }
......
@@ -31,10 +31,12 @@ namespace framework {
 namespace details {
 EagerDeletionOpHandle::EagerDeletionOpHandle(
-    ir::Node *node, Scope *scope, const platform::Place &place,
+    ir::Node *node, Scope *scope, size_t scope_idx,
+    const platform::Place &place,
     const std::unordered_set<ir::MemOptVarInfo *> &vars, GarbageCollector *gc)
     : OpHandleBase(node),
       scope_(scope),
+      scope_idx_(scope_idx),
       place_(place),
       var_infos_(vars.begin(), vars.end()),
       gc_(gc) {
......
@@ -34,7 +34,7 @@ namespace details {
 class EagerDeletionOpHandle : public OpHandleBase {
  public:
-  EagerDeletionOpHandle(ir::Node *node, Scope *scope,
+  EagerDeletionOpHandle(ir::Node *node, Scope *scope, size_t scope_idx,
                         const platform::Place &place,
                         const std::unordered_set<ir::MemOptVarInfo *> &vars,
                         GarbageCollector *gc);
@@ -50,6 +50,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
    */
   Priority GetPriority() const override { return kHighest; }
+  size_t GetScopeIdx() const { return scope_idx_; }
+
  protected:
   void RunImpl() override;
@@ -63,6 +65,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
   void CallOnce();
   Scope *scope_;
+  size_t scope_idx_;
   platform::Place place_;
   std::vector<ir::MemOptVarInfo *> var_infos_;  // not own
   GarbageCollector *gc_;                        // not own
......
@@ -12,9 +12,197 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
+#include <algorithm>
+#include <unordered_set>
+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
+#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
 namespace paddle {
 namespace framework {
-namespace details {}  // namespace details
+namespace details {
+
+static constexpr size_t kUndefinedDevIdx = -1UL;
+
+static std::unordered_set<std::string> kMultiDeviceOps{
+    "sync_batch_norm",
+    "sync_batch_norm_grad",
+    "allreduce",
+    "c_allreduce_sum",
+    "c_allreduce_prod",
+    "c_allreduce_min",
+    "c_allreduce_max",
+    "c_allgather",
+    "c_reducescatter",
+    "c_broadcast",
+    "c_comm_init",
+    "c_comm_init_all",
+    "c_gen_nccl_id",
+    "c_sync_comm_stream",
+    "send",
+    "recv",
+    "send_barrier",
+    "fetch_barrier",
+};
+
+static size_t GetScopeIdxFromOp(const details::OpHandleBase &op) {
+  if (auto *compute_op =
+          dynamic_cast<const details::ComputationOpHandle *>(&op)) {
+    return kMultiDeviceOps.count(compute_op->GetOp()->Type()) == 0
+               ? compute_op->GetScopeIdx()
+               : kUndefinedDevIdx;
+  } else if (auto *gc_op =
+                 dynamic_cast<const details::EagerDeletionOpHandle *>(&op)) {
+    return gc_op->GetScopeIdx();
+  } else if (auto *share_op =
+                 dynamic_cast<const details::ShareTensorBufferOpHandle *>(
+                     &op)) {
+    return share_op->GetScopeIdx();
+  } else {
+    return kUndefinedDevIdx;
+  }
+}
+
+static bool ContainMultiDeviceOp(const ProgramDesc &program,
+                                 size_t begin_block_idx) {
+  for (size_t block_idx = begin_block_idx; block_idx < program.Size();
+       ++block_idx) {
+    for (auto *op_desc : program.Block(block_idx).AllOps()) {
+      if (kMultiDeviceOps.count(op_desc->Type()) > 0) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+static size_t GetUniqueDeviceIdOfOp(const details::OpHandleBase &op) {
+  size_t dev_idx = GetScopeIdxFromOp(op);
+  if (dev_idx == kUndefinedDevIdx) {
+    return kUndefinedDevIdx;
+  }
+
+  const auto &ins = op.Inputs();
+  const auto &outs = op.Outputs();
+  auto in_outs = ins;
+  in_outs.insert(in_outs.end(), outs.begin(), outs.end());
+
+  for (auto *var : in_outs) {
+    auto *var_handle = dynamic_cast<details::VarHandle *>(var);
+    if (var_handle == nullptr) {
+      continue;
+    }
+    if (dev_idx != var_handle->scope_idx()) {
+      return kUndefinedDevIdx;
+    }
+  }
+  return dev_idx;
+}
+
+std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
+    ir::Graph *graph) {
+  if (ContainMultiDeviceOp(graph->OriginProgram(), 1)) {
+    return {};
+  }
+
+  size_t place_num = 0;
+  auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
+  if (op_handles.empty()) {
+    return {};
+  }
+
+  std::unordered_map<details::OpHandleBase *, size_t> op_to_dev_idx;
+  for (auto &op : op_handles) {
+    auto dev_idx = GetUniqueDeviceIdOfOp(*op);
+    if (dev_idx == kUndefinedDevIdx) {
+      VLOG(10) << "Op " << op->Name() << " is not determined";
+      return {};
+    }
+    place_num = std::max(place_num, dev_idx + 1);
+    op_to_dev_idx[op] = dev_idx;
+  }
+
+  for (auto &op : op_handles) {
+    auto dev_idx = op_to_dev_idx.at(op);
+    for (auto &in_var : op->Inputs()) {
+      if (in_var->GeneratedOp()) {
+        auto iter = op_to_dev_idx.find(in_var->GeneratedOp());
+        if (iter == op_to_dev_idx.end() || iter->second != dev_idx) {
+          return {};
+        }
+      }
+    }
+    for (auto &out_var : op->Outputs()) {
+      for (auto &pending_op : out_var->PendingOps()) {
+        auto iter = op_to_dev_idx.find(pending_op);
+        if (iter == op_to_dev_idx.end() || iter->second != dev_idx) {
+          return {};
+        }
+      }
+    }
+  }
+
+  PADDLE_ENFORCE_GE(place_num, 1, platform::errors::NotFound(
+                                      "No place found, this may be a bug"));
+
+  std::vector<std::unique_ptr<ir::Graph>> graphs(place_num);
+  for (auto &g : graphs) {
+    g.reset(new ir::Graph(ProgramDesc()));
+    g->Set(kGraphVars, new GraphVars(1UL));
+    g->Set(kGraphDepVars, new GraphDepVars());
+  }
+
+  for (auto &op : op_handles) {
+    auto dev_idx = op_to_dev_idx.at(op);
+    auto *ret_graph = graphs[dev_idx].get();
+    auto &ret_vars = ret_graph->Get<GraphVars>(kGraphVars)[0];
+    auto &ret_dummy_vars = ret_graph->Get<GraphDepVars>(kGraphDepVars);
+    auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_idx];
+
+    ret_graph->AddNode(graph->RemoveNode(op->Node()).release());
+
+    auto handler = [&](const std::vector<VarHandleBase *> &vars) {
+      for (auto *var : vars) {
+        if (graph->Nodes().count(var->Node()) > 0) {
+          ret_graph->AddNode(graph->RemoveNode(var->Node()).release());
+          auto *dummy_var = dynamic_cast<DummyVarHandle *>(var);
+          if (dummy_var == nullptr) {
+            ret_vars.emplace(var->Name(), origin_vars.at(var->Name()));
+          } else {
+            ret_dummy_vars.emplace(dummy_var);
+          }
+        }
+      }
+    };
+
+    handler(op->Inputs());
+    handler(op->Outputs());
+  }
+
+  graph->Erase(kGraphVars);
+  graph->Erase(kGraphDepVars);
+  return graphs;
+}
+
+bool HasDropLastReadOp(const ir::Graph &graph) {
+  auto ops = ir::FilterByNodeWrapper<OpHandleBase>(graph);
+  for (auto *op : ops) {
+    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
+    if (compute_op && compute_op->GetOp()->Type() == "read" &&
+        compute_op->GetOp()->Attr<bool>("drop_last")) {
+      VLOG(10) << "The graph has drop_last=True read op";
+      return true;
+    }
+  }
+  VLOG(10) << "The graph does not have drop_last=True read op";
+  return false;
+}
+
+}  // namespace details
 }  // namespace framework
 }  // namespace paddle
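
Note: TrySeparateToMultipleSingleDeviceGraphs succeeds only when every op handle maps to exactly one device (its scope index) and no data dependency crosses devices; any communication op listed in kMultiDeviceOps makes the graph inseparable. A minimal Python sketch of that check, assuming each op is a record with a device index and producer/consumer neighbors (names are illustrative, not Paddle APIs):

    def try_separate(ops):
        # Pass 1: every op must have a unique, well-defined device index.
        op_to_dev = {}
        for op in ops:
            if op.device_index is None:  # e.g. an allreduce spans devices
                return None
            op_to_dev[op] = op.device_index

        # Pass 2: every data dependency must stay on one device.
        for op in ops:
            for neighbor in op.producers + op.consumers:
                if op_to_dev.get(neighbor) != op_to_dev[op]:
                    return None

        # Group ops per device; each group becomes one single-device graph.
        num_devices = max(op_to_dev.values()) + 1
        groups = [[] for _ in range(num_devices)]
        for op, dev in op_to_dev.items():
            groups[dev].append(op)
        return groups
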
@@ -47,6 +47,7 @@ constexpr char kGraphVars[] = "vars";
 constexpr char kNRanks[] = "nranks";
 constexpr char kPlaces[] = "places";
+constexpr char kGlobalScope[] = "global_scope";
 constexpr char kLocalScopes[] = "local_scopes";
 constexpr char kNCCLCtxs[] = "nccl_ctxs";
 constexpr char kUseHierarchicalAllReduce[] = "use_hierarchical_allreduce";
@@ -100,6 +101,11 @@ inline std::vector<std::string> GetOpRoleVarsOrEmpty(const OpDesc &op) {
   return boost::get<std::vector<std::string>>(iter->second);
 }
+
+std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
+    ir::Graph *graph);
+
+bool HasDropLastReadOp(const ir::Graph &graph);
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
+#include <algorithm>
 #include <memory>
 #include <utility>
 #include "paddle/fluid/framework/ir/graph_helper.h"
@@ -21,11 +22,11 @@ namespace paddle {
 namespace framework {
 namespace details {
-std::vector<std::unique_ptr<ir::Graph>>
-ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
+static std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
+    ir::Graph *graph, size_t place_num) {
   std::vector<std::unique_ptr<ir::Graph>> graphs;
-  graphs.reserve(places_.size());
-  for (size_t i = 0; i < places_.size(); ++i) {
+  graphs.reserve(place_num);
+  for (size_t i = 0; i < place_num; ++i) {
     ProgramDesc empty;
     graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
     auto &g = graphs.back();
@@ -64,7 +65,7 @@ SeparateMultiDevicesGraph(ir::Graph *graph) {
     }
   }
-  for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
+  for (size_t dev_id = 0; dev_id < place_num; ++dev_id) {
     auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
     auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
     for (auto &name_pair : origin_vars) {
@@ -85,15 +86,34 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<Scope *> &local_exec_scopes,
     const std::vector<platform::Place> &places, ir::Graph *graph)
+    // TODO(Yancey1989): Copying graphs is not safely since it deleted the
+    // attrs.
+    : ParallelSSAGraphExecutor(strategy, local_scopes, local_exec_scopes,
+                               places,
+                               SeparateMultiDevicesGraph(graph,
+                                                         places.size())) {}
+
+ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
+    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
+    const std::vector<Scope *> &local_exec_scopes,
+    const std::vector<platform::Place> &places,
+    std::vector<std::unique_ptr<ir::Graph>> graphs)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
-      places_(std::move(places)),
-      // TODO(Yancey1989): Copying graphs is not safely since it deleted the
-      // attrs.
-      graphs_(SeparateMultiDevicesGraph(graph)) {
+      places_(places),
+      graphs_(std::move(graphs)),
+      feed_status_(places.size(), FeedStatus::kNone) {
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
+  PADDLE_ENFORCE_EQ(places_.size(), graphs_.size(),
+                    platform::errors::InvalidArgument(
+                        "Graph number does not match place number"));
+  PADDLE_ENFORCE_GT(
+      places_.size(), 0,
+      platform::errors::InvalidArgument("place number must be larger than 0"));
   auto seq_allreduce_pass =
       ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
   seq_allreduce_pass->Set<bool>(kUseHierarchicalAllReduce, new bool(false));
@@ -123,22 +143,43 @@ std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
   return result;
 }
+enum ExceptionStatus { kSuccess = 0, kEOF, kOther };
+
 FeedFetchList ParallelSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
+  size_t feed_num = std::count(feed_status_.begin(), feed_status_.end(),
+                               FeedStatus::kHasFeed);
+  bool has_feed = (feed_num > 0);
+
+  VLOG(10) << "Feed num " << feed_num;
+
+  size_t place_num = places_.size();
+
   std::vector<std::future<FeedFetchList>> run_futures;
+  std::vector<ExceptionStatus> exception_status(place_num,
+                                                ExceptionStatus::kSuccess);
   std::vector<FeedFetchList> fetch_data;
   FeedFetchList ret;
-  fetch_data.reserve(places_.size());
-  ret.reserve(fetch_tensors.size());
+  fetch_data.reserve(place_num);
+  ret.reserve(place_num);
   exception_holder_.Clear();
-  for (size_t i = 0; i < places_.size(); ++i) {
-    auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
+  for (size_t i = 0; i < place_num; ++i) {
+    auto call = [&, i]() -> FeedFetchList {
       try {
-        return executors_[i]->Run(fetch_tensors);
+        if (!support_partial_feed_ || !has_feed ||
+            feed_status_[i] == FeedStatus::kHasFeed) {
+          return executors_[i]->Run(fetch_tensors);
+        } else {
+          return FeedFetchList();
+        }
+      } catch (platform::EOFException &) {
+        exception_status[i] = ExceptionStatus::kEOF;
+        exception_holder_.Catch(std::current_exception());
       } catch (...) {
+        exception_status[i] = ExceptionStatus::kOther;
         exception_holder_.Catch(std::current_exception());
       }
       return FeedFetchList();
@@ -153,21 +194,63 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
   if (pool_) {
     for (auto &f : run_futures) {
-      if (exception_holder_.IsCaught()) {
-        f.wait();
-      } else {
-        fetch_data.emplace_back(f.get());
-      }
+      fetch_data.emplace_back(f.get());
     }
   }
-  if (exception_holder_.IsCaught()) {
+
+  bool has_exception = exception_holder_.IsCaught();
+  if (!support_partial_feed_ && has_exception) {
+    VLOG(10) << "Exception rethrow because partial feed is not supported";
+    exception_holder_.ReThrow();
+  }
+
+  std::vector<bool> is_valid(place_num, true);
+
+  if (support_partial_feed_) {
+    if (has_feed) {
+      for (size_t i = 0; i < place_num; ++i) {
+        if (feed_status_[i] == FeedStatus::kNone) {
+          is_valid[i] = false;
+        } else if (exception_status[i] != ExceptionStatus::kSuccess) {
+          PADDLE_ENFORCE_EQ(has_exception, true,
+                            platform::errors::InvalidArgument(
+                                "Thread pool raises exception but not caught"));
+          VLOG(10) << "Exception rethrow because non-EOF exception raises when "
+                      "feed is given";
+          exception_holder_.ReThrow();
+        }
+      }
+    } else {
+      for (size_t i = 0; i < place_num; ++i) {
+        if (exception_status[i] == ExceptionStatus::kOther) {
+          PADDLE_ENFORCE_EQ(has_exception, true,
+                            platform::errors::InvalidArgument(
+                                "Thread pool raises exception but not caught"));
+          VLOG(10) << "Exception rethrow because non-EOF exception raises when "
                      "feed is not given";
+          exception_holder_.ReThrow();
+        } else if (exception_status[i] != ExceptionStatus::kSuccess) {
+          is_valid[i] = false;
+        }
+      }
+    }
+  }
+
+  if (std::count(is_valid.begin(), is_valid.end(), true) == 0) {
+    PADDLE_ENFORCE_EQ(has_exception, true,
+                      platform::errors::InvalidArgument(
+                          "Thread pool raises exception but not caught"));
+    VLOG(10) << "Raise exception because there is no success worker";
     exception_holder_.ReThrow();
   }
+
   for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
     std::vector<const LoDTensor *> lodtensor_ptrs;
-    lodtensor_ptrs.reserve(local_scopes_.size());
-    for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
+    lodtensor_ptrs.reserve(place_num);
+    for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) {
+      if (!is_valid[scope_idx]) {
+        continue;
+      }
       lodtensor_ptrs.push_back(&fetch_data.at(scope_idx).at(fetch_idx));
    }
    ret.emplace_back();
......
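Note: Run() above decides per device whether its result is usable. A sketch of those validity rules in Python (SUCCESS/EOF/OTHER mirror ExceptionStatus, feed_status mirrors FeedStatus; a sketch of the control flow, not an API):

    def valid_devices(feed_status, exception_status, support_partial_feed):
        if not support_partial_feed and any(s != "SUCCESS" for s in exception_status):
            raise RuntimeError("rethrow: partial feed is not supported")
        is_valid = [True] * len(feed_status)
        if support_partial_feed:
            if any(f == "HAS_FEED" for f in feed_status):
                for i, f in enumerate(feed_status):
                    if f == "NONE":
                        is_valid[i] = False      # device ran nothing this step
                    elif exception_status[i] != "SUCCESS":
                        raise RuntimeError("rethrow: a fed device failed")
            else:
                for i, s in enumerate(exception_status):
                    if s == "OTHER":
                        raise RuntimeError("rethrow: non-EOF failure")
                    elif s == "EOF":
                        is_valid[i] = False      # device drained its reader
        if not any(is_valid):
            raise RuntimeError("rethrow: no successful worker")
        return is_valid

Only the valid devices contribute tensors to the merged fetch result.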
@@ -27,12 +27,25 @@ namespace framework {
 namespace details {
 class ParallelSSAGraphExecutor : public SSAGraphExecutor {
+ public:
+  enum FeedStatus {
+    kNone = 0,    // No feed
+    kHasFeed = 1  // Has feed
+  };
+
  public:
   ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<Scope *> &local_exec_scopes,
                            const std::vector<platform::Place> &places,
                            ir::Graph *graph);
+
+  ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
+                           const std::vector<Scope *> &local_scopes,
+                           const std::vector<Scope *> &local_exec_scopes,
+                           const std::vector<platform::Place> &places,
+                           std::vector<std::unique_ptr<ir::Graph>> graphs);
+
   ~ParallelSSAGraphExecutor() final = default;
   const ir::Graph &Graph() const override { return *graphs_[0]; }
@@ -41,10 +54,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
- private:
-  std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
-      ir::Graph *graph);
+  void SetHasFeed(size_t dev_idx, bool has_feed) {
+    feed_status_[dev_idx] = has_feed ? FeedStatus::kHasFeed : FeedStatus::kNone;
+  }
+
+  void EnablePartialFeedSupport() { support_partial_feed_ = true; }
+
+  bool SupportPartialFeed() const { return support_partial_feed_; }
+
+ private:
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
@@ -54,6 +72,9 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<std::unique_ptr<details::FastThreadedSSAGraphExecutor>>
       executors_;
   ExceptionHolder exception_holder_;
+
+  bool support_partial_feed_{false};
+  std::vector<FeedStatus> feed_status_;
 };
 }  // namespace details
......
@@ -228,7 +228,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
   }
   auto *eager_deletion_op = new details::EagerDeletionOpHandle(
-      eager_deletion_node, op->GetScope(), op->GetPlace(),
+      eager_deletion_node, op->GetScope(), op->GetScopeIdx(), op->GetPlace(),
       std::move(var_info), gcs.at(places[op->GetScopeIdx()]).get());
   auto it = std::find_if(
......
@@ -98,7 +98,7 @@ class ReferenceCountPassTestHelper {
         ir::PassRegistry::Instance().Get("reference_count_pass");
     ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
     ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars_);
-    ref_cnt_pass->Apply(&graph_);
+    ref_cnt_pass->Apply(&const_cast<ir::Graph &>(executor_->Graph()));
   }
   bool IsLastLivedOps(const std::string &name,
......
@@ -11,7 +11,7 @@ endif()
 cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
     scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
 cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
-cc_library(set_reader_device_count_pass SRCS set_reader_device_count_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass)
+cc_library(set_reader_device_info_pass SRCS set_reader_device_info_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass)
 cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
 cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......
@@ -22,35 +22,44 @@ namespace paddle {
 namespace framework {
 namespace ir {
-class SetReaderDeviceCountPass : public Pass {
- protected:
-  void ApplyImpl(Graph *graph) const override;
-
- private:
-  int GetDeviceCount() const;
-
-  std::unordered_set<std::string> ReaderOpSet() const;
-
-  const Scope *GlobalScope() const;
-};
-
-int SetReaderDeviceCountPass::GetDeviceCount() const {
+static int GetDeviceCountFromPassAttr(const Pass &pass) {
   return static_cast<int>(
-      Get<const std::vector<platform::Place>>(details::kPlaces).size());
+      pass.Get<const std::vector<platform::Place>>(details::kPlaces).size());
 }
-std::unordered_set<std::string> SetReaderDeviceCountPass::ReaderOpSet() const {
+static std::unordered_set<std::string> ReaderOpSet() {
   return {"create_py_reader"};
 }
-const Scope *SetReaderDeviceCountPass::GlobalScope() const {
-  return Get<const std::vector<Scope *>>(details::kLocalScopes)[0];
-}
-
-void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
-  auto dev_cnt = GetDeviceCount();
-  auto reader_ops = ReaderOpSet();
-  auto scope = GlobalScope();
+class InitReaderDeviceCountPass : public Pass {
+ protected:
+  void ApplyImpl(Graph *graph) const override {
+    using QueueHolder =
+        operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
+
+    auto reader_ops = ReaderOpSet();
+    auto dev_cnt = GetDeviceCountFromPassAttr(*this);
+    const auto &scope = Get<const Scope>(details::kGlobalScope);
+    for (auto &node : graph->Nodes()) {
+      if (node->IsOp() && node->Op() &&
+          reader_ops.count(node->Op()->Type()) != 0) {
+        auto queue_name = node->Op()->Input("blocking_queue")[0];
+        auto var = scope.FindVar(queue_name);
+        if (var && var->IsType<QueueHolder>()) {
+          VLOG(10) << "Set device count of " << queue_name << " to be "
+                   << dev_cnt;
+          var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
+        }
+      }
+    }
+  }
+};
+
+class SetReaderDeviceIndexPass : public Pass {
+ protected:
+  void ApplyImpl(Graph *graph) const override {
+    auto dev_cnt = GetDeviceCountFromPassAttr(*this);
+    auto reader_ops = ReaderOpSet();
     size_t found_op_num = 0;
     for (auto &node : graph->Nodes()) {
@@ -69,30 +78,24 @@ void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
       op_base_attrs["device_index"] = dev_idx;
       op_base_attrs["device_count"] = dev_cnt;
-      auto queue_name = op_handle.GetOp()->Input("blocking_queue");
-      auto var = scope->FindVar(queue_name);
-      PADDLE_ENFORCE_NOT_NULL(
-          var,
-          platform::errors::NotFound("Blocking queue of DataLoader not found"));
-
-      using QueueHolder =
-          operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
-      if (var->IsType<QueueHolder>()) {
-        var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
-      }
-
       ++found_op_num;
       VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx;
     }
   }
   VLOG(10) << "Found op number " << found_op_num;
 }
+};
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
-REGISTER_PASS(set_reader_device_count_pass,
-              paddle::framework::ir::SetReaderDeviceCountPass)
+REGISTER_PASS(init_reader_device_count_pass,
+              paddle::framework::ir::InitReaderDeviceCountPass)
+    .RequirePassAttr(paddle::framework::details::kGlobalScope)
+    .RequirePassAttr(paddle::framework::details::kPlaces);
+
+REGISTER_PASS(set_reader_device_index_pass,
+              paddle::framework::ir::SetReaderDeviceIndexPass)
    .RequirePassAttr(paddle::framework::details::kPlaces);
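
Note: the single old pass is split in two so that the stateful part (setting the blocking queue's device count in the global scope) runs exactly once at ParallelExecutor construction, while the per-graph part only rewrites reader op attributes and stays safely re-runnable. A rough Python sketch of that division of labor, assuming dev_idx is the scope index the reader runs on (reader_ops and the helper names are illustrative):

    def init_reader_device_count(global_scope, places, program):
        # Runs once: configure each DataLoader queue for N devices.
        for op in reader_ops(program):
            queue = global_scope.find(op.input("blocking_queue"))
            queue.set_device_count(len(places))

    def set_reader_device_index(graph, places):
        # Runs per graph build: tag each reader op with its device slot.
        for op in reader_ops(graph):
            op.attrs["device_index"] = op.scope_idx
            op.attrs["device_count"] = len(places)
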
@@ -307,18 +307,18 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
 std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
     const std::vector<platform::Place> places) const {
+  PADDLE_ENFORCE_GT(places.size(), 0,
+                    platform::errors::InvalidArgument(
+                        "place number cannot be empty when splitting"));
   check_memory_size();
-  int batch_size =
-      lod().empty() ? dims()[0] : static_cast<int>(lod()[0].size()) - 1;
-  size_t result_size = std::min(static_cast<size_t>(batch_size), places.size());
-  size_t remainder = batch_size % places.size();
-
-  std::vector<LoDTensor> results;
-  results.reserve(result_size);
-
-  // if result_size(batch_size) is 0, just return #places.size() copys of empty
+  size_t batch_size =
+      lod().empty() ? static_cast<size_t>(dims()[0]) : lod()[0].size() - 1;
+
+  // if batch_size is 0, just return #places.size() copys of empty
   // tensors.
-  if (result_size == 0) {
+  if (batch_size == 0) {
+    std::vector<LoDTensor> empty_results;
+    empty_results.reserve(places.size());
     for (size_t i = 0; i < places.size(); ++i) {
       LoDTensor dst;
       dst.Resize(dims());
@@ -326,18 +326,22 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
       if (!lod().empty()) {
         dst.set_lod(lod());
       }
-      results.emplace_back(dst);
+      empty_results.emplace_back(std::move(dst));
     }
-    return results;
+    return empty_results;
   }
-  int step_width = static_cast<int>(batch_size / result_size);
+  auto step_width = (batch_size + places.size() - 1) / places.size();
+  auto result_size = (batch_size + step_width - 1) / step_width;
+  std::vector<LoDTensor> results;
+  results.reserve(result_size);
+
   for (size_t i = 0; i < result_size; ++i) {
-    int begin = static_cast<int>(i * step_width);
-    int end = static_cast<int>((i + 1) * step_width);
-    if (i + 1 == places.size()) {  // last
-      end += remainder;
-    }
+    auto begin = i * step_width;
+    auto end = std::min<size_t>((i + 1) * step_width, batch_size);
+    PADDLE_ENFORCE_LT(begin, end,
+                      platform::errors::InvalidArgument(
+                          "begin must be less than end, this may be a bug"));
     LoDTensor dst;
     if (lod().empty()) {
@@ -362,7 +366,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
       }
       dst.set_lod(my_lod);
     }
-    results.emplace_back(dst);
+    results.emplace_back(std::move(dst));
   }
   return results;
......
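Note: the split now uses ceiling division, so when the batch is smaller than the device count it produces fewer shards than places instead of padding the last device with the remainder. A worked Python example of the new bounds arithmetic:

    def split_bounds(batch_size, num_places):
        step_width = (batch_size + num_places - 1) // num_places  # ceil division
        result_size = (batch_size + step_width - 1) // step_width
        return [(i * step_width, min((i + 1) * step_width, batch_size))
                for i in range(result_size)]

    print(split_bounds(5, 4))  # [(0, 2), (2, 4), (4, 5)] -> only 3 shards
    print(split_bounds(8, 4))  # [(0, 2), (2, 4), (4, 6), (6, 8)] -> all 4 places

With 5 samples on 4 devices, only 3 shards are produced; the fourth device gets no feed, which is precisely the partial-feed case the executor changes above must tolerate.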
@@ -55,8 +55,9 @@ static bool gProfileStarted = false;
 class ParallelExecutorPrivate {
  public:
-  explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
-      : places_(places) {
+  ParallelExecutorPrivate(const std::vector<platform::Place> &places,
+                          Scope *global_scope)
+      : places_(places), global_scope_(global_scope) {
     if (!FLAGS_pe_profile_fname.empty()) {
       std::call_once(gProfileOnce, [] {
 #ifdef WITH_GPERFTOOLS
@@ -82,6 +83,19 @@ class ParallelExecutorPrivate {
     }
   }
+  void InitReaderDeviceCount(ir::Graph *graph) const {
+    auto pass =
+        ir::PassRegistry::Instance().Get("init_reader_device_count_pass");
+    pass->SetNotOwned<const Scope>(details::kGlobalScope, global_scope_);
+    pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces,
+                                                          &places_);
+    pass->Apply(graph);
+  }
+
+  void SetHasFeed(size_t dev_idx, bool has_feed = true);
+
+  bool AllowPartialFeed() const;
+
   ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph);
   inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
@@ -257,8 +271,20 @@ class ParallelExecutorPrivate {
   ir::MemOptVarInfoMapList mem_opt_var_infos_;
   ir::GarbageCollectorMap gcs_;
+
+  details::ParallelSSAGraphExecutor *inference_executor_{nullptr};
 };
+
+void ParallelExecutorPrivate::SetHasFeed(size_t dev_idx, bool has_feed) {
+  if (inference_executor_) {
+    inference_executor_->SetHasFeed(dev_idx, has_feed);
+  }
+}
+
+bool ParallelExecutorPrivate::AllowPartialFeed() const {
+  return inference_executor_ && inference_executor_->SupportPartialFeed();
+}
+
 ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
   if (FLAGS_use_ngraph) {
     LOG_FIRST_N(WARNING, 1)
@@ -379,6 +405,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
   return graph;
 }
+class ResetHasFeedGuard {
+ public:
+  explicit ResetHasFeedGuard(ParallelExecutorPrivate *pe_member)
+      : pe_member_(pe_member) {}
+
+  ~ResetHasFeedGuard() {
+    for (size_t i = 0; i < pe_member_->places_.size(); ++i) {
+      pe_member_->SetHasFeed(i, false);
+    }
+  }
+
+ private:
+  ParallelExecutorPrivate *pe_member_;
+};
+
 size_t ParallelExecutor::DeviceCount() const { return member_->places_.size(); }
 std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
@@ -407,8 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
                                    const ExecutionStrategy &exec_strategy,
                                    const BuildStrategy &build_strategy,
                                    ir::Graph *graph)
-    : member_(new ParallelExecutorPrivate(places)) {
-  member_->global_scope_ = scope;
+    : member_(new ParallelExecutorPrivate(places, scope)) {
+  member_->InitReaderDeviceCount(graph);
   member_->use_cuda_ = exec_strategy.use_cuda_;
   member_->build_strategy_ = build_strategy;
   member_->use_all_reduce_ = member_->build_strategy_.reduce_ ==
@@ -605,6 +646,22 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
     PADDLE_THROW(
         "Paddle should be compiled with CUDA for ParallelGraph Execution.");
 #endif
+  } else {
+    bool has_drop_last_read_op = details::HasDropLastReadOp(*graph);
+    auto possible_inference_graphs =
+        details::TrySeparateToMultipleSingleDeviceGraphs(graph);
+    if (!possible_inference_graphs.empty()) {
+      VLOG(5) << "Use ParallelSSAGraphExecutor in inference phase";
+      auto *pg_exe = new details::ParallelSSAGraphExecutor(
+          exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
+          member_->places_, std::move(possible_inference_graphs));
+      if (!has_drop_last_read_op) {
+        VLOG(5) << "Enable partial feed support in inference phase";
+        pg_exe->EnablePartialFeedSupport();
+      }
+      final_graphs = pg_exe->Graphs();
+      member_->executor_.reset(pg_exe);
+      member_->inference_executor_ = pg_exe;
     } else {
       if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
         VLOG(3) << "use ThreadedSSAGraphExecutor";
@@ -619,6 +676,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
       }
       final_graphs.emplace_back(graph);
     }
+  }
   VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
   if (!member_->build_strategy_.async_mode_) {
@@ -724,6 +782,8 @@ FeedFetchList ParallelExecutor::Run(
   platform::RecordBlock b(0);
+  ResetHasFeedGuard reset_has_feed_guard(member_);
+
   ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors,
                                 member_->HasGarbageCollectors());
@@ -734,10 +794,22 @@ FeedFetchList ParallelExecutor::Run(
 void ParallelExecutor::FeedTensorsIntoLocalScopes(
     const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
-  PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size());
+  if (!member_->AllowPartialFeed()) {
+    PADDLE_ENFORCE_EQ(
+        member_->local_scopes_.size(), tensors.size(),
+        platform::errors::InvalidArgument(
+            "The feed tensor number does not match the device number"));
+  } else {
+    PADDLE_ENFORCE_GE(member_->local_scopes_.size(), tensors.size(),
+                      platform::errors::InvalidArgument(
+                          "The feed tensor number exceeds the device number"));
+  }
   for (size_t i = 0; i < tensors.size(); ++i) {
     auto &map = tensors[i];
+    if (!map.empty()) {
+      member_->SetHasFeed(i);
+    }
+
     for (auto &pair : map) {
       bool is_persistable = member_->IsPersistable(pair.first);
       if (!is_persistable) {
@@ -757,6 +829,11 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
 void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
     const std::unordered_map<std::string, LoDTensor> &tensors) {
   size_t num_places = member_->places_.size();
+  bool allow_partial_feed = member_->AllowPartialFeed();
+
+  size_t persistable_feed_len = -1UL;
+  size_t non_persistable_feed_len = -1UL;
+
   for (auto &pair : tensors) {
     bool is_persistable = member_->IsPersistable(pair.first);
     VLOG(3) << "Split " << (is_persistable ? "persistable" : "no persistable")
@@ -764,7 +841,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
             << ", place: " << pair.second.place();
     auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
     bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
-    if (!is_persistable && num_places != lod_tensors.size()) {
+    if (!is_persistable && num_places != lod_tensors.size() &&
+        !allow_partial_feed) {
       auto error_info = string::Sprintf(
           "The number(%d) of samples[%s] of current batch is less than the "
          "count(%d) of devices(%s), currently, it is not allowed. ",
@@ -790,7 +868,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
         framework::TensorCopy(pair.second, member_->places_.at(i), &tmp);
       }
     }
-    if (lod_tensors.size() != num_places) {
+    if (lod_tensors.size() != num_places && !allow_partial_feed) {
      auto error_info = string::Sprintf(
          "The number(%d) of samples[%s] of the current batch does not match "
          "the count(%d) of devices(%s). Because that %s is a persistable "
@@ -804,7 +882,31 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
       }
     }
-    for (size_t j = 0; j < num_places; ++j) {
+    if (allow_partial_feed) {
+      if (is_persistable) {
+        if (persistable_feed_len == -1UL) {
+          persistable_feed_len = lod_tensors.size();
+        } else {
+          PADDLE_ENFORCE_EQ(
+              persistable_feed_len, lod_tensors.size(),
+              platform::errors::InvalidArgument(
+                  "The feeded number of different persistable variables "
+                  "should be the same"));
+        }
+      } else {
+        if (non_persistable_feed_len == -1UL) {
+          non_persistable_feed_len = lod_tensors.size();
+        } else {
+          PADDLE_ENFORCE_EQ(
+              non_persistable_feed_len, lod_tensors.size(),
+              platform::errors::InvalidArgument(
+                  "The feeded number of different non-persistable variables "
+                  "should be the same"));
+        }
+      }
+    }
+
+    for (size_t j = 0; j < lod_tensors.size(); ++j) {
       auto *feed_scope = is_persistable ? member_->local_scopes_[j]
                                         : member_->local_exec_scopes_[j];
       auto *feed_var = feed_scope->Var(pair.first);
@@ -814,6 +916,19 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
       t->set_lod(lod_tensors[j].lod());
     }
   }
+
+  if (allow_partial_feed && persistable_feed_len != -1UL &&
+      non_persistable_feed_len != -1UL) {
+    VLOG(10) << "Persistable len " << persistable_feed_len;
+    VLOG(10) << "Non persistable len " << non_persistable_feed_len;
+    PADDLE_ENFORCE_GE(persistable_feed_len, non_persistable_feed_len,
+                      platform::errors::InvalidArgument(
+                          "The feeded number of persistable variables should "
+                          "not be less than non-persistable variables"));
+
+    for (size_t i = 0; i < non_persistable_feed_len; ++i) {
+      member_->SetHasFeed(i);
+    }
+  }
 }
 ParallelExecutor::~ParallelExecutor() {
@@ -864,6 +979,10 @@ bool ParallelExecutor::EnableParallelGraphExecution(
   return enable_parallel_graph;
 }
+const ir::Graph &ParallelExecutor::Graph() const {
+  return member_->executor_->Graph();
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -871,3 +990,4 @@ USE_PASS(reference_count_pass);
 USE_PASS(eager_deletion_pass);
 USE_PASS(buffer_shared_inplace_pass);
 USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
+USE_PASS(init_reader_device_count_pass);
@@ -79,6 +79,8 @@ class ParallelExecutor {
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors);
+  const ir::Graph &Graph() const;
+
  private:
   // broadcast the parameters from the 0th device.
   // trainer_id the trainer index in nccl distributed training.
......
@@ -156,6 +156,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
                   " and it is set by ParallelExecutor instance, not users.")
         .SetDefault(true);
     AddAttr<bool>("infer_out", "").SetDefault(true);
+    AddAttr<bool>("drop_last",
+                  "Whether to drop last batches whose number is less than CPU "
+                  "cores/GPU cards number")
+        .SetDefault(true);
     AddComment(R"DOC(
 Read Operator
......
@@ -20,6 +20,7 @@
 #include <utility>
 #include <vector>
 #include "Python.h"
+#include "boost/optional.hpp"
 #include "gflags/gflags.h"
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/reader.h"
@@ -41,6 +42,58 @@ namespace pybind {
 namespace py = pybind11;
 namespace reader = operators::reader;
+// Check whether the tensor shape matches the VarDesc shape
+// Return the different shape if exists
+static boost::optional<std::vector<int64_t>> DiffTensorShapeWithVarDesc(
+    const framework::LoDTensor &tensor, const framework::VarDesc &var_desc,
+    size_t num_places) {
+  auto tensor_shape = tensor.dims();
+  auto desc_shape = var_desc.GetShape();
+
+  int64_t rank = tensor_shape.size();
+
+  if (UNLIKELY(rank == 0)) {
+    if (desc_shape.size() != 0) {  // Tensor rank = 0 but desc does not match
+      return framework::vectorize<int64_t>(tensor_shape);
+    } else {
+      return boost::none;
+    }
+  }
+
+  PADDLE_ENFORCE_GE(tensor_shape[0], 0,
+                    platform::errors::InvalidArgument(
+                        "Tensor shape must not be less than 0"));
+
+  if (!tensor.lod().empty()) {
+    tensor_shape[0] = -1;  // unknown shape
+  } else {
+    int64_t split_size = (tensor_shape[0] + num_places - 1) / num_places;
+    int64_t remainder = (split_size == 0 ? 0 : tensor_shape[0] % split_size);
+    tensor_shape[0] = split_size;
+    if (desc_shape[0] >= 0) {  // need check dim 0
+      if (tensor_shape[0] != desc_shape[0]) {
+        return framework::vectorize<int64_t>(tensor_shape);
+      }
+
+      if (remainder > 0) {
+        tensor_shape[0] = remainder;
+        return framework::vectorize<int64_t>(tensor_shape);
+      }
+    }
+  }
+
+  for (int64_t idx = 1; idx < rank; ++idx) {
+    PADDLE_ENFORCE_GE(tensor_shape[idx], 0,
+                      platform::errors::InvalidArgument(
+                          "Tensor shape must not be less than 0"));
+    if (desc_shape[idx] >= 0 && tensor_shape[idx] != desc_shape[idx]) {
+      return framework::vectorize<int64_t>(tensor_shape);
+    }
+  }
+
+  return boost::none;
+}
+
 static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
     const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue, size_t idx) {
   return queue;
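
Note: the dim-0 check in DiffTensorShapeWithVarDesc accepts both the full per-device shard and the smaller remainder shard produced by ceiling splitting. A Python sketch of just that arithmetic, assuming a LoD-free tensor:

    def diff_dim0(tensor_dim0, desc_dim0, num_places):
        split_size = (tensor_dim0 + num_places - 1) // num_places  # ceil division
        remainder = 0 if split_size == 0 else tensor_dim0 % split_size
        if desc_dim0 >= 0:
            if split_size != desc_dim0:
                return split_size   # per-device shard mismatches the desc
            if remainder > 0:
                return remainder    # last shard is smaller than the desc dim
        return None                 # compatible

    print(diff_dim0(8, 2, 4))  # None: 4 even shards of 2 samples each
    print(diff_dim0(5, 2, 4))  # 1: the last shard holds only 1 sample
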
@@ -66,10 +119,12 @@ class MultiDeviceFeedReader {
       const std::vector<std::vector<int>> &shapes,
       const std::vector<framework::proto::VarType::Type> &dtypes,
       const std::vector<bool> &need_check_feed,
-      const std::vector<platform::Place> &dst_places, bool use_double_buffer)
+      const std::vector<platform::Place> &dst_places, bool use_double_buffer,
+      bool drop_last)
       : queue_(queue),
         names_(names),
-        pool_(new ::ThreadPool(dst_places.size())) {
+        pool_(new ::ThreadPool(dst_places.size())),
+        drop_last_(drop_last) {
     std::vector<framework::DDim> dims;
     for (auto &shape : shapes) {
       dims.push_back(framework::make_ddim(shape));
@@ -113,14 +168,18 @@ class MultiDeviceFeedReader {
     ReadAsync();
   }
+  bool DropLast() const { return drop_last_; }
+
   ResultDictList ReadNext() {
     CheckNextStatus();
     ResultDictList result(ret_.size());
     for (size_t i = 0; i < ret_.size(); ++i) {
+      if (!ret_[i].empty()) {
         for (size_t j = 0; j < names_.size(); ++j) {
           result[i].emplace(names_[j], std::move(ret_[i][j]));
         }
+      }
     }
     ReadAsync();
     return result;
   }
@@ -155,24 +214,29 @@ class MultiDeviceFeedReader {
   };
   Status WaitFutures(std::exception_ptr *excep) {
-    bool is_success = true;
     *excep = nullptr;
+    size_t success_num = 0;
     for (size_t i = 0; i < futures_.size(); ++i) {
       auto each_status = futures_[i].get();
       if (UNLIKELY(each_status != Status::kSuccess)) {
-        is_success = false;
         if (UNLIKELY(each_status == Status::kException)) {
           PADDLE_ENFORCE_NOT_NULL(exceptions_[i]);
           *excep = exceptions_[i];
           exceptions_[i] = nullptr;
         }
+      } else {
+        ++success_num;
       }
     }
     if (UNLIKELY(*excep)) {
       return Status::kException;
+    }
+
+    if (drop_last_) {
+      return success_num == futures_.size() ? Status::kSuccess : Status::kEOF;
     } else {
-      return is_success ? Status::kSuccess : Status::kEOF;
+      return success_num > 0 ? Status::kSuccess : Status::kEOF;
     }
   }
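
Note: drop_last now decides when a partially successful read round counts as success. In short: with drop_last=True a round succeeds only if every device produced data; with drop_last=False one surviving device is enough, and ReadNext() skips the empty slots. The rule, restated in Python:

    def round_status(per_device_ok, drop_last):
        success_num = sum(per_device_ok)
        if drop_last:
            return "SUCCESS" if success_num == len(per_device_ok) else "EOF"
        return "SUCCESS" if success_num > 0 else "EOF"

    print(round_status([True, True, False], drop_last=True))   # EOF
    print(round_status([True, True, False], drop_last=False))  # SUCCESS
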
@@ -226,6 +290,7 @@ class MultiDeviceFeedReader {
   std::vector<std::exception_ptr> exceptions_;
   std::vector<std::vector<framework::LoDTensor>> ret_;
+  bool drop_last_;
 };
 template <typename QueueType>
@@ -270,6 +335,17 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
 void BindReader(py::module *module) {
   auto &m = *module;
+  m.def("diff_tensor_shape", [](const framework::LoDTensor &tensor,
+                                const framework::VarDesc &var_desc,
+                                size_t num_places) -> py::object {
+    auto diff = DiffTensorShapeWithVarDesc(tensor, var_desc, num_places);
+    if (diff) {
+      return py::cast(std::move(diff.get()));
+    } else {
+      return py::cast(nullptr);
+    }
+  });
+
   m.def("init_lod_tensor_blocking_queue",
         [](framework::Variable &var, size_t capacity,
            bool is_ordered) -> py::object {
@@ -337,10 +413,10 @@ void BindReader(py::module *module) {
            const std::vector<framework::proto::VarType::Type> &dtypes,
            const std::vector<bool> &need_check_feed,
            const std::vector<platform::Place> &dst_places,
-           bool use_double_buffer) {
+           bool use_double_buffer, bool drop_last) {
          return new MultiDeviceFeedReader<reader::LoDTensorBlockingQueue>(
              queue, names, shapes, dtypes, need_check_feed, dst_places,
-             use_double_buffer);
+             use_double_buffer, drop_last);
        },
        py::return_value_policy::take_ownership);
@@ -352,13 +428,13 @@ void BindReader(py::module *module) {
            const std::vector<std::vector<int>> &shapes,
            const std::vector<framework::proto::VarType::Type> &dtypes,
            const std::vector<bool> &need_check_feed,
-           const std::vector<platform::Place> &dst_places,
-           bool use_double_buffer) {
+           const std::vector<platform::Place> &dst_places, bool use_double_buffer,
+           bool drop_last) {
          queue->SetDeviceCount(dst_places.size());
          return new MultiDeviceFeedReader<
              reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
              queue, names, shapes, dtypes, need_check_feed, dst_places,
-             use_double_buffer);
+             use_double_buffer, drop_last);
        },
        py::return_value_policy::take_ownership);
 }
......
@@ -216,18 +216,12 @@ def check_feed_shape_type(var, feed, num_places=1):
         the feed value
     """
     if var.desc.need_check_feed():
-        feed_shape = feed.shape()
-        if six.PY2:
-            feed_shape[0] = long(feed_shape[0] /
-                                 num_places) if len(feed.lod()) == 0 else -1
-        else:
-            feed_shape[0] = int(feed_shape[0] /
-                                num_places) if len(feed.lod()) == 0 else -1
-        if not dimension_is_compatible_with(feed_shape, var.shape):
+        diff_shape = core.diff_tensor_shape(feed, var.desc, num_places)
+        if diff_shape is not None:
             raise ValueError(
                 'The feeded Variable %r should have dimensions = %d, shape = '
                 '%r, but received feeded shape %r on each device' %
-                (var.name, len(var.shape), var.shape, feed_shape))
+                (var.name, len(var.shape), var.shape, diff_shape))
     if not dtype_is_compatible_with(feed._dtype(), var.dtype):
         var_dtype_format = convert_dtype(var.dtype) if isinstance(
             var.dtype, core.VarDesc.VarType) else var.dtype
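The per-device shape arithmetic (including the old six.PY2 long/int special case) now lives in C++ behind the diff_tensor_shape binding added above; Python only inspects the result. A hedged usage sketch, assuming a Paddle build containing this commit and a CPU place:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core

x = fluid.data(name='x', shape=[None, 10], dtype='float32')
t = fluid.LoDTensor()
t.set(np.random.rand(8, 10).astype('float32'), fluid.CPUPlace())

# Returns None when the feed tensor is compatible with x's VarDesc
# across num_places devices, otherwise the offending shape.
print(core.diff_tensor_shape(t, x.desc, 1))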
@@ -646,11 +640,6 @@ class Executor(object):
             exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict)
         elif isinstance(feed, list) or isinstance(feed, tuple):
-            if len(feed) != len(program._places):
-                raise ValueError(
-                    "Feed a list of tensor, the list should be the same size as places"
-                )
-
             res = list()
             for i, each in enumerate(feed):
                 if not isinstance(each, dict):
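Dropping this length check is the user-visible core of the commit: a feed list may now cover only a subset of the places, and only the fed devices run. A sketch of what becomes legal (mirroring feed_list_test in the new unit test below):

import numpy as np
import paddle.fluid as fluid

x = fluid.data(name='x', shape=[None, 10], dtype='float32')
relu_x = fluid.layers.relu(x)

places = fluid.cpu_places(4)
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(
    places=places)

# One feed dict for four places: previously a ValueError, now a partial feed.
feed = [{'x': np.random.rand(2, 10).astype('float32')}]
out, = exe.run(prog, feed=feed, fetch_list=[relu_x])
print(out.shape)  # only the rows of the devices that were actually fed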
...
@@ -88,6 +88,7 @@ class DataLoader(object):
                        iterable=True,
                        return_list=False,
                        use_multiprocess=False,
+                       drop_last=True,
                        keep_order=False):
         """
         Create a DataLoader object for loading data from Python generator.
@@ -134,6 +135,9 @@ class DataLoader(object):
                 can be used in the dygraph mode. In the static graph mode,
                 whether this parameter is set or not has no effect.
                 The Default value is False.
+            drop_last (bool): whether to drop the last batches whose number is
+                less than the CPU core/GPU card number. The default value is
+                True.
             keep_order (bool): whether to assign the data to CPU cores or GPU
                 cards in order. Supposing that there are 2 batches and we use
                 2 GPU cards to run the network. If keep_order=True, GPU 0 would
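A short usage sketch of the new argument (static-graph mode assumed):

import paddle.fluid as fluid

x = fluid.data(name='x', shape=[None, 1], dtype='float32')

# drop_last=False keeps tail batches that cannot fill every place,
# which is what makes partial-data inference feeds possible.
loader = fluid.io.DataLoader.from_generator(
    feed_list=[x], capacity=16, iterable=True, drop_last=False)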
@@ -289,7 +293,7 @@ class DataLoader(object):
                                           return_list, use_multiprocess)
         else:
             return GeneratorLoader(feed_list, capacity, use_double_buffer,
-                                   iterable, return_list, keep_order)
+                                   iterable, return_list, drop_last, keep_order)

     @staticmethod
     def from_dataset(dataset, places, drop_last=True):
@@ -422,7 +426,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
             core.Variable(), self._capacity, False)
         self._reader = core.create_py_reader(
             self.queue, self._var_names, self._shapes, self._dtypes,
-            self._need_check_feed, self._places, self._use_double_buffer)
+            self._need_check_feed, self._places, self._use_double_buffer, True)

     def _start(self):
         if self._use_multiprocess:
@@ -628,6 +632,7 @@ class GeneratorLoader(DataLoaderBase):
                  use_double_buffer=True,
                  iterable=True,
                  return_list=False,
+                 drop_last=True,
                  keep_order=False):
         self._tensor_reader = None
         self._places = None
@@ -635,6 +640,8 @@ class GeneratorLoader(DataLoaderBase):
         self._queue = None
         self._feed_list = feed_list
         self._exited = False
+        self._drop_last = drop_last
+        self._keep_order = keep_order
         if not capacity:
             raise ValueError("Please give value to capacity.")
         self._iterable = iterable
@@ -643,7 +650,6 @@ class GeneratorLoader(DataLoaderBase):
             raise Exception("Feed list must be given under static mode.")
         self._use_double_buffer = use_double_buffer
         self._capacity = capacity
-        self._keep_order = keep_order
         if not self._iterable:
             self._init_non_iterable()
@@ -667,7 +673,8 @@ class GeneratorLoader(DataLoaderBase):
             core.Variable(), self._capacity, self._keep_order)
         self._reader = core.create_py_reader(
             self.queue, self._var_names, self._shapes, self._dtypes,
-            self._need_check_feed, self._places, self._use_double_buffer)
+            self._need_check_feed, self._places, self._use_double_buffer,
+            self._drop_last)

     def _init_non_iterable(self):
         lod_levels = []
@@ -744,7 +751,8 @@ class GeneratorLoader(DataLoaderBase):
         default_main_program().current_block().append_op(
             type='read',
             inputs={'Reader': [self._reader]},
-            outputs={'Out': self._feed_list})
+            outputs={'Out': self._feed_list},
+            attrs={'drop_last': self._drop_last})

     @property
     def queue(self):
...
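For the non-iterable path, drop_last travels as an attribute of the read op appended by _init_non_iterable(), so the C++ side can decide whether a partial final batch is delivered or swallowed. A hedged way to observe it from Python:

import paddle.fluid as fluid

x = fluid.data(name='x', shape=[None, 1], dtype='float32')
loader = fluid.io.DataLoader.from_generator(
    feed_list=[x], capacity=4, iterable=False, drop_last=False)

read_ops = [op for op in fluid.default_main_program().global_block().ops
            if op.type == 'read']
print(read_ops[0].attr('drop_last'))  # False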
@@ -355,4 +355,5 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
     test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
     test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
     test_optimizer_in_control_flow test_dataloader_keep_order
+    test_parallel_executor_inference_feed_partial_data
     test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy as np
import unittest
import six


class TestInferencePartialFeed(unittest.TestCase):
def setUp(self):
self.iterations = 10
self.size = 10
def run_network(self, places, use_split):
startup_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
x = fluid.data(name='x', shape=[None, self.size], dtype='float32')
y = fluid.data(name='y', shape=[None, self.size], dtype='float32')
lr = fluid.data(name='lr', shape=[1], dtype='float32')
lr.persistable = True
relu_x = fluid.layers.relu(x)
relu_y = fluid.layers.relu(y)
relu_lr = fluid.layers.relu(lr)
exe = fluid.Executor(places[0])
exe.run(startup_prog)
prog = fluid.CompiledProgram(main_prog).with_data_parallel(
places=places)
        gen_random = lambda shape: np.random.uniform(
            low=-1.0, high=1.0, size=shape).astype('float32')
        assert_result = lambda feed, result: self.assertTrue(
            np.array_equal(np.maximum(0, feed), result))
def feed_split_test():
for place_num in six.moves.range(1, len(places) * 3):
x_np = gen_random([place_num, self.size])
y_np = gen_random([place_num, self.size])
if place_num <= len(places):
lr_np = gen_random([place_num])
else:
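                    # More batches than devices: feed a single lr value. Since
                    # lr is persistable it is broadcast to every device rather
                    # than split, which the else branch below verifies.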
lr_np = gen_random([1])
relu_x_np, relu_y_np, relu_lr_np = exe.run(
prog,
feed={x.name: x_np,
y.name: y_np,
lr.name: lr_np},
fetch_list=[relu_x, relu_y, relu_lr])
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
if place_num <= len(places):
assert_result(lr_np, relu_lr_np)
else:
expected_relu_lr_np = max(lr_np[0], 0)
self.assertTrue(np.all(expected_relu_lr_np == relu_lr_np))
def feed_list_test():
for place_num in six.moves.range(1, len(places) + 1):
x_np_list = []
y_np_list = []
lr_np_list = []
feed_list = []
for _ in six.moves.range(place_num):
x_np = gen_random([1, self.size])
y_np = gen_random([1, self.size])
lr_np = gen_random([1])
x_np_list.append(x_np)
y_np_list.append(y_np)
lr_np_list.append(lr_np)
feed_list.append({
x.name: x_np,
y.name: y_np,
lr.name: lr_np
})
relu_x_np, relu_y_np, relu_lr_np = exe.run(
prog, feed=feed_list, fetch_list=[relu_x, relu_y, relu_lr])
x_np = np.concatenate(x_np_list)
y_np = np.concatenate(y_np_list)
lr_np = np.concatenate(lr_np_list)
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
assert_result(lr_np, relu_lr_np)
for _ in six.moves.range(self.iterations):
if use_split:
feed_split_test()
else:
feed_list_test()
def test_main(self):
places = [fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
places.append(fluid.cuda_places())
for p in places:
self.run_network(p, use_split=True)
            self.run_network(p, use_split=False)


class TestInferencePartialFeedUsingDataLoader(unittest.TestCase):
def setUp(self):
self.epoch_num = 3
self.batch_num = 101 # a prime number
self.batch_size = 32
def create_reader(self):
def __impl__():
for _ in six.moves.range(self.batch_num):
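                # the trailing comma makes each yield a 1-tuple: one slot per feed var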
yield np.random.random([self.batch_size, 1]).astype('float32'),
return __impl__
def run_network(self, iterable, use_cuda, drop_last):
x = fluid.data(shape=[None, 1], name='x', dtype='float32')
places = fluid.cuda_places() if use_cuda else fluid.cpu_places(4)
loader = fluid.io.DataLoader.from_generator(
feed_list=[x], capacity=16, iterable=iterable, drop_last=drop_last)
y = fluid.layers.fc(x, size=10)
loss = fluid.layers.reduce_mean(y)
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
places=places, loss_name=loss.name)
loader.set_batch_generator(
self.create_reader(), places=places if iterable else None)
for _ in six.moves.range(self.epoch_num):
actual_batch_num = 0
if loader.iterable:
for feed_data in loader():
x_data, = exe.run(prog, feed=feed_data, fetch_list=[x])
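                    # each fetched chunk is a whole number of per-device batches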
self.assertEqual(x_data.shape[0] % self.batch_size, 0)
self.assertTrue(x_data.shape[0] != 0)
actual_batch_num += int(x_data.shape[0] / self.batch_size)
else:
loader.start()
try:
while True:
x_data, = exe.run(prog, fetch_list=[x])
self.assertEqual(x_data.shape[0] % self.batch_size, 0)
self.assertTrue(x_data.shape[0] != 0)
actual_batch_num += int(x_data.shape[0] /
self.batch_size)
except fluid.core.EOFException:
loader.reset()
if not drop_last or len(places) == 1:
self.assertEqual(self.batch_num, actual_batch_num)
else:
self.assertGreater(self.batch_num, actual_batch_num)
def test_main(self):
        use_cuda_list = [False, True] if fluid.is_compiled_with_cuda() else [False]
iterable_list = [False, True]
drop_last_list = [False, True]
for iterable in iterable_list:
for use_cuda in use_cuda_list:
for drop_last in drop_last_list:
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
                        self.run_network(iterable, use_cuda, drop_last)


if __name__ == '__main__':
unittest.main()