提交 7e8efb0f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Configure LLVMDialect and extract pointer size.

PiperOrigin-RevId: 258361644
上级 b542d679
......@@ -39,8 +39,13 @@ cc_library(
deps = [
":failover_compiler",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service/gpu:gpu_constants",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl",
"//tensorflow/compiler/xla/service/gpu:target_constants",
"//tensorflow/core:lib",
"@local_config_mlir//:IR",
"@local_config_mlir//:LLVMDialect",
],
alwayslink = True, # Contains compiler registration
)
......@@ -15,13 +15,33 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h"
#include "mlir/LLVMIR/LLVMDialect.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
namespace mlir {
using ::mlir::MLIRContext;
using ::mlir::LLVM::LLVMDialect;
namespace {
int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) {
LLVMDialect* dialect = context->getRegisteredDialect<LLVMDialect>();
llvm::Module& module = dialect->getLLVMModule();
module.setTargetTriple(gpu::nvptx::kTargetTriple);
module.setDataLayout(gpu::nvptx::kDataLayout);
return module.getDataLayout().getPointerSize();
}
} // namespace
MlirCompiler::MlirCompiler()
: pointer_size_(ConfigureLLVMModuleAndGetPointerSize(&context_)) {}
se::Platform::Id MlirCompiler::PlatformId() const {
return stream_executor::cuda::kCudaPlatformId;
}
......
......@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/service/compiler.h"
namespace xla {
......@@ -26,7 +27,7 @@ namespace mlir {
// generation of a think suitable for XLAs runtime.
class MlirCompiler : public Compiler {
public:
MlirCompiler() {}
MlirCompiler();
se::Platform::Id PlatformId() const override;
......@@ -58,12 +59,15 @@ class MlirCompiler : public Compiler {
const AotCompilationOptions& options) override;
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
// TODO(herhut): Get this from the LLVMDialect in MLIR.
int64 pointer_size = 8;
int64 pointer_size = pointer_size_;
return [pointer_size](const Shape& shape) {
return ShapeUtil::ByteSizeOf(shape, pointer_size);
};
}
private:
::mlir::MLIRContext context_;
int64 pointer_size_;
};
} // namespace mlir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册