layer.h 2.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
16

17
#include <memory>
18
#include <string>
19
#include <unordered_map>
20 21
#include <vector>

22
#include "paddle/phi/api/include/tensor.h"
23 24
#include "paddle/phi/common/place.h"

25
#include "function.h"  //NOLINT
26 27

namespace paddle {
28 29 30 31 32

namespace framework {
class Variable;
}  // namespace framework

33
namespace jit {
34
class CompilationUnit;
35
class FunctionInfo;
36 37 38

using DenseTensor = phi::DenseTensor;
using Tensor = paddle::experimental::Tensor;
39
using Variable = paddle::framework::Variable;
40 41
using VariableMap = std::unordered_map<std::string, std::shared_ptr<Variable>>;
using FunctionInfoMap =
42
    std::unordered_map<std::string, std::shared_ptr<FunctionInfo>>;
43 44 45

class Layer {
 public:
46 47 48
  Layer(const VariableMap& params_map,
        const VariableMap& attrs_map_,
        const FunctionInfoMap& info_map,
49
        const phi::Place& place);
50

51
  jit::Function Function(const std::string& name) const;
52

53 54
  template <typename T>
  T Attribute(const std::string& name) const;
55

56 57 58
  std::vector<Tensor> forward(const std::vector<Tensor>& inputs);

  std::vector<DenseTensor> forward(const std::vector<DenseTensor>& inputs);
59

60 61
  void to(const phi::Place& place);

62 63
  void SetEngine(const std::string& name,
                 const std::shared_ptr<BaseEngine>& engine);
64

65 66
  const std::shared_ptr<jit::FunctionInfo>& FunctionInfo(
      const std::string& name) const;
67

68 69
  std::vector<std::string> FunctionNames() const;

70
 private:
71 72 73
  VariableMap params_map_;
  VariableMap attrs_map_;
  FunctionInfoMap info_map_;
74
  std::shared_ptr<CompilationUnit> unit_;
75 76 77 78
};

}  // namespace jit
}  // namespace paddle