// ctr_double_accessor.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"

namespace paddle {
namespace distributed {

class DownpourCtrDoubleAccessor : public ValueAccessor {
 public:
  struct DownpourCtrDoubleFeatureValue {
    /*
    float unseen_days;
    float delta_score;
    double show;
    double click;
    float embed_w;
    float embed_g2sum;
    float slot;
    float embedx_g2sum;
    std::vector<float> embedx_w;
    */
41 42 43 44
    static int Dim(int embedx_dim) { return 8 + embedx_dim; }
    static int DimSize(size_t dim, int embedx_dim) { return sizeof(float); }
    static int Size(int embedx_dim) {
      return (Dim(embedx_dim) + 2) * sizeof(float);
Y
yaoxuefeng 已提交
45 46 47 48 49
    }
    static int unseen_days_index() { return 0; }
    static int delta_score_index() {
      return DownpourCtrDoubleFeatureValue::unseen_days_index() + 1;
    }
50
    static int ShowIndex() {
Y
yaoxuefeng 已提交
51 52 53
      return DownpourCtrDoubleFeatureValue::delta_score_index() + 1;
    }
    // show is double
54 55
    static int ClickIndex() {
      return DownpourCtrDoubleFeatureValue::ShowIndex() + 2;
Y
yaoxuefeng 已提交
56 57
    }
    // click is double
58 59
    static int Embed_W_Index() {
      return DownpourCtrDoubleFeatureValue::ClickIndex() + 2;
Y
yaoxuefeng 已提交
60 61
    }
    static int embed_g2sum_index() {
62
      return DownpourCtrDoubleFeatureValue::Embed_W_Index() + 1;
Y
yaoxuefeng 已提交
63
    }
64
    static int SlotIndex() {
Y
yaoxuefeng 已提交
65 66 67
      return DownpourCtrDoubleFeatureValue::embed_g2sum_index() + 1;
    }
    static int embedx_g2sum_index() {
68
      return DownpourCtrDoubleFeatureValue::SlotIndex() + 1;
Y
yaoxuefeng 已提交
69
    }
70
    static int Embedx_W_Index() {
Y
yaoxuefeng 已提交
71 72 73 74 75 76 77 78
      return DownpourCtrDoubleFeatureValue::embedx_g2sum_index() + 1;
    }
    static float& unseen_days(float* val) {
      return val[DownpourCtrDoubleFeatureValue::unseen_days_index()];
    }
    static float& delta_score(float* val) {
      return val[DownpourCtrDoubleFeatureValue::delta_score_index()];
    }
79 80
    static double& Show(float* val) {
      return ((double*)(val + DownpourCtrDoubleFeatureValue::ShowIndex()))[0];
Y
yaoxuefeng 已提交
81
    }
82 83
    static double& Click(float* val) {
      return ((double*)(val + DownpourCtrDoubleFeatureValue::ClickIndex()))[0];
Y
yaoxuefeng 已提交
84
    }
85 86
    static float& Slot(float* val) {
      return val[DownpourCtrDoubleFeatureValue::SlotIndex()];
Y
yaoxuefeng 已提交
87
    }
88 89
    static float& EmbedW(float* val) {
      return val[DownpourCtrDoubleFeatureValue::Embed_W_Index()];
Y
yaoxuefeng 已提交
90 91 92 93 94 95 96
    }
    static float& embed_g2sum(float* val) {
      return val[DownpourCtrDoubleFeatureValue::embed_g2sum_index()];
    }
    static float& embedx_g2sum(float* val) {
      return val[DownpourCtrDoubleFeatureValue::embedx_g2sum_index()];
    }
97 98
    static float* EmbedxW(float* val) {
      return (val + DownpourCtrDoubleFeatureValue::Embedx_W_Index());
Y
yaoxuefeng 已提交
99 100 101 102 103 104 105 106 107 108
    }
  };
  struct DownpourCtrDoublePushValue {
    /*
    float slot;
    float show;
    float click;
    float embed_g;
    std::vector<float> embedx_g;
    */
109 110 111 112 113 114
    static int Dim(int embedx_dim) { return 4 + embedx_dim; }
    static int DimSize(int dim, int embedx_dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int SlotIndex() { return 0; }
    static int ShowIndex() {
      return DownpourCtrDoublePushValue::SlotIndex() + 1;
Y
yaoxuefeng 已提交
115
    }
116 117
    static int ClickIndex() {
      return DownpourCtrDoublePushValue::ShowIndex() + 1;
Y
yaoxuefeng 已提交
118
    }
119 120
    static int Embed_G_Index() {
      return DownpourCtrDoublePushValue::ClickIndex() + 1;
Y
yaoxuefeng 已提交
121
    }
122 123
    static int Embedx_G_Index() {
      return DownpourCtrDoublePushValue::Embed_G_Index() + 1;
Y
yaoxuefeng 已提交
124
    }
125 126
    static float& Slot(float* val) {
      return val[DownpourCtrDoublePushValue::SlotIndex()];
Y
yaoxuefeng 已提交
127
    }
128 129
    static float& Show(float* val) {
      return val[DownpourCtrDoublePushValue::ShowIndex()];
Y
yaoxuefeng 已提交
130
    }
131 132
    static float& Click(float* val) {
      return val[DownpourCtrDoublePushValue::ClickIndex()];
Y
yaoxuefeng 已提交
133
    }
134 135
    static float& EmbedG(float* val) {
      return val[DownpourCtrDoublePushValue::Embed_G_Index()];
Y
yaoxuefeng 已提交
136
    }
137 138
    static float* EmbedxG(float* val) {
      return val + DownpourCtrDoublePushValue::Embedx_G_Index();
Y
yaoxuefeng 已提交
139 140 141 142 143 144 145 146 147
    }
  };
  struct DownpourCtrDoublePullValue {
    /*
    float show;
    float click;
    float embed_w;
    std::vector<float> embedx_w;
    */
148 149 150 151 152 153 154 155 156
    static int Dim(int embedx_dim) { return 3 + embedx_dim; }
    static int DimSize(size_t dim) { return sizeof(float); }
    static int Size(int embedx_dim) { return Dim(embedx_dim) * sizeof(float); }
    static int ShowIndex() { return 0; }
    static int ClickIndex() { return 1; }
    static int Embed_W_Index() { return 2; }
    static int Embedx_W_Index() { return 3; }
    static float& Show(float* val) {
      return val[DownpourCtrDoublePullValue::ShowIndex()];
Y
yaoxuefeng 已提交
157
    }
158 159
    static float& Click(float* val) {
      return val[DownpourCtrDoublePullValue::ClickIndex()];
Y
yaoxuefeng 已提交
160
    }
161 162
    static float& EmbedW(float* val) {
      return val[DownpourCtrDoublePullValue::Embed_W_Index()];
Y
yaoxuefeng 已提交
163
    }
164 165
    static float* EmbedxW(float* val) {
      return val + DownpourCtrDoublePullValue::Embedx_W_Index();
Y
yaoxuefeng 已提交
166 167 168 169
    }
  };
  DownpourCtrDoubleAccessor() {}
  virtual ~DownpourCtrDoubleAccessor() {}
170
  virtual int Initialize();
171 172
  virtual void SetTableInfo(AccessorInfo& info);
  virtual size_t GetTableInfo(InfoKey key);
Y
yaoxuefeng 已提交
173
  // value维度
174
  size_t Dim();
Y
yaoxuefeng 已提交
175
  // value各个维度的size
176
  size_t DimSize(size_t dim);
Y
yaoxuefeng 已提交
177
  // value各维度相加总size
178
  size_t Size();
Y
yaoxuefeng 已提交
179
  // value中mf动态长度部分总size大小, sparse下生效
180
  size_t MFSize();
Y
yaoxuefeng 已提交
181
  // pull value维度
182
  size_t SelectDim();
Y
yaoxuefeng 已提交
183
  // pull value各个维度的size
184
  size_t SelectDimSize(size_t dim);
Y
yaoxuefeng 已提交
185
  // pull value各维度相加总size
186
  size_t SelectSize();
Y
yaoxuefeng 已提交
187
  // push value维度
188
  size_t UpdateDim();
Y
yaoxuefeng 已提交
189
  // push value各个维度的size
190
  size_t UpdateDimSize(size_t dim);
Y
yaoxuefeng 已提交
191
  // push value各维度相加总size
192
  size_t UpdateSize();
Y
yaoxuefeng 已提交
193
  // 判断该value是否进行shrink
194 195
  virtual bool Shrink(float* value);
  virtual bool NeedExtendMF(float* value);
Y
yaoxuefeng 已提交
196 197 198 199 200
  // 判断该value是否在save阶段dump,
  // param作为参数用于标识save阶段,如downpour的xbox与batch_model
  // param = 0, save all feature
  // param = 1, save delta feature
  // param = 3, save all feature with time decay
201
  virtual bool Save(float* value, int param) override;
Y
yaoxuefeng 已提交
202
  // update delta_score and unseen_days after save
203
  virtual void UpdateStatAfterSave(float* value, int param) override;
Y
yaoxuefeng 已提交
204 205 206 207 208 209
  // 判断该value是否保存到ssd
  virtual bool save_ssd(float* value);
  // virtual bool save_cache(float* value, int param, double
  // global_cache_threshold) override;
  // keys不存在时,为values生成随机值
  // 要求value的内存由外部调用者分配完毕
210
  virtual int32_t Create(float** value, size_t num);
Y
yaoxuefeng 已提交
211
  // 从values中选取到select_values中
212
  virtual int32_t Select(float** select_values, const float** values,
Y
yaoxuefeng 已提交
213 214
                         size_t num);
  // 将update_values聚合到一起
215
  virtual int32_t Merge(float** update_values,
Y
yaoxuefeng 已提交
216 217
                        const float** other_update_values, size_t num);
  // 将update_values聚合到一起,通过it.next判定是否进入下一个key
218
  // virtual int32_t Merge(float** update_values, iterator it);
Y
yaoxuefeng 已提交
219
  // 将update_values更新应用到values中
220
  virtual int32_t Update(float** values, const float** update_values,
Y
yaoxuefeng 已提交
221
                         size_t num);
222 223 224
  virtual std::string ParseToString(const float* value, int param) override;
  virtual int32_t ParseFromString(const std::string& str, float* v) override;
  virtual bool CreateValue(int type, const float* value);
Y
yaoxuefeng 已提交
225
  //这个接口目前只用来取show
226
  virtual float GetField(float* value, const std::string& name) override {
Y
yaoxuefeng 已提交
227 228
    CHECK(name == "show");
    if (name == "show") {
229
      return (float)DownpourCtrDoubleFeatureValue::Show(value);
Y
yaoxuefeng 已提交
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
    }
    return 0.0;
  }
  // DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, show)
  // DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, click)
  // DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embed_w)
  // DEFINE_GET_INDEX(DownpourCtrDoubleFeatureValue, embedx_w)
 private:
  double show_click_score(double show, double click);

 private:
  SparseValueSGDRule* _embed_sgd_rule;
  SparseValueSGDRule* _embedx_sgd_rule;
  float _show_click_decay_rate;
  int32_t _ssd_unseenday_threshold;
};
}  // namespace distributed
}  // namespace paddle