Commit 7ac969b8 authored by Yu Yang

Debug

* Add alignment check
* Make FetchedData not a shared_ptr
* Remove FetchedData
* Wait & fetch data
Parent 599f7a87
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "lod_tensor.h"
 #include "lod_tensor_array.h"
 #include "op_registry.h"
+#include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/operators/math/concat.h"

 namespace paddle {
@@ -158,15 +159,8 @@ struct ScaleLossGradOpHandle : public OpHandle {
   }
 };

-struct FetchedData {
- public:
-  std::vector<framework::LoDTensor> tensors_;
-
-  explicit FetchedData(size_t num_fetched) { tensors_.resize(num_fetched); }
-};
-
 struct FetchOpHandle : public OpHandle {
-  std::shared_ptr<FetchedData> data_;
+  FeedFetchList *data_;
   size_t offset_;
   std::vector<Scope *> *local_scopes_;
   std::vector<LoDTensor> tensors_;
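Note: `FetchedData` was a one-off container around `std::vector<LoDTensor>`; the commit replaces it with the framework's existing `FeedFetchList` (declared in the newly included `feed_fetch_type.h`) and downgrades the handle's `shared_ptr` to a raw, non-owning pointer. Below is a minimal sketch of the resulting ownership pattern, using hypothetical stand-in types rather than the real Paddle headers:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the Paddle types; the real FeedFetchList lives
// in paddle/fluid/framework/feed_fetch_type.h.
struct LoDTensor {};
using FeedFetchList = std::vector<LoDTensor>;

struct FetchOpHandle {
  // Raw pointer: the list is owned by the caller (Run) and outlives every
  // fetch op, so shared ownership buys nothing.
  FeedFetchList *data_ = nullptr;
  std::size_t offset_ = 0;
};

void Run(std::size_t num_fetches) {
  FeedFetchList fetched_data(num_fetches);  // stack-owned, one slot per fetch
  std::vector<FetchOpHandle> fetch_ops(num_fetches);
  for (std::size_t i = 0; i < num_fetches; ++i) {
    fetch_ops[i].data_ = &fetched_data;  // non-owning view into the list
    fetch_ops[i].offset_ = i;
  }
  // ... ops run and are fully drained here, before fetched_data is destroyed ...
}

int main() { Run(4); }
```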
@@ -175,15 +169,26 @@ struct FetchOpHandle : public OpHandle {
     for (auto *input_var : inputs_) {
       input_var->pending_ops_.erase(this);
     }
-    // Lazily merge tensors. Will faster code.
-    MergeTensors();
   }

   void Wait(platform::DeviceContext *waited_dev) override {
     PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
   }

+  void WaitAndMergeCPUTensors() const {
+    // Wait until the fetch streams are done.
+    for (auto &ctx : dev_ctx_) {
+      ctx.second->Wait();
+    }
+
+    std::vector<const LoDTensor *> tensors_ptr;
+    tensors_ptr.reserve(tensors_.size());
+    for (auto &t : tensors_) {
+      tensors_ptr.emplace_back(&t);
+    }
+    data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
+  }
+
  protected:
   void RunImpl() override {
     for (auto *input : inputs_) {
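Note: `WaitAndMergeCPUTensors` is a wait-then-merge pattern: block on each device context so the device-to-host copies issued by the op have completed, then splice the per-device partial tensors into the caller's result slot. Below is a self-contained sketch of the same pattern with hypothetical `DeviceContext`/`Tensor` stand-ins (the real code waits on `platform::DeviceContext` and merges via `LoDTensor::MergeLoDTensor`):

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

// Hypothetical stand-ins for platform::DeviceContext and LoDTensor.
struct DeviceContext {
  void Wait() { /* block until all queued work (e.g. D2H copies) finishes */ }
};
struct Tensor {
  std::vector<float> data;
};

// Merge per-device partial results into one CPU tensor (plain concatenation
// here; Paddle's MergeLoDTensor also stitches LoD metadata together).
Tensor Merge(const std::vector<const Tensor *> &parts) {
  Tensor out;
  for (const Tensor *p : parts) {
    out.data.insert(out.data.end(), p->data.begin(), p->data.end());
  }
  return out;
}

void WaitAndMerge(std::map<int, DeviceContext *> &dev_ctx,
                  const std::vector<Tensor> &per_device,
                  std::vector<Tensor> *result, std::size_t offset) {
  for (auto &ctx : dev_ctx) ctx.second->Wait();  // copies are now visible
  std::vector<const Tensor *> ptrs;
  ptrs.reserve(per_device.size());
  for (auto &t : per_device) ptrs.emplace_back(&t);
  result->at(offset) = Merge(ptrs);
}

int main() {
  DeviceContext gpu0, gpu1;
  std::map<int, DeviceContext *> ctxs{{0, &gpu0}, {1, &gpu1}};
  std::vector<Tensor> parts{{{1.f, 2.f}}, {{3.f}}};
  std::vector<Tensor> result(1);
  WaitAndMerge(ctxs, parts, &result, 0);
  std::cout << result[0].data.size() << "\n";  // 3
}
```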
@@ -208,15 +213,6 @@ struct FetchOpHandle : public OpHandle {
       }
     }
   }
-
- private:
-  void MergeTensors() const {
-    std::vector<const LoDTensor *> tensors_ptr;
-    for (auto &t : tensors_) {
-      tensors_ptr.emplace_back(&t);
-    }
-    data_->tensors_[offset_].MergeLoDTensor(tensors_ptr, platform::CPUPlace());
-  }
 };

 class ParallelExecutorPrivate {
@@ -325,7 +321,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
       : member_(member) {}

   void Wait(platform::DeviceContext *waited_dev) override {
-    VLOG(3) << "Wait nccl all reduce op";
     OpHandle::Wait(waited_dev);
   }
@@ -355,6 +350,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
         auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
         void *buffer = const_cast<void *>(lod_tensor.data<void>());
+        uintptr_t buf = reinterpret_cast<uintptr_t>(buffer);
+        if (buf % sizeof(float) != 0) {
+          VLOG(3) << "Buffer is not aligned " << buf;
+        }
+
         if (dtype == -1) {
           dtype = ToNCCLDataType(lod_tensor.type());
         }
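Note: the added check only logs misaligned buffers rather than failing; casting the pointer to `uintptr_t` is the portable way to inspect an address, and a float buffer whose address is not a multiple of `sizeof(float)` would presumably break or slow down the NCCL reduction. A minimal standalone version of the same check:

```cpp
#include <cstdint>
#include <iostream>

// Returns true when ptr is suitably aligned for T; the diff checks float
// using sizeof(float), which equals alignof(float) on common platforms.
template <typename T>
bool IsAligned(const void *ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) % alignof(T) == 0;
}

int main() {
  float buf[2] = {0.f, 0.f};
  const char *raw = reinterpret_cast<const char *>(buf);
  std::cout << IsAligned<float>(raw) << "\n";      // 1: aligned
  std::cout << IsAligned<float>(raw + 1) << "\n";  // 0: off by one byte
}
```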
@@ -680,7 +680,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
   bool use_event = true;
-  auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
+  FeedFetchList fetched_data(fetch_tensors.size());
   // Version --> VarHandle
   member_->exception_.reset();
   std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
@@ -728,7 +728,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     auto &vars = fetched_vars[var_name];
     fetch_ops.emplace_back();
     FetchOpHandle *op = &fetch_ops.back();
-    op->data_ = fetched_data;
+    op->data_ = &fetched_data;
     op->offset_ = i;
     op->local_scopes_ = &member_->local_scopes_;
     for (auto &p : member_->places_) {
@@ -786,9 +786,12 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     platform::DeviceContextPool::Instance().Get(p)->Wait();
   }

-  fetch_ops.clear();
-  *member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
-      fetched_data->tensors_;
+  for (auto &fetch_op : fetch_ops) {
+    fetch_op.WaitAndMergeCPUTensors();
+  }
+
+  *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
+      fetched_data;
 }

 void ParallelExecutor::RunOp(
...
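Note: with this change, `Run` drains every `FetchOpHandle` synchronously before returning, so the raw pointer into the stack-allocated `fetched_data` cannot dangle, and the merged tensors are published to the global scope under `fetched_var_name` as a `FeedFetchList` instead of a `LoDTensorArray`. Below is a self-contained sketch of that publish step, with a hypothetical `Scope` stand-in for the real scope API:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-ins: a Scope-like map of named variables and a
// FeedFetchList-like vector of merged CPU tensors.
using Tensor = std::vector<float>;
using FeedFetchList = std::vector<Tensor>;

struct Scope {
  std::unordered_map<std::string, FeedFetchList> vars;
  FeedFetchList *GetMutable(const std::string &name) { return &vars[name]; }
};

void Run(Scope *scope, const std::string &fetched_var_name) {
  FeedFetchList fetched_data(2);  // stack-owned for the duration of the run
  fetched_data[0] = {1.f, 2.f};   // in the real code, filled by fetch ops
  fetched_data[1] = {3.f};
  // All fetch ops have been waited on by this point, so a plain copy-out
  // publishes the merged results; the stack object may then die safely.
  *scope->GetMutable(fetched_var_name) = fetched_data;
}

int main() {
  Scope scope;
  Run(&scope, "fetch");
  std::cout << scope.GetMutable("fetch")->size() << "\n";  // 2
}
```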