提交 3f1d45e9 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add support for "disable_jit".

PiperOrigin-RevId: 328181751
Change-Id: Icc9c27218a9024e64b5a95a1cd35d83c46ab876b
上级 89a8b192
......@@ -54,11 +54,14 @@ namespace xla {
namespace py = pybind11;
// TODO(phawkins): Add support for Tracers.
// TODO(jblespiau): Add support for donate_argnums.
// TODO(jblespiau): Use absl Status.
namespace {
thread_local bool disable_jit;
void SetDisableJit(bool disable_jit_) { disable_jit = disable_jit_; }
bool GetDisableJit() { return disable_jit; }
// Describes the abstract shape and dtype of an argument.
struct ArgSignature {
// This is the XLA dtype of the object.
......@@ -237,8 +240,9 @@ struct CacheEntry {
// release the GIL.
class CompiledFunction {
public:
CompiledFunction(py::function cache_miss_fun, py::function python_f_jitted,
bool jax_enable_x64, std::vector<int> static_argnums);
CompiledFunction(py::function fun, py::function cache_miss_fun,
py::function python_f_jitted, bool jax_enable_x64,
bool jax_disable_jit, std::vector<int> static_argnums);
~CompiledFunction();
// This function will:
......@@ -252,7 +256,7 @@ class CompiledFunction {
// This allows `inspect.signature(cpp_jitted_f)` from Python.
py::object __signature__() {
static const auto* inspect = new py::module(py::module::import("inspect"));
return inspect->attr("signature")(python_f_jitted_);
return inspect->attr("signature")(fun_);
}
private:
......@@ -263,7 +267,9 @@ class CompiledFunction {
const py::args& args, const py::kwargs& kwargs,
const CallSignature& signature,
absl::optional<py::tuple> cache_miss_return = absl::nullopt);
bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_; }
const py::function fun_; // The Python function to jit.
// The Python function in charge of returning a `xla::PyExecutable` from
// the arguments passed to `jitted_f`.
const py::function cache_miss_fun_;
......@@ -274,6 +280,7 @@ class CompiledFunction {
// The value of the Python flag when the object was created.
const bool jax_enable_x64_;
const bool jax_disable_jit_;
// We need to know the static arguments to remove them from the arguments
// passed to the underlying PyExecutable. In sorted order.
......@@ -289,13 +296,16 @@ class CompiledFunction {
xla::PjRtDevice* default_device_ = nullptr;
};
CompiledFunction::CompiledFunction(py::function cache_miss_fun,
CompiledFunction::CompiledFunction(py::function fun,
py::function cache_miss_fun,
py::function python_f_jitted,
bool jax_enable_x64,
bool jax_enable_x64, bool jax_disable_jit,
std::vector<int> static_argnums)
: cache_miss_fun_(std::move(cache_miss_fun)),
: fun_(std::move(fun)),
cache_miss_fun_(std::move(cache_miss_fun)),
python_f_jitted_(std::move(python_f_jitted)),
jax_enable_x64_(jax_enable_x64),
jax_disable_jit_(jax_disable_jit),
static_argnums_(std::move(static_argnums)) {
std::sort(static_argnums_.begin(), static_argnums_.end());
}
......@@ -562,8 +572,6 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
py::array numpy_array = py::cast<py::array>(arg);
// If jax_enable_x64 is not set, we need to coerce 32 bits types.
// Note that this is calling back to Python!
// TODO(jblespiau): We can remove this complexity when we delete
// jax_enable_x64 mode.
if (!jax_enable_x64) {
const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
if (to_dtype) {
......@@ -674,6 +682,9 @@ CacheEntry& CompiledFunction::SetAndReturnCacheEntry(
}
py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
if (JitIsDisabled()) {
return fun_(*args, **kwargs);
}
ParsedArgumentsAsBuffers arguments;
FlattenArguments(args, kwargs, static_argnums_, arguments);
......@@ -728,15 +739,18 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
cfun.def("__call__", &CompiledFunction::Call);
cfun.def_property_readonly("__signature__", &CompiledFunction::__signature__);
jitlib.def("set_disable_jit", &SetDisableJit);
jitlib.def("get_disable_jit", &GetDisableJit);
jitlib.def(
"jit",
[](py::function cache_miss_fun,
[](py::function fun, py::function cache_miss_fun,
py::function fallback_on_unsupported_argument, bool jax_enable_x64,
bool jax_disable_jit,
std::vector<int> static_argnums) -> std::unique_ptr<CompiledFunction> {
return std::make_unique<CompiledFunction>(
std::move(cache_miss_fun),
std::move(fun), std::move(cache_miss_fun),
std::move(fallback_on_unsupported_argument), jax_enable_x64,
std::move(static_argnums));
jax_disable_jit, std::move(static_argnums));
});
// Only for testing purposes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册