ipu_executor.h 2.8 KB
Newer Older
J
jianghaicheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <popart/dataflow.hpp>
18
#include <popart/half.hpp>
J
jianghaicheng 已提交
19
#include <popart/names.hpp>
20
#include <popart/patterns/patterns.hpp>
J
jianghaicheng 已提交
21
#include <popart/session.hpp>
22
#include <popart/tensorinfo.hpp>
A
Allen Guo 已提交
23
#include <popdist/popdist_poplar.hpp>
J
jianghaicheng 已提交
24

25
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
J
jianghaicheng 已提交
26

27 28 29 30 31 32
// ExecutionContext only appears by reference in this header, so a forward
// declaration suffices here.
namespace paddle {
namespace framework {

class ExecutionContext;

}  // namespace framework
}  // namespace paddle

J
jianghaicheng 已提交
33 34 35 36
namespace paddle {
namespace platform {
namespace ipu {

37 38 39
// Forward declarations: the executor below only holds pointers/references to
// these types, so their full definitions are not required in this header.
class IpuStrategy;
struct CompilerResources;

40 41 42 43
struct ExecutorResources {
  // map<tensor_id, paddle_var_ptr>
  popart::WeightsIO weights_io;
  // <popart_var, paddle_var> pairs, include weights and optimizer states
A
Allen Guo 已提交
44
  std::vector<std::pair<popart::TensorId, std::string>> weights_and_opt_state;
45 46
};

J
jianghaicheng 已提交
47 48
class Executor {
 public:
49
  Executor() = default;
J
jianghaicheng 已提交
50 51
  ~Executor();

52
  // Build popart session
53
  void Prepare(const std::string &proto);
J
jianghaicheng 已提交
54

55
  // Run popart session
56 57
  void Run(const std::vector<const Tensor *> &inputs,
           const std::vector<Tensor *> &outputs,
J
jianghaicheng 已提交
58 59
           const framework::ExecutionContext &ctx);

60
  // Sync weights from popart to paddle
A
Allen Guo 已提交
61 62
  void WeightsToHost();

63
  // Detach IPU
64
  void Detach();
J
jianghaicheng 已提交
65 66

  // Scope
67
  void SetScope(const Scope *scope) { scope_ = scope; }
J
jianghaicheng 已提交
68 69

  // Strategy
70 71 72
  void SetIpuStrategy(const IpuStrategy &strategy) {
    ipu_strategy_ = &strategy;
  }
J
jianghaicheng 已提交
73

74 75 76 77
  // CompilerResources
  void SetCompilerResources(CompilerResources *resources) {
    compiler_resources_ = resources;
  }
J
jianghaicheng 已提交
78

79 80
  // Save model to onnx
  void SaveModelToHost(const std::string &path);
J
jianghaicheng 已提交
81 82

 private:
83
  void AcquireDevice();
A
Allen Guo 已提交
84 85 86 87
  void SetWeightsIO();
  void ConvertWeights(bool);
  void WeightsFromPaddle();
  void WeightsToPaddle();
88 89

 private:
90
  // Not own
91
  const Scope *scope_ = nullptr;
J
jianghaicheng 已提交
92
  const IpuStrategy *ipu_strategy_ = nullptr;
93
  CompilerResources *compiler_resources_ = nullptr;
A
Allen Guo 已提交
94
  bool compile_only_ = false;
95

96
  // Deviceinfo for popart session
97
  std::shared_ptr<popart::DeviceInfo> device_;
98
  // Popart session, where graph running
99
  std::unique_ptr<popart::Session> session_;
100
  // A ExecutorResources corresponds to a graph
101
  std::unique_ptr<ExecutorResources> executor_resources_;
J
jianghaicheng 已提交
102 103 104 105 106
};

}  // namespace ipu
}  // namespace platform
}  // namespace paddle