未验证 提交 79c25979 编写于 作者: W wanghuancoder 提交者: GitHub

[Eager] fix cmake generate error, and fix circular import (#37871)

* refine a test case, test=develop

* rm python, test=develop

* refine, test=develop

* fix cmake generate error, and fix circular import, test=develop
上级 506e79d1
......@@ -42,7 +42,7 @@ static PyObject* eager_tensor_method_numpy(EagerTensorObject* self,
return Py_None;
}
auto tensor_dims = self->eagertensor.shape();
auto numpy_dtype = pten::TensorDtype2NumpyDtype(self->eagertensor.type());
auto numpy_dtype = TensorDtype2NumpyDtype(self->eagertensor.type());
auto sizeof_dtype = pten::DataTypeSize(self->eagertensor.type());
Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
......
......@@ -17,9 +17,11 @@ limitations under the License. */
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
......@@ -37,6 +39,38 @@ extern PyTypeObject* g_xpuplace_pytype;
extern PyTypeObject* g_npuplace_pytype;
extern PyTypeObject* g_cudapinnedplace_pytype;
int TensorDtype2NumpyDtype(pten::DataType dtype) {
switch (dtype) {
case pten::DataType::BOOL:
return pybind11::detail::npy_api::NPY_BOOL_;
case pten::DataType::INT8:
return pybind11::detail::npy_api::NPY_INT8_;
case pten::DataType::UINT8:
return pybind11::detail::npy_api::NPY_UINT8_;
case pten::DataType::INT16:
return pybind11::detail::npy_api::NPY_INT16_;
case pten::DataType::INT32:
return pybind11::detail::npy_api::NPY_INT32_;
case pten::DataType::INT64:
return pybind11::detail::npy_api::NPY_INT64_;
case pten::DataType::FLOAT16:
return pybind11::detail::NPY_FLOAT16_;
case pten::DataType::FLOAT32:
return pybind11::detail::npy_api::NPY_FLOAT_;
case pten::DataType::FLOAT64:
return pybind11::detail::npy_api::NPY_DOUBLE_;
case pten::DataType::COMPLEX64:
return pybind11::detail::NPY_COMPLEX64;
case pten::DataType::COMPLEX128:
return pybind11::detail::NPY_COMPLEX128;
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unknow pten::DataType, the int value = %d.",
static_cast<int>(dtype)));
return 0;
}
}
bool PyObject_CheckLongOrConvertToLong(PyObject** obj) {
if ((PyLong_Check(*obj) && !PyBool_Check(*obj))) {
return true;
......
......@@ -21,6 +21,8 @@ typedef struct {
PyObject_HEAD egr::EagerTensor eagertensor;
} EagerTensorObject;
int TensorDtype2NumpyDtype(pten::DataType dtype);
bool PyObject_CheckLongOrConvertToLong(PyObject** obj);
bool PyObject_CheckFloatOrConvertToFloat(PyObject** obj);
bool PyObject_CheckStr(PyObject* obj);
......
if(WITH_GPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info python)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
elseif(WITH_ROCM)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info python)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
else()
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place python)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place)
endif()
cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce)
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/core/convert_utils.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/pybind/tensor_py.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
......@@ -272,36 +270,4 @@ std::string DataType2String(DataType dtype) {
}
}
int TensorDtype2NumpyDtype(pten::DataType dtype) {
switch (dtype) {
case pten::DataType::BOOL:
return pybind11::detail::npy_api::NPY_BOOL_;
case pten::DataType::INT8:
return pybind11::detail::npy_api::NPY_INT8_;
case pten::DataType::UINT8:
return pybind11::detail::npy_api::NPY_UINT8_;
case pten::DataType::INT16:
return pybind11::detail::npy_api::NPY_INT16_;
case pten::DataType::INT32:
return pybind11::detail::npy_api::NPY_INT32_;
case pten::DataType::INT64:
return pybind11::detail::npy_api::NPY_INT64_;
case pten::DataType::FLOAT16:
return pybind11::detail::NPY_FLOAT16_;
case pten::DataType::FLOAT32:
return pybind11::detail::npy_api::NPY_FLOAT_;
case pten::DataType::FLOAT64:
return pybind11::detail::npy_api::NPY_DOUBLE_;
case pten::DataType::COMPLEX64:
return pybind11::detail::NPY_COMPLEX64;
case pten::DataType::COMPLEX128:
return pybind11::detail::NPY_COMPLEX128;
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unknow pten::DataType, the int value = %d.",
static_cast<int>(dtype)));
return 0;
}
}
} // namespace pten
......@@ -48,6 +48,5 @@ pten::LoD TransToPtenLoD(const paddle::framework::LoD& lod);
size_t DataTypeSize(DataType dtype);
DataType String2DataType(const std::string& str);
std::string DataType2String(DataType dtype);
int TensorDtype2NumpyDtype(pten::DataType dtype);
} // namespace pten
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.core as core
from .. import core as core
def monkey_patch_eagertensor():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册