...
 
Commits (3)
    https://gitcode.net/Oneflow-Inc/oneflow/-/commit/2189bd508c13843a38719869a33e20457a7b0391 refine 2021-11-02T10:56:39+08:00 jackalcooper jackalcooper@gmail.com https://gitcode.net/Oneflow-Inc/oneflow/-/commit/f4e8a306c14c3f379e327053980573e134119301 refine 2021-11-02T10:59:40+08:00 jackalcooper jackalcooper@gmail.com https://gitcode.net/Oneflow-Inc/oneflow/-/commit/d02a218acb5afdeacdb6f4007a2a73367d5c75ab add functions for kernel 2021-11-02T17:29:52+08:00 jackalcooper jackalcooper@gmail.com
...@@ -787,6 +787,13 @@ const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf) { ...@@ -787,6 +787,13 @@ const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf) {
return kernel_reg_val->create_fn(&create_ctx); return kernel_reg_val->create_fn(&create_ctx);
} }
// Creates a KernelComputeContext for the kernel described by `kernel_conf`,
// bound to the given device and stream contexts.
//
// NOTE(review): returns a heap-allocated object; ownership transfers to the
// caller, who is expected to wrap it in a smart pointer — confirm all call
// sites do so.
user_op::KernelComputeContext* GetKernelComputeContext(DeviceCtx* device_ctx,
                                                       StreamContext* stream_ctx,
                                                       const KernelConf& kernel_conf) {
  // UserKernelComputeContext derives from user_op::KernelComputeContext, so
  // the pointer conversion is implicit — the original static_cast was redundant.
  return new UserKernelComputeContext(device_ctx, stream_ctx, kernel_conf);
}
} // namespace ir } // namespace ir
} // namespace one } // namespace one
......
...@@ -68,6 +68,9 @@ namespace one { ...@@ -68,6 +68,9 @@ namespace one {
namespace ir { namespace ir {
const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf); const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf);
// Creates a compute context for the kernel described by `kernel_conf` on the
// given device/stream. The returned object is heap-allocated; the caller takes
// ownership (see the definition's call sites for the expected wrapping).
user_op::KernelComputeContext* GetKernelComputeContext(DeviceCtx* device_ctx,
StreamContext* stream_ctx,
const KernelConf& kernel_conf);
} // namespace ir } // namespace ir
......
...@@ -170,6 +170,7 @@ class MlirJitCpuKernel final : public user_op::OpKernel { ...@@ -170,6 +170,7 @@ class MlirJitCpuKernel final : public user_op::OpKernel {
private: private:
void Compute(user_op::KernelComputeContext* ctx) const override { void Compute(user_op::KernelComputeContext* ctx) const override {
ctx->getcallback()(ctx);
WithMlirContext( WithMlirContext(
ctx, {}, ctx, {},
[&ctx](mlir::MLIRContext* mlir_ctx) { [&ctx](mlir::MLIRContext* mlir_ctx) {
......
...@@ -26,6 +26,8 @@ limitations under the License. ...@@ -26,6 +26,8 @@ limitations under the License.
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "oneflow/core/kernel/user_kernel.h" #include "oneflow/core/kernel/user_kernel.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/device/device_context_adapter.h"
namespace { namespace {
...@@ -54,15 +56,14 @@ ParallelContext GetSingleDeviceParallelContext() { ...@@ -54,15 +56,14 @@ ParallelContext GetSingleDeviceParallelContext() {
void InsertLbnSegmentIntoMapping(const ::mlir::ArrayAttr& lbn_segment_keys, void InsertLbnSegmentIntoMapping(const ::mlir::ArrayAttr& lbn_segment_keys,
const ::mlir::ArrayAttr& lbn_segment_sizes, ValueRange values, const ::mlir::ArrayAttr& lbn_segment_sizes, ValueRange values,
std::unordered_map<std::string, mlir::Value>& operand_mapping_) { std::unordered_map<std::string, mlir::Value>& value_mapping_) {
auto operand_it = values.begin(); auto operand_it = values.begin();
for (const auto& bn_size_pair : llvm::zip(lbn_segment_keys, lbn_segment_sizes)) { for (const auto& bn_size_pair : llvm::zip(lbn_segment_keys, lbn_segment_sizes)) {
const auto& bn = std::get<0>(bn_size_pair).dyn_cast<StringAttr>().getValue().str(); const auto& bn = std::get<0>(bn_size_pair).dyn_cast<StringAttr>().getValue().str();
const auto& length = std::get<1>(bn_size_pair).dyn_cast<IntegerAttr>().getInt(); const auto& length = std::get<1>(bn_size_pair).dyn_cast<IntegerAttr>().getInt();
for (size_t i = 0; i < length; i++) { for (size_t i = 0; i < length; i++) {
const auto indexed_bn = bn + "_" + std::to_string(i); const auto indexed_bn = bn + "_" + std::to_string(i);
LOG(ERROR) << "indexed_bn: " << indexed_bn; CHECK(value_mapping_.emplace(indexed_bn, *operand_it).second) << "indexed_bn: " << indexed_bn;
assert(operand_mapping_.emplace(indexed_bn, *operand_it).second);
operand_it += 1; operand_it += 1;
} }
} }
...@@ -86,11 +87,39 @@ class ReturnAllLeaveResultPass : public ReturnAllLeaveResultPassBase<ReturnAllLe ...@@ -86,11 +87,39 @@ class ReturnAllLeaveResultPass : public ReturnAllLeaveResultPassBase<ReturnAllLe
struct JITKernelLaunchContext { struct JITKernelLaunchContext {
const OpKernel* kernel; const OpKernel* kernel;
KernelComputeContext* compute_ctx; KernelComputeContext* compute_ctx;
JITKernelLaunchContext(const OpKernel* kernel, KernelComputeContext* compute_ctx)
: kernel(kernel), compute_ctx(compute_ctx) {}
}; };
KernelComputeContext* GetKernelComputeContext(const ::oneflow::UserOpConf& user_op_conf) { StreamContext* GetStreamCxtFromStreamId(const StreamId& stream_id) {
static std::vector<std::shared_ptr<const OpKernel>> created_kernels; StreamContext* stream_ctx =
static std::vector<std::shared_ptr<KernelComputeContext>> created; NewObj<int, StreamContext, const StreamId&>(stream_id.device_id().device_type(), stream_id);
return stream_ctx;
}
// Returns a process-wide stream context for JIT kernel launches.
// NOTE(review): hard-coded to GPU device 0 of this process's rank — CPU-only
// builds presumably take another path; confirm before relying on this.
StreamContext* GetComputeStreamCxt() {
  static const int64_t kGpuDeviceIndex = 0;
  static DeviceId compute_device_id(GlobalProcessCtx::Rank(), DeviceType::kGPU, kGpuDeviceIndex);
  // Stream index 0 on that device; built once, reused for every launch.
  static StreamContext* compute_stream_ctx =
      GetStreamCxtFromStreamId(StreamId(compute_device_id, 0));
  return compute_stream_ctx;
}
// Returns the process-wide DeviceCtx adapter wrapping the compute stream
// context; crashes (CHECK) if the adapter cannot be created.
DeviceCtx* GetComputeDeviceCxt() {
  static auto adapter = CHECK_NOTNULL(NewDeviceCtxAdapter(GetComputeStreamCxt()));
  return adapter;
}
// Builds a launch context (kernel + compute context) for `kernel_conf`.
// The static vectors deliberately keep every created object alive for the
// lifetime of the process, because the returned raw pointer is stored inside
// JIT-compiled code and may be invoked at any later time.
// NOTE(review): the registries grow on every call and are not mutex-guarded —
// presumably only reached from single-threaded JIT compilation; confirm.
JITKernelLaunchContext* GetKernelLaunchContext(const KernelConf& kernel_conf) {
  static std::vector<std::shared_ptr<const OpKernel>> managed_kernels;
  static std::vector<std::shared_ptr<KernelComputeContext>> managed_compute_contexts;
  static std::vector<std::shared_ptr<JITKernelLaunchContext>> managed_jit_kernel_launch_contexts;
  managed_kernels.emplace_back(one::ir::GetKernel(kernel_conf));
  // Takes ownership of the raw pointer returned by GetKernelComputeContext.
  managed_compute_contexts.emplace_back(
      one::ir::GetKernelComputeContext(GetComputeDeviceCxt(), GetComputeStreamCxt(), kernel_conf));
  auto jit_kernel_launch_ctx = std::make_shared<JITKernelLaunchContext>(
      managed_kernels.back().get(), managed_compute_contexts.back().get());
  managed_jit_kernel_launch_contexts.emplace_back(jit_kernel_launch_ctx);
  return jit_kernel_launch_ctx.get();
}
extern "C" void _mlir_ciface_LaunchOneFlowKernel(JITKernelLaunchContext* ctx) { extern "C" void _mlir_ciface_LaunchOneFlowKernel(JITKernelLaunchContext* ctx) {
......
...@@ -22,7 +22,6 @@ def exec(f): ...@@ -22,7 +22,6 @@ def exec(f):
m = args[0] m = args[0]
assert isinstance(m, oneflow.nn.Module) assert isinstance(m, oneflow.nn.Module)
for arg in args[1::]: for arg in args[1::]:
print(id(arg))
isinstance(arg, oneflow._oneflow_internal.Tensor) isinstance(arg, oneflow._oneflow_internal.Tensor)
func_name = str(uuid.uuid4()).replace("-", "") func_name = str(uuid.uuid4()).replace("-", "")
func_name = f"jit{func_name}" func_name = f"jit{func_name}"
......