diff --git a/paddle/fluid/imperative/py_layer_fwd.h b/paddle/fluid/imperative/py_layer_fwd.h index 1baf73ab3b95da869922e5d4745c91356025799e..79251d7bf7ad6b2ffe0168649c207878d16336da 100644 --- a/paddle/fluid/imperative/py_layer_fwd.h +++ b/paddle/fluid/imperative/py_layer_fwd.h @@ -101,6 +101,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, "`%s` type argument can not be cast into `Tensor`.", ptr->ptr()->ob_type->tp_name)); } + } else if (py::isinstance(*ptr) || + py::isinstance(*ptr)) { + try { + auto tuple_arg = ptr->cast(); + for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { + try { + auto t = iter->cast>(); + input_vars.push_back(t); + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, " + "the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->ptr()->ob_type->tp_name)); + } + } + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->ptr()->ob_type->tp_name)); + } } } } @@ -119,6 +141,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, "`%s` type argument can not be cast into `Tensor`.", ptr->second.ptr()->ob_type->tp_name)); } + } else if (py::isinstance(*ptr->second) || + py::isinstance(*ptr->second)) { + try { + auto tuple_arg = ptr->second.cast(); + for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { + try { + auto t = iter->cast>(); + input_vars.push_back(t); + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, " + "the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->second.ptr()->ob_type->tp_name)); + } + } + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->second.ptr()->ob_type->tp_name)); + } } } }