inplace_op_inference.h
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstdlib>
#include <functional>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h"

namespace paddle {
namespace framework {

/*
  Inplace inference creates In->Out pairs for an in-place operator.
  Given a pair of corresponding names, e.g. X->Out, Out will reuse
  X's memory in place. The base class performs the legality
  validation for both variables.
*/
class InplaceOpInference {
 public:
  virtual ~InplaceOpInference() {}
  virtual std::unordered_map<std::string, std::string> operator()(
      const OpDesc& op_desc, BlockDesc* block) const = 0;
};
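
// A minimal sketch of a hand-written rule (hypothetical "scale" op with
// input "X" and output "Out"; most operators use the helper subclasses
// below instead):
//
//   class ScaleOpInplaceInference : public InplaceOpInference {
//    public:
//     std::unordered_map<std::string, std::string> operator()(
//         const OpDesc& op_desc, BlockDesc* block) const override {
//       return {{"X", "Out"}};  // Out reuses X's memory.
//     }
//   };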

class InplaceInToOut : public InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const OpDesc& op_desc, BlockDesc* block) const override {
    std::unordered_map<std::string, std::string> ret;
    auto in_out_var_names_pair = this->Apply(op_desc, block);
    for (auto& pair : in_out_var_names_pair) {
      PADDLE_ENFORCE(!op_desc.Input(pair.first).empty(),
                     string::Sprintf("op %s does not have input %s!",
                                     op_desc.Type(), pair.first));
      PADDLE_ENFORCE(!op_desc.Output(pair.second).empty(),
                     string::Sprintf("op %s does not have output %s!",
                                     op_desc.Type(), pair.second));
      auto& in_name = op_desc.Input(pair.first).at(0);
      auto& out_name = op_desc.Output(pair.second).at(0);

      auto& in = block->FindRecursiveOrCreateVar(in_name);
      auto& out = block->FindRecursiveOrCreateVar(out_name);
      if (TryInplaceInputOutput(in, out)) ret.insert({in_name, out_name});
    }
    return ret;
  }

 protected:
  virtual std::unordered_map<std::string, std::string> Apply(
      const OpDesc& op_desc, BlockDesc* block) const = 0;

  bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const {
    auto var_can_be_reused = [&](const VarDesc& node) -> bool {
      auto type = node.GetType();
      if (node.Persistable() || type != proto::VarType::LOD_TENSOR ||
          node.GetShape().empty()) {
        return false;
      }
      // Skip framework-internal vars whose names are wrapped in '@',
      // e.g. @EMPTY@ or @LR_DECAY_REUSE_ID@ (created by ops such as
      // while_grad).
      std::string name = node.Name();
      if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
        return false;
      return true;
    };

    auto var_size_in_bytes = [&](const VarDesc& node) -> size_t {
      auto shape = node.GetShape();
      // Accumulate in int64_t to avoid truncation; a dim may be -1
      // (unknown until runtime), hence the std::abs below.
      int64_t size = std::accumulate(shape.begin(), shape.end(),
                                     static_cast<int64_t>(1),
                                     std::multiplies<int64_t>());
      size_t type_size = SizeOfType(node.GetDataType());
      return type_size * std::abs(size);
    };

    return in.Name() != out.Name() && var_can_be_reused(in) &&
           var_can_be_reused(out) &&
           var_size_in_bytes(out) <= var_size_in_bytes(in);
  }
};
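
// Worked example of TryInplaceInputOutput (assumed shapes/dtypes for
// illustration): in = FP32 LoDTensor of shape [32, 64] -> 32 * 64 * 4 =
// 8192 bytes; out = FP32 LoDTensor of shape [32, 64] -> 8192 bytes.
// Both are non-persistable, differently named, and not '@'-wrapped, and
// var_size_in_bytes(out) <= var_size_in_bytes(in), so the pair is
// accepted for in-place reuse.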

/*
  In-place In->Out for operators that have exactly one input and one
  output, e.g. activation ops.
 */
class SingleOpInplaceInToOut : public InplaceInToOut {
 protected:
  std::unordered_map<std::string, std::string> Apply(
      const OpDesc& op_desc, BlockDesc* block) const override {
    PADDLE_ENFORCE(!op_desc.InputNames().empty(),
                   "Op inputs must not be empty");
    PADDLE_ENFORCE(!op_desc.OutputNames().empty(),
                   "Op outputs must not be empty");
    auto x_name = op_desc.InputNames().at(0);
    auto out_name = op_desc.OutputNames().at(0);
    return std::unordered_map<std::string, std::string>{{x_name, out_name}};
  }
};
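
// E.g. for a relu-like op with Input("X") and Output("Out"), Apply
// returns {{"X", "Out"}}; the base class then resolves the actual var
// names and checks that they can legally share memory. ("relu" is used
// here only as an illustrative single-input/single-output op.)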

/*
  For gradient ops, an output reuses the memory of its corresponding
  input, i.e. the Input@GRAD->Input reuse strategy.
 */
class GradOpInplaceInToOut : public InplaceInToOut {
 protected:
  std::unordered_map<std::string, std::string> Apply(
      const OpDesc& op_desc, BlockDesc* block) const override {
    std::unordered_map<std::string, std::string> ret;
    std::unordered_set<std::string> output_names(op_desc.OutputNames().begin(),
                                                 op_desc.OutputNames().end());
    for (auto& input_name : op_desc.InputNames()) {
      if (output_names.count(GradVarName(input_name))) {
        ret.insert({input_name, GradVarName(input_name)});
      }
    }
    return ret;
  }
};
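
// E.g. assuming a grad op with inputs {"X", "Y", "Out@GRAD"} and outputs
// {"X@GRAD", "Y@GRAD"} (an elementwise_add_grad-like signature), Apply
// returns {{"X", "X@GRAD"}, {"Y", "Y@GRAD"}}: each gradient output may
// reuse the memory of its forward input.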

}  // namespace framework
}  // namespace paddle