提交 3689c213 编写于 作者: P Peter Hawkins 提交者: TensorFlower Gardener

[XLA] Add support for multiple computations to CompileAheadOfTime.

Change: 144362931
上级 725da748
......@@ -204,23 +204,23 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config,
string feed_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id));
if (missing_feeds.erase(feed_id) == 0) {
return errors::Aborted(kArgOp, " node found with unknown feed id: ",
feed_id);
return errors::Aborted(kArgOp,
" node found with unknown feed id: ", feed_id);
}
} else if (n->type_string() == kRetvalOp) {
string fetch_id;
TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id));
if (missing_fetches.erase(fetch_id) == 0) {
return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ",
fetch_id);
return errors::Aborted(kRetvalOp,
" node found with unknown fetch id: ", fetch_id);
}
}
}
if (!missing_feeds.empty() || !missing_fetches.empty()) {
return errors::Aborted("Post graph-pruning", ", missing feeds: ",
str_util::Join(missing_feeds, ", "),
", missing fetches: ",
str_util::Join(missing_fetches, ", "));
return errors::Aborted(
"Post graph-pruning",
", missing feeds: ", str_util::Join(missing_feeds, ", "),
", missing fetches: ", str_util::Join(missing_fetches, ", "));
}
return Status::OK();
}
......@@ -351,16 +351,19 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
}
xla::StatusOr<std::unique_ptr<xla::AotCompilationResult>> aot_or =
client->CompileAheadOfTime(computation, arg_layouts, pshape->result(),
aot_opts);
xla::LocalClient::AheadOfTimeComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
aot_or = client->CompileAheadOfTime({instance}, aot_opts);
if (!aot_or.ok()) {
return errors::Unknown("XLA compilation failed: ",
aot_or.status().error_message());
}
compile_result->aot =
xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
aot_or.ConsumeValueOrDie());
std::move(aot_or.ValueOrDie().back()));
compile_result->entry_point = aot_opts.entry_point_name();
compile_result->pointer_size =
xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
......
......@@ -314,12 +314,23 @@ tensorflow::Status LocalClient::ExecuteLocally(
options, result);
}
StatusOr<std::unique_ptr<AotCompilationResult>> LocalClient::CompileAheadOfTime(
const Computation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const Shape& result_layout, const AotCompilationOptions& options) {
return local_service_->CompileAheadOfTime(
computation.handle(), argument_layouts, result_layout, options);
// Compiles every entry of |computations| for ahead-of-time execution.
//
// Each client-side instance is translated into its LocalService counterpart
// (the Computation is replaced by its handle; layouts are forwarded as-is) and
// the whole batch is handed to the local service together with |options|.
//
// Returns the service's result, or an error status if any instance carries a
// null computation, a null argument layout or a null result layout.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
LocalClient::CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
        computations,
    const AotCompilationOptions& options) {
  std::vector<LocalService::AheadOfTimeComputationInstance> service_instances;
  service_instances.reserve(computations.size());
  for (const AheadOfTimeComputationInstance& instance : computations) {
    // LocalService::CompileAheadOfTime dereferences the result layout and
    // every argument layout without checking them, so reject null pointers
    // here and report them as an error instead of crashing.
    TF_RET_CHECK(instance.computation != nullptr);
    TF_RET_CHECK(instance.result_layout != nullptr);
    for (const Shape* argument_layout : instance.argument_layouts) {
      TF_RET_CHECK(argument_layout != nullptr);
    }

    service_instances.emplace_back();
    LocalService::AheadOfTimeComputationInstance& service_instance =
        service_instances.back();
    service_instance.computation = instance.computation->handle();
    service_instance.argument_layouts = instance.argument_layouts;
    service_instance.result_layout = instance.result_layout;
  }
  return local_service_->CompileAheadOfTime(service_instances, options);
}
int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) {
......
......@@ -219,19 +219,26 @@ class LocalClient : public Client {
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const ExecutableBuildOptions& options);
// Compiles the computation for ahead-of-time execution. This is intended for
// use in static compilation. The |argument_layouts| parameter is used to
// inform the compiler of the expected layout for arguments while
// |result_layout| is used to signal the layout of the result. The |options|
// parameter is used to request which target the compiler should emit code
// for.
// A description of a computation to compile using CompileAheadOfTime.
//
// None of the members has a default initializer, so callers must set all
// three fields before passing an instance to CompileAheadOfTime.
struct AheadOfTimeComputationInstance {
// The computation to compile. CompileAheadOfTime rejects a null pointer.
// NOTE(review): presumably non-owning and required to outlive the call --
// confirm against callers.
const Computation* computation;
// Inform the compiler of the expected layout for arguments.
std::vector<const Shape*> argument_layouts;
// Specifies the expected result layout.
const Shape* result_layout;
};
// Compiles a list of computations for ahead-of-time execution. This is
// intended for use in static compilation. The |options| parameter describes
// the target for which the compiler should emit code.
//
// TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
// own library.
StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
const Computation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const Shape& result_layout, const AotCompilationOptions& options);
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
computations,
const AotCompilationOptions& options);
// Returns the size of a pointer in bytes for a given triple.
static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
......
......@@ -128,10 +128,11 @@ class Compiler {
// Compiles a set of HLO modules for ahead-of-time execution. This is intended
// for use in static compilation.
virtual StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
std::unique_ptr<HloModule> module,
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
const AotCompilationOptions& options) = 0;
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
std::vector<std::unique_ptr<HloModule>> module,
std::vector<std::unique_ptr<HloModuleConfig>> module_config,
HloDumper dump_hlo, const AotCompilationOptions& options) = 0;
/////
// The Compiler class also serves as a point to register compiler objects
......
......@@ -478,10 +478,13 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
"Compilation of multiple HLO modules is not yet supported on CPU.");
}
StatusOr<std::unique_ptr<AotCompilationResult>> CpuCompiler::CompileAheadOfTime(
std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
const AotCompilationOptions& aot_options) {
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CpuCompiler::CompileAheadOfTime(
std::vector<std::unique_ptr<HloModule>> hlo_modules,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
HloDumper dump_hlo, const AotCompilationOptions& aot_options) {
TF_RET_CHECK(hlo_modules.size() == module_configs.size());
if (aot_options.PlatformId() != se::host::kHostPlatformId) {
return InvalidArgument("Incompatible AOT compilation platform");
}
......@@ -549,72 +552,78 @@ StatusOr<std::unique_ptr<AotCompilationResult>> CpuCompiler::CompileAheadOfTime(
const llvm::DataLayout& data_layout = llvm_module.getDataLayout();
int64 pointer_size = data_layout.getPointerSize();
TF_RETURN_IF_ERROR(
RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo));
std::vector<std::unique_ptr<AotCompilationResult>> results;
for (int i = 0; i < hlo_modules.size(); ++i) {
HloModule* hlo_module = hlo_modules[i].get();
HloModuleConfig* module_config = module_configs[i].get();
SequentialHloOrdering::HloModuleSequence module_sequence =
CreateModuleSequence(hlo_module.get());
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(
hlo_module.get(),
MakeUnique<SequentialHloOrdering>(hlo_module.get(), module_sequence),
pointer_size));
IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module,
/*hlo_to_profile_idx=*/nullptr);
HloComputation* computation = hlo_module->entry_computation();
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
TF_RETURN_IF_ERROR(
ir_emitter
.EmitComputation(embedded_computation, embedded_computation->name(),
/*is_entry_computation=*/false,
&module_sequence.at(embedded_computation))
.status());
}
const string& entry_point_name = options.entry_point_name();
TF_ASSIGN_OR_RETURN(
llvm::Function * entry_function,
ir_emitter.EmitComputation(computation, entry_point_name,
/*is_entry_computation=*/true));
entry_function->setName(llvm_ir::AsStringRef(entry_point_name));
Disassembler disassembler(*target_machine);
CompilerFunctor compiler_functor(target_machine.get(), &disassembler,
opt_level, CompilerFunctor::AllIntrinsics());
llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
compiler_functor(llvm_module);
llvm::StringRef object_file_data_ref = object_file.getBinary()->getData();
ObjectFileData object_file_data(object_file_data_ref.begin(),
object_file_data_ref.end());
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
// Callers don't need to allocate temporary buffers for parameters.
if (allocation.is_entry_computation_parameter()) {
buffer_sizes.push_back(-1);
continue;
TF_RETURN_IF_ERROR(RunHloPasses(hlo_module, module_config, dump_hlo));
SequentialHloOrdering::HloModuleSequence module_sequence =
CreateModuleSequence(hlo_module);
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(hlo_module, MakeUnique<SequentialHloOrdering>(
hlo_module, module_sequence),
pointer_size));
IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module,
/*hlo_to_profile_idx=*/nullptr);
HloComputation* computation = hlo_module->entry_computation();
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
TF_RETURN_IF_ERROR(
ir_emitter
.EmitComputation(embedded_computation,
embedded_computation->name(),
/*is_entry_computation=*/false,
&module_sequence.at(embedded_computation))
.status());
}
// Callers don't need to allocate anything for thread-local temporary
// buffers. They are lowered to allocas.
if (allocation.is_thread_local()) {
buffer_sizes.push_back(-1);
continue;
const string& entry_point_name = options.entry_point_name();
TF_ASSIGN_OR_RETURN(
llvm::Function * entry_function,
ir_emitter.EmitComputation(computation, entry_point_name,
/*is_entry_computation=*/true));
entry_function->setName(llvm_ir::AsStringRef(entry_point_name));
Disassembler disassembler(*target_machine);
CompilerFunctor compiler_functor(target_machine.get(), &disassembler,
opt_level,
CompilerFunctor::AllIntrinsics());
llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
compiler_functor(llvm_module);
llvm::StringRef object_file_data_ref = object_file.getBinary()->getData();
ObjectFileData object_file_data(object_file_data_ref.begin(),
object_file_data_ref.end());
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
// Callers don't need to allocate temporary buffers for parameters.
if (allocation.is_entry_computation_parameter()) {
buffer_sizes.push_back(-1);
continue;
}
// Callers don't need to allocate anything for thread-local temporary
// buffers. They are lowered to allocas.
if (allocation.is_thread_local()) {
buffer_sizes.push_back(-1);
continue;
}
buffer_sizes.push_back(allocation.size());
}
buffer_sizes.push_back(allocation.size());
}
TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
assignment->GetUniqueTopLevelOutputAllocation());
TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
assignment->GetUniqueTopLevelOutputAllocation());
return std::unique_ptr<AotCompilationResult>(
MakeUnique<CpuAotCompilationResult>(std::move(object_file_data),
std::move(buffer_sizes),
result_allocation->index()));
results.emplace_back(MakeUnique<CpuAotCompilationResult>(
std::move(object_file_data), std::move(buffer_sizes),
result_allocation->index()));
}
return std::move(results);
}
se::Platform::Id CpuCompiler::PlatformId() const {
......
......@@ -123,10 +123,11 @@ class CpuCompiler : public Compiler {
HloDumper dump_hlo,
std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
std::unique_ptr<HloModule> module,
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
const AotCompilationOptions& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
std::vector<std::unique_ptr<HloModule>> module,
std::vector<std::unique_ptr<HloModuleConfig>> module_config,
HloDumper dump_hlo, const AotCompilationOptions& options) override;
perftools::gputools::Platform::Id PlatformId() const override;
......
......@@ -312,10 +312,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> GpuCompiler::Compile(
"Compilation of multiple HLO modules is not yet supported on GPU.");
}
StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::CompileAheadOfTime(
std::unique_ptr<HloModule> module,
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
const AotCompilationOptions& options) {
// Ahead-of-time compilation is not implemented for the GPU backend: none of
// the arguments is used and an Unimplemented status is always returned.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
GpuCompiler::CompileAheadOfTime(
std::vector<std::unique_ptr<HloModule>> module,
std::vector<std::unique_ptr<HloModuleConfig>> module_config,
HloDumper dump_hlo, const AotCompilationOptions& options) {
return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
}
......
......@@ -52,10 +52,11 @@ class GpuCompiler : public Compiler {
HloDumper dump_hlo,
std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
std::unique_ptr<HloModule> module,
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
AotCompilationOptions const& options) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
std::vector<std::unique_ptr<HloModule>> module,
std::vector<std::unique_ptr<HloModuleConfig>> module_config,
HloDumper dump_hlo, AotCompilationOptions const& options) override;
perftools::gputools::Platform::Id PlatformId() const override;
......
......@@ -206,42 +206,49 @@ tensorflow::Status LocalService::ExecuteLocally(
return tensorflow::Status::OK();
}
StatusOr<std::unique_ptr<AotCompilationResult>>
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
LocalService::CompileAheadOfTime(
const ComputationHandle& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const Shape& result_layout, const AotCompilationOptions& options) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
computation_tracker_.Resolve(computation));
VersionedComputationHandle versioned_handle =
user_computation->GetVersionedHandle();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(versioned_handle,
/*include_unused_parameters=*/true));
TF_ASSIGN_OR_RETURN(
std::shared_ptr<const ProgramShape> program_shape,
user_computation->ComputeProgramShape(versioned_handle.version));
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
auto* computation_layout = module_config->mutable_entry_computation_layout();
for (int i = 0; i < argument_layouts.size(); ++i) {
const Shape& argument_layout = *argument_layouts[i];
if (ShapeUtil::IsTuple(argument_layout)) {
return Unimplemented("tuple arguments not supported yet");
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
computations,
const AotCompilationOptions& options) {
std::vector<std::unique_ptr<HloModule>> hlo_modules;
std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
for (const AheadOfTimeComputationInstance& instance : computations) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
computation_tracker_.Resolve(instance.computation));
VersionedComputationHandle versioned_handle =
user_computation->GetVersionedHandle();
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(
versioned_handle,
/*include_unused_parameters=*/true));
hlo_modules.push_back(std::move(hlo_module));
TF_ASSIGN_OR_RETURN(
std::shared_ptr<const ProgramShape> program_shape,
user_computation->ComputeProgramShape(versioned_handle.version));
module_configs.push_back(MakeUnique<HloModuleConfig>(*program_shape));
HloModuleConfig* module_config = module_configs.back().get();
auto* computation_layout =
module_config->mutable_entry_computation_layout();
for (int i = 0; i < instance.argument_layouts.size(); ++i) {
const Shape& argument_layout = *instance.argument_layouts[i];
if (ShapeUtil::IsTuple(argument_layout)) {
return Unimplemented("tuple arguments not supported yet");
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
argument_layout));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
argument_layout));
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
*instance.result_layout));
}
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
result_layout));
return execute_backend_->compiler()
->CompileAheadOfTime(std::move(hlo_module), std::move(module_config),
->CompileAheadOfTime(std::move(hlo_modules), std::move(module_configs),
MakeHloDumper(), options)
.ConsumeValueOrDie();
}
......@@ -426,8 +433,9 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
} else {
se::StreamExecutor* stream_executor;
if (options.device_ordinal() >= 0) {
TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
options.device_ordinal()));
TF_ASSIGN_OR_RETURN(
stream_executor,
execute_backend_->stream_executor(options.device_ordinal()));
} else {
stream_executor = execute_backend_->default_stream_executor();
}
......
......@@ -139,13 +139,21 @@ class LocalService : public Service {
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const LocalExecuteOptions& options, ShapedBuffer* result_buffer);
// Compiles the computation for ahead-of-time execution. This is intended for
// use in static compilation. See |LocalClient::CompileAheadOfTime| for
// additional details.
StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
const ComputationHandle& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const Shape& result_layout, const AotCompilationOptions& Options);
// A description of a computation to compile using CompileAheadOfTime.
struct AheadOfTimeComputationInstance {
// Handle of the computation to compile; CompileAheadOfTime resolves it
// through the computation tracker.
ComputationHandle computation;
// Expected argument layouts. Entry i is copied into the layout of parameter
// i of the entry computation.
std::vector<const Shape*> argument_layouts;
// Expected result layout. Defaults to null, but CompileAheadOfTime
// dereferences it, so it must be set before the instance is used.
const Shape* result_layout = nullptr;
};
// Compiles a list of computations for ahead-of-time execution. This is
// intended for use in static compilation. |computations| describes each
// computation together with its expected argument and result layouts;
// |options| is forwarded to the backend compiler. See
// |LocalClient::CompileAheadOfTime| for additional details.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
    const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
        computations,
    const AotCompilationOptions& options);
// Builds an Executable with the given argument layouts and options. If
// result_layout is non-null, then the executable is compiled to produce a
......
......@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
......@@ -72,16 +73,19 @@ int main(int argc, char** argv) {
llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
xla::Computation computation = builder.Build().ConsumeValueOrDie();
xla::LocalClient::AheadOfTimeComputationInstance instance{
&computation, /*argument_layouts=*/{&opaque_shape}, &r0f32};
xla::cpu::CpuAotCompilationOptions options(
triple_string,
/*cpu_name=*/"", /*features=*/"", "SumAndDouble",
xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);
auto results =
client->CompileAheadOfTime({instance}, options).ConsumeValueOrDie();
auto result = xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
client
->CompileAheadOfTime(builder.Build().ValueOrDie(),
/*argument_layouts=*/{&opaque_shape}, r0f32,
options)
.ConsumeValueOrDie());
std::move(results.front()));
// We should have two buffers, one for the result and one temporary buffer,
// and both should be float-sized. It's lame to hard-code this, but we need
// local_client_aot_test.cc to be able to easily invoke the function.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册