未验证 提交 513d1f97 编写于 作者: Z zhaocaibei123 提交者: GitHub

fix pscore geo&lr_decay (#37995)

* fix

* modify log

* fix batch_size
上级 11c785a4
......@@ -75,9 +75,8 @@ class CostTimer {
}
~CostTimer() {
if (_is_print_cost) {
LOG(INFO) << "CostTimer label:" << _label
<< ", cost:" << butil::gettimeofday_ms() - _start_time_ms
<< "ms";
VLOG(3) << "CostTimer label:" << _label
<< ", cost:" << butil::gettimeofday_ms() - _start_time_ms << "ms";
} else {
*(_profiler_node->recorder) << butil::gettimeofday_ms() - _start_time_ms;
}
......
......@@ -439,13 +439,16 @@ void FleetWrapper::PushSparseFromTensorAsync(
const LoDTensor* shows, const LoDTensor* clks,
std::vector<LoDTensor*>* outputs) {
int batch_size = -1;
bool batch_size_consist = true;
for (auto* input : *inputs) {
int cur_batch_size =
input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
CHECK(batch_size == cur_batch_size); // NOLINT
// CHECK(batch_size == cur_batch_size); // NOLINT
batch_size_consist = false;
break;
}
}
CHECK(batch_size > 0); // NOLINT
......@@ -461,7 +464,7 @@ void FleetWrapper::PushSparseFromTensorAsync(
for (framework::LoDTensor* g_tensor : *outputs) {
float* g_ori = g_tensor->data<float>();
// no cvm
if (true) { // TODO(zhaocaibei123): add config
if (batch_size_consist) { // TODO(zhaocaibei123): add config
// scale_sparse_gradient_with_batch_size_
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
......
......@@ -949,6 +949,10 @@ void GeoCommunicator::InitDense(std::vector<std::string> &varnames,
auto *old_var = old_scope_->Var(t);
old_var->GetMutable<framework::LoDTensor>();
framework::CopyVariable(*global_var, old_var);
// init pserver_scope_
auto *pserver_var = pserver_scope_->Var(t);
pserver_var->GetMutable<framework::LoDTensor>();
framework::CopyVariable(*global_var, pserver_var);
}
VLOG(1) << "init dense table " << table_id << " done";
}
......
......@@ -42,17 +42,24 @@ class SendOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
// auto is_sparse = Attr<int>("is_sparse");
auto is_sparse = Attr<int>("is_sparse");
auto table_id = Attr<int>("table_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
// for common_dense_table, distributed_push_sparse op for push sparse in
// async
if (is_sparse == 0 && send_varnames.size() >= 1 &&
send_varnames[0] != "@PS_STEP_COUNTER@") {
auto fleet = paddle::distributed::FleetWrapper::GetInstance();
std::vector<::std::future<int32_t>> status;
// Note: only send push_dense now!
// communicator->Send(ins, scope) can be used to push_sparse or push_dense
fleet->PushDenseVarsAsync(scope, table_id, ins, &status, 0, -1);
} else {
auto* communicator = paddle::distributed::Communicator::GetInstance();
if (communicator->Check(send_varnames)) {
communicator->Send(ins, scope);
}
}
// auto fleet = paddle::distributed::FleetWrapper::GetInstance();
// if (is_sparse == 0) {
// std::vector<::std::future<int32_t>> status;
......
......@@ -503,7 +503,7 @@ def append_send_ops_pass(program, config):
split_dense_table=config.is_heter_ps_mode)
for merged_name, send in sends.items():
if send.is_sparse():
if send.is_sparse() and not config.is_geo_mode():
continue
is_sparse = 1 if send.is_sparse() else 0
is_sparse = 2 if send.is_distributed() else is_sparse
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册