未验证 提交 00594c1c 编写于 作者: 1 123malin 提交者: GitHub

support dumping params/grads in transpiler mode (#22490)

上级 a06d75a2
...@@ -66,9 +66,11 @@ else() ...@@ -66,9 +66,11 @@ else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor) cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif() endif()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
cc_library(device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_test(device_worker_test SRCS device_worker_test.cc DEPS device_worker)
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory gflags glog) cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory gflags glog)
......
...@@ -23,5 +23,73 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) { ...@@ -23,5 +23,73 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) {
device_reader_ = data_feed; device_reader_ = data_feed;
} }
template <typename T>
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
auto count = tensor->numel();
if (start < 0 || end > count) {
VLOG(3) << "access violation";
return "access violation";
}
std::ostringstream os;
for (int64_t i = start; i < end; i++) {
os << ":" << tensor->data<T>()[i];
}
return os.str();
}
std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
int64_t end) {
auto count = tensor->numel();
if (start < 0 || end > count) {
VLOG(3) << "access violation";
return "access violation";
}
std::ostringstream os;
for (int64_t i = start; i < end; i++) {
os << ":" << static_cast<uint64_t>(tensor->data<int64_t>()[i]);
}
return os.str();
}
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
std::string out_val;
if (tensor->type() == proto::VarType::FP32) {
out_val = PrintLodTensorType<float>(tensor, start, end);
} else if (tensor->type() == proto::VarType::INT64) {
out_val = PrintLodTensorIntType(tensor, start, end);
} else if (tensor->type() == proto::VarType::FP64) {
out_val = PrintLodTensorType<double>(tensor, start, end);
} else {
out_val = "unsupported type";
}
return out_val;
}
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
auto& dims = tensor->dims();
if (tensor->lod().size() != 0) {
auto& lod = tensor->lod()[0];
return {lod[index] * dims[1], lod[index + 1] * dims[1]};
} else {
return {index * dims[1], (index + 1) * dims[1]};
}
}
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
auto& dims = tensor->dims();
if (dims.size() != 2) return false;
if (tensor->lod().size() != 0) {
auto& lod = tensor->lod()[0];
if (lod.size() != batch_size + 1) {
return false;
}
} else {
if (dims[0] != static_cast<int>(batch_size)) {
return false;
}
}
return true;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -45,6 +45,10 @@ limitations under the License. */ ...@@ -45,6 +45,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end);
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
class FleetWrapper; class FleetWrapper;
#define SEC_LOG \ #define SEC_LOG \
...@@ -168,6 +172,8 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -168,6 +172,8 @@ class HogwildWorker : public CPUWorkerBase {
virtual void Initialize(const TrainerDesc& desc); virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles(); virtual void TrainFiles();
virtual void TrainFilesWithProfiler(); virtual void TrainFilesWithProfiler();
virtual void SetNeedDump(bool need_dump_field);
virtual void SetChannelWriter(ChannelObject<std::string>* queue);
virtual void PrintFetchVars(); virtual void PrintFetchVars();
virtual void CreateDeviceResource(const ProgramDesc& main_prog); virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory(); virtual void BindingDataFeedMemory();
...@@ -177,6 +183,8 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -177,6 +183,8 @@ class HogwildWorker : public CPUWorkerBase {
protected: protected:
void CreateThreadOperators(const ProgramDesc& program); void CreateThreadOperators(const ProgramDesc& program);
void CreateThreadScope(const ProgramDesc& program); void CreateThreadScope(const ProgramDesc& program);
virtual void DumpParam(const int batch_id);
std::vector<std::string> op_names_; std::vector<std::string> op_names_;
std::vector<OperatorBase*> ops_; std::vector<OperatorBase*> ops_;
bool thread_barrier_; bool thread_barrier_;
...@@ -184,6 +192,12 @@ class HogwildWorker : public CPUWorkerBase { ...@@ -184,6 +192,12 @@ class HogwildWorker : public CPUWorkerBase {
HogwildWorkerParameter param_; HogwildWorkerParameter param_;
std::vector<std::string> skip_ops_; std::vector<std::string> skip_ops_;
std::map<std::string, int> stat_var_name_map_; std::map<std::string, int> stat_var_name_map_;
// dump params or grads for debug
bool need_dump_param_;
bool need_dump_field_;
std::vector<std::string> dump_param_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
}; };
class DownpourWorker : public HogwildWorker { class DownpourWorker : public HogwildWorker {
...@@ -203,13 +217,11 @@ class DownpourWorker : public HogwildWorker { ...@@ -203,13 +217,11 @@ class DownpourWorker : public HogwildWorker {
void PushGradients(); void PushGradients();
void CollectLabelInfo(size_t table_id); void CollectLabelInfo(size_t table_id);
void AdjustInsWeight(); void AdjustInsWeight();
void DumpParam();
void CopySparseTable(); void CopySparseTable();
void CopyDenseTable(); void CopyDenseTable();
void CopyDenseVars(); void CopyDenseVars();
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end); virtual void DumpParam(const int batch_id);
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
DownpourWorkerParameter param_; DownpourWorkerParameter param_;
// copy table // copy table
CopyTableConfig copy_table_config_; CopyTableConfig copy_table_config_;
...@@ -236,16 +248,11 @@ class DownpourWorker : public HogwildWorker { ...@@ -236,16 +248,11 @@ class DownpourWorker : public HogwildWorker {
std::vector<::std::future<int32_t>> push_sparse_status_; std::vector<::std::future<int32_t>> push_sparse_status_;
bool dump_slot_; bool dump_slot_;
bool need_to_push_dense_; bool need_to_push_dense_;
bool need_dump_field_;
bool need_dump_param_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_; std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
float scale_datanorm_; float scale_datanorm_;
std::vector<::std::future<int32_t>> push_dense_status_; std::vector<::std::future<int32_t>> push_dense_status_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
// skipped ops // skipped ops
std::vector<std::string> skip_ops_; std::vector<std::string> skip_ops_;
std::vector<std::string> dump_param_;
// just save the value in param_ for easy access // just save the value in param_ for easy access
std::map<uint64_t, std::string> label_var_name_; std::map<uint64_t, std::string> label_var_name_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_; std::map<uint64_t, std::vector<std::string>> dense_value_names_;
......
...@@ -12,13 +12,66 @@ ...@@ -12,13 +12,66 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/device_worker.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/trainer.h" #include "paddle/fluid/framework/trainer.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
TEST() { TEST(LodTensor, PrintLodTensor) {
// create hogwild device worker LoDTensor tensor1;
tensor1.Resize({2});
tensor1.mutable_data<float>(platform::CPUPlace());
tensor1.data<float>()[0] = 0.2;
tensor1.data<float>()[1] = 0.5;
std::string res = PrintLodTensor(&tensor1, -1, 2);
ASSERT_EQ(res, "access violation");
res = PrintLodTensor(&tensor1, 0, 2);
ASSERT_EQ(res, ":0.2:0.5");
LoDTensor tensor2;
tensor2.Resize({2});
tensor2.mutable_data<int64_t>(platform::CPUPlace());
tensor2.data<int64_t>()[0] = 1;
tensor2.data<int64_t>()[1] = 2;
res = PrintLodTensor(&tensor2, -1, 2);
ASSERT_EQ(res, "access violation");
res = PrintLodTensor(&tensor2, 0, 2);
ASSERT_EQ(res, ":1:2");
LoDTensor tensor3;
tensor3.Resize({2});
tensor3.mutable_data<double>(platform::CPUPlace());
tensor3.data<double>()[0] = 0.1;
tensor3.data<double>()[1] = 0.2;
res = PrintLodTensor(&tensor3, 0, 2);
ASSERT_EQ(res, ":0.1:0.2");
} }
TEST(LodTensor, GetTensorBound) {
LoD lod{{0, 2}};
LoDTensor tensor;
tensor.set_lod(lod);
tensor.Resize({2, 1});
tensor.mutable_data<float>(platform::CPUPlace());
tensor.data<float>()[0] = 0;
tensor.data<float>()[1] = 1;
std::pair<int64_t, int64_t> res = GetTensorBound(&tensor, 0);
ASSERT_EQ(res.first, 0);
ASSERT_EQ(res.second, 2);
} }
TEST(LodTensor, CheckValidOutput) {
LoD lod{{0, 1, 2}};
LoDTensor tensor;
tensor.set_lod(lod);
tensor.Resize({2, 1});
tensor.mutable_data<float>(platform::CPUPlace());
tensor.data<float>()[0] = 0;
tensor.data<float>()[1] = 1;
ASSERT_TRUE(CheckValidOutput(&tensor, 2));
} }
} // namespace framework
} // namespace paddle
...@@ -129,89 +129,19 @@ void DownpourWorker::SetNeedDump(bool need_dump_field) { ...@@ -129,89 +129,19 @@ void DownpourWorker::SetNeedDump(bool need_dump_field) {
need_dump_field_ = need_dump_field; need_dump_field_ = need_dump_field;
} }
template <typename T> void DownpourWorker::DumpParam(const int batch_id) {
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
auto count = tensor->numel();
if (start < 0 || end > count) {
VLOG(3) << "access violation";
return "access violation";
}
std::ostringstream os;
for (int64_t i = start; i < end; i++) {
os << ":" << tensor->data<T>()[i];
}
return os.str();
}
std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
int64_t end) {
auto count = tensor->numel();
if (start < 0 || end > count) {
VLOG(3) << "access violation";
return "access violation";
}
std::ostringstream os; std::ostringstream os;
for (int64_t i = start; i < end; i++) {
os << ":" << static_cast<uint64_t>(tensor->data<int64_t>()[i]);
}
return os.str();
}
std::string DownpourWorker::PrintLodTensor(LoDTensor* tensor, int64_t start,
int64_t end) {
std::string out_val;
if (tensor->type() == proto::VarType::FP32) {
out_val = PrintLodTensorType<float>(tensor, start, end);
} else if (tensor->type() == proto::VarType::INT64) {
out_val = PrintLodTensorIntType(tensor, start, end);
} else if (tensor->type() == proto::VarType::FP64) {
out_val = PrintLodTensorType<double>(tensor, start, end);
} else {
out_val = "unsupported type";
}
return out_val;
}
std::pair<int64_t, int64_t> DownpourWorker::GetTensorBound(LoDTensor* tensor,
int index) {
auto& dims = tensor->dims();
if (tensor->lod().size() != 0) {
auto& lod = tensor->lod()[0];
return {lod[index] * dims[1], lod[index + 1] * dims[1]};
} else {
return {index * dims[1], (index + 1) * dims[1]};
}
}
bool DownpourWorker::CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
auto& dims = tensor->dims();
if (dims.size() != 2) return false;
if (tensor->lod().size() != 0) {
auto& lod = tensor->lod()[0];
if (lod.size() != batch_size + 1) {
return false;
}
} else {
if (dims[0] != static_cast<int>(batch_size)) {
return false;
}
}
return true;
}
void DownpourWorker::DumpParam() {
std::string os;
for (auto& param : dump_param_) { for (auto& param : dump_param_) {
os.clear(); os.str("");
os = param;
Variable* var = thread_scope_->FindVar(param); Variable* var = thread_scope_->FindVar(param);
if (var == nullptr) { if (var == nullptr) {
continue; continue;
} }
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t len = tensor->numel(); int64_t len = tensor->numel();
os += PrintLodTensor(tensor, 0, len); os << "(" << batch_id << "," << param << ")"
writer_ << os; << PrintLodTensor(tensor, 0, len);
writer_ << os.str();
} }
} }
...@@ -1022,7 +952,7 @@ void DownpourWorker::TrainFiles() { ...@@ -1022,7 +952,7 @@ void DownpourWorker::TrainFiles() {
writer_ << ars[i]; writer_ << ars[i];
} }
if (need_dump_param_ && thread_id_ == 0) { if (need_dump_param_ && thread_id_ == 0) {
DumpParam(); DumpParam(batch_cnt);
} }
} }
......
...@@ -564,7 +564,7 @@ void DownpourWorkerOpt::TrainFiles() { ...@@ -564,7 +564,7 @@ void DownpourWorkerOpt::TrainFiles() {
writer_ << ars[i]; writer_ << ars[i];
} }
if (need_dump_param_ && thread_id_ == 0) { if (need_dump_param_ && thread_id_ == 0) {
DumpParam(); DumpParam(batch_cnt);
} }
} }
......
...@@ -31,6 +31,20 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) { ...@@ -31,6 +31,20 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
} }
use_cvm_ = desc.use_cvm(); use_cvm_ = desc.use_cvm();
thread_barrier_ = desc.thread_barrier(); thread_barrier_ = desc.thread_barrier();
dump_fields_.resize(desc.dump_fields_size());
for (int i = 0; i < desc.dump_fields_size(); ++i) {
dump_fields_[i] = desc.dump_fields(i);
}
need_dump_param_ = false;
dump_param_.resize(desc.dump_param_size());
for (int i = 0; i < desc.dump_param_size(); ++i) {
dump_param_[i] = desc.dump_param(i);
}
if (desc.dump_param_size() != 0) {
need_dump_param_ = true;
}
} }
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) { void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
...@@ -143,6 +157,49 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -143,6 +157,49 @@ void HogwildWorker::TrainFilesWithProfiler() {
op_total_time[i] += timeline.ElapsedSec(); op_total_time[i] += timeline.ElapsedSec();
total_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec();
} }
if (need_dump_field_) {
size_t batch_size = device_reader_->GetCurBatchSize();
std::vector<std::string> ars(batch_size);
for (auto &ar : ars) {
ar.clear();
}
auto &ins_id_vec = device_reader_->GetInsIdVec();
auto &ins_content_vec = device_reader_->GetInsContentVec();
for (size_t i = 0; i < ins_id_vec.size(); i++) {
ars[i] += ins_id_vec[i];
ars[i] = ars[i] + "\t" + ins_content_vec[i];
}
for (auto &field : dump_fields_) {
Variable *var = thread_scope_->FindVar(field);
if (var == nullptr) {
continue;
}
LoDTensor *tensor = var->GetMutable<LoDTensor>();
if (!CheckValidOutput(tensor, batch_size)) {
continue;
}
for (size_t i = 0; i < batch_size; ++i) {
auto output_dim = tensor->dims()[1];
std::string output_dimstr =
boost::lexical_cast<std::string>(output_dim);
ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
auto bound = GetTensorBound(tensor, i);
ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
}
}
// #pragma omp parallel for
for (size_t i = 0; i < ars.size(); i++) {
if (ars[i].length() == 0) {
continue;
}
writer_ << ars[i];
}
if (need_dump_param_ && thread_id_ == 0) {
DumpParam(batch_cnt);
}
}
total_inst += cur_batch; total_inst += cur_batch;
++batch_cnt; ++batch_cnt;
PrintFetchVars(); PrintFetchVars();
...@@ -160,6 +217,11 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -160,6 +217,11 @@ void HogwildWorker::TrainFilesWithProfiler() {
thread_scope_->DropKids(); thread_scope_->DropKids();
timeline.Start(); timeline.Start();
} }
if (need_dump_field_) {
writer_.Flush();
}
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
if (thread_barrier_) { if (thread_barrier_) {
operators::distributed::Communicator::GetInstance() operators::distributed::Communicator::GetInstance()
...@@ -168,6 +230,10 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -168,6 +230,10 @@ void HogwildWorker::TrainFilesWithProfiler() {
#endif #endif
} }
void HogwildWorker::SetChannelWriter(ChannelObject<std::string> *queue) {
writer_.Reset(queue);
}
void HogwildWorker::TrainFiles() { void HogwildWorker::TrainFiles() {
platform::SetNumThreads(1); platform::SetNumThreads(1);
...@@ -214,5 +280,25 @@ void HogwildWorker::PrintFetchVars() { ...@@ -214,5 +280,25 @@ void HogwildWorker::PrintFetchVars() {
} }
} }
void HogwildWorker::SetNeedDump(bool need_dump_field) {
need_dump_field_ = need_dump_field;
}
void HogwildWorker::DumpParam(const int batch_id) {
std::ostringstream os;
for (auto &param : dump_param_) {
os.str("");
Variable *var = thread_scope_->FindVar(param);
if (var == nullptr) {
continue;
}
LoDTensor *tensor = var->GetMutable<LoDTensor>();
int64_t len = tensor->numel();
os << "(" << batch_id << "," << param << ")"
<< PrintLodTensor(tensor, 0, len);
writer_ << os.str();
}
}
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "io/fs.h"
#include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h" #include "paddle/fluid/framework/trainer.h"
...@@ -25,12 +26,29 @@ namespace framework { ...@@ -25,12 +26,29 @@ namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) { Dataset* dataset) {
thread_num_ = trainer_desc.thread_num(); thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
dump_fields_path_ = trainer_desc.dump_fields_path();
dump_converter_ = trainer_desc.dump_converter();
need_dump_field_ = false;
if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") {
need_dump_field_ = true;
}
if (need_dump_field_) {
auto& file_list = dataset->GetFileList();
if (file_list.size() == 0) {
need_dump_field_ = false;
}
}
mpi_rank_ = trainer_desc.mpi_rank();
mpi_size_ = trainer_desc.mpi_size();
dump_file_num_ = trainer_desc.dump_file_num();
for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size(); for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
i++) { i++) {
need_merge_var_names_.push_back( need_merge_var_names_.push_back(
trainer_desc.downpour_param().stat_var_names(i)); trainer_desc.downpour_param().stat_var_names(i));
} }
SetDataset(dataset);
// get filelist from trainer_desc here // get filelist from trainer_desc here
const std::vector<paddle::framework::DataFeed*> readers = const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders(); dataset->GetReaders();
...@@ -53,12 +71,66 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -53,12 +71,66 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
workers_[i]->Initialize(trainer_desc); workers_[i]->Initialize(trainer_desc);
workers_[i]->SetDeviceIndex(i); workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]); workers_[i]->SetDataFeed(readers[i]);
workers_[i]->SetNeedDump(need_dump_field_);
} }
// set debug here // set debug here
SetDebug(trainer_desc.debug()); SetDebug(trainer_desc.debug());
} }
void MultiTrainer::DumpWork(int tid) {
#ifdef _LINUX
int err_no = 0;
std::string path = string::format_string(
"%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);
std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
while (1) {
std::string out_str;
if (!queue_->Get(out_str)) {
break;
}
size_t write_count =
fwrite_unlocked(out_str.data(), 1, out_str.length(), fp.get());
if (write_count != out_str.length()) {
VLOG(3) << "dump text failed";
continue;
}
write_count = fwrite_unlocked("\n", 1, 1, fp.get());
if (write_count != 1) {
VLOG(3) << "dump text failed";
continue;
}
}
#endif
}
void MultiTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>();
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetChannelWriter(queue_.get());
}
dump_thread_num_ = 1;
if (dump_file_num_ > mpi_size_) {
dump_thread_num_ = dump_file_num_ / mpi_size_;
if (dump_file_num_ % mpi_size_ > mpi_rank_) {
dump_thread_num_ += 1;
}
}
for (int i = 0; i < dump_thread_num_; i++) {
dump_thread_.push_back(
std::thread(std::bind(&MultiTrainer::DumpWork, this, i)));
}
}
void MultiTrainer::FinalizeDumpEnv() {
queue_->Close();
for (auto& th : dump_thread_) {
th.join();
}
queue_.reset();
}
// call only after all resources are set in current trainer // call only after all resources are set in current trainer
void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) { const platform::Place& place) {
...@@ -71,6 +143,13 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, ...@@ -71,6 +143,13 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
} }
} }
void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if (need_dump_field_) {
InitDumpEnv();
}
VLOG(3) << "init other env done.";
}
Scope* MultiTrainer::GetWorkerScope(int thread_id) { Scope* MultiTrainer::GetWorkerScope(int thread_id) {
return workers_[thread_id]->GetThreadScope(); return workers_[thread_id]->GetThreadScope();
} }
...@@ -91,7 +170,12 @@ void MultiTrainer::Run() { ...@@ -91,7 +170,12 @@ void MultiTrainer::Run() {
} }
} }
void MultiTrainer::Finalize() { root_scope_->DropKids(); } void MultiTrainer::Finalize() {
if (need_dump_field_) {
FinalizeDumpEnv();
}
root_scope_->DropKids();
}
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -68,10 +68,13 @@ class MultiTrainer : public TrainerBase { ...@@ -68,10 +68,13 @@ class MultiTrainer : public TrainerBase {
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set); virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program, virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place); const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {} virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Run(); virtual void Run();
virtual void Finalize(); virtual void Finalize();
virtual void FinalizeDumpEnv();
virtual void InitDumpEnv();
virtual Scope* GetWorkerScope(int thread_id); virtual Scope* GetWorkerScope(int thread_id);
virtual void DumpWork(int tid);
protected: protected:
int thread_num_; int thread_num_;
...@@ -79,6 +82,17 @@ class MultiTrainer : public TrainerBase { ...@@ -79,6 +82,17 @@ class MultiTrainer : public TrainerBase {
std::vector<DataFeed*> readers_; std::vector<DataFeed*> readers_;
std::vector<std::shared_ptr<DeviceWorker>> workers_; std::vector<std::shared_ptr<DeviceWorker>> workers_;
std::vector<std::string> need_merge_var_names_; std::vector<std::string> need_merge_var_names_;
bool need_dump_field_;
std::string dump_fields_path_;
std::string dump_converter_;
int mpi_rank_;
int mpi_size_;
int dump_file_num_;
std::vector<std::thread> dump_thread_;
int dump_thread_num_;
std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
}; };
class DistMultiTrainer : public MultiTrainer { class DistMultiTrainer : public MultiTrainer {
...@@ -98,16 +112,6 @@ class DistMultiTrainer : public MultiTrainer { ...@@ -98,16 +112,6 @@ class DistMultiTrainer : public MultiTrainer {
protected: protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_; std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
std::vector<std::thread> dump_thread_;
int dump_thread_num_;
std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
bool need_dump_field_;
std::string dump_fields_path_;
std::string dump_converter_;
int mpi_rank_;
int mpi_size_;
int dump_file_num_;
}; };
#if defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_NCCL)
......
...@@ -919,7 +919,7 @@ class Executor(object): ...@@ -919,7 +919,7 @@ class Executor(object):
def _dump_debug_info(self, program=None, trainer=None): def _dump_debug_info(self, program=None, trainer=None):
with open(str(id(program)) + "_train_desc.prototxt", "w") as fout: with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
fout.write(str(trainer)) fout.write(str(trainer))
if program._fleet_opt: if program._fleet_opt and "fleet_desc" in program._fleet_opt:
with open("fleet_desc.prototxt", "w") as fout: with open("fleet_desc.prototxt", "w") as fout:
fout.write(str(program._fleet_opt["fleet_desc"])) fout.write(str(program._fleet_opt["fleet_desc"]))
......
...@@ -333,6 +333,12 @@ class DistributedTranspiler(Fleet): ...@@ -333,6 +333,12 @@ class DistributedTranspiler(Fleet):
self._transpiler.get_pserver_programs( self._transpiler.get_pserver_programs(
self.server_endpoints()[self.server_index()]) self.server_endpoints()[self.server_index()])
def _set_opt_info(self, opt_info):
"""
this function saves the result from DistributedOptimizer.minimize()
"""
self._opt_info = opt_info
fleet = DistributedTranspiler() fleet = DistributedTranspiler()
...@@ -358,9 +364,11 @@ class TranspilerOptimizer(DistributedOptimizer): ...@@ -358,9 +364,11 @@ class TranspilerOptimizer(DistributedOptimizer):
def __init__(self, optimizer, strategy=None): def __init__(self, optimizer, strategy=None):
super(TranspilerOptimizer, self).__init__(optimizer, strategy) super(TranspilerOptimizer, self).__init__(optimizer, strategy)
self.opt_info = dict()
if strategy: if strategy:
if isinstance(strategy, DistributeTranspilerConfig) or isinstance( if isinstance(strategy, DistributeTranspilerConfig):
strategy, DistributedStrategy): self._strategy = strategy
elif isinstance(strategy, DistributedStrategy):
self._strategy = strategy self._strategy = strategy
else: else:
raise TypeError( raise TypeError(
...@@ -369,6 +377,14 @@ class TranspilerOptimizer(DistributedOptimizer): ...@@ -369,6 +377,14 @@ class TranspilerOptimizer(DistributedOptimizer):
else: else:
self._strategy = StrategyFactory.create_sync_strategy() self._strategy = StrategyFactory.create_sync_strategy()
if isinstance(self._strategy, DistributedStrategy):
self.opt_info = self._strategy.get_debug_opt()
self.opt_info["mpi_rank"] = fleet.worker_index()
self.opt_info["mpi_size"] = fleet.worker_num()
self.opt_info["trainer"] = "MultiTrainer"
self.opt_info["device_worker"] = "Hogwild"
fleet._set_opt_info(self.opt_info)
def backward(self, def backward(self,
loss, loss,
startup_program=None, startup_program=None,
...@@ -456,4 +472,5 @@ class TranspilerOptimizer(DistributedOptimizer): ...@@ -456,4 +472,5 @@ class TranspilerOptimizer(DistributedOptimizer):
optimize_ops, params_grads = self._optimizer.minimize( optimize_ops, params_grads = self._optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set) loss, startup_program, parameter_list, no_grad_set)
fleet._transpile(config=self._strategy) fleet._transpile(config=self._strategy)
loss.block.program._fleet_opt = self.opt_info
return optimize_ops, params_grads return optimize_ops, params_grads
...@@ -69,6 +69,23 @@ class DistributedStrategy(object): ...@@ -69,6 +69,23 @@ class DistributedStrategy(object):
self._execute_strategy.num_threads = num_threads self._execute_strategy.num_threads = num_threads
if num_threads > 1: if num_threads > 1:
self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
self.debug_opt = None
def set_debug_opt(self, opt_info):
self.debug_opt = opt_info
def get_debug_opt(self):
opt_info = dict()
if self.debug_opt is not None and isinstance(self.debug_opt, dict):
opt_info["dump_slot"] = bool(self.debug_opt.get("dump_slot", 0))
opt_info["dump_converter"] = str(
self.debug_opt.get("dump_converter", ""))
opt_info["dump_fields"] = self.debug_opt.get("dump_fields", [])
opt_info["dump_file_num"] = self.debug_opt.get("dump_file_num", 16)
opt_info["dump_fields_path"] = self.debug_opt.get(
"dump_fields_path", "")
opt_info["dump_param"] = self.debug_opt.get("dump_param", [])
return opt_info
def get_program_config(self): def get_program_config(self):
return self._program_config return self._program_config
......
...@@ -229,7 +229,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -229,7 +229,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
fetch_list=[self.avg_cost], fetch_list=[self.avg_cost],
fetch_info=["cost"], fetch_info=["cost"],
print_period=2, print_period=2,
debug=False) debug=int(os.getenv("Debug", "0")))
pass_time = time.time() - pass_start pass_time = time.time() - pass_start
if os.getenv("SAVE_MODEL") == "1": if os.getenv("SAVE_MODEL") == "1":
......
...@@ -79,6 +79,17 @@ class FleetDistRunnerBase(object): ...@@ -79,6 +79,17 @@ class FleetDistRunnerBase(object):
elif args.mode == "geo": elif args.mode == "geo":
self.strategy = StrategyFactory.create_geo_strategy( self.strategy = StrategyFactory.create_geo_strategy(
args.geo_sgd_need_push_nums) args.geo_sgd_need_push_nums)
self.dump_param = os.getenv("dump_param", "").split(",")
self.dump_fields = os.getenv("dump_fields", "").split(",")
self.dump_fields_path = os.getenv("dump_fields_path", "")
debug = int(os.getenv("Debug", "0"))
if debug:
self.strategy.set_debug_opt({
"dump_param": self.dump_param,
"dump_fields": self.dump_fields,
"dump_fields_path": self.dump_fields_path
})
return self.strategy return self.strategy
def build_optimizer(self, avg_cost, strategy): def build_optimizer(self, avg_cost, strategy):
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import os import os
import unittest import unittest
import tempfile
from test_dist_fleet_base import TestFleetBase from test_dist_fleet_base import TestFleetBase
...@@ -99,7 +100,11 @@ class TestDistMnistAsyncDataset2x2(TestFleetBase): ...@@ -99,7 +100,11 @@ class TestDistMnistAsyncDataset2x2(TestFleetBase):
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast "FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "", "http_proxy": "",
"SAVE_MODEL": "1" "SAVE_MODEL": "1",
"dump_param": "concat_0.tmp_0",
"dump_fields": "dnn-fc-3.tmp_0,dnn-fc-3.tmp_0@GRAD",
"dump_fields_path": tempfile.mkdtemp(),
"Debug": "1"
} }
required_envs.update(need_envs) required_envs.update(need_envs)
......
...@@ -198,5 +198,30 @@ class TestHalfAsyncStrategy(unittest.TestCase): ...@@ -198,5 +198,30 @@ class TestHalfAsyncStrategy(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, half_async_config) optimizer = fleet.distributed_optimizer(optimizer, half_async_config)
class TestDebugInfo(unittest.TestCase):
def test_debug_info(self):
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.WORKER,
worker_num=2,
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet.init(role)
optimizer = fluid.optimizer.SGD(0.0001)
strategy = StrategyFactory.create_sync_strategy()
strategy.set_debug_opt({
"dump_param": ["fc_0.tmp_0"],
"dump_fields": ["fc_0.tmp_0", "fc_0.tmp_0@GRAD"],
"dump_fields_path": "dump_text/"
})
optimizer = fleet.distributed_optimizer(optimizer, strategy)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -29,6 +29,7 @@ from paddle.fluid.device_worker import DownpourSGD, DownpourSGDOPT ...@@ -29,6 +29,7 @@ from paddle.fluid.device_worker import DownpourSGD, DownpourSGDOPT
from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker
from google.protobuf import text_format from google.protobuf import text_format
import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib
from paddle.fluid.trainer_factory import TrainerFactory
class TestListenAndServOp(unittest.TestCase): class TestListenAndServOp(unittest.TestCase):
...@@ -87,12 +88,8 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -87,12 +88,8 @@ class TestListenAndServOp(unittest.TestCase):
opt_info["program_id_to_worker"] = {program_id: worker} opt_info["program_id_to_worker"] = {program_id: worker}
main_program._fleet_opt = opt_info main_program._fleet_opt = opt_info
trainer = DistMultiTrainer() trainer = TrainerFactory()._create_trainer(main_program._fleet_opt)
trainer._set_program(main_program) trainer._set_program(main_program)
device_worker = DownpourSGD()
device_worker._set_fleet_desc(fleet_desc)
trainer._set_device_worker(device_worker)
trainer._set_fleet_desc(fleet_desc)
trainer._gen_trainer_desc() trainer._gen_trainer_desc()
cmd = "rm fleet_desc.prototxt*" cmd = "rm fleet_desc.prototxt*"
os.system(cmd) os.system(cmd)
...@@ -147,12 +144,8 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -147,12 +144,8 @@ class TestListenAndServOp(unittest.TestCase):
opt_info["program_id_to_worker"] = {program_id: worker} opt_info["program_id_to_worker"] = {program_id: worker}
main_program._fleet_opt = opt_info main_program._fleet_opt = opt_info
trainer = DistMultiTrainer() trainer = TrainerFactory()._create_trainer(main_program._fleet_opt)
trainer._set_program(main_program) trainer._set_program(main_program)
device_worker = DownpourSGD()
device_worker._set_fleet_desc(fleet_desc)
trainer._set_device_worker(device_worker)
trainer._set_fleet_desc(fleet_desc)
trainer._gen_trainer_desc() trainer._gen_trainer_desc()
cmd = "rm fleet_desc.prototxt*" cmd = "rm fleet_desc.prototxt*"
os.system(cmd) os.system(cmd)
...@@ -207,12 +200,8 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -207,12 +200,8 @@ class TestListenAndServOp(unittest.TestCase):
opt_info["program_id_to_worker"] = {program_id: worker} opt_info["program_id_to_worker"] = {program_id: worker}
main_program._fleet_opt = opt_info main_program._fleet_opt = opt_info
trainer = DistMultiTrainer() trainer = TrainerFactory()._create_trainer(main_program._fleet_opt)
trainer._set_program(main_program) trainer._set_program(main_program)
device_worker = DownpourSGDOPT()
device_worker._set_fleet_desc(fleet_desc)
trainer._set_device_worker(device_worker)
trainer._set_fleet_desc(fleet_desc)
trainer._gen_trainer_desc() trainer._gen_trainer_desc()
cmd = "rm fleet_desc.prototxt*" cmd = "rm fleet_desc.prototxt*"
os.system(cmd) os.system(cmd)
......
...@@ -53,15 +53,9 @@ class TrainerFactory(object): ...@@ -53,15 +53,9 @@ class TrainerFactory(object):
device_worker_class = opt_info["device_worker"] device_worker_class = opt_info["device_worker"]
trainer = globals()[trainer_class]() trainer = globals()[trainer_class]()
device_worker = globals()[device_worker_class]() device_worker = globals()[device_worker_class]()
if "fleet_desc" in opt_info:
device_worker._set_fleet_desc(opt_info["fleet_desc"]) # for debug tools
trainer._set_fleet_desc(opt_info["fleet_desc"]) if opt_info is not None:
if opt_info.get("use_cvm") is not None:
trainer._set_use_cvm(opt_info["use_cvm"])
if opt_info.get("no_cvm") is not None:
trainer._set_no_cvm(opt_info["no_cvm"])
if opt_info.get("scale_datanorm") is not None:
trainer._set_scale_datanorm(opt_info["scale_datanorm"])
if opt_info.get("dump_slot") is not None: if opt_info.get("dump_slot") is not None:
trainer._set_dump_slot(opt_info["dump_slot"]) trainer._set_dump_slot(opt_info["dump_slot"])
if opt_info.get("mpi_rank") is not None: if opt_info.get("mpi_rank") is not None:
...@@ -76,6 +70,18 @@ class TrainerFactory(object): ...@@ -76,6 +70,18 @@ class TrainerFactory(object):
trainer._set_dump_file_num(opt_info["dump_file_num"]) trainer._set_dump_file_num(opt_info["dump_file_num"])
if opt_info.get("dump_converter") is not None: if opt_info.get("dump_converter") is not None:
trainer._set_dump_converter(opt_info["dump_converter"]) trainer._set_dump_converter(opt_info["dump_converter"])
if opt_info.get("dump_param") is not None:
trainer._set_dump_param(opt_info["dump_param"])
if "fleet_desc" in opt_info:
device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_fleet_desc(opt_info["fleet_desc"])
if opt_info.get("use_cvm") is not None:
trainer._set_use_cvm(opt_info["use_cvm"])
if opt_info.get("no_cvm") is not None:
trainer._set_no_cvm(opt_info["no_cvm"])
if opt_info.get("scale_datanorm") is not None:
trainer._set_scale_datanorm(opt_info["scale_datanorm"])
if opt_info.get("adjust_ins_weight") is not None: if opt_info.get("adjust_ins_weight") is not None:
trainer._set_adjust_ins_weight(opt_info[ trainer._set_adjust_ins_weight(opt_info[
"adjust_ins_weight"]) "adjust_ins_weight"])
...@@ -84,8 +90,6 @@ class TrainerFactory(object): ...@@ -84,8 +90,6 @@ class TrainerFactory(object):
if opt_info.get("check_nan_var_names") is not None: if opt_info.get("check_nan_var_names") is not None:
trainer._set_check_nan_var_names(opt_info[ trainer._set_check_nan_var_names(opt_info[
"check_nan_var_names"]) "check_nan_var_names"])
if opt_info.get("dump_param") is not None:
trainer._set_dump_param(opt_info["dump_param"])
if opt_info.get("loss_names") is not None: if opt_info.get("loss_names") is not None:
trainer._set_loss_names(opt_info["loss_names"]) trainer._set_loss_names(opt_info["loss_names"])
trainer._set_device_worker(device_worker) trainer._set_device_worker(device_worker)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册