sparse_accessor.cc 12.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/ps/table/sparse_accessor.h"
#include <gflags/gflags.h>
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace distributed {

23
// Creates the embed and embedx SGD rules named in the accessor config, loads
// their settings, and caches the resulting dimensions in sparse_feature_value
// together with the show/click decay rate.
// Always returns 0.
int SparseAccessor::Initialize() {
  // Rule for the embed weight; load_config is given a dimension of 1.
  const auto embed_rule_name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, embed_rule_name);
  _embed_sgd_rule->load_config(_config.embed_sgd_param(), 1);

  // Rule for the embedx weights; load_config is given the configured
  // embedx_dim.
  const auto embedx_rule_name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, embedx_rule_name);
  _embedx_sgd_rule->load_config(_config.embedx_sgd_param(),
                                _config.embedx_dim());

  // Publish the per-part dimensions used by the value layout helpers.
  sparse_feature_value.embed_sgd_dim = _embed_sgd_rule->dim();
  sparse_feature_value.embedx_dim = _config.embedx_dim();
  sparse_feature_value.embedx_sgd_dim = _embedx_sgd_rule->dim();

  _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
  return 0;
}

41
// Fills `info` with the dimension and size figures this accessor reports for
// the stored value, the pull (select) value, the push (update) value and the
// mf part.
void SparseAccessor::SetTableInfo(AccessorInfo& info) {
  // Stored feature value.
  info.dim = Dim();
  info.size = Size();

  // Pull value.
  info.select_dim = SelectDim();
  info.select_size = SelectSize();

  // Push value.
  info.update_dim = UpdateDim();
  info.update_size = UpdateSize();

  // Embedx (mf) part.
  info.mf_size = MFSize();
}

51 52 53
// Returns the single figure selected by `key`; see the matching getter for
// each case. Keys not listed below yield 0.
size_t SparseAccessor::GetTableInfo(InfoKey key) {
  size_t result = 0;
  switch (key) {
    case DIM:
      result = Dim();
      break;
    case SIZE:
      result = Size();
      break;
    case SELECT_DIM:
      result = SelectDim();
      break;
    case SELECT_SIZE:
      result = SelectSize();
      break;
    case UPDATE_DIM:
      result = UpdateDim();
      break;
    case UPDATE_SIZE:
      result = UpdateSize();
      break;
    case MF_SIZE:
      result = MFSize();
      break;
    default:
      // Unknown key: keep the 0 default.
      break;
  }
  return result;
}

73
// Dimension of one stored feature value, as reported by sparse_feature_value.
size_t SparseAccessor::Dim() {
  return static_cast<size_t>(sparse_feature_value.Dim());
}
74

75
size_t SparseAccessor::DimSize(size_t dim) {
76
  auto embedx_dim = _config.embedx_dim();
77
  return sparse_feature_value.DimSize(dim, embedx_dim);
78 79
}

80
// Size of one stored feature value, as reported by sparse_feature_value.
size_t SparseAccessor::Size() {
  return static_cast<size_t>(sparse_feature_value.Size());
}
81

82
size_t SparseAccessor::MFSize() {
83 84 85 86 87
  return (_config.embedx_dim() + sparse_feature_value.embedx_sgd_dim) *
         sizeof(float);  // embedx embedx_g2sum
}

// Pull value: dimensions and sizes of the per-key output written by Select().
size_t SparseAccessor::SelectDim() {
89 90 91 92
  auto embedx_dim = _config.embedx_dim();
  return 1 + embedx_dim;
}

93
// Byte size of field `dim` of a pull value. Every field is a float, so `dim`
// does not influence the result.
size_t SparseAccessor::SelectDimSize(size_t dim) {
  (void)dim;
  return sizeof(float);
}
94

95
// Byte size of one pull value: SelectDim() floats.
size_t SparseAccessor::SelectSize() {
  const size_t float_count = SelectDim();
  return float_count * sizeof(float);
}
96 97

// Push value: dimensions and sizes of the per-key input consumed by Update().
size_t SparseAccessor::UpdateDim() {
99 100 101 102
  auto embedx_dim = _config.embedx_dim();
  return 4 + embedx_dim;
}

103
// Byte size of field `dim` of a push value. Every field is a float, so `dim`
// does not influence the result.
size_t SparseAccessor::UpdateDimSize(size_t dim) {
  (void)dim;
  return sizeof(float);
}
104

105
// Byte size of one push value: UpdateDim() floats.
size_t SparseAccessor::UpdateSize() {
  const size_t float_count = UpdateDim();
  return float_count * sizeof(float);
}
106

107
bool SparseAccessor::Shrink(float* value) {
108 109 110 111 112 113 114
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delete_after_unseen_days =
      _config.ctr_accessor_param().delete_after_unseen_days();
  auto delete_threshold = _config.ctr_accessor_param().delete_threshold();

  // time_decay first
115 116
  sparse_feature_value.Show(value) *= _show_click_decay_rate;
  sparse_feature_value.Click(value) *= _show_click_decay_rate;
117 118

  // shrink after
119 120
  auto score = show_click_score(sparse_feature_value.Show(value),
                                sparse_feature_value.Click(value));
121 122 123 124 125 126 127
  auto unseen_days = sparse_feature_value.unseen_days(value);
  if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
    return true;
  }
  return false;
}

128
bool SparseAccessor::Save(float* value, int param) {
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    // save all
    case 0: {
      return true;
    }
    // save xbox delta
    case 1:
    // save xbox base
    case 2: {
144 145
      if (show_click_score(sparse_feature_value.Show(value),
                           sparse_feature_value.Click(value)) >=
146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
              base_threshold &&
          sparse_feature_value.delta_score(value) >= delta_threshold &&
          sparse_feature_value.unseen_days(value) <= delta_keep_days) {
        // do this after save, because it must not be modified when retry
        if (param == 2) {
          sparse_feature_value.delta_score(value) = 0;
        }
        return true;
      } else {
        return false;
      }
    }
    // already decayed in shrink
    case 3: {
      // do this after save, because it must not be modified when retry
      // sparse_feature_value.unseen_days(value)++;
      return true;
    }
    // save revert batch_model
    case 5: {
      return true;
    }
    default:
      return true;
  }
}

173
void SparseAccessor::UpdateStatAfterSave(float* value, int param) {
174 175 176 177 178 179 180 181
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    case 1: {
182 183
      if (show_click_score(sparse_feature_value.Show(value),
                           sparse_feature_value.Click(value)) >=
184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
              base_threshold &&
          sparse_feature_value.delta_score(value) >= delta_threshold &&
          sparse_feature_value.unseen_days(value) <= delta_keep_days) {
        sparse_feature_value.delta_score(value) = 0;
      }
    }
      return;
    case 3: {
      sparse_feature_value.unseen_days(value)++;
    }
      return;
    default:
      return;
  }
}

200
int32_t SparseAccessor::Create(float** values, size_t num) {
201 202 203 204 205
  auto embedx_dim = _config.embedx_dim();
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* value = values[value_item];
    value[sparse_feature_value.unseen_days_index()] = 0;
    value[sparse_feature_value.delta_score_index()] = 0;
206 207 208
    value[sparse_feature_value.ShowIndex()] = 0;
    value[sparse_feature_value.ClickIndex()] = 0;
    value[sparse_feature_value.SlotIndex()] = -1;
209
    _embed_sgd_rule->init_value(
210
        value + sparse_feature_value.Embed_W_Index(),
211 212
        value + sparse_feature_value.embed_g2sum_index());
    _embedx_sgd_rule->init_value(
213
        value + sparse_feature_value.Embedx_W_Index(),
214 215 216 217 218
        value + sparse_feature_value.embedx_g2sum_index(), false);
  }
  return 0;
}

219 220 221
bool SparseAccessor::NeedExtendMF(float* value) {
  float show = value[sparse_feature_value.ShowIndex()];
  float click = value[sparse_feature_value.ClickIndex()];
222 223 224 225 226
  float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
                click * _config.ctr_accessor_param().click_coeff();
  return score >= _config.embedx_threshold();
}

227
bool SparseAccessor::HasMF(size_t size) {
228 229 230 231
  return size > sparse_feature_value.embedx_g2sum_index();
}

// Select: converts stored values (SparseFeatureValue) into pull values
// (SparsePullValue).
// Copies the pull-visible part of `num` stored values into `select_values`.
//
// For each item the embed weight is copied as a scalar and the embedx weights
// (embedx_dim floats) are copied as one block.
// Always returns 0.
int32_t SparseAccessor::Select(float** select_values, const float** values,
                               size_t num) {
  // Byte length of the embedx weight block; identical for every item.
  const size_t embedx_bytes = _config.embedx_dim() * sizeof(float);

  for (size_t item = 0; item < num; ++item) {
    const float* src = values[item];
    float* dst = select_values[item];

    dst[SparsePullValue::Embed_W_Index()] =
        src[sparse_feature_value.Embed_W_Index()];
    memcpy(dst + SparsePullValue::Embedx_W_Index(),
           src + sparse_feature_value.Embedx_W_Index(), embedx_bytes);
  }
  return 0;
}

// Merge: accumulates one push value (SparsePushValue) into another.
// first dim: item
// second dim: field num
int32_t SparseAccessor::Merge(float** update_values,
251 252
                              const float** other_update_values, size_t num) {
  auto embedx_dim = _config.embedx_dim();
253
  size_t total_dim = SparsePushValue::Dim(embedx_dim);
254 255 256 257
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* other_update_value = other_update_values[value_item];
    for (auto i = 0u; i < total_dim; ++i) {
258
      if (i != SparsePushValue::SlotIndex()) {
259 260 261 262 263 264 265 266 267 268
        update_value[i] += other_update_value[i];
      }
    }
  }
  return 0;
}

// Update: applies push values (SparsePushValue) to stored values
// (SparseFeatureValue).
// first dim: item
// second dim: field num
int32_t SparseAccessor::Update(float** update_values, const float** push_values,
270 271 272 273 274
                               size_t num) {
  auto embedx_dim = _config.embedx_dim();
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* push_value = push_values[value_item];
275 276 277 278 279 280
    float push_show = push_value[SparsePushValue::ShowIndex()];
    float push_click = push_value[SparsePushValue::ClickIndex()];
    float slot = push_value[SparsePushValue::SlotIndex()];
    update_value[sparse_feature_value.ShowIndex()] += push_show;
    update_value[sparse_feature_value.ClickIndex()] += push_click;
    update_value[sparse_feature_value.SlotIndex()] = slot;
281 282 283 284 285
    update_value[sparse_feature_value.delta_score_index()] +=
        (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
        push_click * _config.ctr_accessor_param().click_coeff();
    update_value[sparse_feature_value.unseen_days_index()] = 0;
    _embed_sgd_rule->update_value(
286
        update_value + sparse_feature_value.Embed_W_Index(),
287
        update_value + sparse_feature_value.embed_g2sum_index(),
288
        push_value + SparsePushValue::Embed_G_Index());
289
    _embedx_sgd_rule->update_value(
290
        update_value + sparse_feature_value.Embedx_W_Index(),
291
        update_value + sparse_feature_value.embedx_g2sum_index(),
292
        push_value + SparsePushValue::Embedx_G_Index());
293 294 295 296
  }
  return 0;
}

297
bool SparseAccessor::CreateValue(int stage, const float* value) {
298 299 300 301 302 303
  // stage == 0, pull
  // stage == 1, push
  if (stage == 0) {
    return true;
  } else if (stage == 1) {
    // operation
304 305
    auto show = SparsePushValue::Show(const_cast<float*>(value));
    auto click = SparsePushValue::Click(const_cast<float*>(value));
306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
    auto score = show_click_score(show, click);
    if (score <= 0) {
      return false;
    }
    if (score >= 1) {
      return true;
    }
    return local_uniform_real_distribution<float>()(local_random_engine()) <
           score;
  } else {
    return true;
  }
}

// Weighted show/click score: non-clicked shows (show - click) weighted by
// nonclk_coeff plus clicks weighted by click_coeff, both taken from the ctr
// accessor config.
float SparseAccessor::show_click_score(float show, float click) {
  const auto& ctr_param = _config.ctr_accessor_param();
  const auto nonclk_weight = ctr_param.nonclk_coeff();
  const auto click_weight = ctr_param.click_coeff();
  return (show - click) * nonclk_weight + click * click_weight;
}

326
std::string SparseAccessor::ParseToString(const float* v, int param) {
327 328 329 330 331 332
  thread_local std::ostringstream os;
  os.clear();
  os.str("");
  os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
     << v[5];
  for (int i = sparse_feature_value.embed_g2sum_index();
333
       i < sparse_feature_value.Embedx_W_Index(); i++) {
334 335
    os << " " << v[i];
  }
336 337
  auto show = sparse_feature_value.Show(const_cast<float*>(v));
  auto click = sparse_feature_value.Click(const_cast<float*>(v));
338 339
  auto score = show_click_score(show, click);
  if (score >= _config.embedx_threshold() &&
340 341 342
      param > sparse_feature_value.Embedx_W_Index()) {
    for (auto i = sparse_feature_value.Embedx_W_Index();
         i < sparse_feature_value.Dim(); ++i) {
343 344 345 346 347 348
      os << " " << v[i];
    }
  }
  return os.str();
}

349
int SparseAccessor::ParseFromString(const std::string& str, float* value) {
350 351 352
  int embedx_dim = _config.embedx_dim();

  _embedx_sgd_rule->init_value(
353
      value + sparse_feature_value.Embedx_W_Index(),
354 355 356 357 358 359 360 361
      value + sparse_feature_value.embedx_g2sum_index());
  auto ret = paddle::string::str_to_float(str.data(), value);
  CHECK(ret >= 6) << "expect more than 6 real:" << ret;
  return ret;
}

}  // namespace distributed
}  // namespace paddle