diff --git a/paddle/fluid/imperative/py_layer_fwd.h b/paddle/fluid/imperative/py_layer_fwd.h index e01fd6e0016f147303f07da775007b5146c03575..2d7d319203833337abd52371730bc409ad8e0a1f 100644 --- a/paddle/fluid/imperative/py_layer_fwd.h +++ b/paddle/fluid/imperative/py_layer_fwd.h @@ -81,9 +81,6 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, auto context = bk_function(); auto forward = cls.attr("forward"); - auto result_forward = forward(context, *args, **kwargs); - std::shared_ptr py_layer_ctx = - std::make_shared(context.ptr()); // make inputs to varbase std::vector> input_vars; // process args,`input_vars` only collect `imperative::VarBase` @@ -101,6 +98,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, "`%s` type argument can not be cast into `Tensor`.", ptr->ptr()->ob_type->tp_name)); } + } else if (py::isinstance(*ptr) || + py::isinstance(*ptr)) { + try { + auto tuple_arg = ptr->cast(); + for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { + try { + auto t = iter->cast>(); + input_vars.push_back(t); + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, " + "the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->ptr()->ob_type->tp_name)); + } + } + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->ptr()->ob_type->tp_name)); + } } } } @@ -119,9 +138,35 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, "`%s` type argument can not be cast into `Tensor`.", ptr->second.ptr()->ob_type->tp_name)); } + } else if (py::isinstance(*ptr->second) || + py::isinstance(*ptr->second)) { + try { + auto tuple_arg = ptr->second.cast(); + for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { + try { + auto t = iter->cast>(); + input_vars.push_back(t); + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, " + "the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->second.ptr()->ob_type->tp_name)); + } + } + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->second.ptr()->ob_type->tp_name)); + } } } } + + std::shared_ptr py_layer_ctx = + std::make_shared(context.ptr()); + auto result_forward = forward(context, *args, **kwargs); NameVarBaseMap ins = {{"X", input_vars}}; std::vector> output_vars;