diff --git a/paddle/fluid/imperative/py_layer_fwd.h b/paddle/fluid/imperative/py_layer_fwd.h index e01fd6e0016f147303f07da775007b5146c03575..159371970dcacf0701eb056fdf5caf5c55c473ab 100644 --- a/paddle/fluid/imperative/py_layer_fwd.h +++ b/paddle/fluid/imperative/py_layer_fwd.h @@ -101,6 +101,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, "`%s` type argument can not be cast into `Tensor`.", ptr->ptr()->ob_type->tp_name)); } + } else if (py::isinstance(*ptr) || + py::isinstance(*ptr)) { + try { + auto tuple_arg = ptr->cast(); + for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { + try { + auto t = iter->cast>(); + input_vars.push_back(t); + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, " + "the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->ptr()->ob_type->tp_name)); + } + } + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->ptr()->ob_type->tp_name)); + } } } } @@ -119,6 +141,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, "`%s` type argument can not be cast into `Tensor`.", ptr->second.ptr()->ob_type->tp_name)); } + } else if (py::isinstance(*ptr->second) || + py::isinstance(*ptr->second)) { + try { + auto tuple_arg = ptr->second.cast(); + for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { + try { + auto t = iter->cast>(); + input_vars.push_back(t); + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, " + "the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->second.ptr()->ob_type->tp_name)); + } + } + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->second.ptr()->ob_type->tp_name)); + } } } }