ctr_accessor.h 8.0 KB
Newer Older
Z
zhaocaibei123 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdint.h>
#include <stdio.h>
#include <vector>
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"
21 22
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/sparse_sgd_rule.h"
Z
zhaocaibei123 已提交
23 24 25 26

namespace paddle {
namespace distributed {

27
// DownpourUnitAccessor
Z
zhaocaibei123 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
class CtrCommonAccessor : public ValueAccessor {
 public:
  struct CtrCommonFeatureValue {
    /*
       float slot;
       float unseen_days;
       float delta_score;
       float show;
       float click;
       float embed_w;
       std::vector<float> embed_g2sum;
       std::vector<float> embedx_w;
       std::<vector>float embedx_g2sum;
       */

    int dim() { return 6 + embed_sgd_dim + embedx_sgd_dim + embedx_dim; }
    int dim_size(size_t dim, int embedx_dim) { return sizeof(float); }
    int size() { return dim() * sizeof(float); }
    int slot_index() { return 0; }
    int unseen_days_index() { return slot_index() + 1; }
    int delta_score_index() { return unseen_days_index() + 1; }
    int show_index() { return delta_score_index() + 1; }
    int click_index() { return show_index() + 1; }
    int embed_w_index() { return click_index() + 1; }
    int embed_g2sum_index() { return embed_w_index() + 1; }
    int embedx_w_index() { return embed_g2sum_index() + embed_sgd_dim; }
    int embedx_g2sum_index() { return embedx_w_index() + embedx_dim; }

    float& unseen_days(float* val) { return val[unseen_days_index()]; }
    float& delta_score(float* val) { return val[delta_score_index()]; }
    float& show(float* val) { return val[show_index()]; }
    float& click(float* val) { return val[click_index()]; }
    float& slot(float* val) { return val[slot_index()]; }
    float& embed_w(float* val) { return val[embed_w_index()]; }
    float& embed_g2sum(float* val) { return val[embed_g2sum_index()]; }
    float& embedx_w(float* val) { return val[embedx_w_index()]; }
    float& embedx_g2sum(float* val) { return val[embedx_g2sum_index()]; }
Z
zhaocaibei123 已提交
65

Z
zhaocaibei123 已提交
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
    int embed_sgd_dim;
    int embedx_dim;
    int embedx_sgd_dim;
  };

  struct CtrCommonPushValue {
    /*
       float slot;
       float show;
       float click;
       float embed_g;
       std::vector<float> embedx_g;
       */

    static int dim(int embedx_dim) { return 4 + embedx_dim; }

    static int dim_size(int dim, int embedx_dim) { return sizeof(float); }
    static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); }
    static int slot_index() { return 0; }
    static int show_index() { return CtrCommonPushValue::slot_index() + 1; }
    static int click_index() { return CtrCommonPushValue::show_index() + 1; }
    static int embed_g_index() { return CtrCommonPushValue::click_index() + 1; }
    static int embedx_g_index() {
      return CtrCommonPushValue::embed_g_index() + 1;
    }
    static float& slot(float* val) {
      return val[CtrCommonPushValue::slot_index()];
    }
    static float& show(float* val) {
      return val[CtrCommonPushValue::show_index()];
    }
    static float& click(float* val) {
      return val[CtrCommonPushValue::click_index()];
    }
    static float& embed_g(float* val) {
      return val[CtrCommonPushValue::embed_g_index()];
    }
    static float* embedx_g(float* val) {
      return val + CtrCommonPushValue::embedx_g_index();
    }
  };

  struct CtrCommonPullValue {
    /*
110 111
       float show;
       float click;
Z
zhaocaibei123 已提交
112 113 114 115
       float embed_w;
       std::vector<float> embedx_w;
       */

116
    static int dim(int embedx_dim) { return 3 + embedx_dim; }
Z
zhaocaibei123 已提交
117 118
    static int dim_size(size_t dim) { return sizeof(float); }
    static int size(int embedx_dim) { return dim(embedx_dim) * sizeof(float); }
119 120 121 122 123 124 125 126 127 128
    static int show_index() { return 0; }
    static int click_index() { return 1; }
    static int embed_w_index() { return 2; }
    static int embedx_w_index() { return 3; }
    static float& show(float* val) {
      return val[CtrCommonPullValue::show_index()];
    }
    static float& click(float* val) {
      return val[CtrCommonPullValue::click_index()];
    }
Z
zhaocaibei123 已提交
129 130 131 132 133 134 135 136 137 138 139
    static float& embed_w(float* val) {
      return val[CtrCommonPullValue::embed_w_index()];
    }
    static float* embedx_w(float* val) {
      return val + CtrCommonPullValue::embedx_w_index();
    }
  };
  CtrCommonAccessor() {}
  virtual int initialize();
  virtual ~CtrCommonAccessor() {}

140 141
  virtual void SetTableInfo(AccessorInfo& info);
  virtual size_t GetTableInfo(InfoKey key);
Z
zhaocaibei123 已提交
142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
  // value维度
  virtual size_t dim();
  // value各个维度的size
  virtual size_t dim_size(size_t dim);
  // value各维度相加总size
  virtual size_t size();
  // value中mf动态长度部分总size大小, sparse下生效
  virtual size_t mf_size();
  // pull value维度
  virtual size_t select_dim();
  // pull value各个维度的size
  virtual size_t select_dim_size(size_t dim);
  // pull value各维度相加总size
  virtual size_t select_size();
  // push value维度
  virtual size_t update_dim();
  // push value各个维度的size
  virtual size_t update_dim_size(size_t dim);
  // push value各维度相加总size
  virtual size_t update_size();
  // 判断该value是否进行shrink
  virtual bool shrink(float* value);
  // 判断该value是否保存到ssd
  // virtual bool save_ssd(float* value);
  virtual bool need_extend_mf(float* value);
  virtual bool has_mf(size_t size);
  // 判断该value是否在save阶段dump,
  // param作为参数用于标识save阶段,如downpour的xbox与batch_model
  // param = 0, save all feature
  // param = 1, save delta feature
  // param = 2, save xbox base feature
  bool save(float* value, int param) override;
  // update delta_score and unseen_days after save
  void update_stat_after_save(float* value, int param) override;
  // keys不存在时,为values生成随机值
  // 要求value的内存由外部调用者分配完毕
  virtual int32_t create(float** value, size_t num);
  // 从values中选取到select_values中
  virtual int32_t select(float** select_values, const float** values,
                         size_t num);
  // 将update_values聚合到一起
  virtual int32_t merge(float** update_values,
                        const float** other_update_values, size_t num);
  // 将update_values聚合到一起,通过it.next判定是否进入下一个key
  // virtual int32_t merge(float** update_values, iterator it);
  // 将update_values更新应用到values中
  virtual int32_t update(float** values, const float** update_values,
                         size_t num);

  std::string parse_to_string(const float* value, int param) override;
  int32_t parse_from_string(const std::string& str, float* v) override;
  virtual bool create_value(int type, const float* value);

  // 这个接口目前只用来取show
  float get_field(float* value, const std::string& name) override {
    // CHECK(name == "show");
    if (name == "show") {
      return common_feature_value.show(value);
    }
    return 0.0;
  }

 private:
  // float show_click_score(float show, float click);

  // SparseValueSGDRule* _embed_sgd_rule;
  // SparseValueSGDRule* _embedx_sgd_rule;
  // CtrCommonFeatureValue common_feature_value;
  float _show_click_decay_rate;
  int32_t _ssd_unseenday_threshold;

 public:  // TODO(zhaocaibei123): it should be private, but we make it public
          // for unit test
  CtrCommonFeatureValue common_feature_value;
  float show_click_score(float show, float click);
  SparseValueSGDRule* _embed_sgd_rule;
  SparseValueSGDRule* _embedx_sgd_rule;
};
}  // namespace distributed
}  // namespace paddle