// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdint.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"

namespace paddle {
namespace distributed {

// DownpourUnitAccessor
class CtrCommonAccessor : public ValueAccessor {
 public:
  // Layout descriptor for one feature value stored as a flat float array.
  // Each *Index() helper returns the offset of a field inside that array and
  // each field accessor returns a reference to the float at that offset. The
  // offsets of the variable-length tail depend on embed_sgd_dim and
  // embedx_dim, so those members must be set before the helpers are used.
  struct CtrCommonFeatureValue {
    /*
       float slot;
       float unseen_days;
       float delta_score;
       float show;
       float click;
       float embed_w;
       std::vector<float> embed_g2sum;
       std::vector<float> embedx_w;
       std::vector<float> embedx_g2sum;
       */

    // Number of floats in a value: the 6 scalar fields plus the three
    // variable-length parts.
    int Dim() const { return 6 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
    // Byte size of a single dimension. Every dimension is a float, so both
    // arguments are ignored.
    int DimSize(size_t /*dim*/, int /*embedx_dim*/) const {
      return static_cast<int>(sizeof(float));
    }
    // Byte size of the whole value.
    int Size() const { return static_cast<int>(Dim() * sizeof(float)); }

    // Offsets of the fields, chained so each one follows the previous field.
    int SlotIndex() const { return 0; }
    int unseen_days_index() const { return SlotIndex() + 1; }
    int delta_score_index() const { return unseen_days_index() + 1; }
    int ShowIndex() const { return delta_score_index() + 1; }
    int ClickIndex() const { return ShowIndex() + 1; }
    int Embed_W_Index() const { return ClickIndex() + 1; }
    int embed_g2sum_index() const { return Embed_W_Index() + 1; }
    int Embedx_W_Index() const { return embed_g2sum_index() + embed_sgd_dim; }
    int embedx_g2sum_index() const { return Embedx_W_Index() + embedx_dim; }

    // Field accessors. For the variable-length fields (embed_g2sum, EmbedxW,
    // embedx_g2sum) the returned reference is to the first float of the field.
    float& unseen_days(float* val) const { return val[unseen_days_index()]; }
    float& delta_score(float* val) const { return val[delta_score_index()]; }
    float& Show(float* val) const { return val[ShowIndex()]; }
    float& Click(float* val) const { return val[ClickIndex()]; }
    float& Slot(float* val) const { return val[SlotIndex()]; }
    float& EmbedW(float* val) const { return val[Embed_W_Index()]; }
    float& embed_g2sum(float* val) const { return val[embed_g2sum_index()]; }
    float& EmbedxW(float* val) const { return val[Embedx_W_Index()]; }
    float& embedx_g2sum(float* val) const { return val[embedx_g2sum_index()]; }

    // Lengths of the variable-length parts. Zero-initialized so that a
    // default-constructed descriptor yields deterministic offsets.
    int embed_sgd_dim = 0;
    int embedx_dim = 0;
    int embedx_sgd_dim = 0;
  };

  // Layout descriptor for a push value stored as a flat float array: four
  // scalar fields followed by embedx_dim gradient floats. All helpers are
  // static because the layout needs no per-instance state.
  struct CtrCommonPushValue {
    /*
       float slot;
       float show;
       float click;
       float embed_g;
       std::vector<float> embedx_g;
       */

    // Number of floats in a push value.
    static constexpr int Dim(int embedx_dim) { return 4 + embedx_dim; }

    // Byte size of a single dimension. Every dimension is a float, so both
    // arguments are ignored.
    static constexpr int DimSize(int /*dim*/, int /*embedx_dim*/) {
      return static_cast<int>(sizeof(float));
    }
    // Byte size of the whole push value.
    static constexpr int Size(int embedx_dim) {
      return static_cast<int>(Dim(embedx_dim) * sizeof(float));
    }

    // Offsets of the fields, chained so each one follows the previous field.
    static constexpr int SlotIndex() { return 0; }
    static constexpr int ShowIndex() { return SlotIndex() + 1; }
    static constexpr int ClickIndex() { return ShowIndex() + 1; }
    static constexpr int Embed_G_Index() { return ClickIndex() + 1; }
    static constexpr int Embedx_G_Index() { return Embed_G_Index() + 1; }

    // Scalar field accessors: reference to the float at the field's offset.
    static float& Slot(float* val) { return val[SlotIndex()]; }
    static float& Show(float* val) { return val[ShowIndex()]; }
    static float& Click(float* val) { return val[ClickIndex()]; }
    static float& EmbedG(float* val) { return val[Embed_G_Index()]; }
    // Pointer to the first float of the embedx gradient part.
    static float* EmbedxG(float* val) { return val + Embedx_G_Index(); }
  };

  // Layout descriptor for a pull value stored as a flat float array: three
  // scalar fields followed by embedx_dim weight floats. All helpers are
  // static because the layout needs no per-instance state.
  struct CtrCommonPullValue {
    /*
       float show;
       float click;
       float embed_w;
       std::vector<float> embedx_w;
       */

    // Number of floats in a pull value.
    static constexpr int Dim(int embedx_dim) { return 3 + embedx_dim; }
    // Byte size of a single dimension. Every dimension is a float, so the
    // argument is ignored.
    static constexpr int DimSize(size_t /*dim*/) {
      return static_cast<int>(sizeof(float));
    }
    // Byte size of the whole pull value.
    static constexpr int Size(int embedx_dim) {
      return static_cast<int>(Dim(embedx_dim) * sizeof(float));
    }

    // Fixed offsets of the fields.
    static constexpr int ShowIndex() { return 0; }
    static constexpr int ClickIndex() { return 1; }
    static constexpr int Embed_W_Index() { return 2; }
    static constexpr int Embedx_W_Index() { return 3; }

    // Scalar field accessors: reference to the float at the field's offset.
    static float& Show(float* val) { return val[ShowIndex()]; }
    static float& Click(float* val) { return val[ClickIndex()]; }
    static float& EmbedW(float* val) { return val[Embed_W_Index()]; }
    // Pointer to the first float of the embedx weight part.
    static float* EmbedxW(float* val) { return val + Embedx_W_Index(); }
  };
  CtrCommonAccessor() {}
137
  virtual int Initialize();
Z
zhaocaibei123 已提交
138 139
  virtual ~CtrCommonAccessor() {}

140 141
  virtual void SetTableInfo(AccessorInfo& info);
  virtual size_t GetTableInfo(InfoKey key);
Z
zhaocaibei123 已提交
142
  // value维度
143
  size_t Dim();
Z
zhaocaibei123 已提交
144
  // value各个维度的size
145
  size_t DimSize(size_t dim);
Z
zhaocaibei123 已提交
146
  // value各维度相加总size
147
  size_t Size();
Z
zhaocaibei123 已提交
148
  // value中mf动态长度部分总size大小, sparse下生效
149
  size_t MFSize();
Z
zhaocaibei123 已提交
150
  // pull value维度
151
  size_t SelectDim();
Z
zhaocaibei123 已提交
152
  // pull value各个维度的size
153
  size_t SelectDimSize(size_t dim);
Z
zhaocaibei123 已提交
154
  // pull value各维度相加总size
155
  size_t SelectSize();
Z
zhaocaibei123 已提交
156
  // push value维度
157
  size_t UpdateDim();
Z
zhaocaibei123 已提交
158
  // push value各个维度的size
159
  size_t UpdateDimSize(size_t dim);
Z
zhaocaibei123 已提交
160
  // push value各维度相加总size
161
  size_t UpdateSize();
Z
zhaocaibei123 已提交
162
  // 判断该value是否进行shrink
163
  virtual bool Shrink(float* value);
Z
zhaocaibei123 已提交
164 165
  // 判断该value是否保存到ssd
  // virtual bool save_ssd(float* value);
166 167
  virtual bool NeedExtendMF(float* value);
  virtual bool HasMF(size_t size);
Z
zhaocaibei123 已提交
168 169 170 171 172
  // 判断该value是否在save阶段dump,
  // param作为参数用于标识save阶段,如downpour的xbox与batch_model
  // param = 0, save all feature
  // param = 1, save delta feature
  // param = 2, save xbox base feature
173
  bool Save(float* value, int param) override;
Z
zhaocaibei123 已提交
174
  // update delta_score and unseen_days after save
175
  void UpdateStatAfterSave(float* value, int param) override;
Z
zhaocaibei123 已提交
176 177
  // keys不存在时,为values生成随机值
  // 要求value的内存由外部调用者分配完毕
178
  virtual int32_t Create(float** value, size_t num);
Z
zhaocaibei123 已提交
179
  // 从values中选取到select_values中
180
  virtual int32_t Select(float** select_values, const float** values,
Z
zhaocaibei123 已提交
181 182
                         size_t num);
  // 将update_values聚合到一起
183
  virtual int32_t Merge(float** update_values,
Z
zhaocaibei123 已提交
184 185
                        const float** other_update_values, size_t num);
  // 将update_values聚合到一起,通过it.next判定是否进入下一个key
186
  // virtual int32_t Merge(float** update_values, iterator it);
Z
zhaocaibei123 已提交
187
  // 将update_values更新应用到values中
188
  virtual int32_t Update(float** values, const float** update_values,
Z
zhaocaibei123 已提交
189 190
                         size_t num);

191 192 193
  std::string ParseToString(const float* value, int param) override;
  int32_t ParseFromString(const std::string& str, float* v) override;
  virtual bool CreateValue(int type, const float* value);
Z
zhaocaibei123 已提交
194 195

  // 这个接口目前只用来取show
196
  float GetField(float* value, const std::string& name) override {
Z
zhaocaibei123 已提交
197 198
    // CHECK(name == "show");
    if (name == "show") {
199
      return common_feature_value.Show(value);
Z
zhaocaibei123 已提交
200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
    }
    return 0.0;
  }

 private:
  // float show_click_score(float show, float click);

  // SparseValueSGDRule* _embed_sgd_rule;
  // SparseValueSGDRule* _embedx_sgd_rule;
  // CtrCommonFeatureValue common_feature_value;
  // The constructor body is empty, so these members carry in-class defaults
  // to avoid holding indeterminate values before they are assigned.
  float _show_click_decay_rate = 0.0f;
  int32_t _ssd_unseenday_threshold = 0;

 public:  // TODO(zhaocaibei123): it should be private, but we make it public
          // for unit test
  CtrCommonFeatureValue common_feature_value;
  float show_click_score(float show, float click);
  // Raw pointers default to nullptr instead of being left uninitialized.
  // NOTE(review): ownership is not visible in this header — the destructor
  // here does not delete them; confirm against the .cc file.
  SparseValueSGDRule* _embed_sgd_rule = nullptr;
  SparseValueSGDRule* _embedx_sgd_rule = nullptr;
};
}  // namespace distributed
}  // namespace paddle