accessor.h 5.7 KB
Newer Older
T
tangwei12 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <stdint.h>
#include <stdio.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/common/registerer.h"
#include "paddle/fluid/distributed/ps.pb.h"

namespace paddle {
namespace distributed {

// A (pointer, byte count) pair describing a raw memory range.
//
// The typed constructors reinterpret the incoming pointer as char* and turn
// the element count into a byte count. The struct defines no destructor, so
// it never frees `data`.
struct Region {
  // Empty region: null pointer, zero bytes.
  Region() : data(nullptr), size(0) {}

  // Byte-oriented constructor: `data_num` is stored unchanged as the size.
  Region(char* data, size_t data_num) : data(data), size(data_num) {}

  // Typed constructors: `data_num` counts elements and is scaled by the
  // element width before being forwarded to the byte-oriented constructor.
  Region(float* data, size_t data_num)
      : Region(reinterpret_cast<char*>(data), data_num * sizeof(float)) {}
  Region(int16_t* data, size_t data_num)
      : Region(reinterpret_cast<char*>(data), data_num * sizeof(int16_t)) {}
  Region(int32_t* data, size_t data_num)
      : Region(reinterpret_cast<char*>(data), data_num * sizeof(int32_t)) {}
  Region(int64_t* data, size_t data_num)
      : Region(reinterpret_cast<char*>(data), data_num * sizeof(int64_t)) {}

  char* data;   // start of the range
  size_t size;  // length of the range in bytes
};

// One converter/deconverter pair for a given save `param`.
// ValueAccessor::Configure fills these from the accessor's save-param
// configuration and stores them in a map keyed by `param`. It assigns the
// struct from a braced list {param, converter, deconverter}, so the member
// order below must not change.
struct DataConverter {
  // Save-param id; also used as the key in ValueAccessor's converter map.
  int param;
  // String returned by ValueAccessor::GetConverter (used for save).
  std::string converter;
  // String returned by ValueAccessor::GetDeconverter (used for load).
  std::string deconverter;
};

Y
yaoxuefeng 已提交
48
// Dimension and size description of the values handled by an accessor.
// Every field defaults to 0 so that a default-constructed AccessorInfo never
// exposes indeterminate values.
struct AccessorInfo {
  // Dimension of a value.
  size_t dim = 0;
  // Size of the value's dimensions.
  size_t size = 0;
  // Dimension of a pull value.
  size_t select_dim = 0;
  // Total size of a pull value, summed over all of its dimensions.
  size_t select_size = 0;
  // Dimension of a push value.
  size_t update_dim = 0;
  // Size of the push value's dimensions.
  size_t update_size = 0;
  // Total size of the dynamic-length mf part of a value; only used for
  // sparse tables.
  size_t mf_size = 0;
  // Total dimension of a value; only used for dense tables.
  size_t fea_dim = 0;
};

T
tangwei12 已提交
67 68
class ValueAccessor {
 public:
69 70
  ValueAccessor() {}
  virtual ~ValueAccessor() {}
T
tangwei12 已提交
71

72
  virtual int Configure(const TableAccessorParameter& parameter) {
T
tangwei12 已提交
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
    _config = parameter;
    // data_convert结构体初始化
    if (_config.table_accessor_save_param_size() != 0) {
      for (int i = 0; i < _config.table_accessor_save_param_size(); ++i) {
        int param = _config.table_accessor_save_param(i).param();
        std::string converter =
            _config.table_accessor_save_param(i).converter();
        std::string deconverter =
            _config.table_accessor_save_param(i).deconverter();
        _data_coverter_map[param] = std::make_shared<DataConverter>();
        *(_data_coverter_map[param]) = {param, converter, deconverter};
      }
    }
    return 0;
  }
88
  virtual int Initialize() = 0;
T
tangwei12 已提交
89

90
  virtual AccessorInfo GetAccessorInfo() { return _accessor_info; }
Y
yaoxuefeng 已提交
91

92 93
  virtual bool NeedExtendMF(float* value) { return false; }
  virtual bool HasMF(size_t size) { return false; }
T
tangwei12 已提交
94
  // converter for save
95
  virtual std::string GetConverter(int param) {
T
tangwei12 已提交
96 97 98 99 100 101 102 103
    auto itr = _data_coverter_map.find(param);
    if (itr == _data_coverter_map.end()) {
      return "";
    } else {
      return (*itr).second->converter;
    }
  }
  // deconverter for load
104
  virtual std::string GetDeconverter(int param) {
T
tangwei12 已提交
105 106 107 108 109 110 111 112
    auto itr = _data_coverter_map.find(param);
    if (itr == _data_coverter_map.end()) {
      return "";
    } else {
      return (*itr).second->deconverter;
    }
  }
  // 判断该value是否进行shrink
113
  virtual bool Shrink(float* value) = 0;
T
tangwei12 已提交
114 115 116

  // 判断该value是否在save阶段dump,
  // param作为参数用于标识save阶段,如downpour的xbox与batch_model
117
  virtual bool Save(float* value, int param) = 0;
T
tangwei12 已提交
118
  // update delta_score and unseen_days after save
119
  virtual void UpdateStatAfterSave(float* value, int param) {}
T
tangwei12 已提交
120 121

  // keys不存在时,为values生成随机值
122 123
  virtual int32_t Create(float** value, size_t num) = 0;
  virtual bool CreateValue(int type, const float* value) { return true; }
T
tangwei12 已提交
124
  // 从values中选取到select_values中
125
  virtual int32_t Select(float** select_values, const float** values,
T
tangwei12 已提交
126 127
                         size_t num) = 0;
  // 将update_values聚合到一起
128
  virtual int32_t Merge(float** update_values,
T
tangwei12 已提交
129 130
                        const float** other_update_values, size_t num) = 0;
  // 将update_values聚合到一起,通过it.next判定是否进入下一个key
131
  // virtual int32_t Merge(float** update_values, iterator it);
T
tangwei12 已提交
132
  // 将update_values更新应用到values中
133
  virtual int32_t Update(float** values, const float** update_values,
T
tangwei12 已提交
134 135 136
                         size_t num) = 0;

  // used to save model, will filter feature
137
  virtual std::string ParseToString(const float* value, int param) = 0;
T
tangwei12 已提交
138
  //  parse value from string, used to load model
139
  virtual int32_t ParseFromString(const std::string& data, float* value) = 0;
T
tangwei12 已提交
140

141
  virtual FsDataConverter Converter(int param) {
T
tangwei12 已提交
142
    FsDataConverter data_convert;
143 144
    data_convert.converter = this->GetConverter(param);
    data_convert.deconverter = this->GetDeconverter(param);
T
tangwei12 已提交
145 146 147
    return data_convert;
  }

148 149
  virtual int SetWeight(float** values, const float** update_values,
                        size_t num) {
T
tangwei12 已提交
150 151 152
    return 0;
  }

153
  virtual float GetField(float* value, const std::string& name) { return 0.0; }
Y
yaoxuefeng 已提交
154 155
#define DEFINE_GET_INDEX(class, field) \
  virtual int get_##field##_index() override { return class ::field##_index(); }
T
tangwei12 已提交
156 157 158 159 160 161 162 163

 protected:
  size_t _value_size;
  size_t _select_value_size;
  size_t _update_value_size;
  TableAccessorParameter _config;
  std::unordered_map<int, std::shared_ptr<struct DataConverter>>
      _data_coverter_map;
Y
yaoxuefeng 已提交
164
  AccessorInfo _accessor_info;
T
tangwei12 已提交
165
};
T
tangwei12 已提交
166
// Applies the REGISTER_PSCORE_REGISTERER macro to ValueAccessor.
// NOTE(review): the macro is not defined in this file (presumably it comes
// from registerer.h); confirm what it generates before relying on it.
REGISTER_PSCORE_REGISTERER(ValueAccessor);
T
tangwei12 已提交
167 168
}  // namespace distributed
}  // namespace paddle