Commit 3c7a4240 authored by willzhang4a58

add device_type to kernel_conf


Former-commit-id: 0ed0a2df
Parent 32a54944
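In short, this commit moves the device information into KernelConf itself: Operator::GenKernelConf now records the device type on the generated kernel_conf, ConstructKernel drops its separate DeviceType argument, and every per-op kernel creator changes from Kernel*(DeviceType, const KernelConf&) to Kernel*(const KernelConf&), reading the device back out of the conf. The standalone sketch below illustrates the resulting creator pattern; DeviceType, KernelConf, Kernel and ConcatKernel here are simplified stand-ins, not OneFlow's real classes.

// Standalone sketch (not OneFlow source): the dispatch pattern after this
// commit. The creator takes only the kernel conf and reads the device type
// out of it, instead of receiving DeviceType as an extra argument.
#include <functional>
#include <iostream>
#include <memory>
#include <unordered_map>

enum class DeviceType { kCPU = 0, kGPU = 1 };

struct KernelConf {
  DeviceType device_type;  // models the field this commit adds to the proto
};

struct Kernel {
  virtual ~Kernel() = default;
  virtual const char* Name() const = 0;
};

template <DeviceType dev>
struct ConcatKernel final : Kernel {
  const char* Name() const override {
    return dev == DeviceType::kCPU ? "ConcatKernel<kCPU>" : "ConcatKernel<kGPU>";
  }
};

// New-style creator: signature is Kernel*(const KernelConf&); the device is
// looked up from the conf rather than passed alongside it.
Kernel* CreateConcatKernel(const KernelConf& kernel_conf) {
  static const std::unordered_map<int, std::function<Kernel*()>> creators = {
      {static_cast<int>(DeviceType::kCPU),
       []() { return new ConcatKernel<DeviceType::kCPU>; }},
      {static_cast<int>(DeviceType::kGPU),
       []() { return new ConcatKernel<DeviceType::kGPU>; }},
  };
  return creators.at(static_cast<int>(kernel_conf.device_type))();
}

int main() {
  // The producer side (GenKernelConf in the real code) fills in device_type;
  // the consumer side only ever sees the conf.
  KernelConf conf{DeviceType::kGPU};
  std::unique_ptr<Kernel> kernel(CreateConcatKernel(conf));
  std::cout << kernel->Name() << "\n";  // prints ConcatKernel<kGPU>
  return 0;
}

In the real code the lookup key is GetHashKey(kernel_conf.device_type(), kernel_conf.data_type()) and the creators are generated by macros such as ADD_DEFAULT_KERNEL_CREATOR, but the flow is the same: whoever builds the KernelConf decides the device, and the creator just reads it.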
@@ -20,8 +20,7 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
   }
   for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) {
     ExecKernel ek;
-    ek.kernel =
-        ConstructKernel(GetDeviceType(), parallel_ctx(), node.kernel_conf());
+    ek.kernel = ConstructKernel(parallel_ctx(), node.kernel_conf());
     ek.bn_in_op2regst_desc_id = PbMap2HashMap(node.bn_in_op2regst_desc_id());
     exec_kernel_vec_.push_back(std::move(ek));
   }

@@ -204,7 +204,8 @@ void BoxingTaskNode::BuildWithChainPair(
       node->BindBnInOpAndRegst(dtbn, middle_regst);
     }
     if (lbn != kPackedBlobName) {
-      node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), nullptr);
+      node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), nullptr,
+                                 device_type());
     }
   }
 }

@@ -12,10 +12,11 @@ std::function<BlobDesc*(const std::string&)> ExecNode::GetBlobDesc4BnInOpFunc()
   return std::bind(&ExecNode::GetBlobDesc4BnInOp, this, std::placeholders::_1);
 }
 
-void ExecNode::ToProto(bool is_forward, const ParallelContext* parallel_ctx,
+void ExecNode::ToProto(bool is_forward, DeviceType device_type,
+                       const ParallelContext* parallel_ctx,
                        ExecNodeProto* ret) const {
-  op_->GenKernelConf(GetBlobDesc4BnInOpFunc(), is_forward, parallel_ctx,
-                     ret->mutable_kernel_conf());
+  op_->GenKernelConf(GetBlobDesc4BnInOpFunc(), is_forward, device_type,
+                     parallel_ctx, ret->mutable_kernel_conf());
   for (const auto& bn_regst : bn_in_op2regst_) {
     const std::string& bn_in_op = bn_regst.first;
     auto regst = bn_regst.second.lock();
@@ -34,11 +35,11 @@ BlobDesc* ExecNode::GetBlobDesc4BnInOp(const std::string& bn_in_op) const {
   return regst->MutBlobDesc(lbn);
 }
 
-void ExecGraph::ToExecSequence(bool is_forward,
+void ExecGraph::ToExecSequence(bool is_forward, DeviceType device_type,
                                const ParallelContext* parallel_ctx,
                                ExecSequence* ret) const {
   TopoForEachNode([&](ExecNode* node) {
-    node->ToProto(is_forward, parallel_ctx, ret->add_exec_node());
+    node->ToProto(is_forward, device_type, parallel_ctx, ret->add_exec_node());
   });
 }

@@ -48,7 +48,8 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
   std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOpFunc() const;
   std::string VisualStr() const override { return op_->op_name(); }
-  void ToProto(bool is_forward, const ParallelContext*, ExecNodeProto*) const;
+  void ToProto(bool is_forward, DeviceType, const ParallelContext*,
+               ExecNodeProto*) const;
 
  private:
   BlobDesc* GetBlobDesc4BnInOp(const std::string&) const;
@@ -63,7 +64,7 @@ class ExecGraph final : public Graph<ExecNode, ExecEdge> {
   ExecGraph() = default;
   ~ExecGraph() = default;
 
-  void ToExecSequence(bool is_forward, const ParallelContext*,
+  void ToExecSequence(bool is_forward, DeviceType, const ParallelContext*,
                       ExecSequence*) const;
   const char* TypeName() const override { return "ExecGraph"; }

@@ -37,7 +37,8 @@ void ForwardCompTaskNode::BuildExecGphAndRegst() {
   BuildActivationRegst();
   BuildModelAndTmpRegsts();
   mut_exec_gph().TopoForEachNode([this](ExecNode* node) {
-    node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), parallel_ctx());
+    node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), parallel_ctx(),
+                               device_type());
   });
 }

@@ -62,8 +62,10 @@ void LossCompTaskNode::BuildExecGphAndRegst() {
   sum_node->BindBnInOpAndRegst(sum_op->SoleIbn(), data_tmp_regst);
   loss_regst->AddLbn(sum_op->Lbn4BnInOp(sum_op->SoleObn()));
   sum_node->BindBnInOpAndRegst(sum_op->SoleObn(), loss_regst);
-  loss_op->InferBlobDescs(loss_node->GetBlobDesc4BnInOpFunc(), parallel_ctx());
-  sum_op->InferBlobDescs(sum_node->GetBlobDesc4BnInOpFunc(), parallel_ctx());
+  loss_op->InferBlobDescs(loss_node->GetBlobDesc4BnInOpFunc(), parallel_ctx(),
+                          device_type());
+  sum_op->InferBlobDescs(sum_node->GetBlobDesc4BnInOpFunc(), parallel_ctx(),
+                         device_type());
   in_diff_regst->CopyBlobDescWithoutAddLbn(in_regst.get());
 }

@@ -79,7 +79,8 @@ void MdUpdtCompTaskNode::BuildExecGphAndRegst() {
     data_tmp_regst->AddLbn(lbn);
     node->BindBnInOpAndRegst(dtbn, data_tmp_regst);
   }
-  node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), nullptr);
+  node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), nullptr,
+                             device_type());
 }
 
 void MdUpdtCompTaskNode::LockRegsts() { GetProducedRegst("data_tmp")->Lock(); }

@@ -32,7 +32,8 @@ void SourceCompTaskNode::BuildExecGphAndRegst() {
     data_tmp_regst->AddLbn(lbn);
     node->BindBnInOpAndRegst(dtbn, data_tmp_regst);
   }
-  node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), parallel_ctx());
+  node->op()->InferBlobDescs(node->GetBlobDesc4BnInOpFunc(), parallel_ctx(),
+                             device_type());
 }
 
 void SourceCompTaskNode::FixThrdId() {

@@ -83,7 +83,8 @@ void TaskNode::ToProto(TaskProto* task_proto) {
   task_proto->set_thrd_id(thrd_id_);
   task_proto->set_task_id(task_id_);
   exec_gph_.ToExecSequence(IsBackwardTaskType(GetTaskType()) == false,
-                           parallel_ctx(), task_proto->mutable_exec_sequence());
+                           device_type(), parallel_ctx(),
+                           task_proto->mutable_exec_sequence());
   auto produced_regst_proto = task_proto->mutable_produced_regst_desc();
   for (auto& pair : produced_regsts_) {
     RegstDescProto regst_desc_proto;

@@ -88,12 +88,12 @@ void ConcatKernel<device_type>::BackwardDataContent(
 
 namespace {
 
-Kernel* CreateConcatKernel(DeviceType dev_type) {
+Kernel* CreateConcatKernel(const KernelConf& kernel_conf) {
   static const HashMap<std::string, std::function<Kernel*()>> creators = {
 #define CONCAT_KERNEL_ENTRY(device_type) \
   {GetHashKey(device_type), []() { return new ConcatKernel<device_type>; }},
       OF_PP_FOR_EACH_TUPLE(CONCAT_KERNEL_ENTRY, DEVICE_TYPE_SEQ)};
-  return creators.at(GetHashKey(dev_type))();
+  return creators.at(GetHashKey(kernel_conf.device_type()))();
 }
 
 }  // namespace

@@ -166,27 +166,15 @@ void AddKernelCreator(OperatorConf::OpTypeCase opcase, KernelCreator1 creator) {
   CHECK(GetCreatorsMap().emplace(opcase, creator).second);
 }
 
 void AddKernelCreator(OperatorConf::OpTypeCase opcase, KernelCreator2 creator) {
-  AddKernelCreator(opcase, [creator](DeviceType type, const KernelConf&) {
-    return creator(type);
-  });
-}
-
-void AddKernelCreator(OperatorConf::OpTypeCase opcase, KernelCreator3 creator) {
-  AddKernelCreator(opcase, [creator](DeviceType, const KernelConf& conf) {
-    return creator(conf);
-  });
-}
-
-void AddKernelCreator(OperatorConf::OpTypeCase opcase, KernelCreator4 creator) {
-  AddKernelCreator(
-      opcase, [creator](DeviceType, const KernelConf&) { return creator(); });
+  AddKernelCreator(opcase, [creator](const KernelConf&) { return creator(); });
 }
 
 std::unique_ptr<const Kernel> ConstructKernel(
-    DeviceType device_type, const ParallelContext* parallel_ctx,
-    const KernelConf& conf) {
+    const ParallelContext* parallel_ctx, const KernelConf& conf) {
   OperatorConf::OpTypeCase opcase = conf.op_conf().op_type_case();
   auto it = GetCreatorsMap().find(opcase);
   CHECK(it != GetCreatorsMap().end()) << opcase;
-  Kernel* rptr = it->second(device_type, conf);
+  Kernel* rptr = it->second(conf);
   rptr->Init(parallel_ctx, conf);
   return std::unique_ptr<const Kernel>(rptr);
 }

@@ -117,16 +117,11 @@ class KernelIf : public Kernel {
                 void (Blob::*Copy)(DeviceCtx*, const Blob*)) const;
 };
 
-using KernelCreator1 = std::function<Kernel*(DeviceType, const KernelConf&)>;
-using KernelCreator2 = std::function<Kernel*(DeviceType)>;
-using KernelCreator3 = std::function<Kernel*(const KernelConf&)>;
-using KernelCreator4 = std::function<Kernel*()>;
+using KernelCreator1 = std::function<Kernel*(const KernelConf&)>;
+using KernelCreator2 = std::function<Kernel*()>;
 void AddKernelCreator(OperatorConf::OpTypeCase, KernelCreator1);
 void AddKernelCreator(OperatorConf::OpTypeCase, KernelCreator2);
-void AddKernelCreator(OperatorConf::OpTypeCase, KernelCreator3);
-void AddKernelCreator(OperatorConf::OpTypeCase, KernelCreator4);
-std::unique_ptr<const Kernel> ConstructKernel(DeviceType,
-                                              const ParallelContext*,
+std::unique_ptr<const Kernel> ConstructKernel(const ParallelContext*,
                                               const KernelConf&);
 
 }  // namespace oneflow
@@ -139,12 +134,13 @@ std::unique_ptr<const Kernel> ConstructKernel(DeviceType,
 #define ADD_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, data_type_seq) \
   namespace { \
   \
-  Kernel* CreateKernel(DeviceType dev_type, const KernelConf& kernel_conf) { \
+  Kernel* CreateKernel(const KernelConf& kernel_conf) { \
     static const HashMap<std::string, std::function<Kernel*()>> creators = { \
         OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, \
                                          (kernel_class), DEVICE_TYPE_SEQ, \
                                          data_type_seq)}; \
-    return creators.at(GetHashKey(dev_type, kernel_conf.data_type()))(); \
+    return creators.at( \
+        GetHashKey(kernel_conf.device_type(), kernel_conf.data_type()))(); \
   } \
   \
   COMMAND(AddKernelCreator(op_type_case, CreateKernel)); \
@@ -154,18 +150,18 @@ std::unique_ptr<const Kernel> ConstructKernel(DeviceType,
   {OF_PP_PAIR_SECOND(data_type_pair), \
    []() { return new kernel_class<OF_PP_PAIR_FIRST(data_type_pair)>(); }},
 
-#define ADD_CPU_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, \
-                                       data_type_seq) \
-  namespace { \
-  \
-  Kernel* CreateKernel(DeviceType dev_type, const KernelConf& kernel_conf) { \
-    static const HashMap<int, std::function<Kernel*()>> creators = { \
-        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_CPU_KERNEL_CREATOR_ENTRY, \
-                                         (kernel_class), data_type_seq)}; \
-    return creators.at(kernel_conf.data_type())(); \
-  } \
-  \
-  COMMAND(AddKernelCreator(op_type_case, CreateKernel)); \
+#define ADD_CPU_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, \
+                                       data_type_seq) \
+  namespace { \
+  \
+  Kernel* CreateKernel(const KernelConf& kernel_conf) { \
+    static const HashMap<int, std::function<Kernel*()>> creators = { \
+        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_CPU_KERNEL_CREATOR_ENTRY, \
+                                         (kernel_class), data_type_seq)}; \
+    return creators.at(kernel_conf.data_type())(); \
+  } \
+  \
+  COMMAND(AddKernelCreator(op_type_case, CreateKernel)); \
   }
 
 #endif  // ONEFLOW_CORE_KERNEL_KERNEL_H_

@@ -3,6 +3,7 @@ package oneflow;
 
 import "oneflow/core/operator/op_conf.proto";
 import "oneflow/core/common/data_type.proto";
+import "oneflow/core/job/resource.proto";
 
 message ConcatKernelConf {
   required int64 total_cp_num = 1;
@@ -56,6 +57,7 @@ message KernelConf {
   required bool need_do_col_num = 13;
   required bool is_forward = 14;
   required DataType data_type = 15;
+  required DeviceType device_type = 16;
 
   oneof kernel_type {
     MultinomialLogisticLossKernelConf multinomial_logistic_loss_conf = 106;

@@ -63,15 +63,15 @@ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_KERNEL, DEVICE_TYPE_SEQ,
 
 namespace {
 
-Kernel* CreateMdUpdtKernel(DeviceType dev_type, const KernelConf& kernel_conf) {
+Kernel* CreateMdUpdtKernel(const KernelConf& kernel_conf) {
   const ModelUpdateOpUserConf& user_conf =
       kernel_conf.op_conf().mdupdt_conf().user_conf();
   if (user_conf.has_normal_conf()) {
-    return CreateNormalMdUpdtKernel(dev_type, kernel_conf);
+    return CreateNormalMdUpdtKernel(kernel_conf);
   } else if (user_conf.has_momentum_conf()) {
-    return CreateMomentumMdUpdtKernel(dev_type, kernel_conf);
+    return CreateMomentumMdUpdtKernel(kernel_conf);
   } else if (user_conf.has_rmsprop_conf()) {
-    return CreateRMSPropMdUpdtKernel(dev_type, kernel_conf);
+    return CreateRMSPropMdUpdtKernel(kernel_conf);
   } else {
     UNEXPECTED_RUN();
   }

@@ -37,16 +37,16 @@ class MdUpdateKernelUtil final {
 };
 
 #define DECLARE_MDUPDT_KERNEL_CREATOR(x) \
-  Kernel* Create##x##MdUpdtKernel(DeviceType, const KernelConf&);
+  Kernel* Create##x##MdUpdtKernel(const KernelConf&);
 
 #define DEFINE_MDUPDT_KERNEL_CREATOR(x) \
-  Kernel* Create##x##MdUpdtKernel(DeviceType dev_type, \
-                                  const KernelConf& kernel_conf) { \
+  Kernel* Create##x##MdUpdtKernel(const KernelConf& kernel_conf) { \
     static const HashMap<std::string, std::function<Kernel*()>> creators = { \
         OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, \
                                          (x##MdUpdateKernel), DEVICE_TYPE_SEQ, \
                                          FLOATING_DATA_TYPE_SEQ)}; \
-    return creators.at(GetHashKey(dev_type, kernel_conf.data_type()))(); \
+    return creators.at( \
+        GetHashKey(kernel_conf.device_type(), kernel_conf.data_type()))(); \
   }
 
 }  // namespace oneflow

@@ -72,8 +72,7 @@ class MultinomialLogisticLossKernelUtil<DeviceType::kCPU, PredType, LabelType>
 
 namespace {
 
-Kernel* CreateMultinomialLogisticLossKernel(DeviceType dev_type,
-                                            const KernelConf& kernel_conf) {
+Kernel* CreateMultinomialLogisticLossKernel(const KernelConf& kernel_conf) {
   static const HashMap<std::string, std::function<Kernel*()>> creators = {
 #define MULTINOMIAL_LOGISTIC_LOSS_KERNEL_ENTRY(device_type, pred_type_pair, \
                                                label_type_pair) \
@@ -87,9 +86,9 @@ Kernel* CreateMultinomialLogisticLossKernel(DeviceType dev_type,
       OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MULTINOMIAL_LOGISTIC_LOSS_KERNEL_ENTRY,
                                        DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ,
                                        INT_DATA_TYPE_SEQ)};
-  return creators.at(GetHashKey(
-      dev_type, kernel_conf.multinomial_logistic_loss_conf().prediction_type(),
-      kernel_conf.multinomial_logistic_loss_conf().label_type()))();
+  return creators.at(
+      GetHashKey(kernel_conf.multinomial_logistic_loss_conf().prediction_type(),
+                 kernel_conf.multinomial_logistic_loss_conf().label_type()))();
 }
 
 }  // namespace

@@ -70,8 +70,7 @@ class SoftmaxLossKernelUtil<DeviceType::kCPU, PredType, LabelType> final {
 
 namespace {
 
-Kernel* CreateSoftmaxLossKernel(DeviceType dev_type,
-                                const KernelConf& kernel_conf) {
+Kernel* CreateSoftmaxLossKernel(const KernelConf& kernel_conf) {
   static const HashMap<std::string, std::function<Kernel*()>> creators = {
 #define SOFTMAX_LOSS_KERNEL_ENTRY(device_type, pred_type_pair, \
                                   label_type_pair) \
@@ -86,7 +85,8 @@ Kernel* CreateSoftmaxLossKernel(DeviceType dev_type,
                                        DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ,
                                        INT_DATA_TYPE_SEQ)};
   return creators.at(
-      GetHashKey(dev_type, kernel_conf.softmax_loss_conf().prediction_type(),
+      GetHashKey(kernel_conf.device_type(),
+                 kernel_conf.softmax_loss_conf().prediction_type(),
                  kernel_conf.softmax_loss_conf().label_type()))();
 }

@@ -109,8 +109,8 @@ static bool HasBlobDescWithField(
 
 void Operator::GenKernelConf(
     std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-    bool is_forward, const ParallelContext* parallel_ctx,
-    KernelConf* kernel_conf) const {
+    bool is_forward, DeviceType device_type,
+    const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {
   *(kernel_conf->mutable_op_conf()) = op_conf_;
   *(kernel_conf->mutable_bn_in_op2lbn()) = HashMap2PbMap(bn_in_op2lbn_);
   *(kernel_conf->mutable_data_tmp_bns()) = StdVec2PbRpf(data_tmp_bns_);
@@ -138,6 +138,7 @@ void Operator::GenKernelConf(
     data_type = GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, input_bns_);
   }
   kernel_conf->set_data_type(data_type);
+  kernel_conf->set_device_type(device_type);
 
   VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf);
 }

@@ -106,6 +106,11 @@ class Operator {
   // Read: shape of input_blobs
   // Write: shape of output_blobs, model_blobs, data_tmp_blobs, model_tmp_blobs
   virtual void InferBlobDescs(
+      std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
+      const ParallelContext* parallel_ctx, DeviceType device_type) const {
+    InferBlobDescs(GetBlobDesc4BnInOp, parallel_ctx);
+  }
+  virtual void InferBlobDescs(
       std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
       const ParallelContext* parallel_ctx) const {
@@ -118,14 +123,13 @@ class Operator {
   virtual int32_t MaxModelSplitNum() const { return -1; }
   void GenKernelConf(
       std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-      bool is_forward, const ParallelContext* parallel_ctx,
-      KernelConf* kernel_conf) const;
+      bool is_forward, DeviceType, const ParallelContext*, KernelConf*) const;
 
  protected:
   virtual void VirtualFixParallelDesc(ParallelDesc* pr_desc) const {}
   virtual void VirtualGenKernelConf(
       std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
-      const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const {}
+      const ParallelContext*, KernelConf*) const {}
 
   virtual std::string ibn2lbn(const std::string& input_bn) const;
   virtual std::string obn2lbn(const std::string& output_bn) const;