// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace imperative {

class OpBase;

class VarBase {
 public:
32
  explicit VarBase(bool stop_gradient = false)
33 34 35 36
      : pre_op_(nullptr),
        pre_op_out_idx_(-1),
        var_desc_(nullptr),
        var_(nullptr),
37 38
        grads_(nullptr),
        stop_gradient_(stop_gradient) {}
39 40 41 42 43 44 45 46 47

  virtual ~VarBase() {}

  void ApplyGrad(framework::Scope* scope, framework::Variable* grad);

  void RunBackward(framework::Scope* scope);

  framework::LoDTensor& Grad();

M
minqiyang 已提交
48 49 50 51 52 53 54 55 56
  inline framework::Variable* GradVar() { return grads_; }

  inline std::string GradName() const {
    PADDLE_ENFORCE(
        var_desc_,
        "Couldn't get gradient variable's name, please call backward() first");
    return string::Sprintf("%s@IGrad", var_desc_->Name());
  }

57 58 59 60 61 62
  OpBase* pre_op_;
  int pre_op_out_idx_;

  framework::VarDesc* var_desc_;
  framework::Variable* var_;
  framework::Variable* grads_;
63 64

  bool stop_gradient_;
65 66 67 68 69 70 71 72 73 74
};

class OpBase {
 public:
  OpBase()
      : input_vars_(new std::vector<VarBase*>()),
        output_vars_(new std::vector<VarBase*>()),
        pre_ops_(new std::vector<OpBase*>()),
        pre_ops_out_idx_(new std::vector<int>()),
        op_desc_(nullptr),
M
minqiyang 已提交
75 76
        grad_op_desc_(nullptr),
        grad_to_var_(nullptr) {}
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115

  virtual ~OpBase() {
    delete input_vars_;
    delete output_vars_;

    delete pre_ops_;
    delete pre_ops_out_idx_;

    if (grad_op_desc_) delete grad_op_desc_;
    if (grad_to_var_) delete grad_to_var_;
  }

  std::vector<framework::Variable*> ApplyGrad(framework::Scope* scope);

  std::vector<VarBase*>* input_vars_;
  std::vector<VarBase*>* output_vars_;
  std::vector<OpBase*>* pre_ops_;
  std::vector<int>* pre_ops_out_idx_;
  framework::OpDesc* op_desc_;

  framework::OpDesc* grad_op_desc_;
  std::unordered_map<std::string, std::string>* grad_to_var_;
  framework::BlockDesc* block_;
};

// Base class for user-defined imperative layers. Subclasses override
// Forward() (and optionally Backward()) to define the computation.
class Layer {
 public:
  virtual ~Layer() {}

  // Default forward pass: produces no output variables.
  virtual std::vector<VarBase> Forward(const std::vector<VarBase>& inputs) {
    return std::vector<VarBase>();
  }

  // Default backward pass: not implemented; logs an error instead.
  virtual void Backward() { LOG(ERROR) << "To support customize"; }
};

}  // namespace imperative
}  // namespace paddle