未验证 提交 861053eb 编写于 作者: C chentianyu03 提交者: GitHub

pylayer support tuple/list type args and fix check args bug (#38146)

* Revert "Revert "pylayer support tuple/list type args (#37727)" (#37956)"

This reverts commit d848ff04.

* move check args,kwargs before forward execute
上级 9bac4a76
...@@ -81,9 +81,6 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, ...@@ -81,9 +81,6 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
auto context = bk_function(); auto context = bk_function();
auto forward = cls.attr("forward"); auto forward = cls.attr("forward");
auto result_forward = forward(context, *args, **kwargs);
std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
std::make_shared<operators::PyLayerContext>(context.ptr());
// make inputs to varbase // make inputs to varbase
std::vector<std::shared_ptr<imperative::VarBase>> input_vars; std::vector<std::shared_ptr<imperative::VarBase>> input_vars;
// process args,`input_vars` only collect `imperative::VarBase` // process args,`input_vars` only collect `imperative::VarBase`
...@@ -101,6 +98,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, ...@@ -101,6 +98,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
"`%s` type argument can not be cast into `Tensor`.", "`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name)); ptr->ptr()->ob_type->tp_name));
} }
} else if (py::isinstance<py::tuple>(*ptr) ||
py::isinstance<py::list>(*ptr)) {
try {
auto tuple_arg = ptr->cast<py::tuple>();
for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) {
try {
auto t = iter->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(t);
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, "
"the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name));
}
}
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name));
}
} }
} }
} }
...@@ -119,9 +138,35 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, ...@@ -119,9 +138,35 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
"`%s` type argument can not be cast into `Tensor`.", "`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name)); ptr->second.ptr()->ob_type->tp_name));
} }
} else if (py::isinstance<py::tuple>(*ptr->second) ||
py::isinstance<py::list>(*ptr->second)) {
try {
auto tuple_arg = ptr->second.cast<py::tuple>();
for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) {
try {
auto t = iter->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(t);
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, "
"the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name));
}
}
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name));
}
} }
} }
} }
std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
std::make_shared<operators::PyLayerContext>(context.ptr());
auto result_forward = forward(context, *args, **kwargs);
NameVarBaseMap ins = {{"X", input_vars}}; NameVarBaseMap ins = {{"X", input_vars}};
std::vector<std::shared_ptr<imperative::VarBase>> output_vars; std::vector<std::shared_ptr<imperative::VarBase>> output_vars;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册