From 861053eb8fd920d47f33aa81d67b606382b4d82d Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 16 Dec 2021 10:29:02 +0800 Subject: [PATCH] pylayer support tuple/list type args and fix check args bug (#38146) * Revert "Revert "pylayer support tuple/list type args (#37727)" (#37956)" This reverts commit d848ff04093673d07288f46721c48aded8a7e918. * move check args,kwargs before forward execute --- paddle/fluid/imperative/py_layer_fwd.h | 51 ++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/imperative/py_layer_fwd.h b/paddle/fluid/imperative/py_layer_fwd.h index e01fd6e001..2d7d319203 100644 --- a/paddle/fluid/imperative/py_layer_fwd.h +++ b/paddle/fluid/imperative/py_layer_fwd.h @@ -81,9 +81,6 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, auto context = bk_function(); auto forward = cls.attr("forward"); - auto result_forward = forward(context, *args, **kwargs); - std::shared_ptr py_layer_ctx = - std::make_shared(context.ptr()); // make inputs to varbase std::vector> input_vars; // process args,`input_vars` only collect `imperative::VarBase` @@ -101,6 +98,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, "`%s` type argument can not be cast into `Tensor`.", ptr->ptr()->ob_type->tp_name)); } + } else if (py::isinstance(*ptr) || + py::isinstance(*ptr)) { + try { + auto tuple_arg = ptr->cast(); + for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { + try { + auto t = iter->cast>(); + input_vars.push_back(t); + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, " + "the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->ptr()->ob_type->tp_name)); + } + } + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->ptr()->ob_type->tp_name)); + } } } } @@ -119,9 +138,35 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, "`%s` type argument can not be cast into `Tensor`.", ptr->second.ptr()->ob_type->tp_name)); } + } else if (py::isinstance(*ptr->second) || + py::isinstance(*ptr->second)) { + try { + auto tuple_arg = ptr->second.cast(); + for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { + try { + auto t = iter->cast>(); + input_vars.push_back(t); + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, " + "the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->second.ptr()->ob_type->tp_name)); + } + } + } catch (py::cast_error& err) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The `PyLayer.forward` function contains invalid argument, the " + "`%s` type argument can not be cast into `Tensor`.", + ptr->second.ptr()->ob_type->tp_name)); + } } } } + + std::shared_ptr py_layer_ctx = + std::make_shared(context.ptr()); + auto result_forward = forward(context, *args, **kwargs); NameVarBaseMap ins = {{"X", input_vars}}; std::vector> output_vars; -- GitLab