accessor.h 5.6 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdint.h>
#include <stdio.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
20
#include "paddle/fluid/distributed/common/afs_warpper.h"
T
tangwei12 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"

namespace paddle {
namespace distributed {

// Non-owning view of a raw buffer: `data` points at the first byte and `size`
// is the length in BYTES. The typed constructors take an element count and
// convert it to a byte count. Region never allocates or frees the buffer.
struct Region {
  Region() = default;
  Region(char* data, size_t data_num) : data(data), size(data_num) {}
  Region(float* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)), size(data_num * sizeof(float)) {}
  Region(int16_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)),
        size(data_num * sizeof(int16_t)) {}
  Region(int32_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)),
        size(data_num * sizeof(int32_t)) {}
  Region(int64_t* data, size_t data_num)
      : data(reinterpret_cast<char*>(data)),
        size(data_num * sizeof(int64_t)) {}
  char* data = nullptr;  // first byte of the viewed buffer
  size_t size = 0;       // buffer length in bytes
};

// Save/load conversion strings registered for one save `param`; entries are
// built from TableAccessorParameter::table_accessor_save_param in
// ValueAccessor::Configure. `param` defaults to 0 so a default-initialized
// instance never holds an indeterminate value.
struct DataConverter {
  int param = 0;            // save-stage identifier this entry belongs to
  std::string converter;    // returned by ValueAccessor::GetConverter (save)
  std::string deconverter;  // returned by ValueAccessor::GetDeconverter (load)
};

Y
yaoxuefeng 已提交
48 49 50 51 52 53 54 55 56 57 58
// Layout description of a table's values. Each field is named after the
// matching InfoKey enumerator. Units are not fixed by this header: presumably
// *_size are byte counts and *_dim are float counts. TODO confirm against the
// accessor implementations.
// All fields default to 0, so a default-initialized instance (for example
// ValueAccessor::_accessor_info before SetTableInfo) is never indeterminate.
struct AccessorInfo {
  size_t dim = 0;
  size_t size = 0;
  size_t select_size = 0;
  size_t select_dim = 0;
  size_t update_size = 0;
  size_t update_dim = 0;
  size_t mf_size = 0;
  size_t fea_dim = 0;
};

59 60 61 62 63 64 65 66 67 68 69
// Key accepted by ValueAccessor::GetTableInfo. Enumerators follow the field
// order of AccessorInfo and are numbered consecutively from 0 (DIM) to 7
// (FEA_DIM).
enum InfoKey {
  DIM = 0,      // AccessorInfo::dim
  SIZE,         // AccessorInfo::size
  SELECT_SIZE,  // AccessorInfo::select_size
  SELECT_DIM,   // AccessorInfo::select_dim
  UPDATE_SIZE,  // AccessorInfo::update_size
  UPDATE_DIM,   // AccessorInfo::update_dim
  MF_SIZE,      // AccessorInfo::mf_size
  FEA_DIM       // AccessorInfo::fea_dim
};

T
tangwei12 已提交
70 71
class ValueAccessor {
 public:
72 73
  ValueAccessor() {}
  virtual ~ValueAccessor() {}
T
tangwei12 已提交
74

75
  virtual int Configure(const TableAccessorParameter& parameter) {
T
tangwei12 已提交
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
    _config = parameter;
    // data_convert结构体初始化
    if (_config.table_accessor_save_param_size() != 0) {
      for (int i = 0; i < _config.table_accessor_save_param_size(); ++i) {
        int param = _config.table_accessor_save_param(i).param();
        std::string converter =
            _config.table_accessor_save_param(i).converter();
        std::string deconverter =
            _config.table_accessor_save_param(i).deconverter();
        _data_coverter_map[param] = std::make_shared<DataConverter>();
        *(_data_coverter_map[param]) = {param, converter, deconverter};
      }
    }
    return 0;
  }
91
  virtual int Initialize() = 0;
T
tangwei12 已提交
92

93 94
  virtual void SetTableInfo(AccessorInfo& info) = 0;
  virtual size_t GetTableInfo(InfoKey key) = 0;
Y
yaoxuefeng 已提交
95

96 97
  virtual bool NeedExtendMF(float* value) { return false; }
  virtual bool HasMF(size_t size) { return false; }
T
tangwei12 已提交
98
  // converter for save
99
  virtual std::string GetConverter(int param) {
T
tangwei12 已提交
100 101 102 103 104 105 106 107
    auto itr = _data_coverter_map.find(param);
    if (itr == _data_coverter_map.end()) {
      return "";
    } else {
      return (*itr).second->converter;
    }
  }
  // deconverter for load
108
  virtual std::string GetDeconverter(int param) {
T
tangwei12 已提交
109 110 111 112 113 114 115 116
    auto itr = _data_coverter_map.find(param);
    if (itr == _data_coverter_map.end()) {
      return "";
    } else {
      return (*itr).second->deconverter;
    }
  }
  // 判断该value是否进行shrink
117
  virtual bool Shrink(float* value) = 0;
T
tangwei12 已提交
118 119 120

  // 判断该value是否在save阶段dump,
  // param作为参数用于标识save阶段,如downpour的xbox与batch_model
121
  virtual bool Save(float* value, int param) = 0;
T
tangwei12 已提交
122
  // update delta_score and unseen_days after save
123
  virtual void UpdateStatAfterSave(float* value, int param) {}
T
tangwei12 已提交
124 125

  // keys不存在时,为values生成随机值
126 127
  virtual int32_t Create(float** value, size_t num) = 0;
  virtual bool CreateValue(int type, const float* value) { return true; }
T
tangwei12 已提交
128
  // 从values中选取到select_values中
129
  virtual int32_t Select(float** select_values, const float** values,
T
tangwei12 已提交
130 131
                         size_t num) = 0;
  // 将update_values聚合到一起
132
  virtual int32_t Merge(float** update_values,
T
tangwei12 已提交
133 134
                        const float** other_update_values, size_t num) = 0;
  // 将update_values聚合到一起,通过it.next判定是否进入下一个key
135
  // virtual int32_t Merge(float** update_values, iterator it);
T
tangwei12 已提交
136
  // 将update_values更新应用到values中
137
  virtual int32_t Update(float** values, const float** update_values,
T
tangwei12 已提交
138 139 140
                         size_t num) = 0;

  // used to save model, will filter feature
141
  virtual std::string ParseToString(const float* value, int param) = 0;
T
tangwei12 已提交
142
  //  parse value from string, used to load model
143
  virtual int32_t ParseFromString(const std::string& data, float* value) = 0;
T
tangwei12 已提交
144

145
  virtual FsDataConverter Converter(int param) {
T
tangwei12 已提交
146
    FsDataConverter data_convert;
147 148
    data_convert.converter = this->GetConverter(param);
    data_convert.deconverter = this->GetDeconverter(param);
T
tangwei12 已提交
149 150 151
    return data_convert;
  }

152 153
  virtual int SetWeight(float** values, const float** update_values,
                        size_t num) {
T
tangwei12 已提交
154 155 156
    return 0;
  }

157
  virtual float GetField(float* value, const std::string& name) { return 0.0; }
Y
yaoxuefeng 已提交
158 159
#define DEFINE_GET_INDEX(class, field) \
  virtual int get_##field##_index() override { return class ::field##_index(); }
T
tangwei12 已提交
160 161 162 163 164 165 166 167

 protected:
  size_t _value_size;
  size_t _select_value_size;
  size_t _update_value_size;
  TableAccessorParameter _config;
  std::unordered_map<int, std::shared_ptr<struct DataConverter>>
      _data_coverter_map;
Y
yaoxuefeng 已提交
168
  AccessorInfo _accessor_info;
T
tangwei12 已提交
169
};
T
tangwei12 已提交
170
// Registers ValueAccessor with the pscore registerer (macro from
// registerer.h). Presumably this lets concrete accessors be created by name;
// confirm against registerer.h.
REGISTER_PSCORE_REGISTERER(ValueAccessor);
T
tangwei12 已提交
171 172
}  // namespace distributed
}  // namespace paddle