提交 7ac969b8 编写于 作者: Y Yu Yang

Debug

* Add alignment check
* Make FetchedData not a shared_ptr
* Remove FetchedData
* Wait & fetch data
上级 599f7a87
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "lod_tensor.h"
#include "lod_tensor_array.h"
#include "op_registry.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/operators/math/concat.h"
namespace paddle {
......@@ -158,15 +159,8 @@ struct ScaleLossGradOpHandle : public OpHandle {
}
};
// Holds the tensors produced by a fetch: one slot per requested fetch target.
struct FetchedData {
  // Fetched tensors; sized once at construction to `num_fetched` entries.
  std::vector<framework::LoDTensor> tensors_;

  // Creates `num_fetched` default-constructed tensor slots.
  explicit FetchedData(size_t num_fetched) : tensors_(num_fetched) {}
};
struct FetchOpHandle : public OpHandle {
std::shared_ptr<FetchedData> data_;
FeedFetchList *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<LoDTensor> tensors_;
......@@ -175,15 +169,26 @@ struct FetchOpHandle : public OpHandle {
for (auto *input_var : inputs_) {
input_var->pending_ops_.erase(this);
}
// Lazily merge tensors, which makes the code faster.
MergeTensors();
}
// Always throws. Per the error message, no other op is expected to wait on a
// FetchOp, so reaching this override is treated as an error.
// `waited_dev` is unused.
void Wait(platform::DeviceContext *waited_dev) override {
  PADDLE_THROW("Nobody should wait FetchOp. Unexpected Error");
}
void WaitAndMergeCPUTensors() const {
// Wait fetch stream done.
for (auto &ctx : dev_ctx_) {
ctx.second->Wait();
}
std::vector<const LoDTensor *> tensors_ptr;
tensors_ptr.reserve(tensors_.size());
for (auto &t : tensors_) {
tensors_ptr.emplace_back(&t);
}
data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
}
protected:
void RunImpl() override {
for (auto *input : inputs_) {
......@@ -208,15 +213,6 @@ struct FetchOpHandle : public OpHandle {
}
}
}
private:
void MergeTensors() const {
std::vector<const LoDTensor *> tensors_ptr;
for (auto &t : tensors_) {
tensors_ptr.emplace_back(&t);
}
data_->tensors_[offset_].MergeLoDTensor(tensors_ptr, platform::CPUPlace());
}
};
class ParallelExecutorPrivate {
......@@ -325,7 +321,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
: member_(member) {}
// Logs at verbosity level 3 that the NCCL all-reduce op is being waited on,
// then delegates to the base-class OpHandle::Wait with the same device
// context.
void Wait(platform::DeviceContext *waited_dev) override {
  VLOG(3) << "Wait nccl all reduce op";
  this->OpHandle::Wait(waited_dev);
}
......@@ -355,6 +350,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
void *buffer = const_cast<void *>(lod_tensor.data<void>());
uintptr_t buf = reinterpret_cast<uintptr_t>(buffer);
if (buf % sizeof(float) != 0) {
VLOG(3) << "Buffer is not aligned " << buf;
}
if (dtype == -1) {
dtype = ToNCCLDataType(lod_tensor.type());
}
......@@ -680,7 +680,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
bool use_event = true;
auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
FeedFetchList fetched_data(fetch_tensors.size());
// Version --> VarHandle
member_->exception_.reset();
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
......@@ -728,7 +728,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
auto &vars = fetched_vars[var_name];
fetch_ops.emplace_back();
FetchOpHandle *op = &fetch_ops.back();
op->data_ = fetched_data;
op->data_ = &fetched_data;
op->offset_ = i;
op->local_scopes_ = &member_->local_scopes_;
for (auto &p : member_->places_) {
......@@ -786,9 +786,12 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
fetch_ops.clear();
*member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
fetched_data->tensors_;
for (auto &fetch_op : fetch_ops) {
fetch_op.WaitAndMergeCPUTensors();
}
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetched_data;
}
void ParallelExecutor::RunOp(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册