提交 0a088ab7 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Delay the requirement of the device/pyclient to the first call.

Because jax.jit is annotated on top-level functions, and that backends are initialised only after GoogleInit, we cannot have the client when calling the C++ jit.

PiperOrigin-RevId: 328177723
Change-Id: I462385e4a687461bf41c0d3e2875d4736285c42d
上级 c38ffdb5
......@@ -26,7 +26,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include <exception>
#include <memory>
#include <optional>
#include <stdexcept>
#include "absl/container/flat_hash_map.h"
......@@ -230,12 +232,13 @@ struct CacheEntry {
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`.
// TODO(jblespiau): This class is thread-unsafe. Note that using a mutex for the
// full `Call` will lead to a deadlock because it goes back to Python which will
// release the GIL.
class CompiledFunction {
public:
CompiledFunction(py::function cache_miss_fun, py::function python_f_jitted,
bool jax_enable_x64, std::vector<int> static_argnums,
std::shared_ptr<xla::PyClient> pyclient,
xla::PjRtDevice* device);
bool jax_enable_x64, std::vector<int> static_argnums);
~CompiledFunction();
// This function will:
......@@ -246,9 +249,20 @@ class CompiledFunction {
// (e) reconstruct the `PyTree`.
py::object Call(py::args args, py::kwargs kwargs);
// This allows `inspect.signature(cpp_jitted_f)` from Python.
py::object __signature__() {
static const auto* inspect = new py::module(py::module::import("inspect"));
return inspect->attr("signature")(python_f_jitted_);
}
private:
CacheEntry& GetCacheEntry(const py::args& args, const py::kwargs& kwargs,
const CallSignature& signature);
const CallSignature& signature,
absl::optional<py::tuple> cache_miss_return);
CacheEntry& SetAndReturnCacheEntry(
const py::args& args, const py::kwargs& kwargs,
const CallSignature& signature,
absl::optional<py::tuple> cache_miss_return = absl::nullopt);
// The Python function in charge of returning a `xla::PyExecutable` from
// the arguments passed to `jitted_f`.
......@@ -267,22 +281,22 @@ class CompiledFunction {
// We need a `unique_ptr` here to ensure value pointer stability.
absl::flat_hash_map<CallSignature, std::unique_ptr<CacheEntry>> executables_;
const std::shared_ptr<xla::PyClient> pyclient_;
xla::PjRtDevice* const default_device_;
// As top-level functions are decorated with `jax.jit`, when
// `CompiledFunction` is being instantiated from Python, the clients are not
// yet available (done after GoogleInit). They will be during the first call
// to `Call`.
std::shared_ptr<xla::PyClient> pyclient_ = nullptr;
xla::PjRtDevice* default_device_ = nullptr;
};
CompiledFunction::CompiledFunction(py::function cache_miss_fun,
py::function python_f_jitted,
bool jax_enable_x64,
std::vector<int> static_argnums,
std::shared_ptr<xla::PyClient> pyclient,
xla::PjRtDevice* device)
std::vector<int> static_argnums)
: cache_miss_fun_(std::move(cache_miss_fun)),
python_f_jitted_(std::move(python_f_jitted)),
jax_enable_x64_(jax_enable_x64),
static_argnums_(std::move(static_argnums)),
pyclient_(std::move(pyclient)),
default_device_(device) {
static_argnums_(std::move(static_argnums)) {
std::sort(static_argnums_.begin(), static_argnums_.end());
}
......@@ -493,8 +507,21 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
xla::PjRtDevice* data_device = nullptr;
for (py::handle arg : arguments.flat_dynamic_args) {
if (py::isinstance(arg, device_array)) {
xla::PyBuffer* buffer =
py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
xla::PyBuffer* buffer;
try {
// This can fail, e.g. when device_buffer is a `DeviceConstant`.
buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
} catch (const py::cast_error& e) {
return InvalidArgument(
"%s",
absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
"`device_buffer` field is of type ",
py::cast<std::string>(
arg.attr("device_buffer").get_type().str()),
" while a `PyBuffer` was expected."
));
}
xla::PjRtDevice* device = buffer->buffer()->device();
if (data_device && (device != data_device)) {
return InvalidArgument(
......@@ -577,14 +604,20 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
} // namespace
CacheEntry& CompiledFunction::GetCacheEntry(const py::args& args,
const py::kwargs& kwargs,
const CallSignature& signature) {
CacheEntry& CompiledFunction::GetCacheEntry(
const py::args& args, const py::kwargs& kwargs,
const CallSignature& signature,
absl::optional<py::tuple> cache_miss_return) {
auto found_iterator = executables_.find(signature);
if (found_iterator != executables_.end()) { // Cache hit!
return *(found_iterator->second);
}
return SetAndReturnCacheEntry(args, kwargs, signature, cache_miss_return);
}
CacheEntry& CompiledFunction::SetAndReturnCacheEntry(
const py::args& args, const py::kwargs& kwargs,
const CallSignature& signature,
absl::optional<py::tuple> cache_miss_return) {
// We need to insert the element.
auto result = executables_.emplace(signature, std::make_unique<CacheEntry>());
auto it = result.first;
......@@ -593,7 +626,12 @@ CacheEntry& CompiledFunction::GetCacheEntry(const py::args& args,
result.first->first.IncRef();
// Cache miss? Call the Python cache miss function.
py::tuple executable_and_pytree = cache_miss_fun_(*args, **kwargs);
py::tuple executable_and_pytree;
if (cache_miss_return) {
executable_and_pytree = cache_miss_return.value();
} else {
executable_and_pytree = cache_miss_fun_(*args, **kwargs);
}
if (executable_and_pytree.size() != 4) {
throw std::runtime_error(
"AssertionError: The cache miss function should return 4 "
......@@ -639,6 +677,16 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
ParsedArgumentsAsBuffers arguments;
FlattenArguments(args, kwargs, static_argnums_, arguments);
absl::optional<py::tuple> cache_miss_result = absl::nullopt;
if (!default_device_) {
cache_miss_result = cache_miss_fun_(*args, **kwargs);
auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
cache_miss_result.value()[0]);
pyclient_ = executable->client();
default_device_ = executable->LocalDevices()[0].contents;
}
// The C++ jit do not support Tracers arguments yet. The Python-based jit
// function will be called if any of the dynamic arguments is unsupported.
if (!ConvertArgsToBuffers(jax_enable_x64_, *pyclient_, default_device_,
......@@ -647,7 +695,8 @@ py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
return python_f_jitted_(*args, **kwargs);
}
CacheEntry& cache_entry = GetCacheEntry(args, kwargs, arguments.signature);
CacheEntry& cache_entry =
GetCacheEntry(args, kwargs, arguments.signature, cache_miss_result);
std::vector<std::unique_ptr<xla::PyBuffer>> outputs =
ValueOrThrow(cache_entry.executable->PjRtExecute(arguments.arg_buffers));
......@@ -677,19 +726,18 @@ void BuildJaxjitSubmodule(pybind11::module& m) {
py::class_<CompiledFunction, std::unique_ptr<CompiledFunction>> cfun(
jitlib, "CompiledFunction");
cfun.def("__call__", &CompiledFunction::Call);
jitlib.def("jit",
[](py::function cache_miss_fun,
py::function fallback_on_unsupported_argument,
bool jax_enable_x64, std::vector<int> static_argnums,
xla::ClientAndPtr<xla::PjRtDevice> client_and_device)
-> std::unique_ptr<CompiledFunction> {
return std::make_unique<CompiledFunction>(
std::move(cache_miss_fun),
std::move(fallback_on_unsupported_argument), jax_enable_x64,
std::move(static_argnums), client_and_device.client,
client_and_device.contents);
});
cfun.def_property_readonly("__signature__", &CompiledFunction::__signature__);
jitlib.def(
"jit",
[](py::function cache_miss_fun,
py::function fallback_on_unsupported_argument, bool jax_enable_x64,
std::vector<int> static_argnums) -> std::unique_ptr<CompiledFunction> {
return std::make_unique<CompiledFunction>(
std::move(cache_miss_fun),
std::move(fallback_on_unsupported_argument), jax_enable_x64,
std::move(static_argnums));
});
// Only for testing purposes
jitlib.def("_ScalarToBuffer", [](py::handle scalar, bool jax_enable_x64,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册