//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstring>
#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {
namespace distributed {

void prefetch(const std::string& id_name, const std::string& out_name,
Q
Qiao Longfei 已提交
27
              const std::vector<std::string>& table_names,
Q
Qiao Longfei 已提交
28
              const std::vector<std::string>& epmap,
Q
Qiao Longfei 已提交
29
              const std::vector<int>& height_sections,
T
tangwei12 已提交
30 31
              const framework::ExecutionContext& context,
              const framework::Scope& scope);
Q
Qiao Longfei 已提交
32

33 34 35 36 37 38 39 40 41
template <typename T>
void prefetch_with_reconstruct(const std::string& id_name,
                               const std::string& out_name,
                               const std::vector<std::string>& table_names,
                               const std::vector<std::string>& epmap,
                               const std::vector<int>& height_sections,
                               const framework::ExecutionContext& context,
                               const framework::Scope& scope,
                               framework::LoDTensor* original) {
42 43 44
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& actual_ctx = *pool.Get(context.GetPlace());

45 46 47 48 49 50 51 52
  prefetch(id_name, out_name, table_names, epmap, height_sections, context,
           scope);
  auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
  auto& ids = scope.FindVar(id_name)->Get<framework::LoDTensor>();
  auto* original_value = original->data<T>();
  auto* out_value = out.data<T>();
  size_t original_width = original->numel() / original->dims()[0];

53 54 55 56 57
  bool is_on_cpu_place = true;
  if (!platform::is_cpu_place(ids.place())) {
    is_on_cpu_place = false;
  }

58 59 60
  for (int64_t i = 0; i < ids.numel(); i++) {
    const T* out_rows = out_value + original_width * i;
    T* original_row = original_value + original_width * ids.data<int64_t>()[i];
61 62 63 64 65 66 67
    if (is_on_cpu_place) {
      std::memcpy(original_row, out_rows, original_width * sizeof(T));
    } else {
#ifndef PADDLE_WITH_CUDA
      PADDLE_THROW("paddle is not compiled with CUDA!");
#else
      auto stream =
68 69 70 71
          static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
      memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), original_row,
                   platform::CPUPlace(), out_rows, original_width * sizeof(T),
                   stream);
72 73
#endif
    }
74 75 76
  }
}

Q
Qiao Longfei 已提交
77 78 79
};  // namespace distributed
};  // namespace operators
};  // namespace paddle