Unverified commit 1fe468d3 authored by Thunderbrook, committed by GitHub

Support debugging each output of each instance (#19004)

* dump slot

* test

* proto

* dump slot

* test

* proto

* code style

* code style

* code style

* style

* add delete after unseen days

* add unseen days

* code style

* conflict solve
test=develop

* add clear model

* code style
test=develop

* code style
test=develop

* support debug tensor of each ins
test=develop

* support debug tensor of each ins
test=develop

* learning rate

* code style

* code style

* code style

* code style

* code style

* code style

* code style

* code style

* code style

* code style

* code style

* code style

* code style
test=develop

* code style
test=develop

* unitest

* style

* style

* multi phase

* add channel

* code style

* style

* style

* unitest

* style

* define

* define
test=develop

* style
test=develop

* rm define
test=develop

* linux

* linux
test=develop

* style
test=develop

* output format
test=develop

* windows ci
test=develop
Parent bd35a7f0
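In brief, this commit lets the DownpourWorker dump, for every instance in a batch, its ins_id, its content string, and the values of selected output tensors, streamed through a channel to one dump file per MPI rank. It also threads new parse_ins_id / parse_content switches from the Python Dataset down to the C++ data feed, and parenthesizes std::min / std::max calls so the Windows min/max macros cannot expand them (the "windows ci" fix noted above).

A minimal sketch of the user-facing flow, mirroring the new test_run_with_dump unit test near the end of this diff (file names and slot names are illustrative only):

import paddle.fluid as fluid

slots_vars = [
    fluid.layers.data(name=s, shape=[1], dtype="int64", lod_level=1)
    for s in ["slot1", "slot2", "slot3", "slot4"]
]
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset.set_filelist(["train_a.txt", "train_b.txt"])
dataset.set_parse_ins_id(True)   # new in this commit
dataset.set_parse_content(True)  # new in this commit
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()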
@@ -168,10 +168,10 @@ class ArchiveBase {
#else
if (newsize > Capacity()) {
#endif
- Reserve(std::max(Capacity() * 2, newsize));
+ Reserve((std::max)(Capacity() * 2, newsize));
}
finish_ = buffer_ + newsize;
- cursor_ = std::min(cursor_, finish_);
+ cursor_ = (std::min)(cursor_, finish_);
}
void Reserve(size_t newcap) {
@@ -207,7 +207,7 @@ class ArchiveBase {
#else
if (size > size_t(limit_ - finish_)) {
#endif
- Reserve(std::max(Capacity() * 2, Length() + size));
+ Reserve((std::max)(Capacity() * 2, Length() + size));
}
}
@@ -311,6 +311,18 @@ class Archive<BinaryArchiveType> : public ArchiveBase {
*this >> x;
return x;
}
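// Appends printf-style text at Finish(); if the first snprintf truncates,
// grows the buffer via PrepareWrite and formats again before advancing.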
template <class... ARGS>
void Printf(const char* fmt, ARGS&&... args) {
size_t temp = Limit() - Finish();
int len = snprintf(Finish(), temp, fmt, args...);
CHECK(len >= 0); // NOLINT
if ((size_t)len >= temp) {
PrepareWrite(len + 1);
CHECK(snprintf(Finish(), (size_t)len + 1, fmt, args...) == len);
}
AdvanceFinish(len);
}
};
template <class AR, class T, size_t N>
......
@@ -40,7 +40,7 @@ class ChannelObject {
// capacity can be zero
explicit ChannelObject(size_t capacity) {
- capacity_ = std::min(MaxCapacity(), capacity);
+ capacity_ = (std::min)(MaxCapacity(), capacity);
}
void Clear() {
@@ -192,7 +192,7 @@ class ChannelObject {
std::condition_variable full_cond_;
static constexpr size_t MaxCapacity() {
- return std::numeric_limits<size_t>::max() / 2;
+ return (std::numeric_limits<size_t>::max)() / 2;
}
void Notify() {
@@ -289,7 +289,7 @@ template <class T>
using Channel = std::shared_ptr<ChannelObject<T>>;
template <class T>
- Channel<T> MakeChannel(size_t capacity = std::numeric_limits<size_t>::max()) {
+ Channel<T> MakeChannel(size_t capacity = (std::numeric_limits<size_t>::max)()) {
return std::make_shared<ChannelObject<T>>(capacity);
}
@@ -370,7 +370,7 @@ class ChannelWriter {
void Reset(ChannelObject<T>* channel) {
CHECK(buffer_.empty()) << "Forgot to flush";
- CHECK(channel != nullptr) << "Channel can not be nullptr";
+ // CHECK(channel != nullptr) << "Channel can not be nullptr";
channel_ = channel;
buffer_.clear();
failed_ = !channel;
......
@@ -224,6 +224,7 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
this->thread_id_ = 0;
this->thread_num_ = 1;
this->parse_ins_id_ = false;
this->parse_content_ = false;
this->input_channel_ = nullptr;
this->output_channel_ = nullptr;
this->consume_channel_ = nullptr;
@@ -307,6 +308,11 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num;
}
template <typename T>
void InMemoryDataFeed<T>::SetParseContent(bool parse_content) {
parse_content_ = parse_content;
}
template <typename T>
void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
@@ -766,6 +772,18 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
pos += len + 1;
VLOG(3) << "ins_id " << instance->ins_id_;
}
if (parse_content_) {
int num = strtol(&str[pos], &endptr, 10);
CHECK(num == 1); // NOLINT
pos = endptr - str + 1;
size_t len = 0;
while (str[pos + len] != ' ') {
++len;
}
instance->content_ = std::string(str + pos, len);
pos += len + 1;
VLOG(3) << "content " << instance->content_;
}
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
@@ -890,8 +908,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
std::vector<std::vector<size_t>> offset(use_slots_.size(),
std::vector<size_t>{0});
std::vector<bool> visit(use_slots_.size(), false);
ins_content_vec_.clear();
ins_content_vec_.reserve(ins_vec.size());
ins_id_vec_.clear();
ins_id_vec_.reserve(ins_vec.size());
for (size_t i = 0; i < ins_vec.size(); ++i) {
auto& r = ins_vec[i];
ins_id_vec_.push_back(r.ins_id_);
ins_content_vec_.push_back(r.content_);
for (auto& item : r.float_feasigns_) {
batch_float_feasigns[item.slot()].push_back(item.sign().float_feasign_);
visit[item.slot()] = true;
......
@@ -105,10 +105,18 @@ class DataFeed {
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetParseInsId(bool parse_ins_id) {}
virtual void SetParseContent(bool parse_content) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
virtual const std::vector<std::string>& GetInsIdVec() const {
return ins_id_vec_;
}
virtual const std::vector<std::string>& GetInsContentVec() const {
return ins_content_vec_;
}
virtual int GetCurBatchSize() { return batch_size_; }
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
@@ -164,6 +172,8 @@ class DataFeed {
bool finish_set_filelist_;
bool finish_start_;
std::string pipe_command_;
std::vector<std::string> ins_id_vec_;
std::vector<std::string> ins_content_vec_;
platform::Place place_;
};
@@ -222,6 +232,7 @@ class InMemoryDataFeed : public DataFeed {
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void LoadIntoMemory();
protected:
@@ -232,6 +243,7 @@ class InMemoryDataFeed : public DataFeed {
int thread_id_;
int thread_num_;
bool parse_ins_id_;
bool parse_content_;
std::ifstream file_;
std::shared_ptr<FILE> fp_;
paddle::framework::ChannelObject<T>* input_channel_;
@@ -426,6 +438,7 @@ struct Record {
std::vector<FeatureItem> uint64_feasigns_;
std::vector<FeatureItem> float_feasigns_;
std::string ins_id_;
std::string content_;
};
struct RecordCandidate {
......
@@ -48,6 +48,8 @@ DatasetImpl<T>::DatasetImpl() {
erase_duplicate_feas_ = true;
keep_unmerged_ins_ = true;
min_merge_size_ = 2;
parse_ins_id_ = false;
parse_content_ = false;
}
// set filelist, file_idx_ will reset to zero.
@@ -103,6 +105,16 @@ void DatasetImpl<T>::SetChannelNum(int channel_num) {
channel_num_ = channel_num;
}
template <typename T>
void DatasetImpl<T>::SetParseInsId(bool parse_ins_id) {
parse_ins_id_ = parse_ins_id;
}
template <typename T>
void DatasetImpl<T>::SetParseContent(bool parse_content) {
parse_content_ = parse_content;
}
template <typename T>
void DatasetImpl<T>::SetMergeByInsId(
const std::vector<std::string>& merge_slot_list, bool erase_duplicate_feas,
@@ -378,7 +390,8 @@ void DatasetImpl<T>::CreateReaders() {
readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
readers_[i]->SetFileListIndex(&file_idx_);
readers_[i]->SetFileList(filelist_);
- readers_[i]->SetParseInsId(merge_by_insid_);
+ readers_[i]->SetParseInsId(parse_ins_id_);
readers_[i]->SetParseContent(parse_content_);
if (input_channel_ != nullptr) {
readers_[i]->SetInputChannel(input_channel_.get());
}
......
@@ -58,6 +58,9 @@ class Dataset {
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
// set channel num
virtual void SetChannelNum(int channel_num) = 0;
// set parse ins id
virtual void SetParseInsId(bool parse_ins_id) = 0;
virtual void SetParseContent(bool parse_content) = 0;
// set merge by ins id
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
@@ -133,6 +136,8 @@ class DatasetImpl : public Dataset {
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual void SetChannelNum(int channel_num);
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetMergeByInsId(const std::vector<std::string>& merge_slot_list,
bool erase_duplicate_feas, int min_merge_size,
bool keep_unmerged_ins);
@@ -193,6 +198,8 @@ class DatasetImpl : public Dataset {
int64_t fleet_send_sleep_seconds_;
std::vector<std::thread> preload_threads_;
bool merge_by_insid_;
bool parse_ins_id_;
bool parse_content_;
bool erase_duplicate_feas_;
bool keep_unmerged_ins_;
int min_merge_size_;
......
@@ -114,6 +114,8 @@ class DeviceWorker {
virtual void BindingDataFeedMemory() = 0;
virtual void SetRootScope(Scope* root_scope);
virtual void SetDataFeed(DataFeed* data_feed);
virtual void SetNeedDump(bool need_dump_field) {}
virtual void SetChannelWriter(ChannelObject<std::string>* queue) {}
virtual void SetPlace(const paddle::platform::Place& place) {
place_ = place;
}
@@ -172,6 +174,8 @@ class DownpourWorker : public HogwildWorker {
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void SetNeedDump(bool need_dump_field);
virtual void SetChannelWriter(ChannelObject<std::string>* queue);
protected:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
@@ -183,8 +187,11 @@ class DownpourWorker : public HogwildWorker {
private:
bool need_to_push_dense_;
bool need_dump_field_;
bool dump_slot_;
bool need_to_push_sparse_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
DownpourWorkerParameter param_;
float scale_datanorm_;
// just save the value in param_ for easy access
......
@@ -14,6 +14,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "io/fs.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker_factory.h"
@@ -27,6 +28,19 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
SetDataset(dataset);
dump_fields_path_ = trainer_desc.dump_fields_path();
dump_converter_ = trainer_desc.dump_converter();
need_dump_field_ = false;
if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") {
need_dump_field_ = true;
}
if (need_dump_field_) {
auto& file_list = dataset->GetFileList();
if (file_list.size() == 0) {
need_dump_field_ = false;
}
}
mpi_rank_ = trainer_desc.mpi_rank() / 2;
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
@@ -39,6 +53,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetNeedDump(need_dump_field_);
}
VLOG(3) << "going to initialize pull dense worker";
@@ -48,7 +63,51 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
SetDebug(trainer_desc.debug());
}
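// Dump thread body: drains lines from queue_ and writes each one,
// newline-terminated, to the per-rank dump file until the channel closes.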
void DistMultiTrainer::DumpWork() {
#ifdef _LINUX
while (1) {
std::string out_str;
if (!queue_->Get(out_str)) {
break;
}
size_t write_count =
fwrite_unlocked(out_str.data(), 1, out_str.length(), fp_.get());
if (write_count != out_str.length()) {
VLOG(3) << "dump text failed";
continue;
}
write_count = fwrite_unlocked("\n", 1, 1, fp_.get());
if (write_count != 1) {
VLOG(3) << "dump text failed";
continue;
}
}
#endif
}
void DistMultiTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>();
int err_no = 0;
std::string path = string::format_string(
"%s/part-%03d", dump_fields_path_.c_str(), mpi_rank_);
fp_ = fs_open_write(path, &err_no, dump_converter_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetChannelWriter(queue_.get());
}
dump_thread_ = std::thread(&DistMultiTrainer::DumpWork, this);
}
void DistMultiTrainer::FinalizeDumpEnv() {
queue_->Close();
dump_thread_.join();
queue_.reset();
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if (need_dump_field_) {
InitDumpEnv();
}
pull_dense_worker_->SetRootScope(root_scope_);
pull_dense_worker_->Start();
VLOG(3) << "init other env done.";
@@ -70,6 +129,9 @@ void DistMultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
if (need_dump_field_) {
FinalizeDumpEnv();
}
pull_dense_worker_->Stop();
root_scope_->DropKids();
}
......
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/string/string_helper.h"
#if defined _WIN32 || defined __APPLE__
#else
@@ -71,9 +72,89 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
use_cvm_ = desc.use_cvm();
scale_datanorm_ = desc.scale_datanorm();
dump_slot_ = desc.dump_slot();
dump_fields_.resize(desc.dump_fields_size());
for (int i = 0; i < desc.dump_fields_size(); ++i) {
dump_fields_[i] = desc.dump_fields(i);
}
adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
}
void DownpourWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
writer_.Reset(queue);
}
void DownpourWorker::SetNeedDump(bool need_dump_field) {
need_dump_field_ = need_dump_field;
}
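// Serializes tensor->data<T>() elements in [start, end) as ':'-prefixed values.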
template <typename T>
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
auto count = tensor->numel();
if (start < 0 || end > count) {
VLOG(3) << "access violation";
return "access violation";
}
std::ostringstream os;
for (int64_t i = start; i < end; i++) {
os << ":" << tensor->data<T>()[i];
}
return os.str();
}
std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
int64_t end) {
auto count = tensor->numel();
if (start < 0 || end > count) {
VLOG(3) << "access violation";
return "access violation";
}
std::ostringstream os;
for (int64_t i = start; i < end; i++) {
os << ":" << static_cast<uint64_t>(tensor->data<int64_t>()[i]);
}
return os.str();
}
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
std::string out_val;
if (tensor->type() == proto::VarType::FP32) {
out_val = PrintLodTensorType<float>(tensor, start, end);
} else if (tensor->type() == proto::VarType::INT64) {
out_val = PrintLodTensorIntType(tensor, start, end);
} else if (tensor->type() == proto::VarType::FP64) {
out_val = PrintLodTensorType<double>(tensor, start, end);
} else {
out_val = "unsupported type";
}
return out_val;
}
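// Maps instance `index` to its [begin, end) offsets in the flattened tensor,
// using the first-level LoD when the tensor has one.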
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
auto& dims = tensor->dims();
if (tensor->lod().size() != 0) {
auto& lod = tensor->lod()[0];
return {lod[index] * dims[1], lod[index + 1] * dims[1]};
} else {
return {index * dims[1], (index + 1) * dims[1]};
}
}
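// An output is dumpable only if it is 2-D and its outer LoD (or leading
// dimension, if dense) matches the current batch size.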
bool CheckValidOutput(LoDTensor* tensor, int batch_size) {
auto& dims = tensor->dims();
if (dims.size() != 2) return false;
if (tensor->lod().size() != 0) {
auto& lod = tensor->lod()[0];
if (lod.size() != batch_size + 1) {
return false;
}
} else {
if (dims[0] != batch_size) {
return false;
}
}
return true;
}
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
uint64_t table_id = static_cast<uint64_t>(
param_.program_config(0).pull_sparse_table_id(table_idx));
@@ -646,11 +727,52 @@ void DownpourWorker::TrainFiles() {
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
}
if (need_dump_field_) {
int batch_size = device_reader_->GetCurBatchSize();
std::vector<std::string> ars(batch_size);
for (auto& ar : ars) {
ar.clear();
}
auto& ins_id_vec = device_reader_->GetInsIdVec();
auto& ins_content_vec = device_reader_->GetInsContentVec();
for (size_t i = 0; i < ins_id_vec.size(); i++) {
ars[i] += ins_id_vec[i];
ars[i] = ars[i] + "\t" + ins_content_vec[i];
}
for (auto& field : dump_fields_) {
Variable* var = thread_scope_->FindVar(field);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (!CheckValidOutput(tensor, batch_size)) {
continue;
}
for (int i = 0; i < batch_size; ++i) {
auto output_dim = tensor->dims()[1];
std::string output_dimstr =
boost::lexical_cast<std::string>(output_dim);
ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
auto bound = GetTensorBound(tensor, i);
ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
}
}
// #pragma omp parallel for
for (size_t i = 0; i < ars.size(); i++) {
if (ars[i].length() == 0) {
continue;
}
writer_ << ars[i];
}
}
PrintFetchVars();
thread_scope_->DropKids();
++batch_cnt;
}
if (need_dump_field_) {
writer_.Flush();
}
}
} // end namespace framework
......
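Each dumped line assembled in TrainFiles above has the shape ins_id<TAB>content<TAB>name:dim:v0:v1:... with one tab-separated token per dumped field, where PrintLodTensor contributes the ':'-prefixed values. A rough Python reader for such lines, inferred from the writer code above rather than from any documented contract:

def parse_dump_line(line):
    # ins_id \t content \t "name:dim:v0:v1:..." per dumped field
    cols = line.rstrip("\n").split("\t")
    ins_id, content = cols[0], cols[1]
    fields = {}
    for token in cols[2:]:
        parts = token.split(":")
        name, dim = parts[0], int(parts[1])
        values = [float(v) for v in parts[2:]]
        assert len(values) % dim == 0  # dim values per row of the 2-D output
        fields[name] = values
    return ins_id, content, fields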
@@ -86,9 +86,21 @@ class DistMultiTrainer : public MultiTrainer {
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Run();
virtual void Finalize();
virtual void FinalizeDumpEnv();
virtual void InitDumpEnv();
virtual void DumpWork();
protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
std::thread dump_thread_;
std::shared_ptr<FILE> fp_;
std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
bool need_dump_field_;
std::string dump_fields_path_;
std::string dump_converter_;
std::vector<std::string> dump_fields_;
int mpi_rank_;
};
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
......
@@ -35,6 +35,10 @@ message TrainerDesc {
optional bool use_cvm = 8 [ default = false ];
optional bool dump_slot = 9 [ default = false ];
optional float scale_datanorm = 10 [ default = -1 ];
optional int32 mpi_rank = 11 [ default = -1 ];
optional string dump_fields_path = 12;
repeated string dump_fields = 13;
optional string dump_converter = 14;
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
......
@@ -100,6 +100,10 @@ void BindDataset(py::module* m) {
py::call_guard<py::gil_scoped_release>())
.def("set_queue_num", &framework::Dataset::SetChannelNum,
py::call_guard<py::gil_scoped_release>())
.def("set_parse_ins_id", &framework::Dataset::SetParseInsId,
py::call_guard<py::gil_scoped_release>())
.def("set_parse_content", &framework::Dataset::SetParseContent,
py::call_guard<py::gil_scoped_release>())
.def("set_merge_by_lineid", &framework::Dataset::SetMergeByInsId,
py::call_guard<py::gil_scoped_release>())
.def("merge_by_lineid", &framework::Dataset::MergeByInsId,
......
@@ -282,6 +282,8 @@ class InMemoryDataset(DatasetBase):
self.proto_desc.name = "MultiSlotInMemoryDataFeed"
self.fleet_send_batch_size = None
self.queue_num = None
self.parse_ins_id = False
self.parse_content = False
self.merge_by_lineid = False
def _prepare_to_run(self):
@@ -297,6 +299,8 @@
if self.queue_num is None:
self.queue_num = self.thread_num
self.dataset.set_queue_num(self.queue_num)
self.dataset.set_parse_ins_id(self.parse_ins_id)
self.dataset.set_parse_content(self.parse_content)
self.dataset.set_data_feed_desc(self.desc())
self.dataset.create_channel()
self.dataset.create_readers()
@@ -318,6 +322,40 @@
"""
self.queue_num = queue_num
def set_parse_ins_id(self, parse_ins_id):
"""
Set whether the Dataset needs to parse ins_id.
Args:
parse_ins_id(bool): whether to parse ins_id or not
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_parse_ins_id(True)
"""
self.parse_ins_id = parse_ins_id
def set_parse_content(self, parse_content):
"""
Set whether the Dataset needs to parse content.
Args:
parse_content(bool): whether to parse content or not
Examples:
.. code-block:: python
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_parse_content(True)
"""
self.parse_content = parse_content
def set_fleet_send_batch_size(self, fleet_send_batch_size):
"""
Set fleet send batch size, default is 80000
......
@@ -347,6 +347,21 @@ class PSLib(Fleet):
self._fleet_ptr.clear_model()
self._role_maker._barrier_worker()
def clear_model(self):
"""
clear_model() is called by the user to clear the sparse model.
Examples:
.. code-block:: python
fleet.clear_model()
"""
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
self._fleet_ptr.clear_model()
self._role_maker._barrier_worker()
def load_one_table(self, table_id, model_path, **kwargs):
"""
Load a pslib model for one table, or load params from a paddle model.
@@ -385,6 +400,7 @@
fout.write(my_program.desc.serialize_to_string())
"""
self._role_maker._barrier_worker()
mode = kwargs.get("mode", 0)
scope = kwargs.get("scope", None)
model_proto_file = kwargs.get("model_proto_file", None)
@@ -558,7 +574,7 @@ class DownpourOptimizer(DistributedOptimizer):
parameter_list,
no_grad_set,
self._strategy)
opt_info["mpi_rank"] = fleet._role_maker._get_rank()
fleet._set_opt_info(opt_info)
programs = [loss.block.program for loss in losses]
......
@@ -248,6 +248,9 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["use_cvm"] = strategy.get("use_cvm", False)
opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1)
opt_info["dump_slot"] = False
opt_info["dump_converter"] = ""
opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class == "DownpourCtrAccessor":
opt_info["dump_slot"] = True
......
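The dump-related keys read here are plain entries in the strategy dict handed to the distributed optimizer, while dump_slot is derived from the accessor class rather than set by the user. A hedged example of such a dict (paths and variable names invented for illustration):

strategy = {
    "use_cvm": True,
    "scale_datanorm": -1,
    "dump_fields": ["fc_0.tmp_0", "fc_1.tmp_0"],    # per-instance tensors to dump
    "dump_fields_path": "afs:/user/dump/20190801",  # one part-%03d file per rank
}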
@@ -52,6 +52,65 @@ class TestDataset(unittest.TestCase):
except:
self.assertTrue(True)
def test_config(self):
"""
Testcase for python config.
"""
dataset = fluid.InMemoryDataset()
dataset.set_parse_ins_id(True)
dataset.set_parse_content(True)
self.assertTrue(dataset.parse_ins_id)
self.assertTrue(dataset.parse_content)
def test_run_with_dump(self):
"""
Testcase for InMemoryDataset from create to run.
"""
with open("test_run_with_dump_a.txt", "w") as f:
data = "1 a 1 a 1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 b 1 b 1 2 2 3 4 4 6 6 6 6 1 2\n"
data += "1 c 1 c 1 3 2 3 5 4 7 7 7 7 1 3\n"
f.write(data)
with open("test_run_with_dump_b.txt", "w") as f:
data = "1 d 1 d 1 4 2 3 3 4 5 5 5 5 1 4\n"
data += "1 e 1 e 1 5 2 3 4 4 6 6 6 6 1 5\n"
data += "1 f 1 f 1 6 2 3 5 4 7 7 7 7 1 6\n"
data += "1 g 1 g 1 7 2 3 6 4 8 8 8 8 1 7\n"
f.write(data)
slots = ["slot1", "slot2", "slot3", "slot4"]
slots_vars = []
for slot in slots:
var = fluid.layers.data(
name=slot, shape=[1], dtype="int64", lod_level=1)
slots_vars.append(var)
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_batch_size(32)
dataset.set_thread(3)
dataset.set_filelist(
["test_run_with_dump_a.txt", "test_run_with_dump_b.txt"])
dataset.set_parse_ins_id(True)
dataset.set_parse_content(True)
dataset.set_pipe_command("cat")
dataset.set_use_var(slots_vars)
dataset.load_into_memory()
dataset.set_fea_eval(10000, True)
dataset.local_shuffle()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for i in range(2):
try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except ImportError as e:
pass
except Exception as e:
self.assertTrue(False)
os.remove("./test_run_with_dump_a.txt")
os.remove("./test_run_with_dump_b.txt")
def test_dataset_config(self):
""" Testcase for dataset configuration. """
dataset = fluid.core.Dataset("MultiSlotDataset")
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TestCases for TrainerDesc,
including config, etc.
"""
from __future__ import print_function
import paddle.fluid as fluid
import numpy as np
import os
import shutil
import unittest
class TestTrainerDesc(unittest.TestCase):
""" TestCases for TrainerDesc. """
def test_config(self):
"""
Testcase for python config.
"""
trainer_desc = fluid.trainer_desc.TrainerDesc()
trainer_desc._set_dump_fields(["a", "b"])
trainer_desc._set_mpi_rank(1)
trainer_desc._set_dump_fields_path("path")
dump_fields = trainer_desc.proto_desc.dump_fields
mpi_rank = trainer_desc.proto_desc.mpi_rank
dump_fields_path = trainer_desc.proto_desc.dump_fields_path
self.assertEqual(len(dump_fields), 2)
self.assertEqual(dump_fields[0], "a")
self.assertEqual(dump_fields[1], "b")
self.assertEqual(mpi_rank, 1)
self.assertEqual(dump_fields_path, "path")
if __name__ == '__main__':
unittest.main()
@@ -81,6 +81,19 @@ class TrainerDesc(object):
def _set_dump_slot(self, dump_slot):
self.proto_desc.dump_slot = dump_slot
def _set_mpi_rank(self, mpi_rank):
self.proto_desc.mpi_rank = mpi_rank
def _set_dump_fields(self, dump_fields):
for field in dump_fields:
self.proto_desc.dump_fields.append(field)
def _set_dump_fields_path(self, path):
self.proto_desc.dump_fields_path = path
def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter
def _set_adjust_ins_weight(self, config_dict):
self.proto_desc.adjust_ins_weight_config.need_adjust = \
config_dict.get("need_adjust", False)
......
@@ -41,6 +41,10 @@ class TrainerFactory(object):
trainer._set_use_cvm(opt_info["use_cvm"])
trainer._set_scale_datanorm(opt_info["scale_datanorm"])
trainer._set_dump_slot(opt_info["dump_slot"])
trainer._set_mpi_rank(opt_info["mpi_rank"])
trainer._set_dump_fields(opt_info["dump_fields"])
trainer._set_dump_fields_path(opt_info["dump_fields_path"])
trainer._set_dump_converter(opt_info["dump_converter"])
trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
trainer._set_device_worker(device_worker)
return trainer