Commit 63e6a63d authored by Johannes Reifferscheid, committed by TensorFlower Gardener

Make remaining utilities for calling computations free functions.

- Also give them clearer names (IMO, let me know if you disagree)
- Also get rid of the indirection through `GetNestedComputer()`.

PiperOrigin-RevId: 549568613
Parent 3e240b5b
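
Note (not part of the commit): the diff below drops the NestedComputer indirection from GpuElementalIrEmitter and routes nested-computation calls through free functions in ir_emitter_nested.h. The following is a minimal sketch of the before/after calling pattern using only names that appear in the diff; config, context, b_, callee, scalar_args, and results are illustrative placeholders for values the emitters already hold:

    // Before: the elemental emitter was handed a std::function ("NestedComputer")
    // produced by IrEmitter::GetNestedComputer(), which forwarded to the private
    // member IrEmitter::ComputeNestedElement():
    //   GpuElementalIrEmitter emitter(config, module, &b_, GetNestedComputer());
    //
    // After: the emitter takes the IrEmitterContext directly, and nested calls go
    // through free functions such as CallNestedComputationWithScalars().
    GpuElementalIrEmitter emitter(config, context, &b_);
    StatusOr<std::vector<llvm::Value*>> results =
        CallNestedComputationWithScalars(&b_, config, callee, context,
                                         scalar_args);
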
@@ -304,11 +304,13 @@ cc_library(
cc_library(
name = "ir_emitter",
srcs = [
"elemental_ir_emitter.cc",
"ir_emitter.cc",
"ir_emitter_nested.cc",
"ir_emitter_unnested.cc",
],
hdrs = [
"elemental_ir_emitter.h",
"ir_emitter.h",
"ir_emitter_nested.h",
"ir_emitter_unnested.h",
@@ -318,7 +320,6 @@ cc_library(
deps = [
":backend_configs_cc",
":buffer_allocations",
":elemental_ir_emitter",
":fft_thunk",
":gpu_asm_opts_util",
":gpu_constants",
@@ -379,6 +380,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_type_conversion_util",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:math_ops",
"//tensorflow/compiler/xla/service/llvm_ir:sort_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_module_importer",
@@ -669,37 +671,6 @@ cc_library(
],
)
cc_library(
name = "elemental_ir_emitter",
srcs = ["elemental_ir_emitter.cc"],
hdrs = ["elemental_ir_emitter.h"],
deps = [
":backend_configs_cc",
":target_util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/hlo/ir:hlo",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:math_ops",
"//tensorflow/tsl/platform:logging",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
],
)
cc_library(
name = "buffer_allocations",
srcs = ["buffer_allocations.cc"],
......
@@ -15,17 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
#include <stddef.h>
#include <string>
#include <vector>
#include "tensorflow/tsl/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
@@ -35,19 +31,15 @@ limitations under the License.
#include "llvm/Support/ModRef.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -71,17 +63,17 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
} // namespace
GpuElementalIrEmitter::GpuElementalIrEmitter(
const HloModuleConfig& hlo_module_config, llvm::Module* module,
llvm::IRBuilder<>* b, NestedComputer compute_nested)
: ElementalIrEmitter(module, b),
const HloModuleConfig& hlo_module_config,
IrEmitterContext& ir_emitter_context, llvm::IRBuilder<>* b)
: ElementalIrEmitter(ir_emitter_context.llvm_module(), b),
hlo_module_config_(hlo_module_config),
compute_nested_(std::move(compute_nested)) {}
ir_emitter_context_(ir_emitter_context) {}
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name) {
// Device functions dont have f16 math functions, so we convert the operands
// Device functions don't have f16 math functions, so we convert the operands
// to f32 before calling the function and then convert the result back to f16.
bool cast_result_to_fp16 = false;
std::vector<llvm::Value*> converted_operands(operands.begin(),
@@ -343,5 +335,12 @@ llvm::Value* GpuElementalIrEmitter::EmitThreadId() {
return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block);
}
StatusOr<std::vector<llvm::Value*>> GpuElementalIrEmitter::EmitThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view, bool /*is_reducer*/) {
return CallNestedComputationWithScalars(b(), hlo_module_config_, callee,
ir_emitter_context_, parameters);
}
} // namespace gpu
} // namespace xla
@@ -16,9 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_
#include <functional>
#include <string>
#include <utility>
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
@@ -26,11 +24,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h"
#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -38,14 +35,9 @@ namespace gpu {
class GpuElementalIrEmitter : public ElementalIrEmitter {
public:
// A NestedComputer computes an element of the output of the given computation
// given a Span of its input elements.
using NestedComputer = std::function<StatusOr<std::vector<llvm::Value*>>(
const HloComputation&, absl::Span<llvm::Value* const>)>;
GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
llvm::Module* module, llvm::IRBuilder<>* b,
NestedComputer compute_nested);
IrEmitterContext& ir_emitter_context,
llvm::IRBuilder<>* b);
protected:
llvm_ir::IrArray::Index GetSourceIndexOfBitcast(
@@ -98,9 +90,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
StatusOr<std::vector<llvm::Value*>> EmitThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view, bool /*is_reducer*/) override {
return compute_nested_(callee, parameters);
}
absl::string_view, bool /*is_reducer*/) override;
llvm::Value* EmitThreadId() override;
@@ -137,8 +127,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
absl::string_view name = "");
const HloModuleConfig& hlo_module_config_;
NestedComputer compute_nested_;
IrEmitterContext& ir_emitter_context_;
};
} // namespace gpu
......
@@ -60,8 +60,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
};
}
return EmitTargetElementLoop(
*hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &b_,
GetNestedComputer())
*hlo, GpuElementalIrEmitter(hlo_module_config_, *ir_emitter_context_, &b_)
.MakeElementGenerator(hlo, operand_to_generator));
}
@@ -505,8 +504,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
// kFusion for library calls should be handled by
// IrEmitterUnnested::HandleFusion.
CHECK_EQ(HloInstruction::FusionKind::kLoop, fusion->fusion_kind());
GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
*ir_emitter_context_, &b_);
FusedIrEmitter fused_emitter(elemental_emitter);
BindFusionArguments(fusion, &fused_emitter);
TF_ASSIGN_OR_RETURN(auto generator, fused_emitter.GetGenerator(
@@ -558,53 +557,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"be lowered before IR emission to HLO-soup using BatchNormRewriter.");
}
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElement(
const HloComputation& computation,
absl::Span<llvm::Value* const> parameter_elements) {
std::vector<llvm::Value*> parameter_buffers;
for (llvm::Value* parameter_element : parameter_elements) {
parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
parameter_element->getType(), "parameter_buffer", &b_));
Store(parameter_element, parameter_buffers.back());
}
return ComputeNestedElementFromAddrs(computation, parameter_buffers);
}
StatusOr<std::vector<llvm::Value*>> IrEmitter::ComputeNestedElementFromAddrs(
const HloComputation& computation,
absl::Span<llvm::Value* const> parameter_elements_addrs) {
const Shape& return_shape = computation.root_instruction()->shape();
llvm::Type* return_buffer_type =
llvm_ir::ShapeToIrType(return_shape, module_);
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
return_buffer_type, "return_buffer", &b_);
std::vector<llvm::Value*> allocas_for_returned_scalars;
if (!return_shape.IsTuple()) {
allocas_for_returned_scalars.push_back(return_buffer);
} else {
allocas_for_returned_scalars =
llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, &b_);
llvm_ir::IrArray tuple_array(return_buffer, return_buffer_type,
return_shape);
EmitTuple(tuple_array, allocas_for_returned_scalars, &b_);
}
TF_RETURN_IF_ERROR(CallNestedComputation(
&b_, hlo_module_config_, computation, *ir_emitter_context_,
parameter_elements_addrs, return_buffer));
std::vector<llvm::Value*> returned_scalars;
returned_scalars.reserve(allocas_for_returned_scalars.size());
for (llvm::Value* addr : allocas_for_returned_scalars) {
auto alloca = llvm::cast<llvm::AllocaInst>(addr);
returned_scalars.push_back(Load(alloca->getAllocatedType(), alloca));
}
return returned_scalars;
}
std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
const HloInstruction& hlo) {
std::vector<llvm_ir::IrArray> output_arrays;
......
@@ -133,21 +133,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
const HloComputation& nested_computation, llvm::Value* output_address,
llvm::Value* source_address, llvm::Type* element_type);
GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
return [&](const HloComputation& computation,
absl::Span<llvm::Value* const> parameter_elements) {
return ComputeNestedElement(computation, parameter_elements);
};
}
StatusOr<std::vector<llvm::Value*>> ComputeNestedElement(
const HloComputation& computation,
absl::Span<llvm::Value* const> parameter_elements);
StatusOr<std::vector<llvm::Value*>> ComputeNestedElementFromAddrs(
const HloComputation& computation,
absl::Span<llvm::Value* const> parameter_elements_addrs);
IrEmitterContext* ir_emitter_context_;
llvm::Module* module_;
......
@@ -257,7 +257,7 @@ Status IrEmitterNested::EmitConstants(const HloComputation& computation) {
// Casts the provided llvm::Value* to the default address space. This is useful
// in particular for generating IR for AMDGPU target, as its kernel variables
// are in address space 5 instead of the default address space.
static llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
llvm::Type* arg_type = arg->getType();
CHECK(arg_type->isPointerTy());
if (arg_type->getPointerAddressSpace() != 0) {
@@ -274,16 +274,16 @@ static llvm::Value* AddrCastToDefault(llvm::Value* arg, llvm::IRBuilder<>& b) {
Status CallNestedComputation(llvm::IRBuilder<>* builder,
const HloModuleConfig& hlo_module_config,
const HloComputation& nested_computation,
const HloComputation& computation,
IrEmitterContext& ir_emitter_context,
absl::Span<llvm::Value* const> operands,
llvm::Value* output) {
TF_RET_CHECK(nested_computation.num_parameters() > 0);
TF_RET_CHECK(computation.num_parameters() > 0);
TF_ASSIGN_OR_RETURN(llvm::Function * emitted_function,
IrEmitterNested(hlo_module_config, nested_computation,
&ir_emitter_context)
.CodegenNestedComputation());
TF_ASSIGN_OR_RETURN(
llvm::Function * emitted_function,
IrEmitterNested(hlo_module_config, computation, &ir_emitter_context)
.CodegenNestedComputation());
// Operands are in default address space for non-AMDGPU target.
// However for AMDGPU target, addrspacecast alloca variables from
@@ -301,5 +301,57 @@ Status CallNestedComputation(llvm::IRBuilder<>* builder,
return OkStatus();
}
StatusOr<std::vector<llvm::Value*>> CallNestedComputationWithScalars(
llvm::IRBuilder<>* builder, const HloModuleConfig& hlo_module_config,
const HloComputation& computation, IrEmitterContext& ir_emitter_context,
absl::Span<llvm::Value* const> parameter_elements) {
std::vector<llvm::Value*> parameter_buffers;
for (llvm::Value* parameter_element : parameter_elements) {
parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
parameter_element->getType(), "parameter_buffer", builder));
builder->CreateStore(parameter_element, parameter_buffers.back());
}
return CallNestedComputationWithScalarAddrs(builder, hlo_module_config,
computation, ir_emitter_context,
parameter_buffers);
}
StatusOr<std::vector<llvm::Value*>> CallNestedComputationWithScalarAddrs(
llvm::IRBuilder<>* builder, const HloModuleConfig& hlo_module_config,
const HloComputation& computation, IrEmitterContext& ir_emitter_context,
absl::Span<llvm::Value* const> parameter_elements_addrs) {
const Shape& return_shape = computation.root_instruction()->shape();
llvm::Type* return_buffer_type = llvm_ir::ShapeToIrType(
return_shape, builder->GetInsertBlock()->getModule());
llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
return_buffer_type, "return_buffer", builder);
std::vector<llvm::Value*> allocas_for_returned_scalars;
if (!return_shape.IsTuple()) {
allocas_for_returned_scalars.push_back(return_buffer);
} else {
allocas_for_returned_scalars =
llvm_ir::EmitTupleAllocasAtFunctionEntry(return_shape, builder);
llvm_ir::IrArray tuple_array(return_buffer, return_buffer_type,
return_shape);
llvm_ir::EmitTuple(tuple_array, allocas_for_returned_scalars, builder);
}
TF_RETURN_IF_ERROR(CallNestedComputation(
builder, hlo_module_config, computation, ir_emitter_context,
parameter_elements_addrs, return_buffer));
std::vector<llvm::Value*> returned_scalars;
returned_scalars.reserve(allocas_for_returned_scalars.size());
for (llvm::Value* addr : allocas_for_returned_scalars) {
auto alloca = llvm::cast<llvm::AllocaInst>(addr);
returned_scalars.push_back(
builder->CreateLoad(alloca->getAllocatedType(), alloca));
}
return returned_scalars;
}
} // namespace gpu
} // namespace xla
@@ -42,11 +42,23 @@ namespace gpu {
// - a pointer to the top-level temp buffer.
Status CallNestedComputation(llvm::IRBuilder<>* builder,
const HloModuleConfig& hlo_module_config,
const HloComputation& nested_computation,
const HloComputation& computation,
IrEmitterContext& ir_emitter_context,
absl::Span<llvm::Value* const> operands,
llvm::Value* output);
// Like CallNestedComputation, but parameters and results are scalars.
StatusOr<std::vector<llvm::Value*>> CallNestedComputationWithScalars(
llvm::IRBuilder<>* builder, const HloModuleConfig& hlo_module_config,
const HloComputation& computation, IrEmitterContext& ir_emitter_context,
absl::Span<llvm::Value* const> parameter_scalars);
// Like CallNestedComputationWithScalars, but parameters are scalar addresses.
StatusOr<std::vector<llvm::Value*>> CallNestedComputationWithScalarAddrs(
llvm::IRBuilder<>* builder, const HloModuleConfig& hlo_module_config,
const HloComputation& computation, IrEmitterContext& ir_emitter_context,
absl::Span<llvm::Value* const> parameter_elements_addrs);
} // namespace gpu
} // namespace xla
......
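
Note (not part of the commit): a minimal usage sketch for the two scalar helpers declared above, assuming a caller that already holds builder, config, context, a reducer computation, and the scalar values/addresses (all placeholder names):

    // CallNestedComputationWithScalars takes the parameter values themselves; per
    // the .cc hunk above, it spills each value to an entry-block alloca and then
    // forwards to the *WithScalarAddrs variant.
    StatusOr<std::vector<llvm::Value*>> from_values =
        CallNestedComputationWithScalars(builder, config, reducer, context,
                                         {lhs_value, rhs_value});

    // CallNestedComputationWithScalarAddrs takes pointers that already hold the
    // parameter values, e.g. the reduction_params buffers used by the reduce
    // emitters later in this diff.
    StatusOr<std::vector<llvm::Value*>> from_addrs =
        CallNestedComputationWithScalarAddrs(builder, config, reducer, context,
                                             {lhs_addr, rhs_addr});
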
@@ -351,8 +351,7 @@ StatusOr<xla::gpu::CudnnfMHAKind> AsCudnnBackwardfMHAKind(
IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
IrEmitterContext* ir_emitter_context)
: IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
elemental_emitter_(hlo_module_config_, module_, &b_,
GetNestedComputer()) {}
elemental_emitter_(hlo_module_config_, *ir_emitter_context, &b_) {}
StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
const HloModuleConfig& hlo_module_config,
@@ -3670,7 +3669,9 @@ void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
}
StatusOr<std::vector<llvm::Value*>> returned_scalars =
ComputeNestedElementFromAddrs(*reducer, reduction_params);
CallNestedComputationWithScalarAddrs(&b_, hlo_module_config_, *reducer,
*ir_emitter_context_,
reduction_params);
TF_CHECK_OK(returned_scalars.status());
for (int i = 0; i < returned_scalars->size(); i++) {
@@ -4198,7 +4199,9 @@ void IrEmitterUnnested::GenerateElementForReducer(
// those pointers, and we have returned values on the stack (as well
// as pointers to them).
StatusOr<std::vector<llvm::Value*>> returned_scalars =
ComputeNestedElementFromAddrs(*reducer, reduction_params);
CallNestedComputationWithScalarAddrs(&b_, hlo_module_config_, *reducer,
*ir_emitter_context_,
reduction_params);
TF_CHECK_OK(returned_scalars.status());
for (int i = 0; i < returned_scalars->size(); i++) {
......