提交 8b177845 编写于 作者: J jackalcooper

naive kernel_conf creation

上级 6d411000
......@@ -254,6 +254,13 @@ llvm::Optional<TensorType> JitImporter::GetMlirTensorTypeFromBlobDesc(const Blob
}
}
ParallelContext GetSingleDeviceParallelContext() {
ParallelContext parallel_ctx;
parallel_ctx.set_parallel_id(0);
parallel_ctx.set_parallel_num(1);
return parallel_ctx;
}
void JitImporter::CreateOperandMapping(const ::oneflow::OperatorConf& op_conf,
const std::shared_ptr<const ParallelDesc> parallel_desc,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
......@@ -293,11 +300,16 @@ void JitImporter::CreateOperandMapping(const ::oneflow::OperatorConf& op_conf,
return lbi2logical_blob_desc_.at(bn).get();
};
CHECK_JUST(op->InferLogicalOutBlobDescs(GetLogicalBlobDesc4BnInOp, *parallel_desc));
KernelConf kernel_conf;
static ParallelContext parallel_ctx = GetSingleDeviceParallelContext();
for (auto& kv : lbi2logical_blob_desc_) {
CHECK(
result_type_mapping_.emplace(kv.first, GetMlirTensorTypeFromBlobDesc(*kv.second).getValue())
.second);
}
op->GenKernelConf(GetLogicalBlobDesc4BnInOp, &parallel_ctx, &kernel_conf);
llvm::errs() << "kernel_conf: \n";
llvm::errs() << kernel_conf.DebugString() << "\n";
}
llvm::Optional<mlir::Value> JitImporter::GetResultByBnAndIndex(const std::string& bn,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册