未验证 提交 97ccaa79 编写于 作者: Z Zhanlue Yang 提交者: GitHub

[Eager][Yaml]Supported Scalar and ScalarArray for AutoCodeGen (#40080)

上级 b4665d23
...@@ -31,7 +31,9 @@ yaml_types_mapping = { ...@@ -31,7 +31,9 @@ yaml_types_mapping = {
'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>', 'int64_t[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
'Tensor' : 'Tensor', 'Tensor' : 'Tensor',
'Tensor[]' : 'std::vector<Tensor>', 'Tensor[]' : 'std::vector<Tensor>',
'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>' 'Tensor[Tensor[]]' : 'std::vector<std::vector<Tensor>>',
'Scalar' : 'Scalar',
'ScalarArray' : 'ScalarArray'
} }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
import argparse import argparse
from eager_gen import ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap from eager_gen import yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
atype_to_parsing_function = { atype_to_parsing_function = {
"bool": "CastPyArg2Boolean", "bool": "CastPyArg2Boolean",
...@@ -27,21 +27,9 @@ atype_to_parsing_function = { ...@@ -27,21 +27,9 @@ atype_to_parsing_function = {
"long[]": "CastPyArg2Longs", "long[]": "CastPyArg2Longs",
"float[]": "CastPyArg2Floats", "float[]": "CastPyArg2Floats",
"double[]": "CastPyArg2Float64s", "double[]": "CastPyArg2Float64s",
"string[]": "CastPyArg2Strings" "string[]": "CastPyArg2Strings",
} "Scalar": "CastPyArg2Scalar",
"ScalarArray": "CastPyArg2ScalarArray"
atype_to_cxx_type = {
"bool": "bool",
"int": "int",
"long": "long",
"float": "float",
"string": "std::string",
"bool[]": "std::vector<bool>",
"int[]": "std::vector<int>",
"long[]": "std::vector<long>",
"float[]": "std::vector<float>",
"double[]": "std::vector<double>",
"string[]": "std::vector<std::string>"
} }
...@@ -56,10 +44,10 @@ def ParseArguments(): ...@@ -56,10 +44,10 @@ def ParseArguments():
def GetCxxType(atype): def GetCxxType(atype):
if atype not in atype_to_cxx_type.keys(): if atype not in yaml_types_mapping.keys():
assert False assert False
return atype_to_cxx_type[atype] return yaml_types_mapping[atype]
def FindParsingFunctionFromAttributeType(atype): def FindParsingFunctionFromAttributeType(atype):
......
...@@ -587,14 +587,9 @@ paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs( ...@@ -587,14 +587,9 @@ paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
reinterpret_cast<TensorObject*>(obj)->tensor); reinterpret_cast<TensorObject*>(obj)->tensor);
} }
// For Intermediate State Dygraph, static paddle::experimental::Tensor& GetTensorFromPyObject(
// we use an uninitialized Tensor to represent dispensable Tensor const std::string& op_type, const std::string& arg_name, PyObject* obj,
paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type, ssize_t arg_idx, bool dispensable) {
const std::string& arg_name,
PyObject* args, ssize_t arg_idx,
bool dispensable) {
PyObject* obj = PyTuple_GET_ITEM(args, arg_idx);
if (PyTuple_Check(obj)) { if (PyTuple_Check(obj)) {
obj = PyTuple_GET_ITEM(obj, 0); obj = PyTuple_GET_ITEM(obj, 0);
} }
...@@ -612,6 +607,16 @@ paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type, ...@@ -612,6 +607,16 @@ paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
return reinterpret_cast<TensorObject*>(obj)->tensor; return reinterpret_cast<TensorObject*>(obj)->tensor;
} }
// For Intermediate State Dygraph,
// we use an uninitialized Tensor to represent dispensable Tensor
paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
const std::string& arg_name,
PyObject* args, ssize_t arg_idx,
bool dispensable) {
PyObject* obj = PyTuple_GET_ITEM(args, arg_idx);
return GetTensorFromPyObject(op_type, arg_name, obj, arg_idx, dispensable);
}
std::vector<paddle::experimental::Tensor> GetTensorListFromArgs( std::vector<paddle::experimental::Tensor> GetTensorListFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args, const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable) { ssize_t arg_idx, bool dispensable) {
...@@ -746,5 +751,84 @@ std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs( ...@@ -746,5 +751,84 @@ std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs(
return result; return result;
} }
paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
// obj could be: int, float, bool, paddle.Tensor
PyTypeObject* type = obj->ob_type;
auto type_name = std::string(type->tp_name);
if (type_name == "int") {
int value = CastPyArg2Int(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else if (type_name == "float") {
float value = CastPyArg2Float(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else if (type_name == "bool") {
bool value = CastPyArg2Boolean(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value);
} else if (type_name == "paddle.Tensor") {
paddle::experimental::Tensor& value = GetTensorFromPyObject(
op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/);
return paddle::experimental::Scalar(value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
// Fake a Scalar
return paddle::experimental::Scalar(1.0);
}
paddle::experimental::ScalarArray CastPyArg2ScalarArray(
PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
// In case of ScalarArray, only two possible PyObjects:
// 1. list of int
// 2. Tensor
if (obj == Py_None) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
// obj could be: int, float, bool, paddle.Tensor
PyTypeObject* type = obj->ob_type;
auto type_name = std::string(type->tp_name);
if (type_name == "list") {
std::vector<int> value = CastPyArg2Ints(obj, op_type, arg_pos);
return paddle::experimental::ScalarArray(value);
} else if (type_name == "paddle.Tensor") {
paddle::experimental::Tensor& value = GetTensorFromPyObject(
op_type, "" /*arg_name*/, obj, arg_pos, false /*dispensable*/);
return paddle::experimental::ScalarArray(value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
// Fake a ScalarArray
return paddle::experimental::ScalarArray({1});
}
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
...@@ -11,7 +11,10 @@ limitations under the License. */ ...@@ -11,7 +11,10 @@ limitations under the License. */
#pragma once #pragma once
#include <Python.h> #include <Python.h>
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
namespace paddle { namespace paddle {
...@@ -90,6 +93,13 @@ PyObject* ToPyObject(const std::tuple<Args...>& out) { ...@@ -90,6 +93,13 @@ PyObject* ToPyObject(const std::tuple<Args...>& out) {
return result; return result;
} }
paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
paddle::experimental::ScalarArray CastPyArg2ScalarArray(
PyObject* obj, const std::string& op_type, ssize_t arg_pos);
paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs( paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args, const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false); ssize_t arg_idx, bool dispensable = false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册