ipu_executor.h 2.8 KB
Newer Older
J
jianghaicheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <popart/dataflow.hpp>
18
#include <popart/half.hpp>
J
jianghaicheng 已提交
19
#include <popart/names.hpp>
20
#include <popart/patterns/patterns.hpp>
J
jianghaicheng 已提交
21
#include <popart/session.hpp>
A
Allen Guo 已提交
22
#include <popart/stepio.hpp>
23
#include <popart/tensorinfo.hpp>
J
jianghaicheng 已提交
24

25
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
J
jianghaicheng 已提交
26

27 28 29 30 31 32
// Forward declaration only: this header refers to ExecutionContext solely as
// a const-reference parameter (see Executor::Run), so a declaration suffices
// and the full framework definition does not have to be included here.
namespace paddle {
namespace framework {
class ExecutionContext;
}  // namespace framework
}  // namespace paddle

J
jianghaicheng 已提交
33 34 35 36
namespace paddle {
namespace platform {
namespace ipu {

37 38 39
// Forward declarations: Executor below only refers to these types through
// pointers and references, so their definitions are not required in this
// header.
struct CompilerResources;
class IpuStrategy;

40 41 42 43
struct ExecutorResources {
  // map<tensor_id, paddle_var_ptr>
  popart::WeightsIO weights_io;
  // <popart_var, paddle_var> pairs, include weights and optimizer states
A
Allen Guo 已提交
44
  std::vector<std::pair<popart::TensorId, std::string>> weights_and_opt_state;
45 46
};

J
jianghaicheng 已提交
47 48
class Executor {
 public:
49
  Executor() = default;
J
jianghaicheng 已提交
50 51
  ~Executor();

52
  // Build popart session
53
  void Prepare(const std::string &proto);
J
jianghaicheng 已提交
54

55
  // Run popart session
56 57
  void Run(const std::vector<const Tensor *> &inputs,
           const std::vector<Tensor *> &outputs,
J
jianghaicheng 已提交
58 59
           const framework::ExecutionContext &ctx);

60
  // Sync weights from popart to paddle
A
Allen Guo 已提交
61 62
  void WeightsToHost();

63
  // Detach IPU
64
  void Detach();
J
jianghaicheng 已提交
65

A
Allen Guo 已提交
66 67 68
  // Reset session
  void Reset();

J
jianghaicheng 已提交
69
  // Scope
70
  void SetScope(const Scope *scope) { scope_ = scope; }
J
jianghaicheng 已提交
71 72

  // Strategy
73 74 75
  void SetIpuStrategy(const IpuStrategy &strategy) {
    ipu_strategy_ = &strategy;
  }
J
jianghaicheng 已提交
76

77 78 79 80
  // CompilerResources
  void SetCompilerResources(CompilerResources *resources) {
    compiler_resources_ = resources;
  }
J
jianghaicheng 已提交
81

82 83
  // Save model to onnx
  void SaveModelToHost(const std::string &path);
J
jianghaicheng 已提交
84 85

 private:
86
  void AcquireDevice();
A
Allen Guo 已提交
87 88 89 90
  void SetWeightsIO();
  void ConvertWeights(bool);
  void WeightsFromPaddle();
  void WeightsToPaddle();
91 92

 private:
93
  // Not own
94
  const Scope *scope_ = nullptr;
J
jianghaicheng 已提交
95
  const IpuStrategy *ipu_strategy_ = nullptr;
96
  CompilerResources *compiler_resources_ = nullptr;
A
Allen Guo 已提交
97
  bool compile_only_ = false;
98

99
  // Deviceinfo for popart session
100
  std::shared_ptr<popart::DeviceInfo> device_;
101
  // Popart session, where graph running
102
  std::unique_ptr<popart::Session> session_;
103
  // A ExecutorResources corresponds to a graph
104
  std::unique_ptr<ExecutorResources> executor_resources_;
J
jianghaicheng 已提交
105 106 107 108 109
};

}  // namespace ipu
}  // namespace platform
}  // namespace paddle