//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
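
// Prefetch helpers for PaddlePaddle's distributed lookup-table operators:
// they pull the parameter rows selected by a set of ids from remote
// parameter servers. (File-level summary inferred from the declarations
// below; the implementation lives in parameter_prefetch.cc.)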

#include <cstring>  // std::memcpy
#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/memory/memcpy.h"            // memory::Copy
#include "paddle/fluid/platform/device_context.h"  // DeviceContextPool

namespace paddle {
namespace operators {
namespace distributed {

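// Splits the ids held in the scope variable `id_name` by `height_sections`,
// sends each section to the parameter-server endpoint in `epmap` that owns
// the matching shard of `table_names`, and gathers the returned rows into
// the scope variable `out_name`. (Summary inferred from the signature and
// its use in the distributed lookup-table operators.)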
void prefetch(const std::string& id_name, const std::string& out_name,
              const std::vector<std::string>& table_names,
              const std::vector<std::string>& epmap,
              const std::vector<int>& height_sections,
              const framework::ExecutionContext& context,
              const framework::Scope& scope);
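
// Same remote prefetch as `prefetch`, but additionally scatters each fetched
// row back into `original` at the row offset given by its id, reconstructing
// the corresponding rows of the full parameter tensor. (Summary inferred
// from the function body below.)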
template <typename T>
void prefetch_with_reconstruct(const std::string& id_name,
                               const std::string& out_name,
                               const std::vector<std::string>& table_names,
                               const std::vector<std::string>& epmap,
                               const std::vector<int>& height_sections,
                               const framework::ExecutionContext& context,
                               const framework::Scope& scope,
                               framework::LoDTensor* original) {
  prefetch(id_name, out_name, table_names, epmap, height_sections, context,
           scope);
  auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
  auto& ids = scope.FindVar(id_name)->Get<framework::LoDTensor>();
  auto* original_value = original->data<T>();
  auto* out_value = out.data<T>();
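  // Number of elements in one row of the original tensor.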
  size_t original_width = original->numel() / original->dims()[0];

  bool is_on_cpu_place = platform::is_cpu_place(ids.place());
  if (is_on_cpu_place) {
    // CPU: copy each fetched row back into the original tensor at the
    // position given by its id.
    for (int64_t i = 0; i < ids.numel(); i++) {
      const T* out_rows = out_value + original_width * i;
      T* original_row =
          original_value + original_width * ids.data<int64_t>()[i];
      std::memcpy(original_row, out_rows, original_width * sizeof(T));
    }
  } else {
#ifndef PADDLE_WITH_CUDA
    PADDLE_THROW("paddle is not compiled with CUDA!");
#else
    // GPU: the fetched rows live on the host, so copy each one back to the
    // device asynchronously on the device's CUDA stream.
    platform::DeviceContextPool& pool =
        platform::DeviceContextPool::Instance();
    auto& actual_ctx = *pool.Get(context.GetPlace());
    for (int64_t i = 0; i < ids.numel(); i++) {
      const T* out_rows = out_value + original_width * i;
      T* original_row =
          original_value + original_width * ids.data<int64_t>()[i];
      auto stream =
          static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
      memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), original_row,
                   platform::CPUPlace(), out_rows, original_width * sizeof(T),
                   stream);
    }
#endif
  }
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle