未验证 提交 c8d36c17 编写于 作者: J Juncheng 提交者: GitHub

SystemOpFillJobNamePass (#6138)

* SystemOpFillJobNamePass

* fix
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 eda2a011
......@@ -24,7 +24,7 @@ namespace user_op {
OpKernelInferCache::OpKernelInferCache(const KernelConf& kernel_conf, const JobDesc& job_desc) {
const OperatorConf& op_conf = kernel_conf.op_attribute().op_conf();
std::shared_ptr<Operator> op = CHECK_JUST(ConstructOp(op_conf));
cache_key_.job_desc = &job_desc;
cache_key_.scope = &job_desc;
cache_key_.op_conf_sym = op->GetOpConfWithoutOpNameAndLbn();
cache_key_.ibn_idx2shape_sym.resize(op->input_bns().size());
cache_key_.dtype_signature_sym = SymbolOf(kernel_conf.dtype_signature());
......
......@@ -135,6 +135,7 @@ Maybe<void> JobCompleter::Complete(Job* job) const {
JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalInputCriticalSections));
JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalOutputCriticalSections));
JUST(WithOpGraphAndMutJob(job, &MultiClientAutoSourceAndSinkTick));
JUST(JobPass4Name("SystemOpFillJobNamePass")(job, &job_pass_ctx));
JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx));
if (XrtCompilationEnabled(GlobalJobDesc())) {
#ifdef OF_WITH_XRT
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/job/job.pb.h"
namespace oneflow {
namespace {
// Job pass that copies the job's name (job_conf().job_name()) into the
// `job_name` field of every system op conf that carries one: input,
// wait_and_send_ids, output, return and callback_notify. Any other op is left
// untouched.
class SystemOpFillJobNamePass final : public JobPass {
 public:
  OF_DISALLOW_COPY_AND_MOVE(SystemOpFillJobNamePass);
  SystemOpFillJobNamePass() = default;
  ~SystemOpFillJobNamePass() override = default;

  // Always reports the pass as enabled; `ctx` is not consulted.
  bool IsEnabled(const JobPassCtx& ctx) const { return true; }

  // Walks every op of `job` and stamps the job name on the first matching
  // system op conf kind. `ctx` is unused. Always returns Ok.
  Maybe<void> Apply(Job* job, JobPassCtx* ctx) const override {
    const std::string& name_of_job = job->job_conf().job_name();
    auto* op_list = job->mutable_net()->mutable_op();
    for (OperatorConf& conf : *op_list) {
      // The checks are tried in a fixed order and at most one conf kind is
      // written per op.
      if (conf.has_input_conf()) {
        conf.mutable_input_conf()->set_job_name(name_of_job);
        continue;
      }
      if (conf.has_wait_and_send_ids_conf()) {
        conf.mutable_wait_and_send_ids_conf()->set_job_name(name_of_job);
        continue;
      }
      if (conf.has_output_conf()) {
        conf.mutable_output_conf()->set_job_name(name_of_job);
        continue;
      }
      if (conf.has_return_conf()) {
        conf.mutable_return_conf()->set_job_name(name_of_job);
        continue;
      }
      if (conf.has_callback_notify_conf()) {
        conf.mutable_callback_notify_conf()->set_job_name(name_of_job);
      }
    }
    return Maybe<void>::Ok();
  }
};
// Registers the pass under the string name "SystemOpFillJobNamePass"; this is
// the same name passed to JobPass4Name(...) in JobCompleter::Complete.
REGISTER_JOB_PASS("SystemOpFillJobNamePass", SystemOpFillJobNamePass);
} // namespace
} // namespace oneflow
......@@ -18,7 +18,6 @@ limitations under the License.
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
......@@ -39,7 +38,8 @@ void CallbackNotifyKernel<T>::ForwardDataContent(const KernelContext* ctx) const
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
std::string buffer_name;
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
buffer_name = GetCallbackNotifierBufferName(ctx->job_desc()->job_name());
CHECK(this->op_conf().callback_notify_conf().has_job_name());
buffer_name = GetCallbackNotifierBufferName(this->op_conf().callback_notify_conf().job_name());
} else {
T job_id = *ctx->BnInOp2Blob("in")->dptr<T>();
buffer_name = this->op_conf().callback_notify_conf().callback_buffer_name(job_id);
......
......@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/kernel/case_kernel.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/operator/operator.h"
namespace oneflow {
......
......@@ -18,7 +18,6 @@ limitations under the License.
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
......@@ -34,7 +33,8 @@ class InputKernel final : public Kernel {
private:
void ForwardDataContent(const KernelContext* ctx) const override {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
const auto& job_name = ctx->job_desc()->job_name();
CHECK(this->op_conf().input_conf().has_job_name());
const auto& job_name = this->op_conf().input_conf().job_name();
const auto& op_name = this->op_conf().name();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
auto* buffer = buffer_mgr->Get(GetInputBufferName(job_name, op_name));
......
......@@ -16,7 +16,6 @@ limitations under the License.
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/runtime_blob_shape_infer_helper.h"
#include "oneflow/core/kernel/kernel_observer.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
......@@ -40,7 +39,7 @@ void Kernel::InitBase(const JobDesc* job_desc, const KernelConf& kernel_conf) {
if (shape_infer_helper_) { return; }
kernel_conf_ = kernel_conf;
shape_infer_helper_.reset(
new RuntimeBlobShapeInferHelper(this->op_conf(), this->kernel_conf(), job_desc));
new RuntimeBlobShapeInferHelper(this->op_conf(), this->kernel_conf(), this));
}
void Kernel::Init(const KernelConf& kernel_conf, KernelContext* ctx) {
......
......@@ -19,7 +19,6 @@ limitations under the License.
#include "oneflow/core/device/cpu_device_context.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/job/nd_sbp_util.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/persistence/snapshot.h"
......
......@@ -17,7 +17,6 @@ limitations under the License.
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
......@@ -36,7 +35,8 @@ class OutputKernel final : public Kernel {
template<DeviceType device_type>
void OutputKernel<device_type>::ForwardDataContent(const KernelContext* ctx) const {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
const auto& job_name = ctx->job_desc()->job_name();
CHECK(this->op_conf().output_conf().has_job_name());
const auto& job_name = this->op_conf().output_conf().job_name();
const auto& op_name = this->op_conf().name();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name));
......
......@@ -17,7 +17,6 @@ limitations under the License.
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
......@@ -36,7 +35,8 @@ class ReturnKernel final : public Kernel {
template<DeviceType device_type>
void ReturnKernel<device_type>::ForwardDataContent(const KernelContext* ctx) const {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
const auto& job_name = ctx->job_desc()->job_name();
CHECK(this->op_conf().return_conf().has_job_name());
const auto& job_name = this->op_conf().return_conf().job_name();
const auto& op_name = this->op_conf().name();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name));
......
......@@ -23,7 +23,7 @@ namespace oneflow {
RuntimeBlobShapeInferHelper::RuntimeBlobShapeInferHelper(const OperatorConf& op_conf,
const KernelConf& kernel_conf,
const JobDesc* job_desc) {
const void* scope) {
op_ = CHECK_JUST(ConstructOp(op_conf));
const OpAttribute& op_attribute = kernel_conf.op_attribute();
if (op_attribute.has_parallel_conf_signature()
......@@ -55,7 +55,7 @@ RuntimeBlobShapeInferHelper::RuntimeBlobShapeInferHelper(const OperatorConf& op_
if (kernel_conf.has_parallel_ctx()) {
parallel_ctx_.reset(new ParallelContext(kernel_conf.parallel_ctx()));
}
op_infer_cache_key_.job_desc = job_desc;
op_infer_cache_key_.scope = scope;
op_infer_cache_key_.op_conf_sym = op_->GetOpConfWithoutOpNameAndLbn();
op_infer_cache_key_.ibn_idx2shape_sym.resize(op_->input_bns().size());
op_infer_cache_key_.dtype_signature_sym = SymbolOf(kernel_conf.dtype_signature());
......
......@@ -27,7 +27,7 @@ class BlobDesc;
class RuntimeBlobShapeInferHelper final {
public:
RuntimeBlobShapeInferHelper(const OperatorConf& op_conf, const KernelConf& kernel_conf,
const JobDesc* job_desc);
const void* scope);
~RuntimeBlobShapeInferHelper() = default;
void InferShape(const std::function<Blob*(const std::string&)>& BnInOp2Blob);
......
......@@ -18,7 +18,6 @@ limitations under the License.
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/job/job_instance.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
......@@ -39,7 +38,8 @@ void WaitAndSendIdsKernel<T>::ForwardDataContent(const KernelContext* ctx) const
const auto& conf = this->op_conf().wait_and_send_ids_conf();
if (status->out_idx_ >= status->out_num_) {
if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
const auto& job_name = ctx->job_desc()->job_name();
CHECK(this->op_conf().wait_and_send_ids_conf().has_job_name());
const auto& job_name = this->op_conf().wait_and_send_ids_conf().job_name();
auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
auto* buffer = buffer_mgr->Get(GetSourceTickBufferName(job_name));
status->in_id_ = 0;
......
......@@ -124,6 +124,7 @@ message InputOpConf {
optional string tick = 1;
required string out = 2;
required InterfaceBlobConf blob_conf = 3;
optional string job_name = 4;
}
message ForeignInputOpConf {
......@@ -136,12 +137,14 @@ message ForeignInputOpConf {
// Configuration of a return op.
message ReturnOpConf {
required string in = 1;
required string out = 2;
// Name of the job this op belongs to. Written by SystemOpFillJobNamePass
// from job_conf().job_name(); in multi-client mode ReturnKernel CHECKs that
// it is present and combines it with the op name to look up its output buffer.
optional string job_name = 3;
}
// Configuration of an output op.
message OutputOpConf {
required string in = 1;
required string out = 2;
required InterfaceBlobConf blob_conf = 3;
// Name of the job this op belongs to. Written by SystemOpFillJobNamePass
// from job_conf().job_name(); in multi-client mode OutputKernel CHECKs that
// it is present and combines it with the op name to look up its output buffer.
optional string job_name = 4;
}
message ForeignOutputOpConf {
......@@ -195,11 +198,13 @@ message WaitAndSendIdsOpConf {
required string wait_buffer_name = 2;
repeated Int64List id_list = 3;
required DataType data_type = 4 [default = kInt32];
optional string job_name = 5;
}
// Configuration of a callback-notify op.
message CallbackNotifyOpConf {
required string in = 1;
// Buffer names used outside multi-client mode: CallbackNotifyKernel indexes
// this list with the job id value it reads from its "in" blob.
repeated string callback_buffer_name = 2;
// Name of the job this op belongs to. Written by SystemOpFillJobNamePass
// from job_conf().job_name(); in multi-client mode CallbackNotifyKernel
// CHECKs that it is present and derives the callback notifier buffer name
// from it.
optional string job_name = 3;
}
message ReentrantLockOpConf {
......
......@@ -25,7 +25,7 @@ limitations under the License.
namespace oneflow {
struct OpInferCacheKey final {
const JobDesc* job_desc;
const void* scope;
Symbol<OperatorConf> op_conf_sym;
Symbol<DTypeSignature> dtype_signature_sym;
std::vector<Symbol<Shape>> ibn_idx2shape_sym;
......@@ -36,7 +36,7 @@ struct OpInferCacheValue final {
};
inline bool operator==(const OpInferCacheKey& lhs, const OpInferCacheKey& rhs) {
return lhs.job_desc == rhs.job_desc && lhs.op_conf_sym == rhs.op_conf_sym
return lhs.scope == rhs.scope && lhs.op_conf_sym == rhs.op_conf_sym
&& lhs.dtype_signature_sym == rhs.dtype_signature_sym
&& lhs.ibn_idx2shape_sym == rhs.ibn_idx2shape_sym;
}
......@@ -57,7 +57,7 @@ struct hash<oneflow::OpInferCacheKey> final {
for (const auto& shape_sym : op_infer_cache_key.ibn_idx2shape_sym) {
ibn_idx2shape_sym_hash_value ^= std::hash<Symbol<Shape>>()(shape_sym);
}
return std::hash<const JobDesc*>()(op_infer_cache_key.job_desc)
return std::hash<const void*>()(op_infer_cache_key.scope)
^ std::hash<Symbol<OperatorConf>>()(op_infer_cache_key.op_conf_sym)
^ ibn_idx2shape_sym_hash_value
^ std::hash<Symbol<DTypeSignature>>()(op_infer_cache_key.dtype_signature_sym);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册