提交 e20f8814 编写于 作者: A Anlun Xu 提交者: TensorFlower Gardener

[xla:runtime] Add CpuCompiler::Export

PiperOrigin-RevId: 481007346
上级 53fa1d27
......@@ -1770,6 +1770,25 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
return CpuExecutable::ShapeSizeBytes;
}
StatusOr<std::unique_ptr<AotCompilationResult>> CpuCompiler::Export(
Executable* executable) const {
auto* cpu_executable = tensorflow::down_cast<CpuExecutable*>(executable);
if (!cpu_executable)
return Internal("Could not downcast Executable to CpuExecutable");
HloModuleProto module_proto = cpu_executable->module().ToProto();
TF_ASSIGN_OR_RETURN(std::string obj_file, cpu_executable->GetObjFile());
TF_ASSIGN_OR_RETURN(std::string mlir_module, cpu_executable->GetMlirModule());
TF_ASSIGN_OR_RETURN(XlaFrameworkMapping xla_framework_mapping,
cpu_executable->GetXlaFrameworkMapping());
std::unique_ptr<AotCompilationResult> result =
std::make_unique<CpuXlaRuntimeAotCompilationResult>(
module_proto, obj_file, mlir_module,
cpu_executable->buffer_assignment(), xla_framework_mapping);
return result;
}
} // namespace cpu
} // namespace xla
......
......@@ -191,6 +191,9 @@ class CpuCompiler : public LLVMCompiler {
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
StatusOr<std::unique_ptr<AotCompilationResult>> Export(
Executable* executable) const override;
private:
// Initialize the LLVM target.
static void InitializeLLVMTarget();
......
......@@ -68,6 +68,24 @@ class XlaRuntimeCpuExecutable {
return *default_executable_;
}
StatusOr<std::string> GetObjFile() const {
std::unique_ptr<llvm::MemoryBuffer> obj_file =
jit_executable_->DefaultExecutable()->obj_file();
if (!obj_file)
return InternalError("XlaRuntimeCpuExecutable didn't save the obj file");
std::string data(obj_file->getBuffer().data(),
obj_file->getBuffer().size());
return data;
}
StatusOr<std::string> GetMlirModule() const {
return jit_executable_->mlir_module();
}
XlaFrameworkMapping xla_framework_mapping() { return xla_framework_mapping_; }
private:
std::unique_ptr<xla::runtime::JitExecutable> jit_executable_;
xla::runtime::Executable* default_executable_; // owned by jit_executable_.
......@@ -137,6 +155,21 @@ class CpuExecutable : public Executable {
int64_t SizeOfGeneratedCodeInBytes() const override;
StatusOr<std::string> GetObjFile() const {
if (!IsXlaRuntime()) return InternalError("Not an XLA Runtime executable");
return xla_runtime_executable_->GetObjFile();
}
StatusOr<std::string> GetMlirModule() const {
if (!IsXlaRuntime()) return InternalError("Not an XLA Runtime executable");
return xla_runtime_executable_->GetMlirModule();
}
StatusOr<XlaFrameworkMapping> GetXlaFrameworkMapping() const {
if (!IsXlaRuntime()) return InternalError("Not an XLA Runtime executable");
return xla_runtime_executable_->xla_framework_mapping();
}
private:
// Creates an array suitable for passing as the "buffer_table" argument to the
// JIT compiled function pointer.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册