未验证 提交 83cc313c 编写于 作者: J Juncheng 提交者: GitHub

Refactor GenKernelConf (#4262)

Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 bc81dff0
......@@ -23,6 +23,13 @@ Maybe<void> OpKernelObject::ResetOpAndKernel(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc* parallel_desc) {
auto op = ConstructOp(op_conf_, device_type_, job_desc_.get());
JUST(op->FillOpParallelDesc(*parallel_desc));
const auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& {
return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn));
};
JUST(op->FillLogicalInBlobDesc(LogicalBlobDesc4BnInOp));
JUST(op->FillLogicalOutBlobDesc(LogicalBlobDesc4BnInOp));
JUST(op->FillSbpSignature(op_node_signature.sbp_signature()));
JUST(InferBlobDescs(*op, BlobDesc4BnInOp, &op_node_signature.sbp_signature(), parallel_ctx));
NewPartialInitializedKernel(*op, BlobDesc4BnInOp, op_node_signature, parallel_ctx, parallel_desc);
return Maybe<void>::Ok();
......@@ -40,11 +47,7 @@ void OpKernelObject::NewPartialInitializedKernel(
const OpNodeSignatureDesc& op_node_signature, const ParallelContext* parallel_ctx,
const ParallelDesc* parallel_desc) {
KernelConf kernel_conf;
auto LogicalBlobDesc4BnInOp = [&](const std::string& bn_in_op) -> const BlobDesc& {
return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn_in_op));
};
op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf, LogicalBlobDesc4BnInOp,
parallel_desc, &op_node_signature.sbp_signature());
op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf);
kernel_.reset(new EagerKernel(job_desc_.get(), kernel_conf));
}
......@@ -53,6 +56,13 @@ Maybe<void> SystemOpKernelObject::ResetKernel(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc* parallel_desc) {
auto op = ConstructOp(op_conf_, device_type_, job_desc_.get());
JUST(op->FillOpParallelDesc(*parallel_desc));
const auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& {
return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn));
};
JUST(op->FillLogicalInBlobDesc(LogicalBlobDesc4BnInOp));
JUST(op->FillLogicalOutBlobDesc(LogicalBlobDesc4BnInOp));
JUST(op->FillSbpSignature(op_node_signature.sbp_signature()));
JUST(InferBlobDescs(*op, BlobDesc4BnInOp, &op_node_signature.sbp_signature(), parallel_ctx));
ResetKernel(*op, BlobDesc4BnInOp, op_node_signature, parallel_ctx, parallel_desc);
return Maybe<void>::Ok();
......@@ -70,11 +80,7 @@ void SystemOpKernelObject::ResetKernel(
const OpNodeSignatureDesc& op_node_signature, const ParallelContext* parallel_ctx,
const ParallelDesc* parallel_desc) {
KernelConf kernel_conf;
auto LogicalBlobDesc4BnInOp = [&](const std::string& bn_in_op) -> const BlobDesc& {
return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn_in_op));
};
op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf, LogicalBlobDesc4BnInOp,
parallel_desc, &op_node_signature.sbp_signature());
op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf);
kernel_ = ConstructKernel(job_desc_.get(), kernel_conf, nullptr);
}
......
......@@ -59,11 +59,7 @@ void ExecNode::UnbindBnWithEmptyRegst() {
}
void ExecNode::ToProto(const ParallelContext* parallel_ctx, ExecNodeProto* ret) const {
const OpNode* op_node = Global<OpGraph>::Get()->OpNode4OpName(op_->op_name());
const ParallelDesc* parallel_desc = op_node == nullptr ? nullptr : &op_node->parallel_desc();
const SbpSignature* sbp_signature = op_node == nullptr ? nullptr : &op_node->sbp_signature();
op_->GenKernelConf(GetBlobDesc4BnInOpFunc(), parallel_ctx, ret->mutable_kernel_conf(),
GetLogicalBlobDesc4BnInOpFunc(), parallel_desc, sbp_signature);
op_->GenKernelConf(GetBlobDesc4BnInOpFunc(), parallel_ctx, ret->mutable_kernel_conf());
for (const auto& bn_regst : bn_in_op2regst_) {
const std::string& bn_in_op = bn_regst.first;
auto regst = bn_regst.second;
......@@ -83,7 +79,7 @@ void ExecNode::InferBlobDescs(const ParallelContext* parallel_ctx) {
CHECK_JUST(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, sbp_signature));
Global<OpGraph>::Get()->CheckBlobDescs(op_->op_name(), GetBlobDesc4BnInOp, parallel_ctx);
CHECK_JUST(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_,
GetBlobDesc4BnInOp, parallel_ctx, sbp_signature));
GetBlobDesc4BnInOp, parallel_ctx));
}
std::function<const BlobDesc&(const std::string&)> ExecNode::GetLogicalBlobDesc4BnInOpFunc() const {
......
......@@ -64,7 +64,6 @@ class OpNode final : public Node<OpNode, OpEdge> {
// Setters
Operator* mut_op() { return op_.get(); }
ParallelDesc* mut_parallel_desc() { return &parallel_desc_; }
SbpSignature* mut_sbp_signature() { return mut_op()->mut_sbp_signature(); }
Shape* mut_out_blob_time_shape();
HashMap<std::string, std::vector<std::shared_ptr<BlobDesc>>>* mut_bn2parallel_id2blob_desc() {
return &bn2parallel_id2blob_desc_;
......
......@@ -308,17 +308,17 @@ Maybe<void> Operator::InferInternalBlobDescs(
Maybe<void> Operator::InferInplaceObn2IbnIf(
HashMap<std::string, std::string>* mut_inplace_obn2ibn,
HashMap<std::string, std::string>* con_inplace_obn2ibn,
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
return InferInplaceObn2Ibn(mut_inplace_obn2ibn, con_inplace_obn2ibn, GetBlobDesc4BnInOp,
parallel_ctx, sbp_signature);
parallel_ctx);
}
Maybe<void> Operator::InferInplaceObn2Ibn(
HashMap<std::string, std::string>* mut_inplace_obn2ibn,
HashMap<std::string, std::string>* con_inplace_obn2ibn,
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
for (const std::string& obn : output_bns()) {
const auto& obn_modifier = OutputBlobModifier4Obn(obn);
if (obn_modifier.has_mutable_inplace_ibn()) {
......@@ -373,21 +373,30 @@ void Operator::ForEachBnInOp(std::function<void(const std::string&)> Handler) co
for (const std::string& bn_in_op : tmp_bns()) { Handler(bn_in_op); }
}
Maybe<void> Operator::FillSbpSignature(const SbpSignature& sbp_signature) {
CHECK_OR_RETURN(!sbp_signature_);
sbp_signature_.reset(new SbpSignature(sbp_signature));
*op_attribute_.mutable_sbp_signature() = sbp_signature;
return Maybe<void>::Ok();
}
Maybe<void> Operator::InferSbpSignatureIf(
const SbpSignature& sbp_sig_conf,
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
const ParallelDesc& parallel_desc) {
SbpSignature signature;
if (parallel_desc.parallel_num() == 1) {
auto* bn2sbp = mut_sbp_signature()->mutable_bn_in_op2sbp_parallel();
auto* bn2sbp = signature.mutable_bn_in_op2sbp_parallel();
for (const auto& ibn : input_bns()) { (*bn2sbp)[ibn].mutable_split_parallel()->set_axis(0); }
for (const auto& obn : output_bns()) { (*bn2sbp)[obn].mutable_split_parallel()->set_axis(0); }
} else if (parallel_desc.parallel_num() > 1) {
return InferSbpSignature(mut_sbp_signature(), sbp_sig_conf, CalcOrderValue4SbpSig,
SbpInferHint4Ibn, parallel_desc);
InferSbpSignature(&signature, sbp_sig_conf, CalcOrderValue4SbpSig, SbpInferHint4Ibn,
parallel_desc);
} else {
UNIMPLEMENTED();
}
FillSbpSignature(signature);
return Maybe<void>::Ok();
}
......@@ -529,13 +538,13 @@ Maybe<void> Operator::InferMirroredSignature(
}
Maybe<const SbpSignature*> Operator::sbp_signature() const {
CHECK_OR_RETURN(op_attribute_.has_sbp_signature()) << "sbp signature not infered";
return &op_attribute_.sbp_signature();
CHECK_OR_RETURN(sbp_signature_) << "sbp signature not infered";
return sbp_signature_.get();
}
Maybe<const SbpParallel*> Operator::SbpParallel4BnInOp(const std::string& bn_in_op) const {
CHECK_OR_RETURN(op_attribute_.has_sbp_signature()) << "sbp signature not infered";
const auto& map = op_attribute_.sbp_signature().bn_in_op2sbp_parallel();
CHECK_OR_RETURN(sbp_signature_) << "sbp signature not infered";
const auto& map = sbp_signature_->bn_in_op2sbp_parallel();
const auto& iter = map.find(bn_in_op);
CHECK_OR_RETURN(iter != map.end()) << "blob_name " << bn_in_op << " not found in sbp signature";
return &iter->second;
......@@ -571,10 +580,8 @@ bool HasBlobDescWithField(std::function<const BlobDesc*(const std::string&)> Get
} // namespace
void Operator::GenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf,
std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp,
const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const {
const std::function<const BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
auto* dtype_signature = kernel_conf->mutable_dtype_signature();
for (const std::string& ibn : input_bns()) {
const BlobDesc* blob_desc = GetBlobDesc4BnInOp(ibn);
......@@ -606,31 +613,6 @@ void Operator::GenKernelConf(
kernel_conf->set_data_type(data_type);
}
VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf, LogicalBlobDesc4BnInOp,
parallel_desc, sbp_signature);
}
void Operator::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf,
std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp,
const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const {
VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf, LogicalBlobDesc4BnInOp,
parallel_desc);
}
void Operator::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf,
std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp,
const ParallelDesc* parallel_desc) const {
VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf, LogicalBlobDesc4BnInOp);
}
void Operator::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf,
std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp) const {
VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf);
}
......
......@@ -95,8 +95,6 @@ class Operator {
Maybe<void> InferParallelSignatureIf();
Maybe<const ParallelDesc> GetParallelDesc4BnInOp(const std::string& bn) const;
// Read: shape of input_blobs
// Write: shape of output_blobs
Maybe<void> FillLogicalInBlobDesc(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp);
Maybe<void> FillLogicalInBlobDesc(
......@@ -128,11 +126,11 @@ class Operator {
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const;
Maybe<void> InferInplaceObn2IbnIf(HashMap<std::string, std::string>* mut_inplace_obn2ibn,
HashMap<std::string, std::string>* con_inplace_obn2ibn,
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const;
Maybe<void> InferInplaceObn2IbnIf(
HashMap<std::string, std::string>* mut_inplace_obn2ibn,
HashMap<std::string, std::string>* con_inplace_obn2ibn,
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const;
// Infer out blob's time shape
Maybe<void> InferOutputBlobTimeShapeIf(
......@@ -141,7 +139,8 @@ class Operator {
virtual Maybe<void> InferOutputBlobTimeShape(
std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, const ParallelContext*,
Shape* time_shape) const;
// Infer blob's SbpSignature
Maybe<void> FillSbpSignature(const SbpSignature& sbp_signature);
Maybe<void> InferSbpSignatureIf(
const SbpSignature& sbp_sig_conf,
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
......@@ -152,10 +151,8 @@ class Operator {
std::function<Maybe<const MirroredSigInferHint*>(const std::string&)>
MirroredSigInferHint4Ibn,
bool is_mirrored_parallel_view_conf, const ParallelDesc& parallel_desc);
void GenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext*, KernelConf*,
std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp,
const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const;
void GenKernelConf(const std::function<const BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext*, KernelConf*) const;
const InputBlobModifier& InputBlobModifier4Ibn(const std::string& ibn) const;
const OutputBlobModifier& OutputBlobModifier4Obn(const std::string& obn) const;
Maybe<const SbpParallel*> SbpParallel4BnInOp(const std::string& bn_in_op) const;
......@@ -175,7 +172,6 @@ class Operator {
ParallelSignature* mut_parallel_signature() { return op_attribute_.mutable_parallel_signature(); }
Maybe<const SbpSignature*> sbp_signature() const;
SbpSignature* mut_sbp_signature() { return op_attribute_.mutable_sbp_signature(); }
BlobLastUsedSignature* mut_blob_last_used_signature() {
return op_attribute_.mutable_blob_last_used_signature();
}
......@@ -225,22 +221,8 @@ class Operator {
virtual Maybe<void> InferInplaceObn2Ibn(
HashMap<std::string, std::string>* mut_inplace_obn2ibn,
HashMap<std::string, std::string>* con_inplace_obn2ibn,
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const;
virtual void VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
KernelConf*, std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp,
const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const;
virtual void VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
KernelConf*, std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp,
const ParallelDesc* parallel_desc) const;
virtual void VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
KernelConf*, std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp) const;
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const;
virtual void VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
......@@ -249,8 +231,6 @@ class Operator {
virtual LogicalBlobId lbi4ibn(const std::string& input_bn) const;
virtual LogicalBlobId lbi4obn(const std::string& output_bn) const;
OperatorConf* mut_op_conf() { return op_attribute_.mutable_op_conf(); }
// enroll data blobs
void EnrollTmpBn(const std::string& dtbn);
void EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num, bool has_diff);
......@@ -307,6 +287,7 @@ class Operator {
std::unique_ptr<HashMap<std::string, std::shared_ptr<const ParallelDesc>>> bn2parallel_desc_;
std::unique_ptr<HashMap<std::string, std::shared_ptr<const BlobDesc>>> ibn2logical_blob_desc_;
std::unique_ptr<HashMap<std::string, std::shared_ptr<const BlobDesc>>> obn2logical_blob_desc_;
std::shared_ptr<const SbpSignature> sbp_signature_;
PbRpf<std::string> input_output_bns_;
};
......
......@@ -401,9 +401,9 @@ Maybe<void> UserOp::InferOutBlobDescs(
Maybe<void> UserOp::InferInplaceObn2Ibn(
HashMap<std::string, std::string>* mut_inplace_obn2ibn,
HashMap<std::string, std::string>* con_inplace_obn2ibn,
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
UserOpInferContext infer_ctx(op_conf(), parallel_ctx, sbp_signature, job_desc(),
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
UserOpInferContext infer_ctx(op_conf(), parallel_ctx, JUST(sbp_signature()), job_desc(),
GetBlobDesc4BnInOp);
const user_op::OpKernelRegistryResult* kernel_reg_val =
JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(
......@@ -548,28 +548,23 @@ Symbol<OperatorConf> UserOp::GetOpConfWithoutOpNameAndLbn() const {
void UserOp::VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf,
std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp,
const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const {
const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
auto user_conf = kernel_conf->mutable_user_conf();
*(user_conf->mutable_parallel_ctx()) = *parallel_ctx;
*(user_conf->mutable_sbp_sig()) = *sbp_signature;
#define BLOB_DESCS_TO_PROTO(prefix, is_arg) \
for (const auto& bn : prefix##_bns()) { \
const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn); \
if (blob_desc) { blob_desc->ToProto(&(*user_conf->mutable_bn_in_op2blob_desc())[bn]); } \
if (is_arg) { \
LogicalBlobDesc4BnInOp(bn).ToProto(&(*user_conf->mutable_bn_in_op2logical_blob_desc())[bn]); \
} \
}
BLOB_DESCS_TO_PROTO(input, true)
BLOB_DESCS_TO_PROTO(output, true)
BLOB_DESCS_TO_PROTO(tmp, false)
#undef BLOB_DESCS_TO_PROTO
CHECK_NOTNULL(parallel_desc);
*user_conf->mutable_parallel_conf() = parallel_desc->parallel_conf();
*(user_conf->mutable_sbp_sig()) = *CHECK_JUST(sbp_signature());
ForEachBnInOp([&](const std::string& bn) {
const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn);
if (blob_desc) { blob_desc->ToProto(&(*user_conf->mutable_bn_in_op2blob_desc())[bn]); }
});
for (const std::string& ibn : input_bns()) {
CHECK_JUST(GetLogicalBlobDesc4Ibn(ibn))
->ToProto(&(*user_conf->mutable_bn_in_op2logical_blob_desc())[ibn]);
}
for (const std::string& obn : output_bns()) {
CHECK_JUST(GetLogicalBlobDesc4Obn(obn))
->ToProto(&(*user_conf->mutable_bn_in_op2logical_blob_desc())[obn]);
}
*user_conf->mutable_parallel_conf() = CHECK_JUST(GetOpParallelDesc())->parallel_conf();
}
REGISTER_OP(OperatorConf::kUserConf, UserOp);
......
......@@ -34,11 +34,11 @@ class UserOp final : public Operator {
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext*,
const SbpSignature* sbp_signature) const override;
Maybe<void> InferInplaceObn2Ibn(HashMap<std::string, std::string>* mut_inplace_obn2ibn,
HashMap<std::string, std::string>* con_inplace_obn2ibn,
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
Maybe<void> InferInplaceObn2Ibn(
HashMap<std::string, std::string>* mut_inplace_obn2ibn,
HashMap<std::string, std::string>* con_inplace_obn2ibn,
const std::function<BlobDesc*(const std::string&)>& GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
Symbol<OperatorConf> GetOpConfWithoutOpNameAndLbn() const override;
private:
......@@ -55,11 +55,9 @@ class UserOp final : public Operator {
Maybe<void> InferOutputBlobTimeShape(
std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp, const ParallelContext*,
Shape* time_shape) const override;
void VirtualGenKernelConf(
std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, KernelConf* kernel_conf,
std::function<const BlobDesc&(const std::string&)> LogicalBlobDesc4BnInOp,
const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const override;
void VirtualGenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
KernelConf* kernel_conf) const override;
const user_op::OpRegistryResult* val_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册