未验证 提交 3d09929b 编写于 作者: L Leo Chen 提交者: GitHub

Add check for non-dispensable input (#28666)

* Add check for non-dispensable input

* fix typo
上级 19226ba8
...@@ -36,9 +36,15 @@ namespace pybind { ...@@ -36,9 +36,15 @@ namespace pybind {
static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase( static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
const std::string& op_type, const std::string& arg_name, int arg_idx, const std::string& op_type, const std::string& arg_name, int arg_idx,
const py::handle& handle) { const py::handle& handle, bool dispensable = false) {
PyObject* py_obj = handle.ptr(); // get underlying PyObject PyObject* py_obj = handle.ptr(); // get underlying PyObject
if (!py_obj || py_obj == Py_None) { if (!py_obj || py_obj == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s",
op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
}
return nullptr; return nullptr;
} }
try { try {
...@@ -54,9 +60,15 @@ static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase( ...@@ -54,9 +60,15 @@ static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
static inline std::vector<std::shared_ptr<imperative::VarBase>> static inline std::vector<std::shared_ptr<imperative::VarBase>>
CastPyHandleToVarBaseList(const std::string& op_type, CastPyHandleToVarBaseList(const std::string& op_type,
const std::string& arg_name, int arg_idx, const std::string& arg_name, int arg_idx,
const py::handle& handle) { const py::handle& handle, bool dispensable = false) {
PyObject* py_obj = handle.ptr(); // get underlying PyObject PyObject* py_obj = handle.ptr(); // get underlying PyObject
if (!py_obj || py_obj == Py_None) { if (!py_obj || py_obj == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s",
op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
}
return {}; return {};
} }
std::vector<std::shared_ptr<imperative::VarBase>> result; std::vector<std::shared_ptr<imperative::VarBase>> result;
......
...@@ -166,10 +166,10 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)"; ...@@ -166,10 +166,10 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)"; const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
const char* CAST_VAR_TEMPLATE = R"( const char* CAST_VAR_TEMPLATE = R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)"; auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)";
const char* CAST_VAR_LIST_TEMPLATE = R"( const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)"; auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)";
const char* ARG_TEMPLATE = R"(const %s& %s)"; const char* ARG_TEMPLATE = R"(const %s& %s)";
...@@ -263,9 +263,10 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -263,9 +263,10 @@ GenerateOpFunctions(const std::string& module_name) {
input_args_num++; input_args_num++;
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += ins_cast_str +=
paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name, paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
arg_idx++, TempName(in_name)); arg_idx++, TempName(in_name), dispensable);
if (input.dispensable()) { if (input.dispensable()) {
const auto in_template = input.duplicable() const auto in_template = input.duplicable()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册