// py_layer_fwd.h
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/imperative/tracer.h"

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/operators/py_layer_op.h"

namespace paddle {
namespace imperative {

namespace py = ::pybind11;

bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
  for (const auto& name_pair : ins) {
    for (const auto& var_base : name_pair.second) {
      if (!var_base->OverridedStopGradient()) {
36 37 38 39 40 41 42 43 44 45 46
        for (const auto& pair : outs) {
          for (const auto& var : pair.second) {
            if (var) {
              var->SetOverridedStopGradient(false);
              SetForwardDataTypeOfGradVar(var);
              VLOG(3) << "Set output: " << var->Name()
                      << "'s OverridedStopGradient as "
                      << var->OverridedStopGradient();
            }
          }
        }
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
        return true;
      }
    }
  }
  return false;
}

// Builds the backward node for a py_layer forward call.
//
// A `PyLayerGradOpMaker` is constructed from the forward description (`type`,
// `ins`, `outs`, `attrs`, `inplace_map`), is handed `py_context`, and is then
// invoked to produce the grad node. Every grad op in the produced node gets a
// freshly generated unique id, is assigned `place`, and is passed through
// `ClearNoNeedBufferInputs`.
//
// Returns the grad node, or nullptr when the maker produced no node or an
// empty one.
//
// Declared `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit yields multiple
// definitions of the same symbol.
inline std::shared_ptr<GradOpNode> CreateGradOpNode(
    const std::string& type, const NameVarBaseMap& ins,
    const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
    const platform::Place& place,
    const std::map<std::string, std::string>& inplace_map,
    const std::shared_ptr<operators::PyLayerContext>& py_context) {
  operators::PyLayerGradOpMaker<paddle::imperative::OpBase> maker(
      type, ins, outs, attrs, inplace_map);
  maker.SetPyLayerContext(py_context);

  auto grad_node = maker();
  if (!grad_node || grad_node->empty()) {
    // Nothing to run in backward.
    return nullptr;
  }

  for (auto& grad_op : *grad_node) {
    grad_op.SetId(OpBase::GenerateUniqueId());
    grad_op.SetPlace(place);
    ClearNoNeedBufferInputs(&grad_op);
  }
  return grad_node;
}

// Runs the Python `forward` of a PyLayer class and, when needed, records the
// backward node for it.
//
// Steps:
//   1. Acquire the GIL, create the context object by calling
//      `cls._backward_function()`, and call `cls.forward(context, *args,
//      **kwargs)`.
//   2. Collect the Tensor (`imperative::VarBase`) inputs from `args` and
//      `kwargs`. A tuple/list argument is searched one level deep. Every
//      non-Tensor value is ignored.
//   3. Collect the Tensor outputs from the forward result (the elements of a
//      tuple/list result, or the single returned object). Non-Tensor outputs
//      are ignored.
//   4. If `RequiredGrad` reports that gradient must be tracked, detect the
//      in-place case (an input and an output share the same name) and create
//      the grad node through `CreateGradOpNode`.
//
// Returns the object returned by the Python `forward`, unchanged.
//
// Raises (through PADDLE_THROW / PADDLE_ENFORCE) InvalidArgument when:
//   - a value recognised as a Tensor cannot be cast to `VarBase`;
//   - `forward` returns no Tensor at all;
//   - forward is in-place and some input is a leaf variable that does not stop
//     gradient.
//
// Declared `inline` because the definition lives in a header: without it,
// including this header from more than one translation unit yields multiple
// definitions of the same symbol.
inline py::object PyLayerApply(const platform::Place& place,
                               const py::handle& cls, const py::args args,
                               const py::kwargs kwargs) {
  py::gil_scoped_acquire guard;
  auto bk_function = cls.attr("_backward_function");
  auto context = bk_function();
  auto forward = cls.attr("forward");

  auto result_forward = forward(context, *args, **kwargs);
  std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
      std::make_shared<operators::PyLayerContext>(context.ptr());

  // ---- inputs -------------------------------------------------------------
  std::vector<std::shared_ptr<imperative::VarBase>> input_vars;

  // Appends `obj` to `input_vars` when it is a Tensor; anything else is
  // ignored for now. On a cast failure the reported type name is the one of
  // the object that failed to cast.
  const auto collect_input_tensor = [&input_vars](const py::handle& obj) {
    if (!py::isinstance<imperative::VarBase>(obj)) {
      return;
    }
    try {
      input_vars.push_back(obj.cast<std::shared_ptr<VarBase>>());
    } catch (py::cast_error&) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The `PyLayer.forward` function contains invalid argument, the "
          "`%s` type argument can not be cast into `Tensor`.",
          obj.ptr()->ob_type->tp_name));
    }
  };

  // Handles one forward argument: a Tensor is collected directly, a
  // tuple/list is searched one level deep for Tensors, other types are
  // ignored.
  const auto collect_from_argument =
      [&collect_input_tensor](const py::handle& arg) {
        if (py::isinstance<imperative::VarBase>(arg)) {
          collect_input_tensor(arg);
        } else if (py::isinstance<py::tuple>(arg) ||
                   py::isinstance<py::list>(arg)) {
          try {
            const auto items = arg.cast<py::tuple>();
            for (const auto& item : items) {
              // Non-Tensor elements are skipped instead of being force-cast,
              // matching how top-level arguments are treated.
              collect_input_tensor(item);
            }
          } catch (py::cast_error&) {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The `PyLayer.forward` function contains invalid argument, "
                "the `%s` type argument can not be cast into `Tensor`.",
                arg.ptr()->ob_type->tp_name));
          }
        }
      };

  // Positional arguments first, then keyword argument values.
  for (const auto& arg : args) {
    collect_from_argument(arg);
  }
  for (const auto& kv : kwargs) {
    collect_from_argument(kv.second);
  }
  NameVarBaseMap ins = {{"X", input_vars}};

  // ---- outputs ------------------------------------------------------------
  std::vector<std::shared_ptr<imperative::VarBase>> output_vars;

  // Appends `obj` to `output_vars` when it is a Tensor; other types of output
  // are ignored for now.
  const auto collect_output_tensor = [&output_vars](const py::handle& obj) {
    if (!py::isinstance<imperative::VarBase>(obj)) {
      return;
    }
    try {
      output_vars.push_back(obj.cast<std::shared_ptr<imperative::VarBase>>());
    } catch (py::cast_error&) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The `PyLayer.forward` function returns invalid argument, the "
          "`%s` type argument can not be cast into `Tensor`.",
          obj.ptr()->ob_type->tp_name));
    }
  };

  if (PyTuple_Check(result_forward.ptr()) ||
      PyList_Check(result_forward.ptr())) {
    const auto tuple_result = result_forward.cast<py::tuple>();
    for (const auto& item : tuple_result) {
      collect_output_tensor(item);
    }
  } else {
    collect_output_tensor(result_forward);
  }

  if (output_vars.empty()) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "At least one output of `PyLayer.forward` is a `Tensor`."));
  }

  NameVarBaseMap outs = {{"Out", output_vars}};

  // ---- backward bookkeeping -----------------------------------------------
  if (RequiredGrad(ins, outs)) {
    std::map<std::string, std::string> inplace_map{};

    // Forward is treated as in-place when some input and some output carry
    // the same variable name.
    bool if_inplace = false;
    for (const auto& in_var : input_vars) {
      for (const auto& out_var : output_vars) {
        if (in_var->Name() == out_var->Name()) {
          if_inplace = true;
          break;
        }
      }
      if (if_inplace) {
        break;
      }
    }

    if (if_inplace) {
      // when pylayer forward is inplace strategy, check whether tensor is leaf
      for (auto& t : input_vars) {
        PADDLE_ENFORCE_EQ(t->IsLeaf() && !t->OverridedStopGradient(), false,
                          platform::errors::InvalidArgument(
                              "Leaf Var (%s) that doesn't stop gradient can't "
                              "use inplace strategy.",
                              t->Name()));
      }

      inplace_map["X"] = "Out";
    }

    CreateGradOpNode("py_layer", ins, outs, {{}}, place, inplace_map,
                     py_layer_ctx);
  } else {
    VLOG(3) << "No Grad to track for Op: py_layer_op";
  }

  return result_forward;
}

}  // namespace imperative
}  // namespace paddle