Commit 8d47eef8 authored by Anlun Xu, committed by TensorFlower Gardener

[xla:gpu] Initialize cuBLAS support before running the executable and before capturing gpu graphs

Add the custom call xla.gpu.init_cublas at the beginning of the module's entry function when the module contains a gemm.
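The reason initialization must happen eagerly, per the runtime comment added below, is that cublasCreate fails while a gpu graph is being captured. A minimal illustrative sketch of that failure mode, using raw CUDA/cuBLAS calls (not part of this commit):

#include <cublas_v2.h>
#include <cuda_runtime.h>

// Illustrative only: creating the cuBLAS handle while `stream` is in capture
// mode fails, because handle creation allocates device resources, which CUDA
// disallows mid-capture. Creating the handle *before* capture avoids this.
void CaptureFailureSketch(cudaStream_t stream) {
  cublasHandle_t handle;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal);
  cublasStatus_t st = cublasCreate(&handle);  // fails under graph capture
  cudaGraph_t graph;
  cudaStreamEndCapture(stream, &graph);
  (void)st;
}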

PiperOrigin-RevId: 564550744
Parent a8463067
@@ -27,6 +27,8 @@ limitations under the License.
 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
 #include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
@@ -35,6 +37,7 @@ limitations under the License.
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "xla/mlir/backends/gpu/transforms/uid_generator.h"
+#include "xla/mlir/runtime/ir/rt_dialect.h"
 #include "xla/mlir/runtime/utils/custom_calls.h"
 #include "xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h"
 #include "xla/stream_executor/blas.h"
@@ -85,6 +88,16 @@ class GemmOpLowering : public OpRewritePattern<GEMMOp> {
   LogicalResult matchAndRewrite(GEMMOp op,
                                 PatternRewriter& rewriter) const override {
+    {
+      // Set the requires_blas attribute to true. The runtime pass will add a
+      // cuBLAS initialization custom call to the entry function if the
+      // attribute is set to true.
+      auto module = op.getOperation()->getParentOfType<ModuleOp>();
+      ImplicitLocOpBuilder b(module.getLoc(), rewriter);
+      module->setAttr(b.getStringAttr(runtime::kRequiresBlasAttrName),
+                      BoolAttr::get(b.getContext(), true));
+    }
+
     // Get or create a custom call function declaration.
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     func::FuncOp callee = custom_calls_.GetOrCreate(b, kCustomCallTarget, op);
@@ -386,7 +386,7 @@ static LogicalResult Outline(unsigned ordinal,
   for (auto op : seq) {
     mlir::Operation* captured_op = op.first;
     if (isa<lmhlo_gpu::GEMMOp>(captured_op)) {
-      func->setAttr(b.getStringAttr("xla.requires_blas"),
+      func->setAttr(b.getStringAttr(runtime::kRequiresBlasAttrName),
                     BoolAttr::get(ctx, true));
       break;
     }
@@ -35,6 +35,8 @@ namespace runtime {
 // Attribute name for marking functions exported to runtime.
 static constexpr char const* kExportedAttrName = "rt.exported";
+
+static constexpr char const* kRequiresBlasAttrName = "rt.requires_blas";

 }  // namespace runtime
 }  // namespace xla
@@ -29,6 +29,7 @@ gentbl_cc_library(
 cc_library(
     name = "passes",
     srcs = [
+        "add_initializations.cc",
         "convert_asserts.cc",
         "convert_custom_calls.cc",
         "export_functions.cc",
@@ -42,6 +43,7 @@ cc_library(
         ":custom_call_encoding",
         ":passes_inc_gen",
         "//xla/mlir/runtime/ir:rt",
+        "//xla/mlir/runtime/utils:custom_calls",
         "//xla/runtime:custom_call",
         "//xla/runtime:tracing",
         "//xla/runtime:type_id",
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>
#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "xla/mlir/runtime/ir/rt_dialect.h"
#include "xla/mlir/runtime/transforms/passes.h"
#include "xla/mlir/runtime/utils/custom_calls.h"

namespace xla {
namespace runtime {

using namespace mlir;  // NOLINT

#define GEN_PASS_DEF_ADDINITIALIZATIONS
#include "xla/mlir/runtime/transforms/passes.h.inc"

class AddInitializations
    : public impl::AddInitializationsBase<AddInitializations> {
  void runOnOperation() override;
};

//===----------------------------------------------------------------------===//

void AddInitializations::runOnOperation() {
  ModuleOp module = getOperation();

  bool requires_blas = false;
  if (Attribute requires_blas_attr = module->getAttr(kRequiresBlasAttrName)) {
    requires_blas = cast<BoolAttr>(requires_blas_attr).getValue();
  }
  if (!requires_blas) {
    return;
  }

  SymbolTable sym_table(module);
  CustomCallDeclarations custom_calls(std::move(sym_table));
  ImplicitLocOpBuilder b(module->getLoc(), custom_calls.sym_table().getOp());

  // Declare (or reuse) the custom call that initializes cuBLAS support.
  func::FuncOp initialize_cublas = custom_calls.GetOrCreate(
      b, "xla.gpu.init_cublas", TypeRange(), TypeRange());

  // Insert the initialization call at the start of the entry function, i.e.
  // the exported function with ordinal 0.
  module.walk([&](func::FuncOp func) {
    if (IntegerAttr exported = func.getOperation()->getAttrOfType<IntegerAttr>(
            kExportedAttrName)) {
      int64_t ordinal = exported.getInt();
      if (ordinal == 0) {
        b.setInsertionPointToStart(&*func.getBody().getBlocks().begin());
        b.create<func::CallOp>(initialize_cublas.getName(), TypeRange());
      }
    }
  });
}

std::unique_ptr<OperationPass<ModuleOp>> CreateAddInitializationsPass() {
  return std::make_unique<AddInitializations>();
}

}  // namespace runtime
}  // namespace xla
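The file above is the new pass implementation (presumably xla/mlir/runtime/transforms/add_initializations.cc, matching the BUILD srcs entry and the passes.h.inc include). A minimal sketch of running the pass standalone, assuming the module has already been through the export pass so its entry function carries an rt.exported ordinal:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "xla/mlir/runtime/transforms/passes.h"

// Sketch: run only the new pass. If the module does not carry
// rt.requires_blas = true, the pass is a no-op.
mlir::LogicalResult AddInits(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(xla::runtime::CreateAddInitializationsPass());
  return pm.run(module);
}

In the default pipeline (next hunk), the pass runs right after CreateExportRuntimeFunctionsPass, which assigns the rt.exported ordinals the walk above keys on, and before CreateConvertCustomCallsPass, which presumably lowers the inserted func.call declaration into an actual runtime custom call.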
@@ -74,6 +74,7 @@ static void CreateDefaultXlaGpuRuntimeCompilationPipeline(
   // Export functions to the XLA runtime.
   pm.addPass(CreateExportRuntimeFunctionsPass());
+  pm.addPass(CreateAddInitializationsPass());
   pm.addPass(CreateConvertCustomCallsPass());
   pm.addPass(CreateConvertAssertsPass());
@@ -38,11 +38,13 @@ limitations under the License.
 #include "llvm/Target/TargetMachine.h"
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"  // from @llvm-project
 #include "mlir/ExecutionEngine/OptUtils.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/Parser/Parser.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
+#include "xla/mlir/runtime/ir/rt_dialect.h"
 #include "xla/mlir/runtime/ir/rt_ops.h"
 #include "xla/mlir/runtime/transforms/compiler.h"
 #include "xla/mlir/runtime/transforms/passes.h"
@@ -349,7 +351,7 @@ MakeOptimizingTransformerForJit(llvm::TargetMachine* targetMachine) {
   if (!results_memory_layout.ok()) return results_memory_layout.status();

   bool requires_blas = false;
-  if (Attribute requires_blas_attr = func->getAttr("xla.requires_blas")) {
+  if (Attribute requires_blas_attr = func->getAttr(kRequiresBlasAttrName)) {
     requires_blas = cast<BoolAttr>(requires_blas_attr).getValue();
   }
@@ -48,6 +48,9 @@ CreateExportRuntimeFunctionsPass();
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 CreateConvertCustomCallsPass();

+std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
+CreateAddInitializationsPass();
+
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateConvertAssertsPass();

 //===----------------------------------------------------------------------===//
@@ -190,6 +190,18 @@ def ConvertAsserts : Pass<"xla-rt-convert-asserts", "ModuleOp"> {
   let dependentDialects = ["xla::runtime::RuntimeDialect"];
 }

+def AddInitializations : Pass<"xla-rt-add-initializations", "ModuleOp"> {
+  let summary = "Add initialization custom calls";
+
+  let description = [{
+    Add custom calls that initialize library support to the beginning of the
+    entry function.
+  }];
+
+  let constructor = "xla::runtime::CreateAddInitializationsPass()";
+  let dependentDialects = ["xla::runtime::RuntimeDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // Conversions targeting `rt` dialect.
 //===----------------------------------------------------------------------===//
@@ -376,6 +376,7 @@ cc_library(
         "//xla/stream_executor:blas",
         "//xla/stream_executor:device_memory",
         "@com_google_absl//absl/container:node_hash_map",
+        "@com_google_absl//absl/status",
     ] + if_cuda_is_configured([
         "//xla/service/gpu:gemm_algorithm_picker",
         "//xla/stream_executor/gpu:redzone_allocator",
@@ -20,6 +20,7 @@ limitations under the License.
 #include <utility>
 #include <vector>

+#include "absl/status/status.h"
 #include "xla/runtime/custom_call.h"
 #include "xla/runtime/executable.h"
 #include "xla/service/gpu/gpu_asm_opts_util.h"
@@ -155,6 +156,17 @@ static absl::Status GemmImpl(const ServiceExecutableRunOptions* run_options,
                        deterministic_ops, stream);
 }

+static absl::Status InitCuBLASImpl(
+    const ServiceExecutableRunOptions* run_options) {
+  // Initialize (with memoization) BlasSupport here because cublasCreate fails
+  // during gpu graph capturing.
+  se::StreamExecutor* executor = run_options->stream()->parent();
+  if (!executor->AsBlas()) {
+    return absl::InternalError("Failed to initialize BLAS support");
+  }
+  return absl::OkStatus();
+}
+
 XLA_RUNTIME_DEFINE_CUSTOM_CALL(
     Gemm, FunctionWrapper<GemmImpl>(), checks,
     CustomCall::Bind("xla.gpu.gemm")
@@ -172,8 +184,14 @@ XLA_RUNTIME_DEFINE_CUSTOM_CALL(
         .Attr<DotDimensionNumbers>("dot_dims")
         .Attr<absl::Span<const int32_t>>("precision"));

+XLA_RUNTIME_DEFINE_CUSTOM_CALL(
+    InitCuBLAS, FunctionWrapper<InitCuBLASImpl>(), checks,
+    CustomCall::Bind("xla.gpu.init_cublas")
+        .UserData<const ServiceExecutableRunOptions*>());
+
 void RegisterGemmCustomCalls(runtime::DirectCustomCallRegistry& registry) {
   registry.Register("xla.gpu.gemm", Gemm);
+  registry.Register("xla.gpu.init_cublas", InitCuBLAS);
 }

 }  // namespace gpu
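The executor->AsBlas() call is what actually creates the cuBLAS handle; the "memoization" the comment refers to is StreamExecutor caching the BlasSupport object, so only the first call pays the cublasCreate cost and later lookups (including ones made during graph capture) return the cached pointer. Roughly, and only as a paraphrase of that behavior (`mu_`, `blas_`, and CreateBlasImplementation() are hypothetical names, not the actual stream_executor implementation):

// Paraphrased sketch of memoized BLAS initialization: the expensive
// cublasCreate runs once, on the first AsBlas() call.
blas::BlasSupport* StreamExecutor::AsBlas() {
  absl::MutexLock lock(&mu_);
  if (blas_ == nullptr) {
    blas_ = CreateBlasImplementation();  // hypothetical factory; runs cublasCreate
  }
  return blas_.get();
}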