// parameter_prefetch.h
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {
namespace distributed {

void prefetch(const std::string& id_name, const std::string& out_name,
Q
Qiao Longfei 已提交
27
              const std::vector<std::string>& table_names,
Q
Qiao Longfei 已提交
28
              const std::vector<std::string>& epmap,
Q
Qiao Longfei 已提交
29
              const std::vector<int>& height_sections,
T
tangwei12 已提交
30 31
              const framework::ExecutionContext& context,
              const framework::Scope& scope);
Q
Qiao Longfei 已提交
32

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
template <typename T>
void prefetch_with_reconstruct(const std::string& id_name,
                               const std::string& out_name,
                               const std::vector<std::string>& table_names,
                               const std::vector<std::string>& epmap,
                               const std::vector<int>& height_sections,
                               const framework::ExecutionContext& context,
                               const framework::Scope& scope,
                               framework::LoDTensor* original) {
  prefetch(id_name, out_name, table_names, epmap, height_sections, context,
           scope);
  auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
  auto& ids = scope.FindVar(id_name)->Get<framework::LoDTensor>();
  auto* original_value = original->data<T>();
  auto* out_value = out.data<T>();
  size_t original_width = original->numel() / original->dims()[0];

  for (int64_t i = 0; i < ids.numel(); i++) {
    const T* out_rows = out_value + original_width * i;
    T* original_row = original_value + original_width * ids.data<int64_t>()[i];
    std::memcpy(original_row, out_rows, original_width * sizeof(T));
  }
}

Q
Qiao Longfei 已提交
57 58 59
};  // namespace distributed
};  // namespace operators
};  // namespace paddle