hashtable_inl.h 8.9 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
Thunderbrook 已提交
15
#ifdef PADDLE_WITH_HETERPS
T
Thunderbrook 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44

namespace paddle {
namespace framework {

// Value-combining functor passed to the table's insert() by the insert
// kernels below. It always yields the incoming value and ignores the
// previously stored one. NOTE(review): presumably insert() applies it only
// when the key already exists (last writer wins) — confirm in the container.
template <typename value_type>
struct ReplaceOp {
  __host__ __device__ value_type operator()(value_type new_value,
                                            value_type old_value) {
    (void)old_value;  // deliberately unused: the new value replaces it
    value_type chosen = new_value;
    return chosen;
  }
};

// Inserts `len` (key, value) pairs into the hash table, one pair per thread.
// Expects a 1-D launch with at least `len` threads in total; threads whose
// global index is >= len do nothing. The pair is handed to table->insert()
// together with ReplaceOp. If insert() returns end() the assert fires;
// asserts are compiled out under NDEBUG.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const typename Table::mapped_type* const vals,
                              size_t len) {
  using KeyT = typename Table::key_type;
  using ValT = typename Table::mapped_type;

  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= len) {
    return;
  }

  ReplaceOp<ValT> replace;
  thrust::pair<KeyT, ValT> entry;
  entry.first = keys[tid];
  entry.second = vals[tid];
  auto slot = table->insert(entry, replace);
  assert(slot != table->end() && "error: insert fails: table is full");
}

45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
// Inserts `len` keys into the hash table, one key per thread. The value stored
// for keys[i] is the address of slot (start_index + i) inside the caller's
// byte pool, cast to the table's mapped type (expected to be a pointer).
// Threads whose global index is >= len do nothing. If insert() returns end()
// the assert fires; asserts are compiled out under NDEBUG.
//
// `start_index` is size_t so that the size_t offset passed by the host
// wrapper is not truncated to int.
//
// NOTE(review): each pool slot is assumed to be 80 bytes wide. That is the
// stride hard-coded by the original code, presumably sizeof the pointed-to
// value type. Confirm against the pool allocator.
template <typename Table>
__global__ void insert_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              size_t len, char* pool, size_t start_index) {
  using KeyT = typename Table::key_type;
  // `typename` is required: mapped_type is a dependent type of Table.
  using ValT = typename Table::mapped_type;
  constexpr size_t kPoolSlotBytes = 80;

  // Widen before multiplying so very large grids cannot overflow 32 bits.
  const size_t i =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (i < len) {
    ReplaceOp<ValT> op;
    thrust::pair<KeyT, ValT> kv;
    kv.first = keys[i];
    // The byte offset is computed entirely in size_t, so large pools do not
    // overflow.
    kv.second = (ValT)(pool + (start_index + i) * kPoolSlotBytes);
    auto it = table->insert(kv, op);
    assert(it != table->end() && "error: insert fails: table is full");
  }
}

T
Thunderbrook 已提交
62 63 64 65 66 67 68 69 70 71 72 73 74 75
// Looks up `len` keys, one per thread, and writes the mapped value of every
// hit to vals[i]. On a miss vals[i] is left untouched, so callers must
// pre-initialize `vals` if they need a defined value for absent keys.
template <typename Table>
__global__ void search_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              typename Table::mapped_type* const vals,
                              size_t len) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }
  auto found = table->find(keys[idx]);
  if (found != table->end()) {
    vals[idx] = found->second;
  }
}

76 77 78 79 80 81 82 83 84 85 86 87 88 89
// Variable-stride lookup. For each key found, the FeatureValue pointed to by
// the table entry is copied into record i of `vals`, where consecutive records
// are `pull_feature_value_size` bytes apart. Only sizeof(FeatureValue) bytes
// are written per record. Records for missing keys are left untouched.
template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    char* const vals, size_t len,
                                    size_t pull_feature_value_size) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }
  auto found = table->find(keys[idx]);
  if (found != table->end()) {
    char* const record = vals + idx * pull_feature_value_size;
    FeatureValue* const dst = (FeatureValue*)record;
    *dst = *(found->second);
  }
}
T
Thunderbrook 已提交
90 91 92 93 94 95 96 97 98 99 100 101 102 103
// Applies one gradient per key. For each key present in the table it calls
// sgd.update_value(<stored value>, grads[i]). Keys that are not found are
// skipped without any report. One key per thread; out-of-range threads exit.
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
                              const typename Table::key_type* const keys,
                              const GradType* const grads, size_t len,
                              Sgd sgd) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= len) {
    return;
  }
  auto found = table->find(keys[idx]);
  if (found != table->end()) {
    // getter() is used (rather than operator->) presumably to obtain mutable
    // access to the stored pair so the optimizer can update it in place.
    sgd.update_value((found.getter())->second, grads[idx]);
  }
}

104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
// Variable-stride gradient push. For each key present in the table, record i
// of `grads` is reinterpreted as a FeaturePushValue and passed to
// sgd.dy_mf_update_value(). Consecutive records are `grad_value_size` bytes
// apart. A missing key is reported with a device printf (debug aid only:
// device printf is serialized and slow).
template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
                                    const typename Table::key_type* const keys,
                                    const char* const grads, size_t len,
                                    Sgd sgd, size_t grad_value_size) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) {
    auto it = table->find(keys[i]);
    if (it != table->end()) {
      FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
      sgd.dy_mf_update_value((it.getter())->second, *cur);
    } else {
      // Keys may be 64-bit. Printing them with %d was undefined behavior and
      // produced garbage, so widen explicitly and use %llu. The newline
      // keeps messages from different threads on separate lines.
      printf("yxf::push miss key: %llu\n",
             static_cast<unsigned long long>(keys[i]));
    }
  }
}

T
Thunderbrook 已提交
121 122 123
// Builds the table: constructs the underlying TableContainer with the
// requested `capacity` and creates the reader/writer lock held in rwlock_.
// (This removes scraped "blame" residue lines that had been embedded in the
// body and made it uncompilable.)
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
  container_ = new TableContainer<KeyType, ValType>(capacity);
  rwlock_.reset(new RWLock);
}

// Releases the container allocated with `new` in the constructor. rwlock_ is
// held by a smart pointer (it is set via reset()), so it is not freed here.
template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
  TableContainer<KeyType, ValType>* owned = this->container_;
  delete owned;
}

// Debug helper that forwards to the container's own print() routine.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
  this->container_->print();
}

// Looks up `len` device-resident keys and writes the values of the hits into
// d_vals (misses leave their slot untouched; see search_kernel). The kernel is
// enqueued on `stream` and this function returns without synchronizing.
// An empty batch is a no-op. (This removes scraped "blame" residue lines that
// had been embedded in the signature and body.)
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
                                      size_t len, gpuStream_t stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;  // ceil(len / BLOCK_SIZE_)
  search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

148 149 150 151 152 153 154 155 156 157 158
// Variable-stride lookup. For each hit, the pointed-to FeatureValue is copied
// into a record of d_vals that is pull_feature_value_size_ bytes wide (see
// dy_mf_search_kernel). The kernel is enqueued on `stream`; no synchronization
// is performed. An empty batch is a no-op.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
                                      size_t len, gpuStream_t stream) {
  if (len != 0) {
    const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
    dy_mf_search_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_vals, len, pull_feature_value_size_);
  }
}

T
Thunderbrook 已提交
159 160 161
// Inserts `len` device-resident (key, value) pairs by launching insert_kernel
// on `stream`; the call returns without synchronizing. An empty batch is a
// no-op. (This removes scraped "blame" residue lines that had been embedded in
// the signature and body.)
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
                                         const ValType* d_vals, size_t len,
                                         gpuStream_t stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;  // ceil(len / BLOCK_SIZE_)
  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_vals, len);
}

171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
// Inserts `len` device-resident keys whose values are derived by the kernel
// from `pool` and `start_index` (see the pool-based insert_kernel). The launch
// is enqueued on `stream` without synchronization. Returns immediately when
// the batch is empty or no pool was supplied.
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
                                         char* pool, size_t start_index,
                                         gpuStream_t stream) {
  const bool nothing_to_do = (len == 0) || (pool == NULL);
  if (nothing_to_do) {
    return;
  }
  const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
  insert_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
                                                    pool, start_index);
}

T
Thunderbrook 已提交
186 187 188 189 190 191 192 193 194 195 196
// Writes the contents of the GPU table back into the CPU-side parameter
// objects that each entry points to through gpu_val.cpu_ptr.
//
// Steps:
//   1. Prefetch the container's storage toward the host.
//   2. Walk container_->data()[0 .. container_->size()). A slot whose key
//      equals std::numeric_limits<KeyType>::max() is treated as unused and
//      skipped.
//   3. Prefetch the storage back to device `devid`.
// The fields copied depend on which parameter-server backend is compiled in
// (PADDLE_WITH_PSLIB or PADDLE_WITH_PSCORE).
//
// NOTE(review): the host loop reads container_->data() right after a prefetch
// issued on `stream`. Confirm that prefetch() (or the managed-memory setup)
// makes this host access safe without an explicit stream synchronization.
//
// (This removes scraped "blame" residue lines that had been interleaved with
// the #ifdef sections and made the body uncompilable.)
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
  container_->prefetch(cudaCpuDeviceId, stream);
  size_t num = container_->size();
  KeyType unuse_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* kv = container_->data();
  for (size_t i = 0; i < num; ++i) {
    if (kv[i].first == unuse_key) {
      continue;
    }
    ValType& gpu_val = kv[i].second;
#ifdef PADDLE_WITH_PSLIB
    auto* downpour_value =
        (paddle::ps::DownpourFixedFeatureValue*)(gpu_val.cpu_ptr);
    int downpour_value_size = downpour_value->size();
    // A 7-float value has no room for the mf weights yet (presumably it was
    // created before the embedding existed), so grow it to fit them.
    if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
      downpour_value->resize(gpu_val.mf_size + downpour_value_size);
    }
    float* cpu_val = downpour_value->data();
    // Index 0 of the CPU value is intentionally not overwritten.
    // cpu_val[0] = 0;
    cpu_val[1] = gpu_val.delta_score;
    cpu_val[2] = gpu_val.show;
    cpu_val[3] = gpu_val.clk;
    cpu_val[4] = gpu_val.lr;
    cpu_val[5] = gpu_val.lr_g2sum;
    cpu_val[6] = gpu_val.slot;
    // The mf weights follow the 7 fixed slots.
    if (gpu_val.mf_size > 0) {
      for (int x = 0; x < gpu_val.mf_size; x++) {
        cpu_val[x + 7] = gpu_val.mf[x];
      }
    }
#endif
#ifdef PADDLE_WITH_PSCORE
    auto* downpour_value = (paddle::distributed::VALUE*)(gpu_val.cpu_ptr);
    downpour_value->count_ = gpu_val.show;
    for (int x = 0; x < gpu_val.mf_size; x++) {
      downpour_value->data_[x] = gpu_val.mf[x];
    }
#endif
  }

  container_->prefetch(devid, stream);
}

T
Thunderbrook 已提交
230 231 232 233
// Pushes one gradient per key. update_kernel is launched on `stream` and
// calls sgd.update_value() for every key found in the table; missing keys are
// skipped. The call returns without synchronizing, and an empty batch is a
// no-op. (This removes scraped "blame" residue lines that had been embedded in
// the signature and body.)
template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const GradType* d_grads, size_t len,
                                         Sgd sgd, gpuStream_t stream) {
  if (len == 0) {
    return;
  }
  const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;  // ceil(len / BLOCK_SIZE_)
  update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys,
                                                       d_grads, len, sgd);
}

243 244 245 246 247 248 249 250 251 252 253 254 255 256
// Variable-stride gradient push. d_grads holds one record per key, each
// push_grad_value_size_ bytes wide; dy_mf_update_kernel applies
// sgd.dy_mf_update_value() to every key found. The kernel is enqueued on
// `stream` without synchronization. An empty batch is a no-op.
template <typename KeyType, typename ValType>
template <typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
                                         const char* d_grads, size_t len,
                                         Sgd sgd, gpuStream_t stream) {
  if (len != 0) {
    const int blocks = (len - 1) / BLOCK_SIZE_ + 1;
    dy_mf_update_kernel<<<blocks, BLOCK_SIZE_, 0, stream>>>(
        container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
  }
}

T
Thunderbrook 已提交
257 258 259
}  // end namespace framework
}  // end namespace paddle
#endif