提交 3f12c2e0 编写于 作者: M malin10

bug fix

上级 7c0196d4
...@@ -48,7 +48,7 @@ ExternalProject_Add( ...@@ -48,7 +48,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_SOURCE_DIR} PREFIX ${PSLIB_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR} DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz DOWNLOAD_COMMAND cp /home/malin10/baidu/paddlepaddle/pslib/pslib.tar.gz ${PSLIB_NAME}.tar.gz
&& tar zxvf ${PSLIB_NAME}.tar.gz && tar zxvf ${PSLIB_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1 DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND "" UPDATE_COMMAND ""
......
...@@ -47,7 +47,7 @@ ExternalProject_Add( ...@@ -47,7 +47,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_BRPC_SOURCE_DIR} PREFIX ${PSLIB_BRPC_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_BRPC_DOWNLOAD_DIR} DOWNLOAD_DIR ${PSLIB_BRPC_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_BRPC_URL} -c -q -O ${PSLIB_BRPC_NAME}.tar.gz DOWNLOAD_COMMAND cp /home/malin10/Paddle/pslib_brpc.tar.gz ${PSLIB_BRPC_NAME}.tar.gz
&& tar zxvf ${PSLIB_BRPC_NAME}.tar.gz && tar zxvf ${PSLIB_BRPC_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1 DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND "" UPDATE_COMMAND ""
......
...@@ -217,7 +217,7 @@ elseif(WITH_PSLIB) ...@@ -217,7 +217,7 @@ elseif(WITH_PSLIB)
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper tree_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor pslib_brpc ) graph_to_program_pass variable_helper timer monitor pslib_brpc )
# TODO: Fix these unittest failed on Windows # TODO: Fix these unittest failed on Windows
if(NOT WIN32) if(NOT WIN32)
......
...@@ -859,7 +859,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -859,7 +859,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
} else { } else {
const char* str = reader.get(); const char* str = reader.get();
std::string line = std::string(str); std::string line = std::string(str);
// VLOG(3) << line; VLOG(1) << line;
char* endptr = const_cast<char*>(str); char* endptr = const_cast<char*>(str);
int pos = 0; int pos = 0;
if (parse_ins_id_) { if (parse_ins_id_) {
...@@ -907,9 +907,11 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -907,9 +907,11 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
instance->rank = rank; instance->rank = rank;
pos += len + 1; pos += len + 1;
} }
std::stringstream ss;
for (size_t i = 0; i < use_slots_index_.size(); ++i) { for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i]; int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10); int num = strtol(&str[pos], &endptr, 10);
ss << "(" << idx << ", " << num << "); ";
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
num, 0, num, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -936,7 +938,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -936,7 +938,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10); uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
// if uint64 feasign is equal to zero, ignore it // if uint64 feasign is equal to zero, ignore it
// except when slot is dense // except when slot is dense
if (feasign == 0 && !use_slots_is_dense_[i]) { if (feasign == 0 && !use_slots_is_dense_[i] &&
all_slots_[i] != "12345") {
continue; continue;
} }
FeatureKey f; FeatureKey f;
...@@ -954,6 +957,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) { ...@@ -954,6 +957,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
} }
} }
} }
VLOG(1) << ss.str();
instance->float_feasigns_.shrink_to_fit(); instance->float_feasigns_.shrink_to_fit();
instance->uint64_feasigns_.shrink_to_fit(); instance->uint64_feasigns_.shrink_to_fit();
fea_num_ += instance->uint64_feasigns_.size(); fea_num_ += instance->uint64_feasigns_.size();
......
...@@ -31,6 +31,7 @@ limitations under the License. */ ...@@ -31,6 +31,7 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/channel.h"
...@@ -94,6 +95,22 @@ struct Record { ...@@ -94,6 +95,22 @@ struct Record {
uint64_t search_id; uint64_t search_id;
uint32_t rank; uint32_t rank;
uint32_t cmatch; uint32_t cmatch;
void Print() {
std::stringstream ss;
ss << "int64_feasigns: [";
for (uint64_t i = 0; i < uint64_feasigns_.size(); i++) {
ss << "(" << uint64_feasigns_[i].slot() << ", "
<< uint64_feasigns_[i].sign().uint64_feasign_ << "); ";
}
ss << "]\t\tfloat64_feasigns:[";
for (uint64_t i = 0; i < float_feasigns_.size(); i++) {
ss << "(" << float_feasigns_[i].slot() << ", "
<< float_feasigns_[i].sign().float_feasign_ << "); ";
}
ss << "]\n";
VLOG(1) << ss.str();
}
}; };
struct PvInstanceObject { struct PvInstanceObject {
......
...@@ -365,7 +365,8 @@ void DatasetImpl<T>::TDMDump(std::string name, const uint64_t table_id, ...@@ -365,7 +365,8 @@ void DatasetImpl<T>::TDMDump(std::string name, const uint64_t table_id,
// do sample // do sample
template <typename T> template <typename T>
void DatasetImpl<T>::TDMSample(const uint16_t sample_slot, void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
const uint64_t type_slot) { const uint64_t type_slot,
const uint64_t start_h) {
VLOG(0) << "DatasetImpl<T>::Sample() begin"; VLOG(0) << "DatasetImpl<T>::Sample() begin";
platform::Timer timeline; platform::Timer timeline;
timeline.Start(); timeline.Start();
...@@ -379,6 +380,7 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot, ...@@ -379,6 +380,7 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) { if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) {
continue; continue;
} }
multi_output_channel_[i]->Close();
multi_output_channel_[i]->ReadAll(data[i]); multi_output_channel_[i]->ReadAll(data[i]);
} }
} else { } else {
...@@ -388,17 +390,23 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot, ...@@ -388,17 +390,23 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
input_channel_->ReadAll(data[data.size() - 1]); input_channel_->ReadAll(data[data.size() - 1]);
} }
VLOG(1) << "finish read src data, data.size = " << data.size()
<< "; details: ";
auto tree_ptr = TreeWrapper::GetInstance(); auto tree_ptr = TreeWrapper::GetInstance();
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
for (auto i = 0; i < data.size(); i++) { for (auto i = 0; i < data.size(); i++) {
VLOG(1) << "data[" << i << "]: size = " << data[i].size();
std::vector<T> tmp_results; std::vector<T> tmp_results;
tree_ptr->sample(sample_slot, type_slot, data[i], &tmp_results); tree_ptr->sample(sample_slot, type_slot, &data[i], &tmp_results, start_h);
VLOG(1) << "sample_results(" << sample_slot << ", " << type_slot
<< ") = " << tmp_results.size();
sample_results.push_back(tmp_results); sample_results.push_back(tmp_results);
} }
auto output_channel_num = multi_output_channel_.size(); auto output_channel_num = multi_output_channel_.size();
for (auto i = 0; i < sample_results.size(); i++) { for (auto i = 0; i < sample_results.size(); i++) {
auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num; auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num;
multi_output_channel_[output_idx]->Open();
multi_output_channel_[output_idx]->Write(std::move(sample_results[i])); multi_output_channel_[output_idx]->Write(std::move(sample_results[i]));
} }
......
...@@ -47,8 +47,8 @@ class Dataset { ...@@ -47,8 +47,8 @@ class Dataset {
virtual ~Dataset() {} virtual ~Dataset() {}
virtual void InitTDMTree( virtual void InitTDMTree(
const std::vector<std::pair<std::string, std::string>> config) = 0; const std::vector<std::pair<std::string, std::string>> config) = 0;
virtual void TDMSample(const uint16_t sample_slot, virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot,
const uint64_t type_slot) = 0; const uint64_t start_h) = 0;
virtual void TDMDump(std::string name, const uint64_t table_id, virtual void TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path) = 0; int fea_value_dim, const std::string tree_path) = 0;
// set file list // set file list
...@@ -168,7 +168,8 @@ class DatasetImpl : public Dataset { ...@@ -168,7 +168,8 @@ class DatasetImpl : public Dataset {
virtual void InitTDMTree( virtual void InitTDMTree(
const std::vector<std::pair<std::string, std::string>> config); const std::vector<std::pair<std::string, std::string>> config);
virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot); virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot,
const uint64_t start_h);
virtual void TDMDump(std::string name, const uint64_t table_id, virtual void TDMDump(std::string name, const uint64_t table_id,
int fea_value_dim, const std::string tree_path); int fea_value_dim, const std::string tree_path);
......
...@@ -171,6 +171,7 @@ class DeviceWorker { ...@@ -171,6 +171,7 @@ class DeviceWorker {
device_reader_->SetPlace(place); device_reader_->SetPlace(place);
} }
virtual Scope* GetThreadScope() { return thread_scope_; } virtual Scope* GetThreadScope() { return thread_scope_; }
virtual void GetXpuOpIndex() {}
protected: protected:
virtual void DumpParam(const Scope& scope, const int batch_id); virtual void DumpParam(const Scope& scope, const int batch_id);
......
if(WITH_PSLIB) if(WITH_PSLIB)
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib) cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib)
cc_library(tree_wrapper SRCS tree_wrapper.cc)
else() else()
cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope) cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope)
endif(WITH_PSLIB) endif(WITH_PSLIB)
......
...@@ -192,7 +192,8 @@ framework::proto::VarType::Type HeterWrapper::ToVarType( ...@@ -192,7 +192,8 @@ framework::proto::VarType::Type HeterWrapper::ToVarType(
case VariableMessage::BOOL: case VariableMessage::BOOL:
return framework::proto::VarType::BOOL; // NOLINT return framework::proto::VarType::BOOL; // NOLINT
default: default:
VLOG(0) << "Not support type " << type; PADDLE_THROW(platform::errors::InvalidArgument(
"ToVarType:Unsupported type %d", type));
} }
} }
......
...@@ -12,20 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,20 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include <boost/algorithm/string.hpp>
#include <boost/lexical_cast.hpp>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/tree_wrapper.h" #include "paddle/fluid/framework/fleet/tree_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
int Tree::load(std::string path, std::string tree_pipe_command_) { std::shared_ptr<TreeWrapper> TreeWrapper::s_instance_(nullptr);
int Tree::load(std::string path) {
uint64_t linenum = 0; uint64_t linenum = 0;
size_t idx = 0; size_t idx = 0;
std::vector<std::string> lines; std::vector<std::string> lines;
...@@ -33,10 +38,10 @@ int Tree::load(std::string path, std::string tree_pipe_command_) { ...@@ -33,10 +38,10 @@ int Tree::load(std::string path, std::string tree_pipe_command_) {
std::vector<std::string> items; std::vector<std::string> items;
int err_no; int err_no;
std::shared_ptr<FILE> fp_ = fs_open_read(path, &err_no, tree_pipe_command_); std::shared_ptr<FILE> fp_ = fs_open_read(path, &err_no, "");
string::LineFileReader reader; string::LineFileReader reader;
while (reader.getline(&*(fp_.get()))) { while (reader.getline(&*(fp_.get()))) {
line = std::string(reader.get()); auto line = std::string(reader.get());
strs.clear(); strs.clear();
boost::split(strs, line, boost::is_any_of("\t")); boost::split(strs, line, boost::is_any_of("\t"));
if (0 == linenum) { if (0 == linenum) {
...@@ -132,16 +137,21 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, ...@@ -132,16 +137,21 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
std::shared_ptr<FILE> fp = std::shared_ptr<FILE> fp =
paddle::framework::fs_open(tree_path, "w", &ret, ""); paddle::framework::fs_open(tree_path, "w", &ret, "");
std::vector<uint64_t> fea_keys, std::vector<float *> pull_result_ptr; std::vector<uint64_t> fea_keys;
std::vector<float*> pull_result_ptr;
fea_keys.reserve(_total_node_num); fea_keys.reserve(_total_node_num);
pull_result_ptr.reserve(_total_node_num); pull_result_ptr.reserve(_total_node_num);
for (size_t i = 0; i != _total_node_num; ++i) { for (size_t i = 0; i != _total_node_num; ++i) {
_nodes[i].embedding.resize(fea_value_dim); _nodes[i].embedding.resize(fea_value_dim);
fea_key.push_back(_nodes[i].id); fea_keys.push_back(_nodes[i].id);
pull_result_ptr.push_back(_nodes[i].embedding.data()); pull_result_ptr.push_back(_nodes[i].embedding.data());
} }
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size());
std::string first_line = boost::lexical_cast<std::string>(_total_node_num) + std::string first_line = boost::lexical_cast<std::string>(_total_node_num) +
"\t" + "\t" +
boost::lexical_cast<std::string>(_tree_height); boost::lexical_cast<std::string>(_tree_height);
...@@ -183,7 +193,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, ...@@ -183,7 +193,7 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
bool Tree::trace_back(uint64_t id, bool Tree::trace_back(uint64_t id,
std::vector<std::pair<uint64_t, uint32_t>>* ids) { std::vector<std::pair<uint64_t, uint32_t>>* ids) {
ids.clear(); ids->clear();
std::unordered_map<uint64_t, Node*>::iterator find_it = std::unordered_map<uint64_t, Node*>::iterator find_it =
_leaf_node_map.find(id); _leaf_node_map.find(id);
if (find_it == _leaf_node_map.end()) { if (find_it == _leaf_node_map.end()) {
......
...@@ -103,15 +103,14 @@ class TreeWrapper { ...@@ -103,15 +103,14 @@ class TreeWrapper {
} }
void sample(const uint16_t sample_slot, const uint64_t type_slot, void sample(const uint16_t sample_slot, const uint64_t type_slot,
const std::vector<Record>& src_datas, std::vector<Record>* src_datas,
std::vector<Record>* sample_results) { std::vector<Record>* sample_results, const uint64_t start_h) {
sample_results->clear(); sample_results->clear();
auto debug_idx = 0; for (auto& data : *src_datas) {
for (auto& data : src_datas) { VLOG(1) << "src record";
if (debug_idx == 0) { data.Print();
VLOG(0) << "src record"; uint64_t start_idx = sample_results->size();
data.Print(); VLOG(1) << "before sample, sample_results.size = " << start_idx;
}
uint64_t sample_feasign_idx = -1, type_feasign_idx = -1; uint64_t sample_feasign_idx = -1, type_feasign_idx = -1;
for (uint64_t i = 0; i < data.uint64_feasigns_.size(); i++) { for (uint64_t i = 0; i < data.uint64_feasigns_.size(); i++) {
if (data.uint64_feasigns_[i].slot() == sample_slot) { if (data.uint64_feasigns_[i].slot() == sample_slot) {
...@@ -121,6 +120,8 @@ class TreeWrapper { ...@@ -121,6 +120,8 @@ class TreeWrapper {
type_feasign_idx = i; type_feasign_idx = i;
} }
} }
VLOG(1) << "sample_feasign_idx: " << sample_feasign_idx
<< "; type_feasign_idx: " << type_feasign_idx;
if (sample_feasign_idx > 0) { if (sample_feasign_idx > 0) {
std::vector<std::pair<uint64_t, uint32_t>> trace_ids; std::vector<std::pair<uint64_t, uint32_t>> trace_ids;
for (std::unordered_map<std::string, TreePtr>::iterator ite = for (std::unordered_map<std::string, TreePtr>::iterator ite =
...@@ -139,18 +140,20 @@ class TreeWrapper { ...@@ -139,18 +140,20 @@ class TreeWrapper {
Record instance(data); Record instance(data);
instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ = instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
trace_ids[i].first; trace_ids[i].first;
if (type_feasign_idx > 0) if (type_feasign_idx > 0 && trace_ids[i].second > start_h)
instance.uint64_feasigns_[type_feasign_idx] instance.uint64_feasigns_[type_feasign_idx].sign().uint64_feasign_ =
.sign() (instance.uint64_feasigns_[type_feasign_idx]
.uint64_feasign_ += trace_ids[i].second * 100; .sign()
if (debug_idx == 0) { .uint64_feasign_ +
VLOG(0) << "sample results:" << i; 1) *
instance.Print(); 100 +
} trace_ids[i].second;
sample_results->push_back(instance); sample_results->push_back(instance);
} }
} }
debug_idx += 1; for (auto i = start_idx; i < sample_results->size(); i++) {
sample_results->at(i).Print();
}
} }
return; return;
} }
......
...@@ -611,8 +611,8 @@ class InMemoryDataset(DatasetBase): ...@@ -611,8 +611,8 @@ class InMemoryDataset(DatasetBase):
def init_tdm_tree(self, configs): def init_tdm_tree(self, configs):
self.dataset.init_tdm_tree(configs) self.dataset.init_tdm_tree(configs)
def tdm_sample(self, sample_slot, type_slot): def tdm_sample(self, sample_slot, type_slot, start_h):
self.dataset.tdm_sample(sample_slot, type_slot) self.dataset.tdm_sample(sample_slot, type_slot, start_h)
def tdm_dump(self, name, table_id, fea_value_dim, tree_path): def tdm_dump(self, name, table_id, fea_value_dim, tree_path):
self.dataset.tdm_dump(name, table_id, fea_value_dim, tree_path) self.dataset.tdm_dump(name, table_id, fea_value_dim, tree_path)
......
...@@ -1353,10 +1353,11 @@ class Executor(object): ...@@ -1353,10 +1353,11 @@ class Executor(object):
print_period=100): print_period=100):
is_heter = 0 is_heter = 0
if not program._fleet_opt is None: if not program._fleet_opt is None:
if program._fleet_opt.get("worker_class", "") == "HeterCpuWorker": is_heter = 0
is_heter = 1 #if program._fleet_opt.get("worker_class", "") == "HeterCpuWorker":
if program._fleet_opt("trainer", "") == "HeterXpuTrainer": # is_heter = 1
is_heter = 1 #if program._fleet_opt("trainer", "") == "HeterXpuTrainer":
# is_heter = 1
if scope is None: if scope is None:
scope = global_scope() scope = global_scope()
if fetch_list is None: if fetch_list is None:
......
...@@ -24,8 +24,7 @@ import sys ...@@ -24,8 +24,7 @@ import sys
import time import time
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_pslib from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
from . import hdfs from . import hdfs
from .hdfs import * from .hdfs import *
from . import utils from . import utils
...@@ -35,7 +34,7 @@ __all__ = ["FleetUtil"] ...@@ -35,7 +34,7 @@ __all__ = ["FleetUtil"]
_logger = get_logger( _logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
fleet = fleet_pslib #fleet = fleet_pslib
class FleetUtil(object): class FleetUtil(object):
...@@ -52,14 +51,16 @@ class FleetUtil(object): ...@@ -52,14 +51,16 @@ class FleetUtil(object):
""" """
def __init__(self, mode="pslib"): def __init__(self, mode="pslib"):
global fleet pass
if mode == "pslib":
fleet = fleet_pslib # global fleet
elif mode == "transpiler": # if mode == "pslib":
fleet = fleet_transpiler # fleet = fleet_pslib
else: # elif mode == "transpiler":
raise ValueError( # fleet = fleet_transpiler
"Please choose one mode from [\"pslib\", \"transpiler\"]") # else:
# raise ValueError(
# "Please choose one mode from [\"pslib\", \"transpiler\"]")
def rank0_print(self, s): def rank0_print(self, s):
""" """
......
...@@ -79,7 +79,7 @@ class HDFSClient(FS): ...@@ -79,7 +79,7 @@ class HDFSClient(FS):
time_out=5 * 60 * 1000, #ms time_out=5 * 60 * 1000, #ms
sleep_inter=1000): #ms sleep_inter=1000): #ms
# Raise exception if JAVA_HOME not exists. # Raise exception if JAVA_HOME not exists.
java_home = os.environ["JAVA_HOME"] #java_home = os.environ["JAVA_HOME"]
self.pre_commands = [] self.pre_commands = []
hadoop_bin = '%s/bin/hadoop' % hadoop_home hadoop_bin = '%s/bin/hadoop' % hadoop_home
......
...@@ -489,11 +489,11 @@ def embedding(input, ...@@ -489,11 +489,11 @@ def embedding(input,
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'], check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
'fluid.layers.embedding') 'fluid.layers.embedding')
if is_distributed: # if is_distributed:
is_distributed = False # is_distributed = False
warnings.warn( # warnings.warn(
"is_distributed is go out of use, `fluid.contrib.layers.sparse_embedding` is your needed" # "is_distributed is go out of use, `fluid.contrib.layers.sparse_embedding` is your needed"
) # )
remote_prefetch = True if is_sparse else False remote_prefetch = True if is_sparse else False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册