未验证 提交 19a9f885 编写于 作者: L liufengwei0103 提交者: GitHub

Fix import oneflow error (#6401)

* export initNumpyCAPI

* refine

* refine

* refine

* refine

* refine

* refine
上级 bc72be16
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/extension/python/numpy.h"
namespace py = pybind11;
ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("InitNumpyCAPI", []() { return oneflow::numpy::InitNumpyCAPI().GetOrThrow(); });
}
......@@ -20,8 +20,9 @@ limitations under the License.
namespace oneflow {
// Note: there is a time interval between catching error and reporting an error,
// any error occur in this interval can't be displayed.
Maybe<void> CheckAndClearRegistryFlag();
void CatchRegistryError(const std::function<Maybe<void>()>&);
} // namespace oneflow
......
......@@ -60,15 +60,12 @@ Maybe<DataType> GetOFDataTypeFromNpArray(PyArrayObject* array) {
// defined in numpy_internal.h
// Reference:
// https://numpy.org/doc/stable/reference/c-api/array.html#importing-the-api
void InitNumpyCAPI() {
CatchRegistryError([]() -> Maybe<void> {
CHECK_ISNULL_OR_RETURN(PyArray_API);
CHECK_EQ_OR_RETURN(_import_array(), 0);
return Maybe<void>::Ok();
});
Maybe<void> InitNumpyCAPI() {
CHECK_ISNULL_OR_RETURN(PyArray_API);
CHECK_EQ_OR_RETURN(_import_array(), 0)
<< ". Unable to import Numpy array, try to upgrade Numpy version!";
return Maybe<void>::Ok();
}
COMMAND(InitNumpyCAPI());
} // namespace numpy
} // namespace oneflow
......@@ -40,5 +40,7 @@ Maybe<DataType> NumpyTypeToOFDataType(int np_array_type);
Maybe<DataType> GetOFDataTypeFromNpArray(PyArrayObject* array);
Maybe<void> InitNumpyCAPI();
} // namespace numpy
} // namespace oneflow
......@@ -19,6 +19,7 @@ import collections
import oneflow._oneflow_internal
oneflow._oneflow_internal.InitNumpyCAPI()
oneflow._oneflow_internal.CheckAndClearRegistryFlag()
Size = oneflow._oneflow_internal.Size
device = oneflow._oneflow_internal.device
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册