From 3689c21345785dbf05f7afa38751c23f7b3ff26d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 12 Jan 2017 13:19:09 -0800 Subject: [PATCH] [XLA] Add support for multiple computations to CompileAheadOfTime. Change: 144362931 --- tensorflow/compiler/aot/compile.cc | 27 ++-- .../compiler/xla/client/local_client.cc | 23 ++- tensorflow/compiler/xla/client/local_client.h | 27 ++-- tensorflow/compiler/xla/service/compiler.h | 9 +- .../compiler/xla/service/cpu/cpu_compiler.cc | 139 ++++++++++-------- .../compiler/xla/service/cpu/cpu_compiler.h | 9 +- .../compiler/xla/service/gpu/gpu_compiler.cc | 9 +- .../compiler/xla/service/gpu/gpu_compiler.h | 9 +- .../compiler/xla/service/local_service.cc | 72 +++++---- .../compiler/xla/service/local_service.h | 22 ++- .../xla/tests/local_client_aot_test_helper.cc | 14 +- 11 files changed, 207 insertions(+), 153 deletions(-) diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 50e596786ab..00c07932aca 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -204,23 +204,23 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config, string feed_id; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id)); if (missing_feeds.erase(feed_id) == 0) { - return errors::Aborted(kArgOp, " node found with unknown feed id: ", - feed_id); + return errors::Aborted(kArgOp, + " node found with unknown feed id: ", feed_id); } } else if (n->type_string() == kRetvalOp) { string fetch_id; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id)); if (missing_fetches.erase(fetch_id) == 0) { - return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ", - fetch_id); + return errors::Aborted(kRetvalOp, + " node found with unknown fetch id: ", fetch_id); } } } if (!missing_feeds.empty() || !missing_fetches.empty()) { - return errors::Aborted("Post graph-pruning", ", missing feeds: ", - str_util::Join(missing_feeds, ", "), - ", missing fetches: ", - str_util::Join(missing_fetches, ", ")); + return errors::Aborted( + "Post graph-pruning", + ", missing feeds: ", str_util::Join(missing_feeds, ", "), + ", missing fetches: ", str_util::Join(missing_fetches, ", ")); } return Status::OK(); } @@ -351,16 +351,19 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation, for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } - xla::StatusOr> aot_or = - client->CompileAheadOfTime(computation, arg_layouts, pshape->result(), - aot_opts); + xla::LocalClient::AheadOfTimeComputationInstance instance; + instance.computation = &computation; + instance.argument_layouts = std::move(arg_layouts); + instance.result_layout = &pshape->result(); + xla::StatusOr>> + aot_or = client->CompileAheadOfTime({instance}, aot_opts); if (!aot_or.ok()) { return errors::Unknown("XLA compilation failed: ", aot_or.status().error_message()); } compile_result->aot = xla::unique_ptr_static_cast( - aot_or.ConsumeValueOrDie()); + std::move(aot_or.ValueOrDie().back())); compile_result->entry_point = aot_opts.entry_point_name(); compile_result->pointer_size = xla::LocalClient::PointerSizeForTriple(aot_opts.triple()); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 148c033eaa3..384aae867b1 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -314,12 +314,23 @@ tensorflow::Status LocalClient::ExecuteLocally( options, result); } -StatusOr> LocalClient::CompileAheadOfTime( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const Shape& result_layout, const AotCompilationOptions& options) { - return local_service_->CompileAheadOfTime( - computation.handle(), argument_layouts, result_layout, options); +StatusOr>> +LocalClient::CompileAheadOfTime( + const tensorflow::gtl::ArraySlice + computations, + const AotCompilationOptions& options) { + std::vector service_instances; + service_instances.reserve(computations.size()); + for (const AheadOfTimeComputationInstance& instance : computations) { + service_instances.push_back({}); + LocalService::AheadOfTimeComputationInstance& service_instance = + service_instances.back(); + TF_RET_CHECK(instance.computation != nullptr); + service_instance.computation = instance.computation->handle(); + service_instance.argument_layouts = instance.argument_layouts; + service_instance.result_layout = instance.result_layout; + } + return local_service_->CompileAheadOfTime(service_instances, options); } int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 1d6243a3b68..33366b97fd5 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -219,19 +219,26 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); - // Compiles the computation for ahead-of-time execution. This is intended for - // use in static compilation. The |argument_layouts| parameter is used to - // inform the compiler of the expected layout for arguments while - // |result_layout| is used to signal the layout of the result. The |options| - // parameter is used to request which target the compiler should emit code - // for. + // A description of a computation to compile using CompileAheadOfTime. + struct AheadOfTimeComputationInstance { + const Computation* computation; + // Inform the compiler of the expected layout for arguments. + std::vector argument_layouts; + // Specifies the expected result layout. + const Shape* result_layout; + }; + + // Compiles a list of computations for ahead-of-time execution. This is + // intended for use in static compilation. The |options| parameter describes + // the target for which the compiler should emit code. // // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its // own library. - StatusOr> CompileAheadOfTime( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const Shape& result_layout, const AotCompilationOptions& options); + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice + computations, + const AotCompilationOptions& options); // Returns the size of a pointer in bytes for a given triple. static int64 PointerSizeForTriple(tensorflow::StringPiece triple); diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 632081a747e..85c2d03e1bc 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -128,10 +128,11 @@ class Compiler { // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. - virtual StatusOr> CompileAheadOfTime( - std::unique_ptr module, - std::unique_ptr module_config, HloDumper dump_hlo, - const AotCompilationOptions& options) = 0; + virtual StatusOr>> + CompileAheadOfTime( + std::vector> module, + std::vector> module_config, + HloDumper dump_hlo, const AotCompilationOptions& options) = 0; ///// // The Compiler class also serves as a point to register compiler objects diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index d566cfd8c8f..b9f4537b809 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -478,10 +478,13 @@ StatusOr>> CpuCompiler::Compile( "Compilation of multiple HLO modules is not yet supported on CPU."); } -StatusOr> CpuCompiler::CompileAheadOfTime( - std::unique_ptr hlo_module, - std::unique_ptr module_config, HloDumper dump_hlo, - const AotCompilationOptions& aot_options) { +StatusOr>> +CpuCompiler::CompileAheadOfTime( + std::vector> hlo_modules, + std::vector> module_configs, + HloDumper dump_hlo, const AotCompilationOptions& aot_options) { + TF_RET_CHECK(hlo_modules.size() == module_configs.size()); + if (aot_options.PlatformId() != se::host::kHostPlatformId) { return InvalidArgument("Incompatible AOT compilation platform"); } @@ -549,72 +552,78 @@ StatusOr> CpuCompiler::CompileAheadOfTime( const llvm::DataLayout& data_layout = llvm_module.getDataLayout(); int64 pointer_size = data_layout.getPointerSize(); - TF_RETURN_IF_ERROR( - RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo)); + std::vector> results; + for (int i = 0; i < hlo_modules.size(); ++i) { + HloModule* hlo_module = hlo_modules[i].get(); + HloModuleConfig* module_config = module_configs[i].get(); - SequentialHloOrdering::HloModuleSequence module_sequence = - CreateModuleSequence(hlo_module.get()); - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run( - hlo_module.get(), - MakeUnique(hlo_module.get(), module_sequence), - pointer_size)); - - IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/nullptr); - HloComputation* computation = hlo_module->entry_computation(); - for (auto embedded_computation : - computation->MakeEmbeddedComputationsList()) { - TF_RETURN_IF_ERROR( - ir_emitter - .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, - &module_sequence.at(embedded_computation)) - .status()); - } - const string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, entry_point_name, - /*is_entry_computation=*/true)); - - entry_function->setName(llvm_ir::AsStringRef(entry_point_name)); - - Disassembler disassembler(*target_machine); - CompilerFunctor compiler_functor(target_machine.get(), &disassembler, - opt_level, CompilerFunctor::AllIntrinsics()); - llvm::object::OwningBinary object_file = - compiler_functor(llvm_module); - llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); - ObjectFileData object_file_data(object_file_data_ref.begin(), - object_file_data_ref.end()); - - BufferSizes buffer_sizes; - for (const BufferAllocation& allocation : assignment->Allocations()) { - // Callers don't need to allocate temporary buffers for parameters. - if (allocation.is_entry_computation_parameter()) { - buffer_sizes.push_back(-1); - continue; + TF_RETURN_IF_ERROR(RunHloPasses(hlo_module, module_config, dump_hlo)); + + SequentialHloOrdering::HloModuleSequence module_sequence = + CreateModuleSequence(hlo_module); + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run(hlo_module, MakeUnique( + hlo_module, module_sequence), + pointer_size)); + + IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module, + /*hlo_to_profile_idx=*/nullptr); + HloComputation* computation = hlo_module->entry_computation(); + for (auto embedded_computation : + computation->MakeEmbeddedComputationsList()) { + TF_RETURN_IF_ERROR( + ir_emitter + .EmitComputation(embedded_computation, + embedded_computation->name(), + /*is_entry_computation=*/false, + &module_sequence.at(embedded_computation)) + .status()); } - // Callers don't need to allocate anything for thread-local temporary - // buffers. They are lowered to allocas. - if (allocation.is_thread_local()) { - buffer_sizes.push_back(-1); - continue; + const string& entry_point_name = options.entry_point_name(); + TF_ASSIGN_OR_RETURN( + llvm::Function * entry_function, + ir_emitter.EmitComputation(computation, entry_point_name, + /*is_entry_computation=*/true)); + + entry_function->setName(llvm_ir::AsStringRef(entry_point_name)); + + Disassembler disassembler(*target_machine); + CompilerFunctor compiler_functor(target_machine.get(), &disassembler, + opt_level, + CompilerFunctor::AllIntrinsics()); + llvm::object::OwningBinary object_file = + compiler_functor(llvm_module); + llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); + ObjectFileData object_file_data(object_file_data_ref.begin(), + object_file_data_ref.end()); + + BufferSizes buffer_sizes; + for (const BufferAllocation& allocation : assignment->Allocations()) { + // Callers don't need to allocate temporary buffers for parameters. + if (allocation.is_entry_computation_parameter()) { + buffer_sizes.push_back(-1); + continue; + } + // Callers don't need to allocate anything for thread-local temporary + // buffers. They are lowered to allocas. + if (allocation.is_thread_local()) { + buffer_sizes.push_back(-1); + continue; + } + buffer_sizes.push_back(allocation.size()); } - buffer_sizes.push_back(allocation.size()); - } - TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, - assignment->GetUniqueTopLevelOutputAllocation()); + TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation, + assignment->GetUniqueTopLevelOutputAllocation()); - return std::unique_ptr( - MakeUnique(std::move(object_file_data), - std::move(buffer_sizes), - result_allocation->index())); + results.emplace_back(MakeUnique( + std::move(object_file_data), std::move(buffer_sizes), + result_allocation->index())); + } + return std::move(results); } se::Platform::Id CpuCompiler::PlatformId() const { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 349724d8406..d7d77ce58a6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -123,10 +123,11 @@ class CpuCompiler : public Compiler { HloDumper dump_hlo, std::vector stream_exec) override; - StatusOr> CompileAheadOfTime( - std::unique_ptr module, - std::unique_ptr module_config, HloDumper dump_hlo, - const AotCompilationOptions& options) override; + StatusOr>> + CompileAheadOfTime( + std::vector> module, + std::vector> module_config, + HloDumper dump_hlo, const AotCompilationOptions& options) override; perftools::gputools::Platform::Id PlatformId() const override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index a13279c6ff6..2f95446e6c4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -312,10 +312,11 @@ StatusOr>> GpuCompiler::Compile( "Compilation of multiple HLO modules is not yet supported on GPU."); } -StatusOr> GpuCompiler::CompileAheadOfTime( - std::unique_ptr module, - std::unique_ptr module_config, HloDumper dump_hlo, - const AotCompilationOptions& options) { +StatusOr>> +GpuCompiler::CompileAheadOfTime( + std::vector> module, + std::vector> module_config, + HloDumper dump_hlo, const AotCompilationOptions& options) { return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index fefa4031041..a074607760f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -52,10 +52,11 @@ class GpuCompiler : public Compiler { HloDumper dump_hlo, std::vector stream_exec) override; - StatusOr> CompileAheadOfTime( - std::unique_ptr module, - std::unique_ptr module_config, HloDumper dump_hlo, - AotCompilationOptions const& options) override; + StatusOr>> + CompileAheadOfTime( + std::vector> module, + std::vector> module_config, + HloDumper dump_hlo, AotCompilationOptions const& options) override; perftools::gputools::Platform::Id PlatformId() const override; diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 38465e37e7b..7f86a3cbb57 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -206,42 +206,49 @@ tensorflow::Status LocalService::ExecuteLocally( return tensorflow::Status::OK(); } -StatusOr> +StatusOr>> LocalService::CompileAheadOfTime( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const Shape& result_layout, const AotCompilationOptions& options) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule(versioned_handle, - /*include_unused_parameters=*/true)); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - auto module_config = MakeUnique(*program_shape); - auto* computation_layout = module_config->mutable_entry_computation_layout(); - for (int i = 0; i < argument_layouts.size(); ++i) { - const Shape& argument_layout = *argument_layouts[i]; - if (ShapeUtil::IsTuple(argument_layout)) { - return Unimplemented("tuple arguments not supported yet"); + const tensorflow::gtl::ArraySlice + computations, + const AotCompilationOptions& options) { + std::vector> hlo_modules; + std::vector> module_configs; + for (const AheadOfTimeComputationInstance& instance : computations) { + TF_ASSIGN_OR_RETURN(UserComputation * user_computation, + computation_tracker_.Resolve(instance.computation)); + VersionedComputationHandle versioned_handle = + user_computation->GetVersionedHandle(); + + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, + computation_tracker_.BuildHloModule( + versioned_handle, + /*include_unused_parameters=*/true)); + hlo_modules.push_back(std::move(hlo_module)); + + TF_ASSIGN_OR_RETURN( + std::shared_ptr program_shape, + user_computation->ComputeProgramShape(versioned_handle.version)); + + module_configs.push_back(MakeUnique(*program_shape)); + HloModuleConfig* module_config = module_configs.back().get(); + auto* computation_layout = + module_config->mutable_entry_computation_layout(); + for (int i = 0; i < instance.argument_layouts.size(); ++i) { + const Shape& argument_layout = *instance.argument_layouts[i]; + if (ShapeUtil::IsTuple(argument_layout)) { + return Unimplemented("tuple arguments not supported yet"); + } + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + argument_layout)); } TF_RETURN_IF_ERROR( - computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( - argument_layout)); + computation_layout->mutable_result_layout()->CopyLayoutFromShape( + *instance.result_layout)); } - TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( - result_layout)); return execute_backend_->compiler() - ->CompileAheadOfTime(std::move(hlo_module), std::move(module_config), + ->CompileAheadOfTime(std::move(hlo_modules), std::move(module_configs), MakeHloDumper(), options) .ConsumeValueOrDie(); } @@ -426,8 +433,9 @@ StatusOr> LocalService::ExecuteLocallyInternal( } else { se::StreamExecutor* stream_executor; if (options.device_ordinal() >= 0) { - TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor( - options.device_ordinal())); + TF_ASSIGN_OR_RETURN( + stream_executor, + execute_backend_->stream_executor(options.device_ordinal())); } else { stream_executor = execute_backend_->default_stream_executor(); } diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 3e160a0201e..9fe0d5993b3 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -139,13 +139,21 @@ class LocalService : public Service { tensorflow::gtl::ArraySlice arguments, const LocalExecuteOptions& options, ShapedBuffer* result_buffer); - // Compiles the computation for ahead-of-time execution. This is intended for - // use in static compilation. See |LocalClient::CompileAheadOfTime| for - // additional details. - StatusOr> CompileAheadOfTime( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const Shape& result_layout, const AotCompilationOptions& Options); + // A description of a computation to compile using CompileAheadOfTime. + struct AheadOfTimeComputationInstance { + ComputationHandle computation; + std::vector argument_layouts; + const Shape* result_layout = nullptr; + }; + + // Compiles a list of computations for ahead-of-time execution. This is + // intended for use in static compilation. See + // |LocalClient::CompileAheadOfTime| for additional details. + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice + computations, + const AotCompilationOptions& Options); // Builds an Executable with the given argument layouts and options. If // result_layout is non-null, then the executable is compiled to produce a diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 50e5dec0f62..50d9ee50835 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -72,16 +73,19 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); + xla::Computation computation = builder.Build().ConsumeValueOrDie(); + xla::LocalClient::AheadOfTimeComputationInstance instance{ + &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; + xla::cpu::CpuAotCompilationOptions options( triple_string, /*cpu_name=*/"", /*features=*/"", "SumAndDouble", xla::cpu::CpuAotCompilationOptions::RelocationModel::Static); + + auto results = + client->CompileAheadOfTime({instance}, options).ConsumeValueOrDie(); auto result = xla::unique_ptr_static_cast( - client - ->CompileAheadOfTime(builder.Build().ValueOrDie(), - /*argument_layouts=*/{&opaque_shape}, r0f32, - options) - .ConsumeValueOrDie()); + std::move(results.front())); // We should have two buffers, one for the result and one temporary buffer, // and both should be float-sized. It's lame to hard-code this, but we need // local_client_aot_test.cc to be able to easily invoke the function. -- GitLab