ctr_accessor.cc 12.5 KB
Newer Older
Z
zhaocaibei123 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/table/ctr_accessor.h"
16

Z
zhaocaibei123 已提交
17
#include <gflags/gflags.h>
18

Z
zhaocaibei123 已提交
19 20 21 22 23 24
#include "glog/logging.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace distributed {

25
int CtrCommonAccessor::Initialize() {
Z
zhaocaibei123 已提交
26 27
  auto name = _config.embed_sgd_param().name();
  _embed_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
28
  _embed_sgd_rule->LoadConfig(_config.embed_sgd_param(), 1);
Z
zhaocaibei123 已提交
29 30 31

  name = _config.embedx_sgd_param().name();
  _embedx_sgd_rule = CREATE_PSCORE_CLASS(SparseValueSGDRule, name);
32 33
  _embedx_sgd_rule->LoadConfig(_config.embedx_sgd_param(),
                               _config.embedx_dim());
Z
zhaocaibei123 已提交
34

35
  common_feature_value.embed_sgd_dim = _embed_sgd_rule->Dim();
Z
zhaocaibei123 已提交
36
  common_feature_value.embedx_dim = _config.embedx_dim();
37
  common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
Z
zhaocaibei123 已提交
38
  _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
Z
zhaocaibei123 已提交
39 40
  _ssd_unseenday_threshold =
      _config.ctr_accessor_param().ssd_unseenday_threshold();
Z
zhaocaibei123 已提交
41

42 43 44 45
  if (_config.ctr_accessor_param().show_scale()) {
    _show_scale = true;
  }

46
  InitAccessorInfo();
Z
zhaocaibei123 已提交
47 48 49
  return 0;
}

50 51 52
void CtrCommonAccessor::InitAccessorInfo() {
  _accessor_info.dim = common_feature_value.Dim();
  _accessor_info.size = common_feature_value.Size();
Z
zhaocaibei123 已提交
53 54

  auto embedx_dim = _config.embedx_dim();
55 56 57 58 59 60
  _accessor_info.select_dim = 3 + embedx_dim;
  _accessor_info.select_size = _accessor_info.select_dim * sizeof(float);
  _accessor_info.update_dim = 4 + embedx_dim;
  _accessor_info.update_size = _accessor_info.update_dim * sizeof(float);
  _accessor_info.mf_size =
      (embedx_dim + common_feature_value.embedx_sgd_dim) * sizeof(float);
Z
zhaocaibei123 已提交
61 62
}

63
bool CtrCommonAccessor::Shrink(float* value) {
Z
zhaocaibei123 已提交
64 65 66 67 68 69 70
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delete_after_unseen_days =
      _config.ctr_accessor_param().delete_after_unseen_days();
  auto delete_threshold = _config.ctr_accessor_param().delete_threshold();

  // time_decay first
71 72
  common_feature_value.Show(value) *= _show_click_decay_rate;
  common_feature_value.Click(value) *= _show_click_decay_rate;
Z
zhaocaibei123 已提交
73 74

  // shrink after
75 76 77
  auto score = ShowClickScore(common_feature_value.Show(value),
                              common_feature_value.Click(value));
  auto unseen_days = common_feature_value.UnseenDays(value);
Z
zhaocaibei123 已提交
78 79 80 81 82 83
  if (score < delete_threshold || unseen_days > delete_after_unseen_days) {
    return true;
  }
  return false;
}

Z
zhaocaibei123 已提交
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
// Decides whether `value` belongs in the cache dump.
//
// A feature qualifies only if its show/click score reaches base_threshold and
// it has been unseen for at most delta_keep_days; among those, it is kept when
// its show counter exceeds `global_cache_threshold`. `param` is not used.
bool CtrCommonAccessor::SaveCache(float* value, int param,
                                  double global_cache_threshold) {
  auto min_score = _config.ctr_accessor_param().base_threshold();
  auto max_unseen_days = _config.ctr_accessor_param().delta_keep_days();

  auto score = ShowClickScore(common_feature_value.Show(value),
                              common_feature_value.Click(value));
  // Written as negated >= / <= so that NaN inputs are rejected exactly as a
  // combined "a >= b && c <= d" test would reject them.
  if (!(score >= min_score)) {
    return false;
  }
  if (!(common_feature_value.UnseenDays(value) <= max_unseen_days)) {
    return false;
  }
  return common_feature_value.Show(value) > global_cache_threshold;
}

// Returns true when the feature's unseen-days counter is strictly greater
// than the SSD threshold cached in _ssd_unseenday_threshold.
bool CtrCommonAccessor::SaveSSD(float* value) {
  return common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold;
}

103
bool CtrCommonAccessor::Save(float* value, int param) {
Z
zhaocaibei123 已提交
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    // save all
    case 0: {
      return true;
    }
    // save xbox delta
    case 1:
    // save xbox base
    case 2: {
119 120 121 122
      if (ShowClickScore(common_feature_value.Show(value),
                         common_feature_value.Click(value)) >= base_threshold &&
          common_feature_value.DeltaScore(value) >= delta_threshold &&
          common_feature_value.UnseenDays(value) <= delta_keep_days) {
Z
zhaocaibei123 已提交
123 124
        // do this after save, because it must not be modified when retry
        if (param == 2) {
125
          common_feature_value.DeltaScore(value) = 0;
Z
zhaocaibei123 已提交
126 127 128 129 130 131 132 133 134
        }
        return true;
      } else {
        return false;
      }
    }
    // already decayed in shrink
    case 3: {
      // do this after save, because it must not be modified when retry
135
      // common_feature_value.UnseenDays(value)++;
Z
zhaocaibei123 已提交
136 137 138 139 140 141 142 143 144 145 146
      return true;
    }
    // save revert batch_model
    case 5: {
      return true;
    }
    default:
      return true;
  }
}

147
void CtrCommonAccessor::UpdateStatAfterSave(float* value, int param) {
Z
zhaocaibei123 已提交
148 149 150 151 152 153 154 155
  auto base_threshold = _config.ctr_accessor_param().base_threshold();
  auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
  auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
  if (param == 2) {
    delta_threshold = 0;
  }
  switch (param) {
    case 1: {
156 157 158 159 160
      if (ShowClickScore(common_feature_value.Show(value),
                         common_feature_value.Click(value)) >= base_threshold &&
          common_feature_value.DeltaScore(value) >= delta_threshold &&
          common_feature_value.UnseenDays(value) <= delta_keep_days) {
        common_feature_value.DeltaScore(value) = 0;
Z
zhaocaibei123 已提交
161 162 163 164
      }
    }
      return;
    case 3: {
165
      common_feature_value.UnseenDays(value)++;
Z
zhaocaibei123 已提交
166 167 168 169 170 171 172
    }
      return;
    default:
      return;
  }
}

173
int32_t CtrCommonAccessor::Create(float** values, size_t num) {
Z
zhaocaibei123 已提交
174 175 176
  auto embedx_dim = _config.embedx_dim();
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* value = values[value_item];
177 178
    value[common_feature_value.UnseenDaysIndex()] = 0;
    value[common_feature_value.DeltaScoreIndex()] = 0;
179 180 181
    value[common_feature_value.ShowIndex()] = 0;
    value[common_feature_value.ClickIndex()] = 0;
    value[common_feature_value.SlotIndex()] = -1;
182 183 184 185 186
    _embed_sgd_rule->InitValue(value + common_feature_value.EmbedWIndex(),
                               value + common_feature_value.EmbedG2SumIndex());
    _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
                                value + common_feature_value.EmbedxG2SumIndex(),
                                false);
Z
zhaocaibei123 已提交
187 188 189 190
  }
  return 0;
}

191 192 193
bool CtrCommonAccessor::NeedExtendMF(float* value) {
  float show = value[common_feature_value.ShowIndex()];
  float click = value[common_feature_value.ClickIndex()];
Z
zhaocaibei123 已提交
194 195 196 197 198
  float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
                click * _config.ctr_accessor_param().click_coeff();
  return score >= _config.embedx_threshold();
}

199
bool CtrCommonAccessor::HasMF(size_t size) {
200
  return size > common_feature_value.EmbedxG2SumIndex();
Z
zhaocaibei123 已提交
201 202 203
}

// from CommonFeatureValue to CtrCommonPullValue
204
// Copies the pull-visible fields of `num` stored feature values into the
// caller's pull buffers.
//
// For every item, show, click and embed_w are copied field by field and the
// embedx weights (_config.embedx_dim() floats) are copied as one block.
// select_values[i] must have room for the CtrCommonPullValue layout.
//
// Always returns 0.
int32_t CtrCommonAccessor::Select(float** select_values, const float** values,
                                  size_t num) {
  // Loop-invariant: byte length of the embedx weight block.
  const size_t embedx_bytes = _config.embedx_dim() * sizeof(float);

  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* select_value = select_values[value_item];
    const float* value = values[value_item];

    select_value[CtrCommonPullValue::ShowIndex()] =
        value[common_feature_value.ShowIndex()];
    select_value[CtrCommonPullValue::ClickIndex()] =
        value[common_feature_value.ClickIndex()];
    select_value[CtrCommonPullValue::EmbedWIndex()] =
        value[common_feature_value.EmbedWIndex()];
    memcpy(select_value + CtrCommonPullValue::EmbedxWIndex(),
           value + common_feature_value.EmbedxWIndex(), embedx_bytes);
  }
  return 0;
}

// from CtrCommonPushValue to CtrCommonPushValue
// first dim: item
// second dim: field num
226
int32_t CtrCommonAccessor::Merge(float** update_values,
Z
zhaocaibei123 已提交
227 228 229
                                 const float** other_update_values,
                                 size_t num) {
  auto embedx_dim = _config.embedx_dim();
230
  size_t total_dim = CtrCommonPushValue::Dim(embedx_dim);
Z
zhaocaibei123 已提交
231 232 233 234
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* other_update_value = other_update_values[value_item];
    for (auto i = 0u; i < total_dim; ++i) {
235
      if (i != CtrCommonPushValue::SlotIndex()) {
Z
zhaocaibei123 已提交
236 237 238 239 240 241 242 243 244 245
        update_value[i] += other_update_value[i];
      }
    }
  }
  return 0;
}

// from CtrCommonPushValue to CommonFeatureValue
// first dim: item
// second dim: field num
246
int32_t CtrCommonAccessor::Update(float** update_values,
Z
zhaocaibei123 已提交
247 248 249 250 251
                                  const float** push_values, size_t num) {
  auto embedx_dim = _config.embedx_dim();
  for (size_t value_item = 0; value_item < num; ++value_item) {
    float* update_value = update_values[value_item];
    const float* push_value = push_values[value_item];
252 253 254 255 256 257
    float push_show = push_value[CtrCommonPushValue::ShowIndex()];
    float push_click = push_value[CtrCommonPushValue::ClickIndex()];
    float slot = push_value[CtrCommonPushValue::SlotIndex()];
    update_value[common_feature_value.ShowIndex()] += push_show;
    update_value[common_feature_value.ClickIndex()] += push_click;
    update_value[common_feature_value.SlotIndex()] = slot;
258
    update_value[common_feature_value.DeltaScoreIndex()] +=
Z
zhaocaibei123 已提交
259 260
        (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
        push_click * _config.ctr_accessor_param().click_coeff();
261
    update_value[common_feature_value.UnseenDaysIndex()] = 0;
262
    // TODO(zhaocaibei123): add configure show_scale
263 264 265 266 267
    if (!_show_scale) {
      push_show = 1;
    }
    VLOG(3) << "accessor show scale:" << _show_scale
            << ", push_show:" << push_show;
268 269 270
    _embed_sgd_rule->UpdateValue(
        update_value + common_feature_value.EmbedWIndex(),
        update_value + common_feature_value.EmbedG2SumIndex(),
271
        push_value + CtrCommonPushValue::EmbedGIndex(), push_show);
272 273 274
    _embedx_sgd_rule->UpdateValue(
        update_value + common_feature_value.EmbedxWIndex(),
        update_value + common_feature_value.EmbedxG2SumIndex(),
275
        push_value + CtrCommonPushValue::EmbedxGIndex(), push_show);
Z
zhaocaibei123 已提交
276 277 278 279
  }
  return 0;
}

280
bool CtrCommonAccessor::CreateValue(int stage, const float* value) {
Z
zhaocaibei123 已提交
281 282 283 284 285 286
  // stage == 0, pull
  // stage == 1, push
  if (stage == 0) {
    return true;
  } else if (stage == 1) {
    // operation
287 288
    auto show = CtrCommonPushValue::Show(const_cast<float*>(value));
    auto click = CtrCommonPushValue::Click(const_cast<float*>(value));
289
    auto score = ShowClickScore(show, click);
Z
zhaocaibei123 已提交
290 291 292 293 294 295 296 297 298 299 300 301 302
    if (score <= 0) {
      return false;
    }
    if (score >= 1) {
      return true;
    }
    return local_uniform_real_distribution<float>()(local_random_engine()) <
           score;
  } else {
    return true;
  }
}

303
// Weighted popularity score of a feature:
//   (show - click) * nonclk_coeff + click * click_coeff
// with both coefficients taken from _config.ctr_accessor_param().
float CtrCommonAccessor::ShowClickScore(float show, float click) {
  const auto& ctr_param = _config.ctr_accessor_param();
  auto nonclk_coeff = ctr_param.nonclk_coeff();
  auto click_coeff = ctr_param.click_coeff();
  return (show - click) * nonclk_coeff + click * click_coeff;
}

309
std::string CtrCommonAccessor::ParseToString(const float* v, int param) {
Z
zhaocaibei123 已提交
310 311 312 313 314
  thread_local std::ostringstream os;
  os.clear();
  os.str("");
  os << v[0] << " " << v[1] << " " << v[2] << " " << v[3] << " " << v[4] << " "
     << v[5];
315 316
  for (int i = common_feature_value.EmbedG2SumIndex();
       i < common_feature_value.EmbedxWIndex(); i++) {
Z
zhaocaibei123 已提交
317 318
    os << " " << v[i];
  }
319 320
  auto show = common_feature_value.Show(const_cast<float*>(v));
  auto click = common_feature_value.Click(const_cast<float*>(v));
321
  auto score = ShowClickScore(show, click);
322
  if (score >= _config.embedx_threshold() &&
323 324
      param > common_feature_value.EmbedxWIndex()) {
    for (auto i = common_feature_value.EmbedxWIndex();
325
         i < common_feature_value.Dim(); ++i) {
Z
zhaocaibei123 已提交
326 327 328 329 330 331
      os << " " << v[i];
    }
  }
  return os.str();
}

332
int CtrCommonAccessor::ParseFromString(const std::string& str, float* value) {
Z
zhaocaibei123 已提交
333 334
  int embedx_dim = _config.embedx_dim();

335 336
  _embedx_sgd_rule->InitValue(value + common_feature_value.EmbedxWIndex(),
                              value + common_feature_value.EmbedxG2SumIndex());
Z
zhaocaibei123 已提交
337 338 339 340 341 342 343
  auto ret = paddle::string::str_to_float(str.data(), value);
  CHECK(ret >= 6) << "expect more than 6 real:" << ret;
  return ret;
}

}  // namespace distributed
}  // namespace paddle