// ctr_double_accessor.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdint.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"

namespace paddle {
namespace distributed {

27
class CtrDoubleAccessor : public ValueAccessor {
Y
yaoxuefeng 已提交
28
 public:
29
  struct CtrDoubleFeatureValue {
Y
yaoxuefeng 已提交
30 31 32 33 34 35 36 37 38 39 40
    /*
    float unseen_days;
    float delta_score;
    double show;
    double click;
    float embed_w;
    float embed_g2sum;
    float slot;
    float embedx_g2sum;
    std::vector<float> embedx_w;
    */
41 42 43 44
    static int Dim(int embedx_dim) { return 8 + embedx_dim; }
    static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
    static int Size(int embedx_dim) {
      return (Dim(embedx_dim) + 2) * sizeof(float);
Y
yaoxuefeng 已提交
45
    }
46 47
    static int UnseenDaysIndex() { return 0; }
    static int DeltaScoreIndex() {
48
      return CtrDoubleFeatureValue::UnseenDaysIndex() + 1;
Y
yaoxuefeng 已提交
49
    }
50
    static int ShowIndex() {
51
      return CtrDoubleFeatureValue::DeltaScoreIndex() + 1;
Y
yaoxuefeng 已提交
52 53
    }
    // show is double
54
    static int ClickIndex() { return CtrDoubleFeatureValue::ShowIndex() + 2; }
Y
yaoxuefeng 已提交
55
    // click is double
56
    static int EmbedWIndex() { return CtrDoubleFeatureValue::ClickIndex() + 2; }
57
    static int EmbedG2SumIndex() {
58
      return CtrDoubleFeatureValue::EmbedWIndex() + 1;
Y
yaoxuefeng 已提交
59
    }
60
    static int SlotIndex() {
61
      return CtrDoubleFeatureValue::EmbedG2SumIndex() + 1;
Y
yaoxuefeng 已提交
62
    }
63
    static int EmbedxG2SumIndex() {
64
      return CtrDoubleFeatureValue::SlotIndex() + 1;
Y
yaoxuefeng 已提交
65
    }
66
    static int EmbedxWIndex() {
67
      return CtrDoubleFeatureValue::EmbedxG2SumIndex() + 1;
Y
yaoxuefeng 已提交
68
    }
69
    static float& UnseenDays(float* val) {
70
      return val[CtrDoubleFeatureValue::UnseenDaysIndex()];
Y
yaoxuefeng 已提交
71
    }
72
    static float& DeltaScore(float* val) {
73
      return val[CtrDoubleFeatureValue::DeltaScoreIndex()];
Y
yaoxuefeng 已提交
74
    }
75
    static double& Show(float* val) {
76
      return ((double*)(val + CtrDoubleFeatureValue::ShowIndex()))[0];
Y
yaoxuefeng 已提交
77
    }
78
    static double& Click(float* val) {
79
      return ((double*)(val + CtrDoubleFeatureValue::ClickIndex()))[0];
Y
yaoxuefeng 已提交
80
    }
81
    static float& Slot(float* val) {
82
      return val[CtrDoubleFeatureValue::SlotIndex()];
Y
yaoxuefeng 已提交
83
    }
84
    static float& EmbedW(float* val) {
85
      return val[CtrDoubleFeatureValue::EmbedWIndex()];
Y
yaoxuefeng 已提交
86
    }
87
    static float& EmbedG2Sum(float* val) {
88
      return val[CtrDoubleFeatureValue::EmbedG2SumIndex()];
Y
yaoxuefeng 已提交
89
    }
90
    static float& EmbedxG2Sum(float* val) {
91
      return val[CtrDoubleFeatureValue::EmbedxG2SumIndex()];
Y
yaoxuefeng 已提交
92
    }
93
    static float* EmbedxW(float* val) {
94
      return (val + CtrDoubleFeatureValue::EmbedxWIndex());
Y
yaoxuefeng 已提交
95 96
    }
  };
97
  struct CtrDoublePushValue {
Y
yaoxuefeng 已提交
98 99 100 101 102 103 104
    /*
    float slot;
    float show;
    float click;
    float embed_g;
    std::vector<float> embedx_g;
    */
105 106 107 108
    static int Dim(int embedx_dim) { return 4 + embedx_dim; }
    static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int SlotIndex() { return 0; }
109 110 111 112
    static int ShowIndex() { return CtrDoublePushValue::SlotIndex() + 1; }
    static int ClickIndex() { return CtrDoublePushValue::ShowIndex() + 1; }
    static int EmbedGIndex() { return CtrDoublePushValue::ClickIndex() + 1; }
    static int EmbedxGIndex() { return CtrDoublePushValue::EmbedGIndex() + 1; }
113
    static float& Slot(float* val) {
114
      return val[CtrDoublePushValue::SlotIndex()];
Y
yaoxuefeng 已提交
115
    }
116
    static float& Show(float* val) {
117
      return val[CtrDoublePushValue::ShowIndex()];
Y
yaoxuefeng 已提交
118
    }
119
    static float& Click(float* val) {
120
      return val[CtrDoublePushValue::ClickIndex()];
Y
yaoxuefeng 已提交
121
    }
122
    static float& EmbedG(float* val) {
123
      return val[CtrDoublePushValue::EmbedGIndex()];
Y
yaoxuefeng 已提交
124
    }
125
    static float* EmbedxG(float* val) {
126
      return val + CtrDoublePushValue::EmbedxGIndex();
Y
yaoxuefeng 已提交
127 128
    }
  };
129
  struct CtrDoublePullValue {
Y
yaoxuefeng 已提交
130 131 132 133 134 135
    /*
    float show;
    float click;
    float embed_w;
    std::vector<float> embedx_w;
    */
136 137 138 139 140
    static int Dim(int embedx_dim) { return 3 + embedx_dim; }
    static int DimSize(size_t dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int ShowIndex() { return 0; }
    static int ClickIndex() { return 1; }
141 142
    static int EmbedWIndex() { return 2; }
    static int EmbedxWIndex() { return 3; }
143
    static float& Show(float* val) {
144
      return val[CtrDoublePullValue::ShowIndex()];
Y
yaoxuefeng 已提交
145
    }
146
    static float& Click(float* val) {
147
      return val[CtrDoublePullValue::ClickIndex()];
Y
yaoxuefeng 已提交
148
    }
149
    static float& EmbedW(float* val) {
150
      return val[CtrDoublePullValue::EmbedWIndex()];
Y
yaoxuefeng 已提交
151
    }
152
    static float* EmbedxW(float* val) {
153
      return val + CtrDoublePullValue::EmbedxWIndex();
Y
yaoxuefeng 已提交
154 155
    }
  };
156 157
  CtrDoubleAccessor() {}
  virtual ~CtrDoubleAccessor() {}
158
  virtual int Initialize();
159 160
  // 初始化AccessorInfo
  virtual void InitAccessorInfo();
Y
yaoxuefeng 已提交
161
  // 判断该value是否进行shrink
162 163
  virtual bool Shrink(float* value);
  virtual bool NeedExtendMF(float* value);
Y
yaoxuefeng 已提交
164 165 166 167 168
  // 判断该value是否在save阶段dump,
  // param作为参数用于标识save阶段,如downpour的xbox与batch_model
  // param = 0, save all feature
  // param = 1, save delta feature
  // param = 3, save all feature with time decay
169
  virtual bool Save(float* value, int param) override;
Y
yaoxuefeng 已提交
170
  // update delta_score and unseen_days after save
171
  virtual void UpdateStatAfterSave(float* value, int param) override;
Y
yaoxuefeng 已提交
172
  // 判断该value是否保存到ssd
173
  virtual bool SaveSSD(float* value);
Y
yaoxuefeng 已提交
174 175 176 177
  // virtual bool save_cache(float* value, int param, double
  // global_cache_threshold) override;
  // keys不存在时,为values生成随机值
  // 要求value的内存由外部调用者分配完毕
178
  virtual int32_t Create(float** value, size_t num);
Y
yaoxuefeng 已提交
179
  // 从values中选取到select_values中
180
  virtual int32_t Select(float** select_values, const float** values,
Y
yaoxuefeng 已提交
181 182
                         size_t num);
  // 将update_values聚合到一起
183
  virtual int32_t Merge(float** update_values,
Y
yaoxuefeng 已提交
184 185
                        const float** other_update_values, size_t num);
  // 将update_values聚合到一起,通过it.next判定是否进入下一个key
186
  // virtual int32_t Merge(float** update_values, iterator it);
Y
yaoxuefeng 已提交
187
  // 将update_values更新应用到values中
188
  virtual int32_t Update(float** values, const float** update_values,
Y
yaoxuefeng 已提交
189
                         size_t num);
190 191 192
  virtual std::string ParseToString(const float* value, int param) override;
  virtual int32_t ParseFromString(const std::string& str, float* v) override;
  virtual bool CreateValue(int type, const float* value);
Y
yaoxuefeng 已提交
193
  //这个接口目前只用来取show
194
  virtual float GetField(float* value, const std::string& name) override {
Y
yaoxuefeng 已提交
195 196
    CHECK(name == "show");
    if (name == "show") {
197
      return (float)CtrDoubleFeatureValue::Show(value);
Y
yaoxuefeng 已提交
198 199 200
    }
    return 0.0;
  }
201 202 203 204
  // DEFINE_GET_INDEX(CtrDoubleFeatureValue, show)
  // DEFINE_GET_INDEX(CtrDoubleFeatureValue, click)
  // DEFINE_GET_INDEX(CtrDoubleFeatureValue, embed_w)
  // DEFINE_GET_INDEX(CtrDoubleFeatureValue, embedx_w)
Y
yaoxuefeng 已提交
205
 private:
206
  double ShowClickScore(double show, double click);
Y
yaoxuefeng 已提交
207 208 209 210 211 212

 private:
  SparseValueSGDRule* _embed_sgd_rule;
  SparseValueSGDRule* _embedx_sgd_rule;
  float _show_click_decay_rate;
  int32_t _ssd_unseenday_threshold;
213
  bool _show_scale = false;
Y
yaoxuefeng 已提交
214 215 216
};
}  // namespace distributed
}  // namespace paddle