...
 
Commits (3)
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/2189bd508c13843a38719869a33e20457a7b0391  refine  2021-11-02T10:56:39+08:00  jackalcooper <jackalcooper@gmail.com>
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/f4e8a306c14c3f379e327053980573e134119301  refine  2021-11-02T10:59:40+08:00  jackalcooper <jackalcooper@gmail.com>
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/d02a218acb5afdeacdb6f4007a2a73367d5c75ab  add functions for kernel  2021-11-02T17:29:52+08:00  jackalcooper <jackalcooper@gmail.com>
@@ -787,6 +787,13 @@ const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf) {
   return kernel_reg_val->create_fn(&create_ctx);
 }
 
+user_op::KernelComputeContext* GetKernelComputeContext(DeviceCtx* device_ctx,
+                                                       StreamContext* stream_ctx,
+                                                       const KernelConf& kernel_conf) {
+  auto ctx = new UserKernelComputeContext(device_ctx, stream_ctx, kernel_conf);
+  return static_cast<user_op::KernelComputeContext*>(ctx);
+}
+
 }  // namespace ir
 }  // namespace one
...
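Together with GetKernel above, this new helper gives the JIT side a way to run a user kernel outside the normal executor. A minimal sketch of the intended pairing (hypothetical driver code, not part of these commits; it assumes the caller already owns a DeviceCtx and a StreamContext for the target stream):

// Hypothetical illustration: build the kernel and its compute context from
// the same KernelConf, then run the kernel once on the given stream.
void RunKernelOnce(DeviceCtx* device_ctx, StreamContext* stream_ctx,
                   const KernelConf& kernel_conf) {
  const user_op::OpKernel* kernel = one::ir::GetKernel(kernel_conf);
  user_op::KernelComputeContext* compute_ctx =
      one::ir::GetKernelComputeContext(device_ctx, stream_ctx, kernel_conf);
  kernel->Compute(compute_ctx);  // user_op kernels compute against this context
}

Note that neither object is freed here; the GetKernelLaunchContext hunk further down keeps both alive in static vectors for the lifetime of the process.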
@@ -68,6 +68,9 @@ namespace one {
 namespace ir {
 
 const user_op::OpKernel* GetKernel(const KernelConf& kernel_conf);
+user_op::KernelComputeContext* GetKernelComputeContext(DeviceCtx* device_ctx,
+                                                       StreamContext* stream_ctx,
+                                                       const KernelConf& kernel_conf);
 
 }  // namespace ir
...
@@ -170,6 +170,7 @@ class MlirJitCpuKernel final : public user_op::OpKernel {
  private:
   void Compute(user_op::KernelComputeContext* ctx) const override {
+    ctx->getcallback()(ctx);
     WithMlirContext(
         ctx, {},
         [&ctx](mlir::MLIRContext* mlir_ctx) {
...
@@ -26,6 +26,8 @@ limitations under the License.
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "oneflow/core/kernel/user_kernel.h"
+#include "oneflow/core/rpc/include/global_process_ctx.h"
+#include "oneflow/core/device/device_context_adapter.h"
 
 namespace {
@@ -54,15 +56,14 @@ ParallelContext GetSingleDeviceParallelContext() {
 void InsertLbnSegmentIntoMapping(const ::mlir::ArrayAttr& lbn_segment_keys,
                                  const ::mlir::ArrayAttr& lbn_segment_sizes, ValueRange values,
-                                 std::unordered_map<std::string, mlir::Value>& operand_mapping_) {
+                                 std::unordered_map<std::string, mlir::Value>& value_mapping_) {
   auto operand_it = values.begin();
   for (const auto& bn_size_pair : llvm::zip(lbn_segment_keys, lbn_segment_sizes)) {
     const auto& bn = std::get<0>(bn_size_pair).dyn_cast<StringAttr>().getValue().str();
     const auto& length = std::get<1>(bn_size_pair).dyn_cast<IntegerAttr>().getInt();
     for (size_t i = 0; i < length; i++) {
       const auto indexed_bn = bn + "_" + std::to_string(i);
-      LOG(ERROR) << "indexed_bn: " << indexed_bn;
-      assert(operand_mapping_.emplace(indexed_bn, *operand_it).second);
+      CHECK(value_mapping_.emplace(indexed_bn, *operand_it).second) << "indexed_bn: " << indexed_bn;
       operand_it += 1;
     }
   }
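For reference, the renamed value_mapping_ is keyed by blob names expanded with an index suffix: keys ["in", "out"] with segment sizes [2, 1] produce the entries in_0, in_1, out_0. A standalone sketch of the same expansion, with plain STL types standing in for the MLIR attribute and value types (illustrative only):

#include <cassert>
#include <map>
#include <string>
#include <vector>

// Stand-in for InsertLbnSegmentIntoMapping: expands each blob name ("bn")
// into indexed keys bn_0 .. bn_{size-1}, pairing them in order with the
// flat list of values.
std::map<std::string, int> ExpandBnSegments(const std::vector<std::string>& keys,
                                            const std::vector<int>& sizes,
                                            const std::vector<int>& values) {
  std::map<std::string, int> mapping;
  size_t value_idx = 0;
  for (size_t seg = 0; seg < keys.size(); ++seg) {
    for (int i = 0; i < sizes[seg]; ++i) {
      const std::string indexed_bn = keys[seg] + "_" + std::to_string(i);
      const bool inserted = mapping.emplace(indexed_bn, values[value_idx++]).second;
      assert(inserted && "duplicate indexed_bn");  // mirrors the CHECK above
      (void)inserted;
    }
  }
  return mapping;
}

// ExpandBnSegments({"in", "out"}, {2, 1}, {10, 11, 12})
//   == {{"in_0", 10}, {"in_1", 11}, {"out_0", 12}}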
@@ -86,11 +87,39 @@ class ReturnAllLeaveResultPass : public ReturnAllLeaveResultPassBase<ReturnAllLe
 struct JITKernelLaunchContext {
   const OpKernel* kernel;
   KernelComputeContext* compute_ctx;
+  JITKernelLaunchContext(const OpKernel* kernel, KernelComputeContext* compute_ctx)
+      : kernel(kernel), compute_ctx(compute_ctx) {}
 };
 
-KernelComputeContext* GetKernelComputeContext(const ::oneflow::UserOpConf& user_op_conf) {
-  static std::vector<std::shared_ptr<const OpKernel>> created_kernels;
-  static std::vector<std::shared_ptr<KernelComputeContext>> created;
+StreamContext* GetStreamCxtFromStreamId(const StreamId& stream_id) {
+  StreamContext* stream_ctx =
+      NewObj<int, StreamContext, const StreamId&>(stream_id.device_id().device_type(), stream_id);
+  return stream_ctx;
+}
+
+StreamContext* GetComputeStreamCxt() {
+  static int64_t GPU0 = 0;
+  static DeviceId device_id(GlobalProcessCtx::Rank(), DeviceType::kGPU, GPU0);
+  static StreamContext* stream_ctx = GetStreamCxtFromStreamId(StreamId(device_id, 0));
+  return stream_ctx;
+}
+
+DeviceCtx* GetComputeDeviceCxt() {
+  static auto device_ctx = CHECK_NOTNULL(NewDeviceCtxAdapter(GetComputeStreamCxt()));
+  return device_ctx;
+}
+
+JITKernelLaunchContext* GetKernelLaunchContext(const KernelConf& kernel_conf) {
+  static std::vector<std::shared_ptr<const OpKernel>> managed_kernels;
+  static std::vector<std::shared_ptr<KernelComputeContext>> managed_compute_contexts;
+  static std::vector<std::shared_ptr<JITKernelLaunchContext>> managed_jit_kernel_launch_contexts;
+  managed_kernels.emplace_back(one::ir::GetKernel(kernel_conf));
+  managed_compute_contexts.emplace_back(
+      one::ir::GetKernelComputeContext(GetComputeDeviceCxt(), GetComputeStreamCxt(), kernel_conf));
+  auto jit_kernel_launch_ctx = std::make_shared<JITKernelLaunchContext>(
+      managed_kernels.back().get(), managed_compute_contexts.back().get());
+  managed_jit_kernel_launch_contexts.emplace_back(jit_kernel_launch_ctx);
+  return jit_kernel_launch_ctx.get();
+}
+
 extern "C" void _mlir_ciface_LaunchOneFlowKernel(JITKernelLaunchContext* ctx) {
...
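GetKernelLaunchContext interns every kernel and compute context it creates in static vectors, so the raw JITKernelLaunchContext pointer handed to JIT-compiled code stays valid for the life of the process. The body of _mlir_ciface_LaunchOneFlowKernel is elided above; one plausible minimal form, assuming the trampoline simply forwards to the bundled kernel (a guess, not the commit's actual body):

// Hypothetical sketch: MLIR-generated code calls this C symbol with the
// JITKernelLaunchContext built host-side by GetKernelLaunchContext.
extern "C" void _mlir_ciface_LaunchOneFlowKernel(JITKernelLaunchContext* ctx) {
  ctx->kernel->Compute(ctx->compute_ctx);  // run the wrapped user_op kernel
}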
@@ -22,7 +22,6 @@ def exec(f):
         m = args[0]
         assert isinstance(m, oneflow.nn.Module)
         for arg in args[1::]:
-            print(id(arg))
             isinstance(arg, oneflow._oneflow_internal.Tensor)
         func_name = str(uuid.uuid4()).replace("-", "")
         func_name = f"jit{func_name}"
...