From 0a088ab76427221029273fc97736bb346f79e0c9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 24 Aug 2020 11:29:48 -0700 Subject: [PATCH] Delay the requirement of the device/pyclient to the first call. Because jax.jit is annotated on top-level functions, and that backends are initialised only after GoogleInit, we cannot have the client when calling the C++ jit. PiperOrigin-RevId: 328177723 Change-Id: I462385e4a687461bf41c0d3e2875d4736285c42d --- tensorflow/compiler/xla/python/jax_jit.cc | 114 +++++++++++++++------- 1 file changed, 81 insertions(+), 33 deletions(-) diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc index 96cf1e64b85..239cbcaee8d 100644 --- a/tensorflow/compiler/xla/python/jax_jit.cc +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -26,7 +26,9 @@ limitations under the License. #include "tensorflow/compiler/xla/python/jax_jit.h" +#include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -230,12 +232,13 @@ struct CacheEntry { // A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the // bookkeeping of the different signatures used and the dispatch of calls to // the correct underlying `PyExecutable`. +// TODO(jblespiau): This class is thread-unsafe. Note that using a mutex for the +// full `Call` will lead to a deadlock because it goes back to Python which will +// release the GIL. class CompiledFunction { public: CompiledFunction(py::function cache_miss_fun, py::function python_f_jitted, - bool jax_enable_x64, std::vector static_argnums, - std::shared_ptr pyclient, - xla::PjRtDevice* device); + bool jax_enable_x64, std::vector static_argnums); ~CompiledFunction(); // This function will: @@ -246,9 +249,20 @@ class CompiledFunction { // (e) reconstruct the `PyTree`. py::object Call(py::args args, py::kwargs kwargs); + // This allows `inspect.signature(cpp_jitted_f)` from Python. + py::object __signature__() { + static const auto* inspect = new py::module(py::module::import("inspect")); + return inspect->attr("signature")(python_f_jitted_); + } + private: CacheEntry& GetCacheEntry(const py::args& args, const py::kwargs& kwargs, - const CallSignature& signature); + const CallSignature& signature, + absl::optional cache_miss_return); + CacheEntry& SetAndReturnCacheEntry( + const py::args& args, const py::kwargs& kwargs, + const CallSignature& signature, + absl::optional cache_miss_return = absl::nullopt); // The Python function in charge of returning a `xla::PyExecutable` from // the arguments passed to `jitted_f`. @@ -267,22 +281,22 @@ class CompiledFunction { // We need a `unique_ptr` here to ensure value pointer stability. absl::flat_hash_map> executables_; - const std::shared_ptr pyclient_; - xla::PjRtDevice* const default_device_; + // As top-level functions are decorated with `jax.jit`, when + // `CompiledFunction` is being instantiated from Python, the clients are not + // yet available (done after GoogleInit). They will be during the first call + // to `Call`. + std::shared_ptr pyclient_ = nullptr; + xla::PjRtDevice* default_device_ = nullptr; }; CompiledFunction::CompiledFunction(py::function cache_miss_fun, py::function python_f_jitted, bool jax_enable_x64, - std::vector static_argnums, - std::shared_ptr pyclient, - xla::PjRtDevice* device) + std::vector static_argnums) : cache_miss_fun_(std::move(cache_miss_fun)), python_f_jitted_(std::move(python_f_jitted)), jax_enable_x64_(jax_enable_x64), - static_argnums_(std::move(static_argnums)), - pyclient_(std::move(pyclient)), - default_device_(device) { + static_argnums_(std::move(static_argnums)) { std::sort(static_argnums_.begin(), static_argnums_.end()); } @@ -493,8 +507,21 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, xla::PjRtDevice* data_device = nullptr; for (py::handle arg : arguments.flat_dynamic_args) { if (py::isinstance(arg, device_array)) { - xla::PyBuffer* buffer = - py::cast(arg.attr("device_buffer")); + xla::PyBuffer* buffer; + try { + // This can fail, e.g. when device_buffer is a `DeviceConstant`. + buffer = py::cast(arg.attr("device_buffer")); + } catch (const py::cast_error& e) { + return InvalidArgument( + "%s", + absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: " + "`device_buffer` field is of type ", + py::cast( + arg.attr("device_buffer").get_type().str()), + " while a `PyBuffer` was expected." + + )); + } xla::PjRtDevice* device = buffer->buffer()->device(); if (data_device && (device != data_device)) { return InvalidArgument( @@ -577,14 +604,20 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, } // namespace -CacheEntry& CompiledFunction::GetCacheEntry(const py::args& args, - const py::kwargs& kwargs, - const CallSignature& signature) { +CacheEntry& CompiledFunction::GetCacheEntry( + const py::args& args, const py::kwargs& kwargs, + const CallSignature& signature, + absl::optional cache_miss_return) { auto found_iterator = executables_.find(signature); if (found_iterator != executables_.end()) { // Cache hit! return *(found_iterator->second); } - + return SetAndReturnCacheEntry(args, kwargs, signature, cache_miss_return); +} +CacheEntry& CompiledFunction::SetAndReturnCacheEntry( + const py::args& args, const py::kwargs& kwargs, + const CallSignature& signature, + absl::optional cache_miss_return) { // We need to insert the element. auto result = executables_.emplace(signature, std::make_unique()); auto it = result.first; @@ -593,7 +626,12 @@ CacheEntry& CompiledFunction::GetCacheEntry(const py::args& args, result.first->first.IncRef(); // Cache miss? Call the Python cache miss function. - py::tuple executable_and_pytree = cache_miss_fun_(*args, **kwargs); + py::tuple executable_and_pytree; + if (cache_miss_return) { + executable_and_pytree = cache_miss_return.value(); + } else { + executable_and_pytree = cache_miss_fun_(*args, **kwargs); + } if (executable_and_pytree.size() != 4) { throw std::runtime_error( "AssertionError: The cache miss function should return 4 " @@ -639,6 +677,16 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { ParsedArgumentsAsBuffers arguments; FlattenArguments(args, kwargs, static_argnums_, arguments); + absl::optional cache_miss_result = absl::nullopt; + if (!default_device_) { + cache_miss_result = cache_miss_fun_(*args, **kwargs); + auto executable = py::cast>( + cache_miss_result.value()[0]); + + pyclient_ = executable->client(); + default_device_ = executable->LocalDevices()[0].contents; + } + // The C++ jit do not support Tracers arguments yet. The Python-based jit // function will be called if any of the dynamic arguments is unsupported. if (!ConvertArgsToBuffers(jax_enable_x64_, *pyclient_, default_device_, @@ -647,7 +695,8 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { return python_f_jitted_(*args, **kwargs); } - CacheEntry& cache_entry = GetCacheEntry(args, kwargs, arguments.signature); + CacheEntry& cache_entry = + GetCacheEntry(args, kwargs, arguments.signature, cache_miss_result); std::vector> outputs = ValueOrThrow(cache_entry.executable->PjRtExecute(arguments.arg_buffers)); @@ -677,19 +726,18 @@ void BuildJaxjitSubmodule(pybind11::module& m) { py::class_> cfun( jitlib, "CompiledFunction"); cfun.def("__call__", &CompiledFunction::Call); - - jitlib.def("jit", - [](py::function cache_miss_fun, - py::function fallback_on_unsupported_argument, - bool jax_enable_x64, std::vector static_argnums, - xla::ClientAndPtr client_and_device) - -> std::unique_ptr { - return std::make_unique( - std::move(cache_miss_fun), - std::move(fallback_on_unsupported_argument), jax_enable_x64, - std::move(static_argnums), client_and_device.client, - client_and_device.contents); - }); + cfun.def_property_readonly("__signature__", &CompiledFunction::__signature__); + + jitlib.def( + "jit", + [](py::function cache_miss_fun, + py::function fallback_on_unsupported_argument, bool jax_enable_x64, + std::vector static_argnums) -> std::unique_ptr { + return std::make_unique( + std::move(cache_miss_fun), + std::move(fallback_on_unsupported_argument), jax_enable_x64, + std::move(static_argnums)); + }); // Only for testing purposes jitlib.def("_ScalarToBuffer", [](py::handle scalar, bool jax_enable_x64, -- GitLab