accessor.h 5.9 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdint.h>
#include <stdio.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"

namespace paddle {
namespace distributed {

// View of a raw memory range: `data` points at the first byte and `size` is
// the length in bytes. The typed constructors convert an element count into
// a byte count. Region never allocates or frees the memory it points at.
struct Region {
  Region() : data(nullptr), size(0) {}
  // `data_num` is already a byte count for char buffers.
  Region(char* data, size_t data_num) : data(data), size(data_num) {}
  // For typed buffers, `data_num` is an element count; sizeof() yields the
  // byte length without relying on hard-coded shift amounts.
  Region(float* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num * sizeof(float)) {}
  Region(int16_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num * sizeof(int16_t)) {}
  Region(int32_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num * sizeof(int32_t)) {}
  Region(int64_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num * sizeof(int64_t)) {}
  char* data;   // first byte of the range (nullptr for an empty Region)
  size_t size;  // length of the range in bytes
};

// One save/load conversion rule taken from a table_accessor_save_param entry
// of TableAccessorParameter. ValueAccessor keys these by `param` and
// aggregate-initializes them, so field order matters.
struct DataConverter {
  // Save-phase identifier this rule applies to.
  int param;
  // Converter used when saving (returned by ValueAccessor::GetConverter).
  std::string converter;
  // Deconverter used when loading (returned by ValueAccessor::GetDeconverter).
  std::string deconverter;
};

Y
yaoxuefeng 已提交
50
// Dimensions and sizes describing a table's value layout, as returned by
// ValueAccessor::GetAccessorInfo().
// NOTE(review): the unit of the *_size fields (bytes vs. floats) is not
// visible in this header — confirm against the concrete accessors.
struct AccessorInfo {
  // Number of dimensions of a stored value.
  size_t dim;
  // Size of a stored value across all of its dimensions.
  size_t size;
  // Number of dimensions of a pulled (selected) value.
  size_t select_dim;
  // Total size of a pulled value, summed over all of its dimensions.
  size_t select_size;
  // Number of dimensions of a pushed (update) value.
  size_t update_dim;
  // Size of a pushed value across all of its dimensions.
  size_t update_size;
  // Total size of the dynamic-length mf part of a value; only meaningful for
  // sparse tables.
  size_t mf_size;
  // Total value dimension; only meaningful for dense tables.
  size_t fea_dim;
};

T
tangwei12 已提交
69 70
class ValueAccessor {
 public:
71 72
  ValueAccessor() {}
  virtual ~ValueAccessor() {}
T
tangwei12 已提交
73

74
  virtual int Configure(const TableAccessorParameter& parameter) {
T
tangwei12 已提交
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
    _config = parameter;
    // data_convert结构体初始化
    if (_config.table_accessor_save_param_size() != 0) {
      for (int i = 0; i < _config.table_accessor_save_param_size(); ++i) {
        int param = _config.table_accessor_save_param(i).param();
        std::string converter =
            _config.table_accessor_save_param(i).converter();
        std::string deconverter =
            _config.table_accessor_save_param(i).deconverter();
        _data_coverter_map[param] = std::make_shared<DataConverter>();
        *(_data_coverter_map[param]) = {param, converter, deconverter};
      }
    }
    return 0;
  }
90
  virtual int Initialize() = 0;
T
tangwei12 已提交
91

92
  virtual AccessorInfo GetAccessorInfo() { return _accessor_info; }
Y
yaoxuefeng 已提交
93

94 95
  virtual bool NeedExtendMF(float* value) { return false; }
  virtual bool HasMF(size_t size) { return false; }
T
tangwei12 已提交
96
  // converter for save
97
  virtual std::string GetConverter(int param) {
T
tangwei12 已提交
98 99 100 101 102 103 104 105
    auto itr = _data_coverter_map.find(param);
    if (itr == _data_coverter_map.end()) {
      return "";
    } else {
      return (*itr).second->converter;
    }
  }
  // deconverter for load
106
  virtual std::string GetDeconverter(int param) {
T
tangwei12 已提交
107 108 109 110 111 112 113 114
    auto itr = _data_coverter_map.find(param);
    if (itr == _data_coverter_map.end()) {
      return "";
    } else {
      return (*itr).second->deconverter;
    }
  }
  // 判断该value是否进行shrink
115
  virtual bool Shrink(float* value) = 0;
T
tangwei12 已提交
116 117 118

  // 判断该value是否在save阶段dump,
  // param作为参数用于标识save阶段,如downpour的xbox与batch_model
119
  virtual bool Save(float* value, int param) = 0;
T
tangwei12 已提交
120
  // update delta_score and unseen_days after save
121
  virtual void UpdateStatAfterSave(float* value, int param) {}
Z
zhaocaibei123 已提交
122 123 124 125 126
  // 判断该value是否保存到ssd
  virtual bool SaveSSD(float* value) = 0;
  //
  virtual bool SaveCache(float* value, int param,
                         double global_cache_threshold) = 0;
T
tangwei12 已提交
127 128

  // keys不存在时,为values生成随机值
129 130
  virtual int32_t Create(float** value, size_t num) = 0;
  virtual bool CreateValue(int type, const float* value) { return true; }
T
tangwei12 已提交
131
  // 从values中选取到select_values中
132
  virtual int32_t Select(float** select_values, const float** values,
T
tangwei12 已提交
133 134
                         size_t num) = 0;
  // 将update_values聚合到一起
135
  virtual int32_t Merge(float** update_values,
T
tangwei12 已提交
136 137
                        const float** other_update_values, size_t num) = 0;
  // 将update_values聚合到一起,通过it.next判定是否进入下一个key
138
  // virtual int32_t Merge(float** update_values, iterator it);
T
tangwei12 已提交
139
  // 将update_values更新应用到values中
140
  virtual int32_t Update(float** values, const float** update_values,
T
tangwei12 已提交
141 142 143
                         size_t num) = 0;

  // used to save model, will filter feature
144
  virtual std::string ParseToString(const float* value, int param) = 0;
T
tangwei12 已提交
145
  //  parse value from string, used to load model
146
  virtual int32_t ParseFromString(const std::string& data, float* value) = 0;
T
tangwei12 已提交
147

148
  virtual FsDataConverter Converter(int param) {
T
tangwei12 已提交
149
    FsDataConverter data_convert;
150 151
    data_convert.converter = this->GetConverter(param);
    data_convert.deconverter = this->GetDeconverter(param);
T
tangwei12 已提交
152 153 154
    return data_convert;
  }

155 156
  virtual int SetWeight(float** values, const float** update_values,
                        size_t num) {
T
tangwei12 已提交
157 158 159
    return 0;
  }

160
  virtual float GetField(float* value, const std::string& name) { return 0.0; }
Y
yaoxuefeng 已提交
161 162
#define DEFINE_GET_INDEX(class, field) \
  virtual int get_##field##_index() override { return class ::field##_index(); }
T
tangwei12 已提交
163 164 165 166 167 168 169 170

 protected:
  size_t _value_size;
  size_t _select_value_size;
  size_t _update_value_size;
  TableAccessorParameter _config;
  std::unordered_map<int, std::shared_ptr<struct DataConverter>>
      _data_coverter_map;
Y
yaoxuefeng 已提交
171
  AccessorInfo _accessor_info;
T
tangwei12 已提交
172
};
T
tangwei12 已提交
173
// Registers ValueAccessor as a base with the pscore registerer (macro from
// registerer.h). Presumably this lets concrete accessors be created by class
// name — confirm against registerer.h.
REGISTER_PSCORE_REGISTERER(ValueAccessor);
T
tangwei12 已提交
174 175
}  // namespace distributed
}  // namespace paddle