// layer.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16

17
#include <memory>
18
#include <string>
19
#include <unordered_map>
20 21
#include <vector>

22
#include "paddle/phi/api/include/tensor.h"
23 24
#include "paddle/phi/common/place.h"

25
#include "function.h"  //NOLINT
26 27

namespace paddle {
28 29 30 31 32

// Forward declaration only: within this header framework::Variable is named
// solely as the pointee of std::shared_ptr, so its definition is not required.
namespace framework {
class Variable;
}  // namespace framework

33
namespace jit {
34
class CompilationUnit;
35
class FunctionInfo;
36 37 38

using DenseTensor = phi::DenseTensor;
using Tensor = paddle::experimental::Tensor;
39
using Variable = paddle::framework::Variable;
40 41
using Name2VariableMap =
    std::unordered_map<std::string, std::shared_ptr<Variable>>;
42 43 44 45
using Name2EngineMap =
    std::unordered_map<std::string, std::shared_ptr<BaseEngine>>;
using Name2FunctionInfoMap =
    std::unordered_map<std::string, std::shared_ptr<FunctionInfo>>;
46 47 48

class Layer {
 public:
49 50 51
  Layer(const Name2VariableMap& params_map,
        const Name2VariableMap& attrs_map_,
        const Name2FunctionInfoMap& info_map,
52
        const phi::Place& place);
53

54
  jit::Function Function(const std::string& name) const;
55

56 57
  template <typename T>
  T Attribute(const std::string& name) const;
58

59 60 61
  std::vector<Tensor> forward(const std::vector<Tensor>& inputs);

  std::vector<DenseTensor> forward(const std::vector<DenseTensor>& inputs);
62

63 64
  void to(const phi::Place& place);

65 66
  void SetEngine(const std::string& name,
                 const std::shared_ptr<BaseEngine>& engine);
67

68
  const Name2EngineMap& EngineMap() const;
69

70 71
  const std::shared_ptr<jit::FunctionInfo>& FunctionInfo(
      const std::string& name) const;
72

73
 private:
74 75 76
  Name2VariableMap params_map_;
  Name2VariableMap attrs_map_;
  Name2FunctionInfoMap info_map_;
77
  std::shared_ptr<CompilationUnit> unit_;
78 79 80 81
};

}  // namespace jit
}  // namespace paddle