py_func_op.cc 12.4 KB
Newer Older
S
sneaxiy 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/py_func_op.h"
X
Xin Pan 已提交
16

M
minqiyang 已提交
17
#include <memory>
S
sneaxiy 已提交
18 19
#include <set>
#include <string>
M
minqiyang 已提交
20 21
#include <unordered_set>
#include <utility>
S
sneaxiy 已提交
22 23 24 25 26 27
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

S
sneaxiy 已提交
28
// Short alias for the pybind11 namespace, used throughout this file.
namespace py = ::pybind11;
S
sneaxiy 已提交
29 30 31

// Registry of Python callables used by py_func ops. The callable-id
// attributes of a py_func op are indices into this vector:
// AppendPythonCallableObjectAndReturnId appends and returns the index, and
// GetPythonCallableObject looks an index up. Nothing in this file removes
// entries, and no lock guards this vector here.
// NOTE(review): push_back may reallocate and invalidate pointers previously
// returned by GetPythonCallableObject; appending concurrently with op
// execution would be unsafe — confirm registration only happens before ops run.
static std::vector<py::object> g_py_callables;

S
sneaxiy 已提交
32 33 34 35
// Attribute names of the py_func op (declared in PyFuncOpMaker):
//   forward_callable_id  - registry index of the Python callable to run.
//   backward_callable_id - registry index of the backward Python callable;
//                          a negative value means no grad op is created.
//   backward_skip_vars   - forward inputs/outputs not fed to the grad op.
const char kForwardPythonCallableId[] = "forward_callable_id";
const char kBackwardPythonCallableId[] = "backward_callable_id";
const char kPyFuncBackwardSkipVars[] = "backward_skip_vars";

S
sneaxiy 已提交
36
/**
 * Register a Python callable and return the id it is stored under.
 *
 * The id is the callable's index in g_py_callables; it is the value the
 * py_func callable-id attributes refer to.
 */
size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
  const size_t assigned_id = g_py_callables.size();
  g_py_callables.push_back(py_obj);
  return assigned_id;
}

S
sneaxiy 已提交
41 42 43
// Return py::object* instead of py::object
// Returning py::object would cause reference count increasing
// but without GIL, reference count in Python may not be safe
S
sneaxiy 已提交
44
static py::object *GetPythonCallableObject(size_t i) {
45 46 47 48 49
  PADDLE_ENFORCE_LT(
      i, g_py_callables.size(),
      platform::errors::InvalidArgument(
          "Invalid python callable id %d, which should be less than %d.", i,
          g_py_callables.size()));
S
sneaxiy 已提交
50 51 52
  return &g_py_callables[i];
}

S
sneaxiy 已提交
53
/**
 * Build a human-readable description of a registered Python callable.
 *
 * Used only for VLOG output. If the callable has a `_func` attribute, the
 * result is "<str(_func)> wrapped by <str(callable)>". Otherwise it is just
 * str(callable), so that a debug-logging helper never raises merely because
 * the attribute is absent.
 *
 * Acquires the GIL for the duration of the call.
 */
static std::string PythonFuncDebugString(const py::object &py_callable) {
  py::gil_scoped_acquire guard;
  std::string wrapper_func_str = py::str(py_callable);
  // Presumably `_func` is the user function wrapped on the Python side —
  // confirm against the Python registry class.
  if (!py::hasattr(py_callable, "_func")) {
    return wrapper_func_str;
  }
  std::string inner_func_str = py::str(py_callable.attr("_func"));
  return inner_func_str + " wrapped by " + wrapper_func_str;
}

S
sneaxiy 已提交
61 62
static void CallPythonFunc(py::object *callable,
                           const std::vector<framework::LoDTensor> &ins,
S
sneaxiy 已提交
63
                           std::vector<framework::LoDTensor *> *outs) {
S
sneaxiy 已提交
64
  py::gil_scoped_acquire guard;
S
sneaxiy 已提交
65 66
  py::tuple in_args(ins.size());
  for (size_t i = 0; i < ins.size(); ++i) {
S
sneaxiy 已提交
67
    in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr);
S
sneaxiy 已提交
68 69
  }

S
sneaxiy 已提交
70
  auto ret = (*callable)(*in_args);
S
sneaxiy 已提交
71
  auto ret_tuple = py::cast<py::tuple>(ret);
S
sneaxiy 已提交
72
  size_t ret_num = py::len(ret_tuple);
S
sneaxiy 已提交
73 74
  size_t out_num = outs->size();
  if (UNLIKELY(ret_num != out_num)) {
S
sneaxiy 已提交
75 76 77
    // Python function has no return values or returns None
    // In this case, ret_num = 1 && ret[0] == None && out_num should be 0
    // Otherwise, ret_num must be equal to out_num
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
    PADDLE_ENFORCE_EQ(ret_num == 1, true,
                      platform::errors::InvalidArgument(
                          "Python function has no return values or returns "
                          "None. In this case, ret_num = 1 && ret[0] == None "
                          "&& out_num should be 0. But ret_num is %d",
                          ret_num));

    PADDLE_ENFORCE_EQ(
        out_num == 0, true,
        platform::errors::InvalidArgument(
            "Python function has no return values or returns None. In "
            "this case, ret_num = 1 && ret[0] == None && out_num should "
            "be 0. But out_num is %d",
            out_num));

    PADDLE_ENFORCE_EQ(
        py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr, true,
        platform::errors::InvalidArgument(
            "Python function has no return values or returns None. In "
            "this case, ret_num = 1 && ret[0] == None && out_num should "
            "be 0. But ret[0] is not None"));
S
sneaxiy 已提交
99 100 101
  }

  for (size_t i = 0; i < out_num; ++i) {
S
sneaxiy 已提交
102 103
    auto *out = (*outs)[i];
    if (out == nullptr) {
S
sneaxiy 已提交
104 105
      continue;
    }
S
sneaxiy 已提交
106
    try {
S
sneaxiy 已提交
107 108
      auto *py_out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
      PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
109 110
                              platform::errors::InvalidArgument(
                                  "Output tensor %d should not be nullptr", i));
S
sneaxiy 已提交
111 112
      out->set_lod(py_out_tensor->lod());
      out->ShareDataWith(*py_out_tensor);
S
sneaxiy 已提交
113
    } catch (py::cast_error &) {
114
      PADDLE_THROW(platform::errors::InvalidArgument(
W
Wilber 已提交
115 116 117
          "py::cast to LoDTensor error. The %d-th output expection is "
          "LoDTensor",
          i));
S
sneaxiy 已提交
118 119 120 121
    }
  }
}

122
class PyFuncOpVarTypeInference : public framework::StaticGraphVarTypeInference {
S
sneaxiy 已提交
123
 public:
M
minqiyang 已提交
124
  void operator()(framework::InferVarTypeContext *ctx) const override {
125 126
    bool has_out = ctx->HasOutput("Out");
    bool has_in = ctx->HasInput("X");
S
sneaxiy 已提交
127 128 129 130 131

    /**
     * X or Out can be empty, so that py_func can be more flexible
     * to support Python functions with no input or no output
     */
132 133 134 135 136 137 138
    PADDLE_ENFORCE_EQ(
        has_in || has_out, true,
        platform::errors::InvalidArgument("Input(X) or Output(Out) must exist, "
                                          "but has_in is %d, has_out is %d.",
                                          has_in, has_out));

    PADDLE_ENFORCE_GE(
139
        BOOST_GET_CONST(int, ctx->GetAttr(kForwardPythonCallableId)), 0,
140 141
        platform::errors::InvalidArgument(
            "Function id cannot be less than 0, but received value is %d.",
142
            BOOST_GET_CONST(int, ctx->GetAttr(kForwardPythonCallableId))));
S
sneaxiy 已提交
143

S
sneaxiy 已提交
144 145
    if (!has_out) return;

S
sneaxiy 已提交
146 147 148 149 150
    /**
     * Traverse all outputs, check if name of any output ends with @GRAD.
     * If found, set its shape, dtype, lod_level, type to be the same as
     * the corresponding forward variable
     */
S
sneaxiy 已提交
151
    const std::string kGradVarSuffix = framework::kGradVarSuffix;
152
    auto &out_var_names = Output(ctx, "Out");
S
sneaxiy 已提交
153 154 155
    for (auto &out_var_name : out_var_names) {
      if (out_var_name == framework::kEmptyVarName ||
          out_var_name.size() < kGradVarSuffix.size()) {
S
sneaxiy 已提交
156 157 158
        continue;
      }

S
sneaxiy 已提交
159 160 161
      size_t len = out_var_name.size() - kGradVarSuffix.size();
      if (out_var_name.substr(len) == kGradVarSuffix) {
        auto fwd_var_name = out_var_name.substr(0, len);
162 163 164 165
        OP_INOUT_CHECK(HasVar(ctx, out_var_name), "Var", out_var_name,
                       "py_func");
        OP_INOUT_CHECK(HasVar(ctx, fwd_var_name), "Var", fwd_var_name,
                       "py_func");
S
sneaxiy 已提交
166 167
        VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
                 << fwd_var_name << ")";
M
minqiyang 已提交
168

169 170 171 172
        SetShape(ctx, out_var_name, GetShape(ctx, fwd_var_name));
        SetDataType(ctx, out_var_name, GetDataType(ctx, fwd_var_name));
        SetLoDLevel(ctx, out_var_name, GetLoDLevel(ctx, fwd_var_name));
        SetType(ctx, out_var_name, GetType(ctx, fwd_var_name));
S
sneaxiy 已提交
173 174
      }
    }
S
sneaxiy 已提交
175 176 177
  }
};

S
sneaxiy 已提交
178 179 180
class PyFuncOpShapeInference : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
181 182 183
    PADDLE_ENFORCE_EQ(!ctx->IsRuntime(), true,
                      platform::errors::InvalidArgument(
                          "Infer shape cannot be called in runtime."));
S
sneaxiy 已提交
184 185 186
  }
};

S
sneaxiy 已提交
187 188 189 190 191
/**
 * Declares the py_func op's inputs, outputs and attributes.
 *
 * X and Out are duplicable. The two callable-id attributes index the
 * registry of Python callables. backward_callable_id defaults to -1, which
 * means no backward Python function, so no grad op is created.
 */
class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Any number of tensors can flow into and out of the Python function.
    AddInput("X", "Inputs of py_func op.").AsDuplicable();
    AddOutput("Out", "Outputs of py_func op").AsDuplicable();

    AddAttr<int>(kForwardPythonCallableId,
                 "Index of registered forward Python function.")
        .SetDefault(0);
    AddAttr<int>(kBackwardPythonCallableId,
                 "Index of registered backward Python function.")
        .SetDefault(-1);

    // By default every forward input/output is forwarded to the grad op.
    const std::vector<std::string> no_skip_vars;
    AddAttr<std::vector<std::string>>(kPyFuncBackwardSkipVars,
                                      "Unused forward in/out in backward op")
        .SetDefault(no_skip_vars);

    AddComment(R"DOC("PyFunc Op")DOC");
  }
};

S
sneaxiy 已提交
205 206 207 208 209 210 211 212 213 214
/**
 * There are several benefits when backward op of py_func op is
 * still py_func op.
 *
 *  - Less codes are needed, since codes of backward is almost
 *    the same as forward.
 *
 *  - To support high order derivative, so that py_func is
 *    infinite-order differentiable
 */
S
sneaxiy 已提交
215
class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase {
S
sneaxiy 已提交
216 217 218 219 220 221 222 223 224 225 226
 private:
  static std::string DebugString(const std::vector<std::string> &strs) {
    if (strs.empty()) return "";
    std::string ret = strs[0];
    for (size_t i = 1; i < strs.size(); ++i) {
      ret += " ";
      ret += strs[i];
    }
    return ret;
  }

S
sneaxiy 已提交
227 228 229 230 231
 public:
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;

  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
    auto &fwd_attrs = Attrs();
S
sneaxiy 已提交
232
    // no backward op when backward_id is less than 0
233
    if (BOOST_GET_CONST(int, fwd_attrs.at(kBackwardPythonCallableId)) < 0) {
S
sneaxiy 已提交
234 235 236 237 238 239 240
      return {};
    }

    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
    grad_op->SetType("py_func");

    framework::AttributeMap bwd_attrs;
S
sneaxiy 已提交
241 242 243
    bwd_attrs[kForwardPythonCallableId] =
        fwd_attrs.at(kBackwardPythonCallableId);
    bwd_attrs[kBackwardPythonCallableId] = -1;
S
sneaxiy 已提交
244 245
    grad_op->SetAttrMap(bwd_attrs);

S
sneaxiy 已提交
246 247 248 249 250 251
    // All forward inputs
    auto fwd_ins = Input("X");
    // All forward outputs
    auto fwd_outs = Output("Out");

    // For memory reused, some inputs/output in forward part may be not needed
S
sneaxiy 已提交
252
    // in backward part. Skipping these vars helps to save memory
253 254
    auto &backward_skip_var_list = BOOST_GET_CONST(
        std::vector<std::string>, fwd_attrs.at(kPyFuncBackwardSkipVars));
S
sneaxiy 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
    std::unordered_set<std::string> backward_skip_var_set(
        backward_skip_var_list.begin(), backward_skip_var_list.end());
    std::vector<std::string> bwd_ins;
    bwd_ins.reserve(fwd_ins.size() + fwd_outs.size());
    for (auto &fwd_in : fwd_ins) {
      if (backward_skip_var_set.count(fwd_in) == 0) {
        bwd_ins.emplace_back(fwd_in);
      }
    }

    for (auto &fwd_out : fwd_outs) {
      if (backward_skip_var_set.count(fwd_out) == 0) {
        bwd_ins.emplace_back(fwd_out);
      }
    }

    // Backward OG cannot be skipped
    // But in Python side, if OG is kEmptyVarName, input tensor would be None
    auto fwd_out_grads = OutputGrad("Out");
    bwd_ins.reserve(bwd_ins.size() + fwd_out_grads.size());
    bwd_ins.insert(bwd_ins.end(), fwd_out_grads.begin(), fwd_out_grads.end());
S
sneaxiy 已提交
276

S
sneaxiy 已提交
277 278 279
    // Backward IG cannot be skipped
    // But in Python side, if IG is not needed, users can just return None
    auto bwd_outs = InputGrad("X", false);
S
sneaxiy 已提交
280

S
sneaxiy 已提交
281 282
    VLOG(10) << "PyFunc Grad Input: " << DebugString(bwd_ins);
    VLOG(10) << "PyFunc Grad Output: " << DebugString(bwd_outs);
S
sneaxiy 已提交
283

S
sneaxiy 已提交
284 285
    grad_op->SetInput("X", bwd_ins);
    grad_op->SetOutput("Out", bwd_outs);
S
sneaxiy 已提交
286 287 288 289 290 291 292

    std::vector<std::unique_ptr<framework::OpDesc>> ret(1);
    ret[0] = std::move(grad_op);
    return ret;
  }
};

S
sneaxiy 已提交
293 294 295 296 297 298 299 300 301 302 303 304
/**
 * py_func operator: runs the registered Python callable named by
 * forward_callable_id on the op's inputs.
 *
 * Inputs are read from the scope. A missing variable or an uninitialized
 * tensor is passed to Python as None. GPU tensors are copied to CPU with
 * TensorCopySync; other tensors share their memory. Outputs absent from the
 * scope are passed as nullptr, and CallPythonFunc skips them.
 */
class PyFuncOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 protected:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    const auto &in_names = Inputs("X");
    const auto &out_names = Outputs("Out");

    std::vector<framework::LoDTensor> inputs(in_names.size());
    for (size_t idx = 0; idx < in_names.size(); ++idx) {
      // When py_func op is called in backward, the variable may be null.
      auto *in_var = scope.FindVar(in_names[idx]);
      if (in_var == nullptr) continue;

      const auto &src = in_var->Get<framework::LoDTensor>();
      if (!src.IsInitialized()) continue;

      auto &dst = inputs[idx];
      if (platform::is_gpu_place(src.place())) {
        framework::TensorCopySync(src, platform::CPUPlace(), &dst);
      } else {
        dst.ShareDataWith(src);
      }
      dst.set_lod(src.lod());
    }

    std::vector<framework::LoDTensor *> outputs;
    outputs.reserve(out_names.size());
    for (const auto &out_name : out_names) {
      auto *out_var = scope.FindVar(out_name);
      outputs.push_back(out_var == nullptr
                            ? nullptr
                            : out_var->GetMutable<framework::LoDTensor>());
    }

    auto callable_id = static_cast<size_t>(Attr<int>(kForwardPythonCallableId));
    auto *py_callable = GetPythonCallableObject(callable_id);
    VLOG(10) << "Call Python function with id " << callable_id << ": "
             << PythonFuncDebugString(*py_callable);
    CallPythonFunc(py_callable, inputs, &outputs);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// Register py_func with its kernel-less OperatorBase implementation, proto
// maker, var-type/shape inference hooks and grad op maker (the grad op is
// another py_func op).
REGISTER_OPERATOR(py_func,
                  ops::PyFuncOp,
                  ops::PyFuncOpMaker,
                  ops::PyFuncOpVarTypeInference,
                  ops::PyFuncOpShapeInference,
                  ops::PyFuncOpGradDescMaker);