未验证 提交 85c6937b 编写于 作者: Z zhaocaibei123 提交者: GitHub

add slot attr for push sparse op (#44422)

* add slot attr for push sparse op

* add pybind

* remove fleet

* add unittest

* fix
上级 1a7f2de3
...@@ -529,10 +529,12 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -529,10 +529,12 @@ void FleetWrapper::PushSparseFromTensorAsync(
uint64_t padding_id, uint64_t padding_id,
platform::Place place, platform::Place place,
std::vector<const LoDTensor*>* inputs, std::vector<const LoDTensor*>* inputs,
std::vector<int>& slots,
const LoDTensor* shows, const LoDTensor* shows,
const LoDTensor* clks, const LoDTensor* clks,
std::vector<LoDTensor*>* outputs, std::vector<LoDTensor*>* outputs,
bool use_cvm_op) { bool use_cvm_op) {
CHECK(slots.size() == inputs->size());
int batch_size = -1; int batch_size = -1;
bool batch_size_consist = true; bool batch_size_consist = true;
for (auto* input : *inputs) { for (auto* input : *inputs) {
...@@ -568,8 +570,8 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -568,8 +570,8 @@ void FleetWrapper::PushSparseFromTensorAsync(
// TODO(zhaocaibei123): check type of show/clk is int? float? uint64? // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
// const long int* show_tensor = shows->data<int64_t>(); // const long int* show_tensor = shows->data<int64_t>();
// const long int* clk_tensor = clks->data<int64_t>(); // const long int* clk_tensor = clks->data<int64_t>();
const int64_t* show_tensor = shows->data<int64_t>(); const float* show_tensor = shows->data<float>();
const int64_t* clk_tensor = clks->data<int64_t>(); const float* clk_tensor = clks->data<float>();
for (size_t index = 0; index < inputs->size(); ++index) { for (size_t index = 0; index < inputs->size(); ++index) {
framework::LoDTensor* g_tensor = outputs->at(index); framework::LoDTensor* g_tensor = outputs->at(index);
...@@ -603,15 +605,14 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -603,15 +605,14 @@ void FleetWrapper::PushSparseFromTensorAsync(
push_keys.emplace_back(real_id); push_keys.emplace_back(real_id);
if (use_cvm_op) { if (use_cvm_op) {
push_values.emplace_back(fea_dim + 1); push_values.emplace_back(fea_dim + 1);
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot push_values.back()[0] = static_cast<float>(slots[index]);
float* data = push_values.back().data() + 1; float* data = push_values.back().data() + 1;
memcpy(data, g + output_len, sizeof(float) * fea_dim); memcpy(data, g + output_len, sizeof(float) * fea_dim);
} else { } else {
push_values.emplace_back(fea_dim + 3); push_values.emplace_back(fea_dim + 3);
// slot show clk grad... consistent with CtrCommonPushValue defined // slot show clk grad... consistent with CtrCommonPushValue defined
// in // in ctr_accessor.h
// ctr_accessor.h push_values.back()[0] = static_cast<float>(slots[index]);
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot
push_values.back()[1] = push_values.back()[1] =
(i >= show_size ? 1 : static_cast<float>(show_tensor[i])); (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
push_values.back()[2] = push_values.back()[2] =
...@@ -631,18 +632,16 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -631,18 +632,16 @@ void FleetWrapper::PushSparseFromTensorAsync(
push_keys.emplace_back(real_id); push_keys.emplace_back(real_id);
if (use_cvm_op) { if (use_cvm_op) {
push_values.emplace_back(fea_dim + 1); push_values.emplace_back(fea_dim + 1);
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot push_values.back()[0] = static_cast<float>(slots[index]);
float* data = push_values.back().data() + 1; float* data = push_values.back().data() + 1;
memcpy(data, g + output_len, sizeof(float) * fea_dim); memcpy(data, g + output_len, sizeof(float) * fea_dim);
} else { } else {
push_values.emplace_back(fea_dim + 3); push_values.emplace_back(fea_dim + 3);
// slot show clk grad... consistent with CtrCommonPushValue defined in // slot show clk grad... consistent with CtrCommonPushValue defined in
// ctr_accessor.h // ctr_accessor.h
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot push_values.back()[0] = static_cast<float>(slots[index]);
push_values.back()[1] = push_values.back()[1] = (i >= show_size ? 1 : show_tensor[i]);
(i >= show_size ? 1 : static_cast<float>(show_tensor[i])); push_values.back()[2] = (i >= clk_size ? 0 : clk_tensor[i]);
push_values.back()[2] =
(i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
float* data = push_values.back().data() + 3; float* data = push_values.back().data() + 3;
memcpy(data, g + output_len, sizeof(float) * fea_dim); memcpy(data, g + output_len, sizeof(float) * fea_dim);
} }
......
...@@ -190,11 +190,13 @@ class FleetWrapper { ...@@ -190,11 +190,13 @@ class FleetWrapper {
const std::vector<std::string>& input_names, const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs, // NOLINT std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<const LoDTensor*>* outputs); // NOLINT std::vector<const LoDTensor*>* outputs); // NOLINT
void PushSparseFromTensorAsync(const uint64_t table_id, void PushSparseFromTensorAsync(const uint64_t table_id,
int fea_dim, int fea_dim,
uint64_t padding_id, uint64_t padding_id,
platform::Place place, platform::Place place,
std::vector<const LoDTensor*>* inputs, std::vector<const LoDTensor*>* inputs,
std::vector<int>& slots, // NOLINT
const LoDTensor* shows, const LoDTensor* shows,
const LoDTensor* clicks, const LoDTensor* clicks,
std::vector<LoDTensor*>* outputs, std::vector<LoDTensor*>* outputs,
......
...@@ -309,7 +309,7 @@ void PrivateQueueDataFeed<T>::ReadThread() { ...@@ -309,7 +309,7 @@ void PrivateQueueDataFeed<T>::ReadThread() {
std::string filename; std::string filename;
while (PickOneFile(&filename)) { while (PickOneFile(&filename)) {
int err_no = 0; int err_no = 0;
fp_ = fs_open_read(filename, &err_no, pipe_command_); fp_ = fs_open_read(filename, &err_no, pipe_command_, true);
__fsetlocking(&*fp_, FSETLOCKING_BYCALLER); __fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
T instance; T instance;
while (ParseOneInstanceFromPipe(&instance)) { while (ParseOneInstanceFromPipe(&instance)) {
...@@ -538,7 +538,7 @@ void InMemoryDataFeed<T>::LoadIntoMemory() { ...@@ -538,7 +538,7 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
} else { } else {
#endif #endif
int err_no = 0; int err_no = 0;
this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_); this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_, true);
#ifdef PADDLE_WITH_BOX_PS #ifdef PADDLE_WITH_BOX_PS
} }
#endif #endif
...@@ -574,7 +574,7 @@ void InMemoryDataFeed<T>::LoadIntoMemoryFromSo() { ...@@ -574,7 +574,7 @@ void InMemoryDataFeed<T>::LoadIntoMemoryFromSo() {
(defined PADDLE_WITH_PSLIB) (defined PADDLE_WITH_PSLIB)
VLOG(3) << "LoadIntoMemoryFromSo() begin, thread_id=" << thread_id_; VLOG(3) << "LoadIntoMemoryFromSo() begin, thread_id=" << thread_id_;
int buf_len = 1024 * 1024 * 10; int buf_len = 1024 * 1024 * 10;
char* buf = (char*)malloc(buf_len + 10); char* buf = reinterpret_cast<char*>(malloc(buf_len + 10));
auto ps_gpu_ptr = PSGPUWrapper::GetInstance(); auto ps_gpu_ptr = PSGPUWrapper::GetInstance();
paddle::framework::CustomParser* parser = paddle::framework::CustomParser* parser =
...@@ -681,7 +681,7 @@ void MultiSlotDataFeed::ReadThread() { ...@@ -681,7 +681,7 @@ void MultiSlotDataFeed::ReadThread() {
std::string filename; std::string filename;
while (PickOneFile(&filename)) { while (PickOneFile(&filename)) {
int err_no = 0; int err_no = 0;
fp_ = fs_open_read(filename, &err_no, pipe_command_); fp_ = fs_open_read(filename, &err_no, pipe_command_, true);
CHECK(fp_ != nullptr); CHECK(fp_ != nullptr);
__fsetlocking(&*fp_, FSETLOCKING_BYCALLER); __fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
std::vector<MultiSlotType> instance; std::vector<MultiSlotType> instance;
...@@ -2175,7 +2175,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByFile(void) { ...@@ -2175,7 +2175,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByFile(void) {
lines); lines);
} else { } else {
int err_no = 0; int err_no = 0;
this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_); this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_, true);
CHECK(this->fp_ != nullptr); CHECK(this->fp_ != nullptr);
__fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER); __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
...@@ -2265,7 +2265,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByLine(void) { ...@@ -2265,7 +2265,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByLine(void) {
do { do {
int err_no = 0; int err_no = 0;
this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_); this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_, true);
CHECK(this->fp_ != nullptr); CHECK(this->fp_ != nullptr);
__fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER); __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
lines = line_reader.read_file(this->fp_.get(), line_func, lines); lines = line_reader.read_file(this->fp_.get(), line_func, lines);
...@@ -2314,7 +2314,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByCommand(void) { ...@@ -2314,7 +2314,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByCommand(void) {
do { do {
int err_no = 0; int err_no = 0;
this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_); this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_, true);
CHECK(this->fp_ != nullptr); CHECK(this->fp_ != nullptr);
__fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER); __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
......
...@@ -102,7 +102,7 @@ void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name, ...@@ -102,7 +102,7 @@ void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
cmd += " -D fs.default.name=" + fs_name; cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi; cmd += " -D hadoop.job.ugi=" + fs_ugi;
cmd += " -Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=500000"; cmd += " -Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=500000";
paddle::framework::hdfs_set_command(cmd); paddle::framework::dataset_hdfs_set_command(cmd);
} }
template <typename T> template <typename T>
......
...@@ -230,6 +230,20 @@ const std::string& hdfs_command() { return hdfs_command_internal(); } ...@@ -230,6 +230,20 @@ const std::string& hdfs_command() { return hdfs_command_internal(); }
void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; } void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; }
// Dataset files and model files may be stored on different AFS clusters, so
// the dataset side keeps its own command string.
//
// Returns a mutable reference to a function-local static string whose initial
// value is "hadoop fs"; callers read or overwrite it through this reference.
static std::string& dataset_hdfs_command_internal() {
  static std::string dataset_command("hadoop fs");
  return dataset_command;
}
// Read-only view of the dataset HDFS command string held by
// dataset_hdfs_command_internal().
const std::string& dataset_hdfs_command() {
  const std::string& current = dataset_hdfs_command_internal();
  return current;
}
// Overwrites the dataset HDFS command string held by
// dataset_hdfs_command_internal() with a copy of `x`.
void dataset_hdfs_set_command(const std::string& x) {
  std::string& target = dataset_hdfs_command_internal();
  target = x;
}
static std::string& customized_download_cmd_internal() { static std::string& customized_download_cmd_internal() {
static std::string x = ""; static std::string x = "";
return x; return x;
...@@ -243,19 +257,30 @@ void set_download_command(const std::string& x) { ...@@ -243,19 +257,30 @@ void set_download_command(const std::string& x) {
std::shared_ptr<FILE> hdfs_open_read(std::string path, std::shared_ptr<FILE> hdfs_open_read(std::string path,
int* err_no, int* err_no,
const std::string& converter) { const std::string& converter,
bool read_data) {
if (download_cmd() != "") { // use customized download command if (download_cmd() != "") { // use customized download command
path = string::format_string( path = string::format_string(
"%s \"%s\"", download_cmd().c_str(), path.c_str()); "%s \"%s\"", download_cmd().c_str(), path.c_str());
} else { } else {
if (fs_end_with_internal(path, ".gz")) { if (fs_end_with_internal(path, ".gz")) {
if (read_data) {
path = string::format_string(
"%s -text \"%s\"", dataset_hdfs_command().c_str(), path.c_str());
} else {
path = string::format_string( path = string::format_string(
"%s -text \"%s\"", hdfs_command().c_str(), path.c_str()); "%s -text \"%s\"", hdfs_command().c_str(), path.c_str());
}
} else {
if (read_data) {
path = string::format_string(
"%s -cat \"%s\"", dataset_hdfs_command().c_str(), path.c_str());
} else { } else {
path = string::format_string( path = string::format_string(
"%s -cat \"%s\"", hdfs_command().c_str(), path.c_str()); "%s -cat \"%s\"", hdfs_command().c_str(), path.c_str());
} }
} }
}
bool is_pipe = true; bool is_pipe = true;
fs_add_read_converter_internal(path, is_pipe, converter); fs_add_read_converter_internal(path, is_pipe, converter);
...@@ -370,13 +395,14 @@ int fs_select_internal(const std::string& path) { ...@@ -370,13 +395,14 @@ int fs_select_internal(const std::string& path) {
std::shared_ptr<FILE> fs_open_read(const std::string& path, std::shared_ptr<FILE> fs_open_read(const std::string& path,
int* err_no, int* err_no,
const std::string& converter) { const std::string& converter,
bool read_data) {
switch (fs_select_internal(path)) { switch (fs_select_internal(path)) {
case 0: case 0:
return localfs_open_read(path, converter); return localfs_open_read(path, converter);
case 1: case 1:
return hdfs_open_read(path, err_no, converter); return hdfs_open_read(path, err_no, converter, read_data);
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
...@@ -64,13 +64,18 @@ extern const std::string& hdfs_command(); ...@@ -64,13 +64,18 @@ extern const std::string& hdfs_command();
extern void hdfs_set_command(const std::string& x); extern void hdfs_set_command(const std::string& x);
extern const std::string& dataset_hdfs_command();
extern void dataset_hdfs_set_command(const std::string& x);
extern const std::string& download_cmd(); extern const std::string& download_cmd();
extern void set_download_command(const std::string& x); extern void set_download_command(const std::string& x);
extern std::shared_ptr<FILE> hdfs_open_read(std::string path, extern std::shared_ptr<FILE> hdfs_open_read(std::string path,
int* err_no, int* err_no,
const std::string& converter); const std::string& converter,
bool read_data);
extern std::shared_ptr<FILE> hdfs_open_write(std::string path, extern std::shared_ptr<FILE> hdfs_open_write(std::string path,
int* err_no, int* err_no,
...@@ -91,7 +96,8 @@ extern void hdfs_mv(const std::string& src, const std::string& dest); ...@@ -91,7 +96,8 @@ extern void hdfs_mv(const std::string& src, const std::string& dest);
// auto-detect fs // auto-detect fs
extern std::shared_ptr<FILE> fs_open_read(const std::string& path, extern std::shared_ptr<FILE> fs_open_read(const std::string& path,
int* err_no, int* err_no,
const std::string& converter); const std::string& converter,
bool read_data = false);
extern std::shared_ptr<FILE> fs_open_write(const std::string& path, extern std::shared_ptr<FILE> fs_open_write(const std::string& path,
int* err_no, int* err_no,
......
...@@ -45,5 +45,18 @@ TEST(FS, mv) { ...@@ -45,5 +45,18 @@ TEST(FS, mv) {
} catch (...) { } catch (...) {
VLOG(3) << "test hdfs_mv, catch expected errors of unknown prefix"; VLOG(3) << "test hdfs_mv, catch expected errors of unknown prefix";
} }
try {
paddle::framework::dataset_hdfs_set_command(
"hadoop -D hadoop.job.ugi=anotherxxx fs -text");
int err_no = 0;
paddle::framework::hdfs_open_read("afs:/none.gz", &err_no, "", true);
paddle::framework::hdfs_open_read("afs:/none.gz", &err_no, "", false);
paddle::framework::hdfs_open_read("afs:/none", &err_no, "", true);
paddle::framework::hdfs_open_read("afs:/none", &err_no, "", false);
} catch (...) {
VLOG(3) << "test hdfs_open_read, catch expected errors of unknown path";
}
#endif #endif
} }
...@@ -134,6 +134,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -134,6 +134,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"in the order of input variables for mapping") "in the order of input variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int>("slot", "slot of id").SetDefault(0).AsExtra();
AddAttr<bool>("grad_inplace", AddAttr<bool>("grad_inplace",
"(boolean, default false) " "(boolean, default false) "
"If the grad op reuse the input's variable.") "If the grad op reuse the input's variable.")
......
...@@ -105,6 +105,7 @@ class LookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -105,6 +105,7 @@ class LookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.") AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.")
.SetDefault(0) .SetDefault(0)
.AsExtra(); .AsExtra();
AddAttr<int>("slot", "slot of id").SetDefault(0).AsExtra();
AddAttr<std::vector<int64_t>>("height_sections", AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.") "Height for each output SelectedRows.")
.SetDefault(std::vector<int64_t>({})) .SetDefault(std::vector<int64_t>({}))
......
...@@ -113,6 +113,11 @@ class DistributedPushSparseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -113,6 +113,11 @@ class DistributedPushSparseOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_cvm_op", "(boolean, default false) Use cvm op or not.") AddAttr<bool>("use_cvm_op", "(boolean, default false) Use cvm op or not.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("slots",
"[slot_id1, slot_id2] Slots array of Ids.")
.SetDefault({})
.AsExtra();
AddComment(R"DOC( AddComment(R"DOC(
Lookup Tablel Prefetch Operator. Lookup Tablel Prefetch Operator.
This operator is used to perform lookup on parameter W, This operator is used to perform lookup on parameter W,
......
...@@ -33,6 +33,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> { ...@@ -33,6 +33,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> {
auto table_id = context.Attr<int>("table_id"); auto table_id = context.Attr<int>("table_id");
auto emb_dim = context.Attr<int>("size"); auto emb_dim = context.Attr<int>("size");
auto use_cvm_op = context.Attr<bool>("use_cvm_op"); auto use_cvm_op = context.Attr<bool>("use_cvm_op");
auto slots = context.Attr<std::vector<int>>("slots");
auto inputs = context.MultiInput<framework::LoDTensor>("Ids"); auto inputs = context.MultiInput<framework::LoDTensor>("Ids");
auto shows = context.Input<framework::LoDTensor>("Shows"); auto shows = context.Input<framework::LoDTensor>("Shows");
...@@ -47,6 +48,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> { ...@@ -47,6 +48,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> {
static_cast<uint64_t>(padding_idx), static_cast<uint64_t>(padding_idx),
context.GetPlace(), context.GetPlace(),
&inputs, &inputs,
slots,
shows, shows,
clks, clks,
&outputs, &outputs,
...@@ -103,6 +105,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> { ...@@ -103,6 +105,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> {
static_cast<uint64_t>(padding_idx), static_cast<uint64_t>(padding_idx),
context.GetPlace(), context.GetPlace(),
&tmp_input_vec, &tmp_input_vec,
slots,
tmp_shows_tensor, tmp_shows_tensor,
tmp_clicks_tensor, tmp_clicks_tensor,
&tmp_output_vec); &tmp_output_vec);
......
...@@ -150,7 +150,7 @@ class DistributedOpsPass(PassBase): ...@@ -150,7 +150,7 @@ class DistributedOpsPass(PassBase):
print('ShowClickEntry not configured, will not use') print('ShowClickEntry not configured, will not use')
show = _program.global_block().create_var( show = _program.global_block().create_var(
name="show", name="show",
dtype=core.VarDesc.VarType.INT64, dtype=core.VarDesc.VarType.FP32,
persistable=False, persistable=False,
stop_gradient=True) stop_gradient=True)
_program.global_block()._insert_op(index=0, _program.global_block()._insert_op(index=0,
...@@ -165,7 +165,7 @@ class DistributedOpsPass(PassBase): ...@@ -165,7 +165,7 @@ class DistributedOpsPass(PassBase):
clk = _program.global_block().create_var( clk = _program.global_block().create_var(
name="clk", name="clk",
dtype=core.VarDesc.VarType.INT64, dtype=core.VarDesc.VarType.FP32,
persistable=False, persistable=False,
stop_gradient=True) stop_gradient=True)
_program.global_block()._insert_op(index=0, _program.global_block()._insert_op(index=0,
...@@ -190,6 +190,9 @@ class DistributedOpsPass(PassBase): ...@@ -190,6 +190,9 @@ class DistributedOpsPass(PassBase):
padding_idx = ops[0].attr("padding_idx") padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed") is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type op_type = ops[0].type
slots = [op.attr("slot") for op in ops]
print('debug zcb slots: ', slots)
outputs = [ outputs = [
_program.global_block().vars[op.input("Out@GRAD")[0]] _program.global_block().vars[op.input("Out@GRAD")[0]]
for op in ops for op in ops
...@@ -204,7 +207,7 @@ class DistributedOpsPass(PassBase): ...@@ -204,7 +207,7 @@ class DistributedOpsPass(PassBase):
'W': w, 'W': w,
"Outputs": outputs, "Outputs": outputs,
"Shows": show, "Shows": show,
"Clicks": clk "Clicks": clk,
}, },
outputs={"Outputs": outputs}, outputs={"Outputs": outputs},
attrs={ attrs={
...@@ -213,7 +216,8 @@ class DistributedOpsPass(PassBase): ...@@ -213,7 +216,8 @@ class DistributedOpsPass(PassBase):
"padding_idx": padding_idx, "padding_idx": padding_idx,
"table_id": table_id, "table_id": table_id,
"size": self.emb_size[param], "size": self.emb_size[param],
"use_cvm_op": use_cvm_op "use_cvm_op": use_cvm_op,
"slots": slots
}) })
def _pull_sparse_fuse(self, _program, pull_sparse_ops, attrs, send_ctx): def _pull_sparse_fuse(self, _program, pull_sparse_ops, attrs, send_ctx):
......
...@@ -1073,7 +1073,8 @@ def sparse_embedding(input, ...@@ -1073,7 +1073,8 @@ def sparse_embedding(input,
entry=None, entry=None,
table_class="MemorySparseTable", table_class="MemorySparseTable",
param_attr=None, param_attr=None,
dtype='float32'): dtype='float32',
slot=None):
r""" r"""
:api_attr: Static Graph :api_attr: Static Graph
...@@ -1220,6 +1221,9 @@ def sparse_embedding(input, ...@@ -1220,6 +1221,9 @@ def sparse_embedding(input,
) )
entry_str = entry._to_attr() entry_str = entry._to_attr()
if slot == None:
slot = 0
helper.append_op(type='lookup_table', helper.append_op(type='lookup_table',
inputs={ inputs={
'Ids': input, 'Ids': input,
...@@ -1233,9 +1237,9 @@ def sparse_embedding(input, ...@@ -1233,9 +1237,9 @@ def sparse_embedding(input,
'remote_prefetch': True, 'remote_prefetch': True,
'is_test': is_test, 'is_test': is_test,
'entry': entry_str, 'entry': entry_str,
'table_class': table_class 'table_class': table_class,
'slot': slot
}) })
return tmp return tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册