提交 a4951843 编写于 作者: S sneaxiy

inference feed partial data, test=develop

上级 6dadb5de
......@@ -9,7 +9,7 @@ cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_pr
cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope computation_op_handle share_tensor_buffer_functor)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(multi_devices_helper INTERFACE SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
......@@ -65,6 +65,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
multi_devices_helper
sequential_execution_pass
modify_op_lock_and_record_event_pass
all_reduce_deps_pass
......@@ -72,7 +73,7 @@ set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto
eager_deletion_pass
buffer_shared_inplace_op_pass
buffer_shared_cross_op_memory_reuse_pass
set_reader_device_count_pass)
set_reader_device_info_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
......@@ -66,7 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPrintGraphPass("graph_viz_pass", "_fused_graph");
AppendMultiDevPass();
AppendSetReaderDeviceCountPass();
AppendSetReaderDeviceIndexPass();
AppendMultiGraphOptPasses();
AppendPassToSetMkldnnAttr("mkldnn_placement_pass");
......@@ -225,8 +225,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
&strategy_);
}
void AppendSetReaderDeviceCountPass() {
AppendPass("set_reader_device_count_pass");
void AppendSetReaderDeviceIndexPass() {
AppendPass("set_reader_device_index_pass");
}
void AppendPrintGraphPass(const std::string &pass_name,
......@@ -399,12 +399,9 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "set_reader_device_count_pass") {
} else if (pass->Type() == "set_reader_device_index_pass") {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
pass->Erase(kLocalScopes);
pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
&local_scopes);
}
VLOG(1) << "Start Apply Pass " << pass->Type();
graph = pass->Apply(graph);
......@@ -441,7 +438,7 @@ USE_PASS(fuse_sgd_op_pass);
USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(set_reader_device_count_pass);
USE_PASS(set_reader_device_index_pass);
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......
......@@ -34,6 +34,8 @@ class ComputationOpHandle : public OpHandleBase {
OperatorBase *GetOp() { return op_.get(); }
const OperatorBase *GetOp() const { return op_.get(); }
std::string Name() const override;
const Scope *GetScope() const { return scope_; }
......
......@@ -31,10 +31,12 @@ namespace framework {
namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, Scope *scope, const platform::Place &place,
ir::Node *node, Scope *scope, size_t scope_idx,
const platform::Place &place,
const std::unordered_set<ir::MemOptVarInfo *> &vars, GarbageCollector *gc)
: OpHandleBase(node),
scope_(scope),
scope_idx_(scope_idx),
place_(place),
var_infos_(vars.begin(), vars.end()),
gc_(gc) {
......
......@@ -34,7 +34,7 @@ namespace details {
class EagerDeletionOpHandle : public OpHandleBase {
public:
EagerDeletionOpHandle(ir::Node *node, Scope *scope,
EagerDeletionOpHandle(ir::Node *node, Scope *scope, size_t scope_idx,
const platform::Place &place,
const std::unordered_set<ir::MemOptVarInfo *> &vars,
GarbageCollector *gc);
......@@ -50,6 +50,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
*/
Priority GetPriority() const override { return kHighest; }
size_t GetScopeIdx() const { return scope_idx_; }
protected:
void RunImpl() override;
......@@ -63,6 +65,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
void CallOnce();
Scope *scope_;
size_t scope_idx_;
platform::Place place_;
std::vector<ir::MemOptVarInfo *> var_infos_; // not own
GarbageCollector *gc_; // not own
......
......@@ -12,9 +12,197 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include <algorithm>
#include <unordered_set>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {} // namespace details
namespace details {
static constexpr size_t kUndefinedDevIdx = -1UL;
static std::unordered_set<std::string> kMultiDeviceOps{
"sync_batch_norm",
"sync_batch_norm_grad",
"allreduce",
"c_allreduce_sum",
"c_allreduce_prod",
"c_allreduce_min",
"c_allreduce_max",
"c_allgather",
"c_reducescatter",
"c_broadcast",
"c_comm_init",
"c_comm_init_all",
"c_gen_nccl_id",
"c_sync_comm_stream",
"send",
"recv",
"send_barrier",
"fetch_barrier",
};
static size_t GetScopeIdxFromOp(const details::OpHandleBase &op) {
if (auto *compute_op =
dynamic_cast<const details::ComputationOpHandle *>(&op)) {
return kMultiDeviceOps.count(compute_op->GetOp()->Type()) == 0
? compute_op->GetScopeIdx()
: kUndefinedDevIdx;
} else if (auto *gc_op =
dynamic_cast<const details::EagerDeletionOpHandle *>(&op)) {
return gc_op->GetScopeIdx();
} else if (auto *share_op =
dynamic_cast<const details::ShareTensorBufferOpHandle *>(
&op)) {
return share_op->GetScopeIdx();
} else {
return kUndefinedDevIdx;
}
}
static bool ContainMultiDeviceOp(const ProgramDesc &program,
size_t begin_block_idx) {
for (size_t block_idx = begin_block_idx; block_idx < program.Size();
++block_idx) {
for (auto *op_desc : program.Block(block_idx).AllOps()) {
if (kMultiDeviceOps.count(op_desc->Type()) > 0) {
return true;
}
}
}
return false;
}
static size_t GetUniqueDeviceIdOfOp(const details::OpHandleBase &op) {
size_t dev_idx = GetScopeIdxFromOp(op);
if (dev_idx == kUndefinedDevIdx) {
return kUndefinedDevIdx;
}
const auto &ins = op.Inputs();
const auto &outs = op.Outputs();
auto in_outs = ins;
in_outs.insert(in_outs.end(), outs.begin(), outs.end());
for (auto *var : in_outs) {
auto *var_handle = dynamic_cast<details::VarHandle *>(var);
if (var_handle == nullptr) {
continue;
}
if (dev_idx != var_handle->scope_idx()) {
return kUndefinedDevIdx;
}
}
return dev_idx;
}
std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
ir::Graph *graph) {
if (ContainMultiDeviceOp(graph->OriginProgram(), 1)) {
return {};
}
size_t place_num = 0;
auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
if (op_handles.empty()) {
return {};
}
std::unordered_map<details::OpHandleBase *, size_t> op_to_dev_idx;
for (auto &op : op_handles) {
auto dev_idx = GetUniqueDeviceIdOfOp(*op);
if (dev_idx == kUndefinedDevIdx) {
VLOG(10) << "Op " << op->Name() << " is not determined";
return {};
}
place_num = std::max(place_num, dev_idx + 1);
op_to_dev_idx[op] = dev_idx;
}
for (auto &op : op_handles) {
auto dev_idx = op_to_dev_idx.at(op);
for (auto &in_var : op->Inputs()) {
if (in_var->GeneratedOp()) {
auto iter = op_to_dev_idx.find(in_var->GeneratedOp());
if (iter == op_to_dev_idx.end() || iter->second != dev_idx) {
return {};
}
}
}
for (auto &out_var : op->Outputs()) {
for (auto &pending_op : out_var->PendingOps()) {
auto iter = op_to_dev_idx.find(pending_op);
if (iter == op_to_dev_idx.end() || iter->second != dev_idx) {
return {};
}
}
}
}
PADDLE_ENFORCE_GE(place_num, 1, platform::errors::NotFound(
"No place found, this may be a bug"));
std::vector<std::unique_ptr<ir::Graph>> graphs(place_num);
for (auto &g : graphs) {
g.reset(new ir::Graph(ProgramDesc()));
g->Set(kGraphVars, new GraphVars(1UL));
g->Set(kGraphDepVars, new GraphDepVars());
}
for (auto &op : op_handles) {
auto dev_idx = op_to_dev_idx.at(op);
auto *ret_graph = graphs[dev_idx].get();
auto &ret_vars = ret_graph->Get<GraphVars>(kGraphVars)[0];
auto &ret_dummy_vars = ret_graph->Get<GraphDepVars>(kGraphDepVars);
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_idx];
ret_graph->AddNode(graph->RemoveNode(op->Node()).release());
auto handler = [&](const std::vector<VarHandleBase *> &vars) {
for (auto *var : vars) {
if (graph->Nodes().count(var->Node()) > 0) {
ret_graph->AddNode(graph->RemoveNode(var->Node()).release());
auto *dummy_var = dynamic_cast<DummyVarHandle *>(var);
if (dummy_var == nullptr) {
ret_vars.emplace(var->Name(), origin_vars.at(var->Name()));
} else {
ret_dummy_vars.emplace(dummy_var);
}
}
}
};
handler(op->Inputs());
handler(op->Outputs());
}
graph->Erase(kGraphVars);
graph->Erase(kGraphDepVars);
return graphs;
}
bool HasDropLastReadOp(const ir::Graph &graph) {
auto ops = ir::FilterByNodeWrapper<OpHandleBase>(graph);
for (auto *op : ops) {
auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
if (compute_op && compute_op->GetOp()->Type() == "read" &&
compute_op->GetOp()->Attr<bool>("drop_last")) {
VLOG(10) << "The graph has drop_last=True read op";
return true;
}
}
VLOG(10) << "The graph does not have drop_last=True read op";
return false;
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -47,6 +47,7 @@ constexpr char kGraphVars[] = "vars";
constexpr char kNRanks[] = "nranks";
constexpr char kPlaces[] = "places";
constexpr char kGlobalScope[] = "global_scope";
constexpr char kLocalScopes[] = "local_scopes";
constexpr char kNCCLCtxs[] = "nccl_ctxs";
constexpr char kUseHierarchicalAllReduce[] = "use_hierarchical_allreduce";
......@@ -100,6 +101,11 @@ inline std::vector<std::string> GetOpRoleVarsOrEmpty(const OpDesc &op) {
return boost::get<std::vector<std::string>>(iter->second);
}
std::vector<std::unique_ptr<ir::Graph>> TrySeparateToMultipleSingleDeviceGraphs(
ir::Graph *graph);
bool HasDropLastReadOp(const ir::Graph &graph);
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -21,11 +22,11 @@ namespace paddle {
namespace framework {
namespace details {
std::vector<std::unique_ptr<ir::Graph>>
ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
static std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
ir::Graph *graph, size_t place_num) {
std::vector<std::unique_ptr<ir::Graph>> graphs;
graphs.reserve(places_.size());
for (size_t i = 0; i < places_.size(); ++i) {
graphs.reserve(place_num);
for (size_t i = 0; i < place_num; ++i) {
ProgramDesc empty;
graphs.emplace_back(std::unique_ptr<ir::Graph>(new ir::Graph(empty)));
auto &g = graphs.back();
......@@ -64,7 +65,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
}
}
for (size_t dev_id = 0; dev_id < places_.size(); ++dev_id) {
for (size_t dev_id = 0; dev_id < place_num; ++dev_id) {
auto &dev_vars = graphs[dev_id]->Get<GraphVars>(kGraphVars)[0];
auto &origin_vars = graph->Get<GraphVars>(kGraphVars)[dev_id];
for (auto &name_pair : origin_vars) {
......@@ -85,15 +86,34 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
// TODO(Yancey1989): Copying graphs is not safely since it deleted the
// attrs.
: ParallelSSAGraphExecutor(strategy, local_scopes, local_exec_scopes,
places,
SeparateMultiDevicesGraph(graph,
places.size())) {}
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
// TODO(Yancey1989): Copying graphs is not safely since it deleted the
// attrs.
graphs_(SeparateMultiDevicesGraph(graph)) {
places_(places),
graphs_(std::move(graphs)),
feed_status_(places.size(), FeedStatus::kNone) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
PADDLE_ENFORCE_EQ(places_.size(), graphs_.size(),
platform::errors::InvalidArgument(
"Graph number does not match place number"));
PADDLE_ENFORCE_GT(
places_.size(), 0,
platform::errors::InvalidArgument("place number must be larger than 0"));
auto seq_allreduce_pass =
ir::PassRegistry::Instance().Get("all_reduce_deps_pass");
seq_allreduce_pass->Set<bool>(kUseHierarchicalAllReduce, new bool(false));
......@@ -123,22 +143,43 @@ std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
return result;
}
enum ExceptionStatus { kSuccess = 0, kEOF, kOther };
FeedFetchList ParallelSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
size_t feed_num = std::count(feed_status_.begin(), feed_status_.end(),
FeedStatus::kHasFeed);
bool has_feed = (feed_num > 0);
VLOG(10) << "Feed num " << feed_num;
size_t place_num = places_.size();
std::vector<std::future<FeedFetchList>> run_futures;
std::vector<ExceptionStatus> exception_status(place_num,
ExceptionStatus::kSuccess);
std::vector<FeedFetchList> fetch_data;
FeedFetchList ret;
fetch_data.reserve(places_.size());
ret.reserve(fetch_tensors.size());
fetch_data.reserve(place_num);
ret.reserve(place_num);
exception_holder_.Clear();
for (size_t i = 0; i < places_.size(); ++i) {
auto call = [this, i, &fetch_tensors]() -> FeedFetchList {
for (size_t i = 0; i < place_num; ++i) {
auto call = [&, i]() -> FeedFetchList {
try {
return executors_[i]->Run(fetch_tensors);
if (!support_partial_feed_ || !has_feed ||
feed_status_[i] == FeedStatus::kHasFeed) {
return executors_[i]->Run(fetch_tensors);
} else {
return FeedFetchList();
}
} catch (platform::EOFException &) {
exception_status[i] = ExceptionStatus::kEOF;
exception_holder_.Catch(std::current_exception());
} catch (...) {
exception_status[i] = ExceptionStatus::kOther;
exception_holder_.Catch(std::current_exception());
}
return FeedFetchList();
......@@ -153,21 +194,63 @@ FeedFetchList ParallelSSAGraphExecutor::Run(
if (pool_) {
for (auto &f : run_futures) {
if (exception_holder_.IsCaught()) {
f.wait();
} else {
fetch_data.emplace_back(f.get());
fetch_data.emplace_back(f.get());
}
}
bool has_exception = exception_holder_.IsCaught();
if (!support_partial_feed_ && has_exception) {
VLOG(10) << "Exception rethrow because partial feed is not supported";
exception_holder_.ReThrow();
}
std::vector<bool> is_valid(place_num, true);
if (support_partial_feed_) {
if (has_feed) {
for (size_t i = 0; i < place_num; ++i) {
if (feed_status_[i] == FeedStatus::kNone) {
is_valid[i] = false;
} else if (exception_status[i] != ExceptionStatus::kSuccess) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Exception rethrow because non-EOF exception raises when "
"feed is given";
exception_holder_.ReThrow();
}
}
} else {
for (size_t i = 0; i < place_num; ++i) {
if (exception_status[i] == ExceptionStatus::kOther) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Exception rethrow because non-EOF exception raises when "
"feed is not given";
exception_holder_.ReThrow();
} else if (exception_status[i] != ExceptionStatus::kSuccess) {
is_valid[i] = false;
}
}
}
}
if (exception_holder_.IsCaught()) {
if (std::count(is_valid.begin(), is_valid.end(), true) == 0) {
PADDLE_ENFORCE_EQ(has_exception, true,
platform::errors::InvalidArgument(
"Thread pool raises exception but not caught"));
VLOG(10) << "Raise exception because there is no success worker";
exception_holder_.ReThrow();
}
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.reserve(local_scopes_.size());
for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) {
lodtensor_ptrs.reserve(place_num);
for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) {
if (!is_valid[scope_idx]) {
continue;
}
lodtensor_ptrs.push_back(&fetch_data.at(scope_idx).at(fetch_idx));
}
ret.emplace_back();
......
......@@ -27,12 +27,25 @@ namespace framework {
namespace details {
class ParallelSSAGraphExecutor : public SSAGraphExecutor {
public:
enum FeedStatus {
kNone = 0, // No feed
kHasFeed = 1 // Has feed
};
public:
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
std::vector<std::unique_ptr<ir::Graph>> graphs);
~ParallelSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
......@@ -41,10 +54,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
std::vector<std::unique_ptr<ir::Graph>> SeparateMultiDevicesGraph(
ir::Graph *graph);
void SetHasFeed(size_t dev_idx, bool has_feed) {
feed_status_[dev_idx] = has_feed ? FeedStatus::kHasFeed : FeedStatus::kNone;
}
void EnablePartialFeedSupport() { support_partial_feed_ = true; }
bool SupportPartialFeed() const { return support_partial_feed_; }
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
......@@ -54,6 +72,9 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
std::vector<std::unique_ptr<details::FastThreadedSSAGraphExecutor>>
executors_;
ExceptionHolder exception_holder_;
bool support_partial_feed_{false};
std::vector<FeedStatus> feed_status_;
};
} // namespace details
......
......@@ -228,7 +228,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
}
auto *eager_deletion_op = new details::EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(),
eager_deletion_node, op->GetScope(), op->GetScopeIdx(), op->GetPlace(),
std::move(var_info), gcs.at(places[op->GetScopeIdx()]).get());
auto it = std::find_if(
......
......@@ -98,7 +98,7 @@ class ReferenceCountPassTestHelper {
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars_);
ref_cnt_pass->Apply(&graph_);
ref_cnt_pass->Apply(&const_cast<ir::Graph &>(executor_->Graph()));
}
bool IsLastLivedOps(const std::string &name,
......
......@@ -11,7 +11,7 @@ endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES} reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(set_reader_device_count_pass SRCS set_reader_device_count_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass)
cc_library(set_reader_device_info_pass SRCS set_reader_device_info_pass.cc DEPS graph graph_helper pass multi_devices_graph_pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
......
......@@ -22,77 +22,80 @@ namespace paddle {
namespace framework {
namespace ir {
class SetReaderDeviceCountPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override;
private:
int GetDeviceCount() const;
std::unordered_set<std::string> ReaderOpSet() const;
const Scope *GlobalScope() const;
};
int SetReaderDeviceCountPass::GetDeviceCount() const {
static int GetDeviceCountFromPassAttr(const Pass &pass) {
return static_cast<int>(
Get<const std::vector<platform::Place>>(details::kPlaces).size());
pass.Get<const std::vector<platform::Place>>(details::kPlaces).size());
}
std::unordered_set<std::string> SetReaderDeviceCountPass::ReaderOpSet() const {
static std::unordered_set<std::string> ReaderOpSet() {
return {"create_py_reader"};
}
const Scope *SetReaderDeviceCountPass::GlobalScope() const {
return Get<const std::vector<Scope *>>(details::kLocalScopes)[0];
}
void SetReaderDeviceCountPass::ApplyImpl(Graph *graph) const {
auto dev_cnt = GetDeviceCount();
auto reader_ops = ReaderOpSet();
auto scope = GlobalScope();
size_t found_op_num = 0;
for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) {
auto &op_handle = dynamic_cast<details::ComputationOpHandle &>(
node->Wrapper<details::OpHandleBase>());
auto *op_desc = node->Op();
auto &op_base_attrs =
const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs());
int dev_idx = static_cast<int>(op_handle.GetScopeIdx());
op_desc->SetAttr("device_index", dev_idx);
op_desc->SetAttr("device_count", dev_cnt);
op_base_attrs["device_index"] = dev_idx;
op_base_attrs["device_count"] = dev_cnt;
auto queue_name = op_handle.GetOp()->Input("blocking_queue");
auto var = scope->FindVar(queue_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound("Blocking queue of DataLoader not found"));
using QueueHolder =
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
if (var->IsType<QueueHolder>()) {
var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
class InitReaderDeviceCountPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override {
using QueueHolder =
operators::reader::OrderedMultiDeviceLoDTensorBlockingQueueHolder;
auto reader_ops = ReaderOpSet();
auto dev_cnt = GetDeviceCountFromPassAttr(*this);
const auto &scope = Get<const Scope>(details::kGlobalScope);
for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) {
auto queue_name = node->Op()->Input("blocking_queue")[0];
auto var = scope.FindVar(queue_name);
if (var && var->IsType<QueueHolder>()) {
VLOG(10) << "Set device count of " << queue_name << " to be "
<< dev_cnt;
var->GetMutable<QueueHolder>()->GetQueue()->SetDeviceCount(dev_cnt);
}
}
++found_op_num;
VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx;
}
}
};
VLOG(10) << "Found op number " << found_op_num;
}
class SetReaderDeviceIndexPass : public Pass {
protected:
void ApplyImpl(Graph *graph) const override {
auto dev_cnt = GetDeviceCountFromPassAttr(*this);
auto reader_ops = ReaderOpSet();
size_t found_op_num = 0;
for (auto &node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
reader_ops.count(node->Op()->Type()) != 0) {
auto &op_handle = dynamic_cast<details::ComputationOpHandle &>(
node->Wrapper<details::OpHandleBase>());
auto *op_desc = node->Op();
auto &op_base_attrs =
const_cast<framework::AttributeMap &>(op_handle.GetOp()->Attrs());
int dev_idx = static_cast<int>(op_handle.GetScopeIdx());
op_desc->SetAttr("device_index", dev_idx);
op_desc->SetAttr("device_count", dev_cnt);
op_base_attrs["device_index"] = dev_idx;
op_base_attrs["device_count"] = dev_cnt;
++found_op_num;
VLOG(10) << "Found op " << op_desc->Type() << " on device " << dev_idx;
}
}
VLOG(10) << "Found op number " << found_op_num;
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(set_reader_device_count_pass,
paddle::framework::ir::SetReaderDeviceCountPass)
REGISTER_PASS(init_reader_device_count_pass,
paddle::framework::ir::InitReaderDeviceCountPass)
.RequirePassAttr(paddle::framework::details::kGlobalScope)
.RequirePassAttr(paddle::framework::details::kPlaces);
REGISTER_PASS(set_reader_device_index_pass,
paddle::framework::ir::SetReaderDeviceIndexPass)
.RequirePassAttr(paddle::framework::details::kPlaces);
......@@ -307,18 +307,18 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
const std::vector<platform::Place> places) const {
PADDLE_ENFORCE_GT(places.size(), 0,
platform::errors::InvalidArgument(
"place number cannot be empty when splitting"));
check_memory_size();
int batch_size =
lod().empty() ? dims()[0] : static_cast<int>(lod()[0].size()) - 1;
size_t result_size = std::min(static_cast<size_t>(batch_size), places.size());
size_t remainder = batch_size % places.size();
size_t batch_size =
lod().empty() ? static_cast<size_t>(dims()[0]) : lod()[0].size() - 1;
std::vector<LoDTensor> results;
results.reserve(result_size);
// if result_size(batch_size) is 0, just return #places.size() copys of empty
// if batch_size is 0, just return #places.size() copys of empty
// tensors.
if (result_size == 0) {
if (batch_size == 0) {
std::vector<LoDTensor> empty_results;
empty_results.reserve(places.size());
for (size_t i = 0; i < places.size(); ++i) {
LoDTensor dst;
dst.Resize(dims());
......@@ -326,18 +326,22 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
if (!lod().empty()) {
dst.set_lod(lod());
}
results.emplace_back(dst);
empty_results.emplace_back(std::move(dst));
}
return results;
return empty_results;
}
int step_width = static_cast<int>(batch_size / result_size);
auto step_width = (batch_size + places.size() - 1) / places.size();
auto result_size = (batch_size + step_width - 1) / step_width;
std::vector<LoDTensor> results;
results.reserve(result_size);
for (size_t i = 0; i < result_size; ++i) {
int begin = static_cast<int>(i * step_width);
int end = static_cast<int>((i + 1) * step_width);
if (i + 1 == places.size()) { // last
end += remainder;
}
auto begin = i * step_width;
auto end = std::min<size_t>((i + 1) * step_width, batch_size);
PADDLE_ENFORCE_LT(begin, end,
platform::errors::InvalidArgument(
"begin must be less than end, this may be a bug"));
LoDTensor dst;
if (lod().empty()) {
......@@ -362,7 +366,7 @@ std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
}
dst.set_lod(my_lod);
}
results.emplace_back(dst);
results.emplace_back(std::move(dst));
}
return results;
......
......@@ -55,8 +55,9 @@ static bool gProfileStarted = false;
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
: places_(places) {
ParallelExecutorPrivate(const std::vector<platform::Place> &places,
Scope *global_scope)
: places_(places), global_scope_(global_scope) {
if (!FLAGS_pe_profile_fname.empty()) {
std::call_once(gProfileOnce, [] {
#ifdef WITH_GPERFTOOLS
......@@ -82,6 +83,19 @@ class ParallelExecutorPrivate {
}
}
void InitReaderDeviceCount(ir::Graph *graph) const {
auto pass =
ir::PassRegistry::Instance().Get("init_reader_device_count_pass");
pass->SetNotOwned<const Scope>(details::kGlobalScope, global_scope_);
pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces,
&places_);
pass->Apply(graph);
}
void SetHasFeed(size_t dev_idx, bool has_feed = true);
bool AllowPartialFeed() const;
ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph);
inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
......@@ -257,8 +271,20 @@ class ParallelExecutorPrivate {
ir::MemOptVarInfoMapList mem_opt_var_infos_;
ir::GarbageCollectorMap gcs_;
details::ParallelSSAGraphExecutor *inference_executor_{nullptr};
};
void ParallelExecutorPrivate::SetHasFeed(size_t dev_idx, bool has_feed) {
if (inference_executor_) {
inference_executor_->SetHasFeed(dev_idx, has_feed);
}
}
bool ParallelExecutorPrivate::AllowPartialFeed() const {
return inference_executor_ && inference_executor_->SupportPartialFeed();
}
ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
if (FLAGS_use_ngraph) {
LOG_FIRST_N(WARNING, 1)
......@@ -379,6 +405,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
return graph;
}
class ResetHasFeedGuard {
public:
explicit ResetHasFeedGuard(ParallelExecutorPrivate *pe_member)
: pe_member_(pe_member) {}
~ResetHasFeedGuard() {
for (size_t i = 0; i < pe_member_->places_.size(); ++i) {
pe_member_->SetHasFeed(i, false);
}
}
private:
ParallelExecutorPrivate *pe_member_;
};
size_t ParallelExecutor::DeviceCount() const { return member_->places_.size(); }
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
......@@ -407,8 +448,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
const ExecutionStrategy &exec_strategy,
const BuildStrategy &build_strategy,
ir::Graph *graph)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
: member_(new ParallelExecutorPrivate(places, scope)) {
member_->InitReaderDeviceCount(graph);
member_->use_cuda_ = exec_strategy.use_cuda_;
member_->build_strategy_ = build_strategy;
member_->use_all_reduce_ = member_->build_strategy_.reduce_ ==
......@@ -606,18 +647,35 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
"Paddle should be compiled with CUDA for ParallelGraph Execution.");
#endif
} else {
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
bool has_drop_last_read_op = details::HasDropLastReadOp(*graph);
auto possible_inference_graphs =
details::TrySeparateToMultipleSingleDeviceGraphs(graph);
if (!possible_inference_graphs.empty()) {
VLOG(5) << "Use ParallelSSAGraphExecutor in inference phase";
auto *pg_exe = new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
member_->places_, std::move(possible_inference_graphs));
if (!has_drop_last_read_op) {
VLOG(5) << "Enable partial feed support in inference phase";
pg_exe->EnablePartialFeedSupport();
}
final_graphs = pg_exe->Graphs();
member_->executor_.reset(pg_exe);
member_->inference_executor_ = pg_exe;
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
}
final_graphs.emplace_back(graph);
}
final_graphs.emplace_back(graph);
}
VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
......@@ -724,6 +782,8 @@ FeedFetchList ParallelExecutor::Run(
platform::RecordBlock b(0);
ResetHasFeedGuard reset_has_feed_guard(member_);
ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors,
member_->HasGarbageCollectors());
......@@ -734,10 +794,22 @@ FeedFetchList ParallelExecutor::Run(
void ParallelExecutor::FeedTensorsIntoLocalScopes(
const std::vector<std::unordered_map<std::string, LoDTensor>> &tensors) {
PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size());
if (!member_->AllowPartialFeed()) {
PADDLE_ENFORCE_EQ(
member_->local_scopes_.size(), tensors.size(),
platform::errors::InvalidArgument(
"The feed tensor number does not match the device number"));
} else {
PADDLE_ENFORCE_GE(member_->local_scopes_.size(), tensors.size(),
platform::errors::InvalidArgument(
"The feed tensor number exceeds the device number"));
}
for (size_t i = 0; i < tensors.size(); ++i) {
auto &map = tensors[i];
if (!map.empty()) {
member_->SetHasFeed(i);
}
for (auto &pair : map) {
bool is_persistable = member_->IsPersistable(pair.first);
if (!is_persistable) {
......@@ -757,6 +829,11 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
const std::unordered_map<std::string, LoDTensor> &tensors) {
size_t num_places = member_->places_.size();
bool allow_partial_feed = member_->AllowPartialFeed();
size_t persistable_feed_len = -1UL;
size_t non_persistable_feed_len = -1UL;
for (auto &pair : tensors) {
bool is_persistable = member_->IsPersistable(pair.first);
VLOG(3) << "Split " << (is_persistable ? "persistable" : "no persistable")
......@@ -764,7 +841,8 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
<< ", place: " << pair.second.place();
auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
if (!is_persistable && num_places != lod_tensors.size()) {
if (!is_persistable && num_places != lod_tensors.size() &&
!allow_partial_feed) {
auto error_info = string::Sprintf(
"The number(%d) of samples[%s] of current batch is less than the "
"count(%d) of devices(%s), currently, it is not allowed. ",
......@@ -790,7 +868,7 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
framework::TensorCopy(pair.second, member_->places_.at(i), &tmp);
}
}
if (lod_tensors.size() != num_places) {
if (lod_tensors.size() != num_places && !allow_partial_feed) {
auto error_info = string::Sprintf(
"The number(%d) of samples[%s] of the current batch does not match "
"the count(%d) of devices(%s). Because that %s is a persistable "
......@@ -804,7 +882,31 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
}
for (size_t j = 0; j < num_places; ++j) {
if (allow_partial_feed) {
if (is_persistable) {
if (persistable_feed_len == -1UL) {
persistable_feed_len = lod_tensors.size();
} else {
PADDLE_ENFORCE_EQ(
persistable_feed_len, lod_tensors.size(),
platform::errors::InvalidArgument(
"The feeded number of different persistable variables "
"should be the same"));
}
} else {
if (non_persistable_feed_len == -1UL) {
non_persistable_feed_len = lod_tensors.size();
} else {
PADDLE_ENFORCE_EQ(
non_persistable_feed_len, lod_tensors.size(),
platform::errors::InvalidArgument(
"The feeded number of different non-persistable variables "
"should be the same"));
}
}
}
for (size_t j = 0; j < lod_tensors.size(); ++j) {
auto *feed_scope = is_persistable ? member_->local_scopes_[j]
: member_->local_exec_scopes_[j];
auto *feed_var = feed_scope->Var(pair.first);
......@@ -814,6 +916,19 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
t->set_lod(lod_tensors[j].lod());
}
}
if (allow_partial_feed && persistable_feed_len != -1UL &&
non_persistable_feed_len != -1UL) {
VLOG(10) << "Persistable len " << persistable_feed_len;
VLOG(10) << "Non persistable len " << non_persistable_feed_len;
PADDLE_ENFORCE_GE(persistable_feed_len, non_persistable_feed_len,
platform::errors::InvalidArgument(
"The feeded number of persistable variables should "
"not be less than non-persistable variables"));
for (size_t i = 0; i < non_persistable_feed_len; ++i) {
member_->SetHasFeed(i);
}
}
}
ParallelExecutor::~ParallelExecutor() {
......@@ -864,6 +979,10 @@ bool ParallelExecutor::EnableParallelGraphExecution(
return enable_parallel_graph;
}
const ir::Graph &ParallelExecutor::Graph() const {
return member_->executor_->Graph();
}
} // namespace framework
} // namespace paddle
......@@ -871,3 +990,4 @@ USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass);
USE_PASS(buffer_shared_inplace_pass);
USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
USE_PASS(init_reader_device_count_pass);
......@@ -79,6 +79,8 @@ class ParallelExecutor {
FeedFetchList Run(const std::vector<std::string> &fetch_tensors);
const ir::Graph &Graph() const;
private:
// broadcast the parameters from the 0th device.
// trainer_id the trainer index in nccl distributed training.
......
......@@ -156,6 +156,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
" and it is set by ParallelExecutor instance, not users.")
.SetDefault(true);
AddAttr<bool>("infer_out", "").SetDefault(true);
AddAttr<bool>("drop_last",
"Whether to drop last batches whose number is less than CPU "
"cores/GPU cards number")
.SetDefault(true);
AddComment(R"DOC(
Read Operator
......
......@@ -20,6 +20,7 @@
#include <utility>
#include <vector>
#include "Python.h"
#include "boost/optional.hpp"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/reader.h"
......@@ -41,6 +42,58 @@ namespace pybind {
namespace py = pybind11;
namespace reader = operators::reader;
// Check whether the tensor shape matches the VarDesc shape
// Return the different shape if exists
static boost::optional<std::vector<int64_t>> DiffTensorShapeWithVarDesc(
const framework::LoDTensor &tensor, const framework::VarDesc &var_desc,
size_t num_places) {
auto tensor_shape = tensor.dims();
auto desc_shape = var_desc.GetShape();
int64_t rank = tensor_shape.size();
if (UNLIKELY(rank == 0)) {
if (desc_shape.size() != 0) { // Tensor rank = 0 but desc does not match
return framework::vectorize<int64_t>(tensor_shape);
} else {
return boost::none;
}
}
PADDLE_ENFORCE_GE(tensor_shape[0], 0,
platform::errors::InvalidArgument(
"Tensor shape must not be less than 0"));
if (!tensor.lod().empty()) {
tensor_shape[0] = -1; // unknown shape
} else {
int64_t split_size = (tensor_shape[0] + num_places - 1) / num_places;
int64_t remainder = (split_size == 0 ? 0 : tensor_shape[0] % split_size);
tensor_shape[0] = split_size;
if (desc_shape[0] >= 0) { // need check dim 0
if (tensor_shape[0] != desc_shape[0]) {
return framework::vectorize<int64_t>(tensor_shape);
}
if (remainder > 0) {
tensor_shape[0] = remainder;
return framework::vectorize<int64_t>(tensor_shape);
}
}
}
for (int64_t idx = 1; idx < rank; ++idx) {
PADDLE_ENFORCE_GE(tensor_shape[idx], 0,
platform::errors::InvalidArgument(
"Tensor shape must not be less than 0"));
if (desc_shape[idx] >= 0 && tensor_shape[idx] != desc_shape[idx]) {
return framework::vectorize<int64_t>(tensor_shape);
}
}
return boost::none;
}
static const std::shared_ptr<reader::LoDTensorBlockingQueue> &GetQueue(
const std::shared_ptr<reader::LoDTensorBlockingQueue> &queue, size_t idx) {
return queue;
......@@ -66,10 +119,12 @@ class MultiDeviceFeedReader {
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places, bool use_double_buffer)
const std::vector<platform::Place> &dst_places, bool use_double_buffer,
bool drop_last)
: queue_(queue),
names_(names),
pool_(new ::ThreadPool(dst_places.size())) {
pool_(new ::ThreadPool(dst_places.size())),
drop_last_(drop_last) {
std::vector<framework::DDim> dims;
for (auto &shape : shapes) {
dims.push_back(framework::make_ddim(shape));
......@@ -113,12 +168,16 @@ class MultiDeviceFeedReader {
ReadAsync();
}
bool DropLast() const { return drop_last_; }
ResultDictList ReadNext() {
CheckNextStatus();
ResultDictList result(ret_.size());
for (size_t i = 0; i < ret_.size(); ++i) {
for (size_t j = 0; j < names_.size(); ++j) {
result[i].emplace(names_[j], std::move(ret_[i][j]));
if (!ret_[i].empty()) {
for (size_t j = 0; j < names_.size(); ++j) {
result[i].emplace(names_[j], std::move(ret_[i][j]));
}
}
}
ReadAsync();
......@@ -155,24 +214,29 @@ class MultiDeviceFeedReader {
};
Status WaitFutures(std::exception_ptr *excep) {
bool is_success = true;
*excep = nullptr;
size_t success_num = 0;
for (size_t i = 0; i < futures_.size(); ++i) {
auto each_status = futures_[i].get();
if (UNLIKELY(each_status != Status::kSuccess)) {
is_success = false;
if (UNLIKELY(each_status == Status::kException)) {
PADDLE_ENFORCE_NOT_NULL(exceptions_[i]);
*excep = exceptions_[i];
exceptions_[i] = nullptr;
}
} else {
++success_num;
}
}
if (UNLIKELY(*excep)) {
return Status::kException;
}
if (drop_last_) {
return success_num == futures_.size() ? Status::kSuccess : Status::kEOF;
} else {
return is_success ? Status::kSuccess : Status::kEOF;
return success_num > 0 ? Status::kSuccess : Status::kEOF;
}
}
......@@ -226,6 +290,7 @@ class MultiDeviceFeedReader {
std::vector<std::exception_ptr> exceptions_;
std::vector<std::vector<framework::LoDTensor>> ret_;
bool drop_last_;
};
template <typename QueueType>
......@@ -270,6 +335,17 @@ void BindMultiDeviceReader(py::module *module, const char *reader_name) {
void BindReader(py::module *module) {
auto &m = *module;
m.def("diff_tensor_shape", [](const framework::LoDTensor &tensor,
const framework::VarDesc &var_desc,
size_t num_places) -> py::object {
auto diff = DiffTensorShapeWithVarDesc(tensor, var_desc, num_places);
if (diff) {
return py::cast(std::move(diff.get()));
} else {
return py::cast(nullptr);
}
});
m.def("init_lod_tensor_blocking_queue",
[](framework::Variable &var, size_t capacity,
bool is_ordered) -> py::object {
......@@ -337,10 +413,10 @@ void BindReader(py::module *module) {
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places,
bool use_double_buffer) {
bool use_double_buffer, bool drop_last) {
return new MultiDeviceFeedReader<reader::LoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer);
use_double_buffer, drop_last);
},
py::return_value_policy::take_ownership);
......@@ -352,13 +428,13 @@ void BindReader(py::module *module) {
const std::vector<std::vector<int>> &shapes,
const std::vector<framework::proto::VarType::Type> &dtypes,
const std::vector<bool> &need_check_feed,
const std::vector<platform::Place> &dst_places,
bool use_double_buffer) {
const std::vector<platform::Place> &dst_places, bool use_double_buffer,
bool drop_last) {
queue->SetDeviceCount(dst_places.size());
return new MultiDeviceFeedReader<
reader::OrderedMultiDeviceLoDTensorBlockingQueue>(
queue, names, shapes, dtypes, need_check_feed, dst_places,
use_double_buffer);
use_double_buffer, drop_last);
},
py::return_value_policy::take_ownership);
}
......
......@@ -216,18 +216,12 @@ def check_feed_shape_type(var, feed, num_places=1):
the feed value
"""
if var.desc.need_check_feed():
feed_shape = feed.shape()
if six.PY2:
feed_shape[0] = long(feed_shape[0] /
num_places) if len(feed.lod()) == 0 else -1
else:
feed_shape[0] = int(feed_shape[0] /
num_places) if len(feed.lod()) == 0 else -1
if not dimension_is_compatible_with(feed_shape, var.shape):
diff_shape = core.diff_tensor_shape(feed, var.desc, num_places)
if diff_shape is not None:
raise ValueError(
'The feeded Variable %r should have dimensions = %d, shape = '
'%r, but received feeded shape %r on each device' %
(var.name, len(var.shape), var.shape, feed_shape))
(var.name, len(var.shape), var.shape, diff_shape))
if not dtype_is_compatible_with(feed._dtype(), var.dtype):
var_dtype_format = convert_dtype(var.dtype) if isinstance(
var.dtype, core.VarDesc.VarType) else var.dtype
......@@ -646,11 +640,6 @@ class Executor(object):
exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple):
if len(feed) != len(program._places):
raise ValueError(
"Feed a list of tensor, the list should be the same size as places"
)
res = list()
for i, each in enumerate(feed):
if not isinstance(each, dict):
......
......@@ -88,6 +88,7 @@ class DataLoader(object):
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
keep_order=False):
"""
Create a DataLoader object for loading data from Python generator.
......@@ -134,6 +135,9 @@ class DataLoader(object):
can be used in the dygraph mode. In the static graph mode,
whether this parameter is set or not has no effect.
The Default value is False.
drop_last (bool): whether to drop the last batches whose number is
less than the CPU core/GPU card number. The default value is
True.
keep_order (bool): whether to assign the data to CPU cores or GPU
cards in order. Supposing that there are 2 batches and we use
2 GPU cards to run the network. If keep_order=True, GPU 0 would
......@@ -289,7 +293,7 @@ class DataLoader(object):
return_list, use_multiprocess)
else:
return GeneratorLoader(feed_list, capacity, use_double_buffer,
iterable, return_list, keep_order)
iterable, return_list, drop_last, keep_order)
@staticmethod
def from_dataset(dataset, places, drop_last=True):
......@@ -422,7 +426,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
core.Variable(), self._capacity, False)
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer)
self._need_check_feed, self._places, self._use_double_buffer, True)
def _start(self):
if self._use_multiprocess:
......@@ -628,6 +632,7 @@ class GeneratorLoader(DataLoaderBase):
use_double_buffer=True,
iterable=True,
return_list=False,
drop_last=True,
keep_order=False):
self._tensor_reader = None
self._places = None
......@@ -635,6 +640,8 @@ class GeneratorLoader(DataLoaderBase):
self._queue = None
self._feed_list = feed_list
self._exited = False
self._drop_last = drop_last
self._keep_order = keep_order
if not capacity:
raise ValueError("Please give value to capacity.")
self._iterable = iterable
......@@ -643,7 +650,6 @@ class GeneratorLoader(DataLoaderBase):
raise Exception("Feed list must be given under static mode.")
self._use_double_buffer = use_double_buffer
self._capacity = capacity
self._keep_order = keep_order
if not self._iterable:
self._init_non_iterable()
......@@ -667,7 +673,8 @@ class GeneratorLoader(DataLoaderBase):
core.Variable(), self._capacity, self._keep_order)
self._reader = core.create_py_reader(
self.queue, self._var_names, self._shapes, self._dtypes,
self._need_check_feed, self._places, self._use_double_buffer)
self._need_check_feed, self._places, self._use_double_buffer,
self._drop_last)
def _init_non_iterable(self):
lod_levels = []
......@@ -744,7 +751,8 @@ class GeneratorLoader(DataLoaderBase):
default_main_program().current_block().append_op(
type='read',
inputs={'Reader': [self._reader]},
outputs={'Out': self._feed_list})
outputs={'Out': self._feed_list},
attrs={'drop_last': self._drop_last})
@property
def queue(self):
......
......@@ -355,4 +355,5 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu
test_parallel_executor_crf_auto_growth test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass
test_data_norm_op test_imperative_using_non_zero_gpu test_fuse_bn_act_pass
test_optimizer_in_control_flow test_dataloader_keep_order
test_parallel_executor_inference_feed_partial_data
test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy as np
import unittest
import six
class TestInferencePartialFeed(unittest.TestCase):
def setUp(self):
self.iterations = 10
self.size = 10
def run_network(self, places, use_split):
startup_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
x = fluid.data(name='x', shape=[None, self.size], dtype='float32')
y = fluid.data(name='y', shape=[None, self.size], dtype='float32')
lr = fluid.data(name='lr', shape=[1], dtype='float32')
lr.persistable = True
relu_x = fluid.layers.relu(x)
relu_y = fluid.layers.relu(y)
relu_lr = fluid.layers.relu(lr)
exe = fluid.Executor(places[0])
exe.run(startup_prog)
prog = fluid.CompiledProgram(main_prog).with_data_parallel(
places=places)
gen_random = lambda shape:np.random.uniform(low=-1.0, high=1.0, size=shape).astype('float32')
assert_result = lambda feed, result: self.assertTrue(np.array_equal(np.maximum(0, feed), result))
def feed_split_test():
for place_num in six.moves.range(1, len(places) * 3):
x_np = gen_random([place_num, self.size])
y_np = gen_random([place_num, self.size])
if place_num <= len(places):
lr_np = gen_random([place_num])
else:
lr_np = gen_random([1])
relu_x_np, relu_y_np, relu_lr_np = exe.run(
prog,
feed={x.name: x_np,
y.name: y_np,
lr.name: lr_np},
fetch_list=[relu_x, relu_y, relu_lr])
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
if place_num <= len(places):
assert_result(lr_np, relu_lr_np)
else:
expected_relu_lr_np = max(lr_np[0], 0)
self.assertTrue(np.all(expected_relu_lr_np == relu_lr_np))
def feed_list_test():
for place_num in six.moves.range(1, len(places) + 1):
x_np_list = []
y_np_list = []
lr_np_list = []
feed_list = []
for _ in six.moves.range(place_num):
x_np = gen_random([1, self.size])
y_np = gen_random([1, self.size])
lr_np = gen_random([1])
x_np_list.append(x_np)
y_np_list.append(y_np)
lr_np_list.append(lr_np)
feed_list.append({
x.name: x_np,
y.name: y_np,
lr.name: lr_np
})
relu_x_np, relu_y_np, relu_lr_np = exe.run(
prog, feed=feed_list, fetch_list=[relu_x, relu_y, relu_lr])
x_np = np.concatenate(x_np_list)
y_np = np.concatenate(y_np_list)
lr_np = np.concatenate(lr_np_list)
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
assert_result(lr_np, relu_lr_np)
for _ in six.moves.range(self.iterations):
if use_split:
feed_split_test()
else:
feed_list_test()
def test_main(self):
places = [fluid.cpu_places(4)]
if fluid.is_compiled_with_cuda():
places.append(fluid.cuda_places())
for p in places:
self.run_network(p, use_split=True)
self.run_network(p, use_split=False)
class TestInferencePartialFeedUsingDataLoader(unittest.TestCase):
def setUp(self):
self.epoch_num = 3
self.batch_num = 101 # a prime number
self.batch_size = 32
def create_reader(self):
def __impl__():
for _ in six.moves.range(self.batch_num):
yield np.random.random([self.batch_size, 1]).astype('float32'),
return __impl__
def run_network(self, iterable, use_cuda, drop_last):
x = fluid.data(shape=[None, 1], name='x', dtype='float32')
places = fluid.cuda_places() if use_cuda else fluid.cpu_places(4)
loader = fluid.io.DataLoader.from_generator(
feed_list=[x], capacity=16, iterable=iterable, drop_last=drop_last)
y = fluid.layers.fc(x, size=10)
loss = fluid.layers.reduce_mean(y)
exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
places=places, loss_name=loss.name)
loader.set_batch_generator(
self.create_reader(), places=places if iterable else None)
for _ in six.moves.range(self.epoch_num):
actual_batch_num = 0
if loader.iterable:
for feed_data in loader():
x_data, = exe.run(prog, feed=feed_data, fetch_list=[x])
self.assertEqual(x_data.shape[0] % self.batch_size, 0)
self.assertTrue(x_data.shape[0] != 0)
actual_batch_num += int(x_data.shape[0] / self.batch_size)
else:
loader.start()
try:
while True:
x_data, = exe.run(prog, fetch_list=[x])
self.assertEqual(x_data.shape[0] % self.batch_size, 0)
self.assertTrue(x_data.shape[0] != 0)
actual_batch_num += int(x_data.shape[0] /
self.batch_size)
except fluid.core.EOFException:
loader.reset()
if not drop_last or len(places) == 1:
self.assertEqual(self.batch_num, actual_batch_num)
else:
self.assertGreater(self.batch_num, actual_batch_num)
def test_main(self):
use_cuda_list = [False, True] if fluid.is_compiled_with_cuda(
) else [False]
iterable_list = [False, True]
drop_last_list = [False, True]
for iterable in iterable_list:
for use_cuda in use_cuda_list:
for drop_last in drop_last_list:
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
self.run_network(iterable, use_cuda, drop_last)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册