Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
3c7a4240
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 接近 3 年
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3c7a4240
编写于
2月 01, 2018
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add device_type to kernel_conf
Former-commit-id:
0ed0a2df
上级
32a54944
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
74 addition
and
76 deletion
+74
-76
oneflow/core/actor/actor.cpp
oneflow/core/actor/actor.cpp
+1
-2
oneflow/core/graph/boxing_task_node.cpp
oneflow/core/graph/boxing_task_node.cpp
+2
-1
oneflow/core/graph/exec_graph.cpp
oneflow/core/graph/exec_graph.cpp
+6
-5
oneflow/core/graph/exec_graph.h
oneflow/core/graph/exec_graph.h
+3
-2
oneflow/core/graph/forward_compute_task_node.cpp
oneflow/core/graph/forward_compute_task_node.cpp
+2
-1
oneflow/core/graph/loss_compute_task_node.cpp
oneflow/core/graph/loss_compute_task_node.cpp
+4
-2
oneflow/core/graph/model_update_compute_task_node.cpp
oneflow/core/graph/model_update_compute_task_node.cpp
+2
-1
oneflow/core/graph/source_compute_task_node.cpp
oneflow/core/graph/source_compute_task_node.cpp
+2
-1
oneflow/core/graph/task_node.cpp
oneflow/core/graph/task_node.cpp
+2
-1
oneflow/core/kernel/concat_kernel.cpp
oneflow/core/kernel/concat_kernel.cpp
+2
-2
oneflow/core/kernel/kernel.cpp
oneflow/core/kernel/kernel.cpp
+3
-15
oneflow/core/kernel/kernel.h
oneflow/core/kernel/kernel.h
+18
-22
oneflow/core/kernel/kernel.proto
oneflow/core/kernel/kernel.proto
+2
-0
oneflow/core/kernel/model_update_kernel.cpp
oneflow/core/kernel/model_update_kernel.cpp
+4
-4
oneflow/core/kernel/model_update_kernel.h
oneflow/core/kernel/model_update_kernel.h
+4
-4
oneflow/core/kernel/multinomial_logistic_loss_kernel.cpp
oneflow/core/kernel/multinomial_logistic_loss_kernel.cpp
+4
-5
oneflow/core/kernel/softmax_loss_kernel.cpp
oneflow/core/kernel/softmax_loss_kernel.cpp
+3
-3
oneflow/core/operator/operator.cpp
oneflow/core/operator/operator.cpp
+3
-2
oneflow/core/operator/operator.h
oneflow/core/operator/operator.h
+7
-3
未找到文件。
oneflow/core/actor/actor.cpp
浏览文件 @
3c7a4240
...
...
@@ -20,8 +20,7 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
}
for
(
const
ExecNodeProto
&
node
:
task_proto
.
exec_sequence
().
exec_node
())
{
ExecKernel
ek
;
ek
.
kernel
=
ConstructKernel
(
GetDeviceType
(),
parallel_ctx
(),
node
.
kernel_conf
());
ek
.
kernel
=
ConstructKernel
(
parallel_ctx
(),
node
.
kernel_conf
());
ek
.
bn_in_op2regst_desc_id
=
PbMap2HashMap
(
node
.
bn_in_op2regst_desc_id
());
exec_kernel_vec_
.
push_back
(
std
::
move
(
ek
));
}
...
...
oneflow/core/graph/boxing_task_node.cpp
浏览文件 @
3c7a4240
...
...
@@ -204,7 +204,8 @@ void BoxingTaskNode::BuildWithChainPair(
node
->
BindBnInOpAndRegst
(
dtbn
,
middle_regst
);
}
if
(
lbn
!=
kPackedBlobName
)
{
node
->
op
()
->
InferBlobDescs
(
node
->
GetBlobDesc4BnInOpFunc
(),
nullptr
);
node
->
op
()
->
InferBlobDescs
(
node
->
GetBlobDesc4BnInOpFunc
(),
nullptr
,
device_type
());
}
}
}
...
...
oneflow/core/graph/exec_graph.cpp
浏览文件 @
3c7a4240
...
...
@@ -12,10 +12,11 @@ std::function<BlobDesc*(const std::string&)> ExecNode::GetBlobDesc4BnInOpFunc()
return
std
::
bind
(
&
ExecNode
::
GetBlobDesc4BnInOp
,
this
,
std
::
placeholders
::
_1
);
}
void
ExecNode
::
ToProto
(
bool
is_forward
,
const
ParallelContext
*
parallel_ctx
,
void
ExecNode
::
ToProto
(
bool
is_forward
,
DeviceType
device_type
,
const
ParallelContext
*
parallel_ctx
,
ExecNodeProto
*
ret
)
const
{
op_
->
GenKernelConf
(
GetBlobDesc4BnInOpFunc
(),
is_forward
,
parallel_ctx
,
ret
->
mutable_kernel_conf
());
op_
->
GenKernelConf
(
GetBlobDesc4BnInOpFunc
(),
is_forward
,
device_type
,
parallel_ctx
,
ret
->
mutable_kernel_conf
());
for
(
const
auto
&
bn_regst
:
bn_in_op2regst_
)
{
const
std
::
string
&
bn_in_op
=
bn_regst
.
first
;
auto
regst
=
bn_regst
.
second
.
lock
();
...
...
@@ -34,11 +35,11 @@ BlobDesc* ExecNode::GetBlobDesc4BnInOp(const std::string& bn_in_op) const {
return
regst
->
MutBlobDesc
(
lbn
);
}
void
ExecGraph
::
ToExecSequence
(
bool
is_forward
,
void
ExecGraph
::
ToExecSequence
(
bool
is_forward
,
DeviceType
device_type
,
const
ParallelContext
*
parallel_ctx
,
ExecSequence
*
ret
)
const
{
TopoForEachNode
([
&
](
ExecNode
*
node
)
{
node
->
ToProto
(
is_forward
,
parallel_ctx
,
ret
->
add_exec_node
());
node
->
ToProto
(
is_forward
,
device_type
,
parallel_ctx
,
ret
->
add_exec_node
());
});
}
...
...
oneflow/core/graph/exec_graph.h
浏览文件 @
3c7a4240
...
...
@@ -48,7 +48,8 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std
::
function
<
BlobDesc
*
(
const
std
::
string
&
)
>
GetBlobDesc4BnInOpFunc
()
const
;
std
::
string
VisualStr
()
const
override
{
return
op_
->
op_name
();
}
void
ToProto
(
bool
is_forward
,
const
ParallelContext
*
,
ExecNodeProto
*
)
const
;
void
ToProto
(
bool
is_forward
,
DeviceType
,
const
ParallelContext
*
,
ExecNodeProto
*
)
const
;
private:
BlobDesc
*
GetBlobDesc4BnInOp
(
const
std
::
string
&
)
const
;
...
...
@@ -63,7 +64,7 @@ class ExecGraph final : public Graph<ExecNode, ExecEdge> {
ExecGraph
()
=
default
;
~
ExecGraph
()
=
default
;
void
ToExecSequence
(
bool
is_forward
,
const
ParallelContext
*
,
void
ToExecSequence
(
bool
is_forward
,
DeviceType
,
const
ParallelContext
*
,
ExecSequence
*
)
const
;
const
char
*
TypeName
()
const
override
{
return
"ExecGraph"
;
}
...
...
oneflow/core/graph/forward_compute_task_node.cpp
浏览文件 @
3c7a4240
...
...
@@ -37,7 +37,8 @@ void ForwardCompTaskNode::BuildExecGphAndRegst() {
BuildActivationRegst
();
BuildModelAndTmpRegsts
();
mut_exec_gph
().
TopoForEachNode
([
this
](
ExecNode
*
node
)
{
node
->
op
()
->
InferBlobDescs
(
node
->
GetBlobDesc4BnInOpFunc
(),
parallel_ctx
());
node
->
op
()
->
InferBlobDescs
(
node
->
GetBlobDesc4BnInOpFunc
(),
parallel_ctx
(),
device_type
());
});
}
...
...
oneflow/core/graph/loss_compute_task_node.cpp
浏览文件 @
3c7a4240
...
...
@@ -62,8 +62,10 @@ void LossCompTaskNode::BuildExecGphAndRegst() {
sum_node
->
BindBnInOpAndRegst
(
sum_op
->
SoleIbn
(),
data_tmp_regst
);
loss_regst
->
AddLbn
(
sum_op
->
Lbn4BnInOp
(
sum_op
->
SoleObn
()));
sum_node
->
BindBnInOpAndRegst
(
sum_op
->
SoleObn
(),
loss_regst
);
loss_op
->
InferBlobDescs
(
loss_node
->
GetBlobDesc4BnInOpFunc
(),
parallel_ctx
());
sum_op
->
InferBlobDescs
(
sum_node
->
GetBlobDesc4BnInOpFunc
(),
parallel_ctx
());
loss_op
->
InferBlobDescs
(
loss_node
->
GetBlobDesc4BnInOpFunc
(),
parallel_ctx
(),
device_type
());
sum_op
->
InferBlobDescs
(
sum_node
->
GetBlobDesc4BnInOpFunc
(),
parallel_ctx
(),
device_type
());
in_diff_regst
->
CopyBlobDescWithoutAddLbn
(
in_regst
.
get
());
}
...
...
oneflow/core/graph/model_update_compute_task_node.cpp
浏览文件 @
3c7a4240
...
...
@@ -79,7 +79,8 @@ void MdUpdtCompTaskNode::BuildExecGphAndRegst() {
data_tmp_regst
->
AddLbn
(
lbn
);
node
->
BindBnInOpAndRegst
(
dtbn
,
data_tmp_regst
);
}
node
->
op
()
->
InferBlobDescs
(
node
->
GetBlobDesc4BnInOpFunc
(),
nullptr
);
node
->
op
()
->
InferBlobDescs
(
node
->
GetBlobDesc4BnInOpFunc
(),
nullptr
,
device_type
());
}
void
MdUpdtCompTaskNode
::
LockRegsts
()
{
GetProducedRegst
(
"data_tmp"
)
->
Lock
();
}
...
...
oneflow/core/graph/source_compute_task_node.cpp
浏览文件 @
3c7a4240
...
...
@@ -32,7 +32,8 @@ void SourceCompTaskNode::BuildExecGphAndRegst() {
data_tmp_regst
->
AddLbn
(
lbn
);
node
->
BindBnInOpAndRegst
(
dtbn
,
data_tmp_regst
);
}
node
->
op
()
->
InferBlobDescs
(
node
->
GetBlobDesc4BnInOpFunc
(),
parallel_ctx
());
node
->
op
()
->
InferBlobDescs
(
node
->
GetBlobDesc4BnInOpFunc
(),
parallel_ctx
(),
device_type
());
}
void
SourceCompTaskNode
::
FixThrdId
()
{
...
...
oneflow/core/graph/task_node.cpp
浏览文件 @
3c7a4240
...
...
@@ -83,7 +83,8 @@ void TaskNode::ToProto(TaskProto* task_proto) {
task_proto
->
set_thrd_id
(
thrd_id_
);
task_proto
->
set_task_id
(
task_id_
);
exec_gph_
.
ToExecSequence
(
IsBackwardTaskType
(
GetTaskType
())
==
false
,
parallel_ctx
(),
task_proto
->
mutable_exec_sequence
());
device_type
(),
parallel_ctx
(),
task_proto
->
mutable_exec_sequence
());
auto
produced_regst_proto
=
task_proto
->
mutable_produced_regst_desc
();
for
(
auto
&
pair
:
produced_regsts_
)
{
RegstDescProto
regst_desc_proto
;
...
...
oneflow/core/kernel/concat_kernel.cpp
浏览文件 @
3c7a4240
...
...
@@ -88,12 +88,12 @@ void ConcatKernel<device_type>::BackwardDataContent(
namespace
{
Kernel
*
CreateConcatKernel
(
DeviceType
dev_type
)
{
Kernel
*
CreateConcatKernel
(
const
KernelConf
&
kernel_conf
)
{
static
const
HashMap
<
std
::
string
,
std
::
function
<
Kernel
*
()
>>
creators
=
{
#define CONCAT_KERNEL_ENTRY(device_type) \
{GetHashKey(device_type), []() { return new ConcatKernel<device_type>; }},
OF_PP_FOR_EACH_TUPLE
(
CONCAT_KERNEL_ENTRY
,
DEVICE_TYPE_SEQ
)};
return
creators
.
at
(
GetHashKey
(
dev_type
))();
return
creators
.
at
(
GetHashKey
(
kernel_conf
.
device_type
()
))();
}
}
// namespace
...
...
oneflow/core/kernel/kernel.cpp
浏览文件 @
3c7a4240
...
...
@@ -166,27 +166,15 @@ void AddKernelCreator(OperatorConf::OpTypeCase opcase, KernelCreator1 creator) {
CHECK
(
GetCreatorsMap
().
emplace
(
opcase
,
creator
).
second
);
}
void
AddKernelCreator
(
OperatorConf
::
OpTypeCase
opcase
,
KernelCreator2
creator
)
{
AddKernelCreator
(
opcase
,
[
creator
](
DeviceType
type
,
const
KernelConf
&
)
{
return
creator
(
type
);
});
}
void
AddKernelCreator
(
OperatorConf
::
OpTypeCase
opcase
,
KernelCreator3
creator
)
{
AddKernelCreator
(
opcase
,
[
creator
](
DeviceType
,
const
KernelConf
&
conf
)
{
return
creator
(
conf
);
});
}
void
AddKernelCreator
(
OperatorConf
::
OpTypeCase
opcase
,
KernelCreator4
creator
)
{
AddKernelCreator
(
opcase
,
[
creator
](
DeviceType
,
const
KernelConf
&
)
{
return
creator
();
});
AddKernelCreator
(
opcase
,
[
creator
](
const
KernelConf
&
)
{
return
creator
();
});
}
std
::
unique_ptr
<
const
Kernel
>
ConstructKernel
(
DeviceType
device_type
,
const
ParallelContext
*
parallel_ctx
,
const
KernelConf
&
conf
)
{
const
ParallelContext
*
parallel_ctx
,
const
KernelConf
&
conf
)
{
OperatorConf
::
OpTypeCase
opcase
=
conf
.
op_conf
().
op_type_case
();
auto
it
=
GetCreatorsMap
().
find
(
opcase
);
CHECK
(
it
!=
GetCreatorsMap
().
end
())
<<
opcase
;
Kernel
*
rptr
=
it
->
second
(
device_type
,
conf
);
Kernel
*
rptr
=
it
->
second
(
conf
);
rptr
->
Init
(
parallel_ctx
,
conf
);
return
std
::
unique_ptr
<
const
Kernel
>
(
rptr
);
}
...
...
oneflow/core/kernel/kernel.h
浏览文件 @
3c7a4240
...
...
@@ -117,16 +117,11 @@ class KernelIf : public Kernel {
void
(
Blob
::*
Copy
)(
DeviceCtx
*
,
const
Blob
*
))
const
;
};
using
KernelCreator1
=
std
::
function
<
Kernel
*
(
DeviceType
,
const
KernelConf
&
)
>
;
using
KernelCreator2
=
std
::
function
<
Kernel
*
(
DeviceType
)
>
;
using
KernelCreator3
=
std
::
function
<
Kernel
*
(
const
KernelConf
&
)
>
;
using
KernelCreator4
=
std
::
function
<
Kernel
*
()
>
;
using
KernelCreator1
=
std
::
function
<
Kernel
*
(
const
KernelConf
&
)
>
;
using
KernelCreator2
=
std
::
function
<
Kernel
*
()
>
;
void
AddKernelCreator
(
OperatorConf
::
OpTypeCase
,
KernelCreator1
);
void
AddKernelCreator
(
OperatorConf
::
OpTypeCase
,
KernelCreator2
);
void
AddKernelCreator
(
OperatorConf
::
OpTypeCase
,
KernelCreator3
);
void
AddKernelCreator
(
OperatorConf
::
OpTypeCase
,
KernelCreator4
);
std
::
unique_ptr
<
const
Kernel
>
ConstructKernel
(
DeviceType
,
const
ParallelContext
*
,
std
::
unique_ptr
<
const
Kernel
>
ConstructKernel
(
const
ParallelContext
*
,
const
KernelConf
&
);
}
// namespace oneflow
...
...
@@ -139,12 +134,13 @@ std::unique_ptr<const Kernel> ConstructKernel(DeviceType,
#define ADD_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, data_type_seq) \
namespace { \
\
Kernel* CreateKernel(
DeviceType dev_type, const KernelConf& kernel_conf) {
\
Kernel* CreateKernel(
const KernelConf& kernel_conf) {
\
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, \
(kernel_class), DEVICE_TYPE_SEQ, \
data_type_seq)}; \
return creators.at(GetHashKey(dev_type, kernel_conf.data_type()))(); \
return creators.at( \
GetHashKey(kernel_conf.device_type(), kernel_conf.data_type()))(); \
} \
\
COMMAND(AddKernelCreator(op_type_case, CreateKernel)); \
...
...
@@ -154,18 +150,18 @@ std::unique_ptr<const Kernel> ConstructKernel(DeviceType,
{OF_PP_PAIR_SECOND(data_type_pair), \
[]() { return new kernel_class<OF_PP_PAIR_FIRST(data_type_pair)>(); }},
#define ADD_CPU_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class,
\
data_type_seq)
\
namespace {
\
\
Kernel* CreateKernel(
DeviceType dev_type, const KernelConf& kernel_conf) {
\
static const HashMap<int, std::function<Kernel*()>> creators = {
\
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_CPU_KERNEL_CREATOR_ENTRY,
\
(kernel_class), data_type_seq)};
\
return creators.at(kernel_conf.data_type())();
\
}
\
\
COMMAND(AddKernelCreator(op_type_case, CreateKernel));
\
#define ADD_CPU_DEFAULT_KERNEL_CREATOR(op_type_case, kernel_class, \
data_type_seq) \
namespace { \
\
Kernel* CreateKernel(
const KernelConf& kernel_conf) {
\
static const HashMap<int, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_CPU_KERNEL_CREATOR_ENTRY, \
(kernel_class), data_type_seq)}; \
return creators.at(kernel_conf.data_type())(); \
} \
\
COMMAND(AddKernelCreator(op_type_case, CreateKernel)); \
}
#endif // ONEFLOW_CORE_KERNEL_KERNEL_H_
oneflow/core/kernel/kernel.proto
浏览文件 @
3c7a4240
...
...
@@ -3,6 +3,7 @@ package oneflow;
import
"oneflow/core/operator/op_conf.proto"
;
import
"oneflow/core/common/data_type.proto"
;
import
"oneflow/core/job/resource.proto"
;
message
ConcatKernelConf
{
required
int64
total_cp_num
=
1
;
...
...
@@ -56,6 +57,7 @@ message KernelConf {
required
bool
need_do_col_num
=
13
;
required
bool
is_forward
=
14
;
required
DataType
data_type
=
15
;
required
DeviceType
device_type
=
16
;
oneof
kernel_type
{
MultinomialLogisticLossKernelConf
multinomial_logistic_loss_conf
=
106
;
...
...
oneflow/core/kernel/model_update_kernel.cpp
浏览文件 @
3c7a4240
...
...
@@ -63,15 +63,15 @@ OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_KERNEL, DEVICE_TYPE_SEQ,
namespace
{
Kernel
*
CreateMdUpdtKernel
(
DeviceType
dev_type
,
const
KernelConf
&
kernel_conf
)
{
Kernel
*
CreateMdUpdtKernel
(
const
KernelConf
&
kernel_conf
)
{
const
ModelUpdateOpUserConf
&
user_conf
=
kernel_conf
.
op_conf
().
mdupdt_conf
().
user_conf
();
if
(
user_conf
.
has_normal_conf
())
{
return
CreateNormalMdUpdtKernel
(
dev_type
,
kernel_conf
);
return
CreateNormalMdUpdtKernel
(
kernel_conf
);
}
else
if
(
user_conf
.
has_momentum_conf
())
{
return
CreateMomentumMdUpdtKernel
(
dev_type
,
kernel_conf
);
return
CreateMomentumMdUpdtKernel
(
kernel_conf
);
}
else
if
(
user_conf
.
has_rmsprop_conf
())
{
return
CreateRMSPropMdUpdtKernel
(
dev_type
,
kernel_conf
);
return
CreateRMSPropMdUpdtKernel
(
kernel_conf
);
}
else
{
UNEXPECTED_RUN
();
}
...
...
oneflow/core/kernel/model_update_kernel.h
浏览文件 @
3c7a4240
...
...
@@ -37,16 +37,16 @@ class MdUpdateKernelUtil final {
};
#define DECLARE_MDUPDT_KERNEL_CREATOR(x) \
Kernel* Create##x##MdUpdtKernel(
DeviceType,
const KernelConf&);
Kernel* Create##x##MdUpdtKernel(const KernelConf&);
#define DEFINE_MDUPDT_KERNEL_CREATOR(x) \
Kernel* Create##x##MdUpdtKernel(DeviceType dev_type, \
const KernelConf& kernel_conf) { \
Kernel* Create##x##MdUpdtKernel(const KernelConf& kernel_conf) { \
static const HashMap<std::string, std::function<Kernel*()>> creators = { \
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_KERNEL_CREATOR_ENTRY, \
(x##MdUpdateKernel), DEVICE_TYPE_SEQ, \
FLOATING_DATA_TYPE_SEQ)}; \
return creators.at(GetHashKey(dev_type, kernel_conf.data_type()))(); \
return creators.at( \
GetHashKey(kernel_conf.device_type(), kernel_conf.data_type()))(); \
}
}
// namespace oneflow
...
...
oneflow/core/kernel/multinomial_logistic_loss_kernel.cpp
浏览文件 @
3c7a4240
...
...
@@ -72,8 +72,7 @@ class MultinomialLogisticLossKernelUtil<DeviceType::kCPU, PredType, LabelType>
namespace
{
Kernel
*
CreateMultinomialLogisticLossKernel
(
DeviceType
dev_type
,
const
KernelConf
&
kernel_conf
)
{
Kernel
*
CreateMultinomialLogisticLossKernel
(
const
KernelConf
&
kernel_conf
)
{
static
const
HashMap
<
std
::
string
,
std
::
function
<
Kernel
*
()
>>
creators
=
{
#define MULTINOMIAL_LOGISTIC_LOSS_KERNEL_ENTRY(device_type, pred_type_pair, \
label_type_pair) \
...
...
@@ -87,9 +86,9 @@ Kernel* CreateMultinomialLogisticLossKernel(DeviceType dev_type,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
MULTINOMIAL_LOGISTIC_LOSS_KERNEL_ENTRY
,
DEVICE_TYPE_SEQ
,
FLOATING_DATA_TYPE_SEQ
,
INT_DATA_TYPE_SEQ
)};
return
creators
.
at
(
GetHashKey
(
dev_type
,
kernel_conf
.
multinomial_logistic_loss_conf
().
prediction_type
(),
kernel_conf
.
multinomial_logistic_loss_conf
().
label_type
()))();
return
creators
.
at
(
GetHashKey
(
kernel_conf
.
multinomial_logistic_loss_conf
().
prediction_type
(),
kernel_conf
.
multinomial_logistic_loss_conf
().
label_type
()))();
}
}
// namespace
...
...
oneflow/core/kernel/softmax_loss_kernel.cpp
浏览文件 @
3c7a4240
...
...
@@ -70,8 +70,7 @@ class SoftmaxLossKernelUtil<DeviceType::kCPU, PredType, LabelType> final {
namespace
{
Kernel
*
CreateSoftmaxLossKernel
(
DeviceType
dev_type
,
const
KernelConf
&
kernel_conf
)
{
Kernel
*
CreateSoftmaxLossKernel
(
const
KernelConf
&
kernel_conf
)
{
static
const
HashMap
<
std
::
string
,
std
::
function
<
Kernel
*
()
>>
creators
=
{
#define SOFTMAX_LOSS_KERNEL_ENTRY(device_type, pred_type_pair, \
label_type_pair) \
...
...
@@ -86,7 +85,8 @@ Kernel* CreateSoftmaxLossKernel(DeviceType dev_type,
DEVICE_TYPE_SEQ
,
FLOATING_DATA_TYPE_SEQ
,
INT_DATA_TYPE_SEQ
)};
return
creators
.
at
(
GetHashKey
(
dev_type
,
kernel_conf
.
softmax_loss_conf
().
prediction_type
(),
GetHashKey
(
kernel_conf
.
device_type
(),
kernel_conf
.
softmax_loss_conf
().
prediction_type
(),
kernel_conf
.
softmax_loss_conf
().
label_type
()))();
}
...
...
oneflow/core/operator/operator.cpp
浏览文件 @
3c7a4240
...
...
@@ -109,8 +109,8 @@ static bool HasBlobDescWithField(
void
Operator
::
GenKernelConf
(
std
::
function
<
const
BlobDesc
*
(
const
std
::
string
&
)
>
GetBlobDesc4BnInOp
,
bool
is_forward
,
const
ParallelContext
*
parallel_ctx
,
KernelConf
*
kernel_conf
)
const
{
bool
is_forward
,
DeviceType
device_type
,
const
ParallelContext
*
parallel_ctx
,
KernelConf
*
kernel_conf
)
const
{
*
(
kernel_conf
->
mutable_op_conf
())
=
op_conf_
;
*
(
kernel_conf
->
mutable_bn_in_op2lbn
())
=
HashMap2PbMap
(
bn_in_op2lbn_
);
*
(
kernel_conf
->
mutable_data_tmp_bns
())
=
StdVec2PbRpf
(
data_tmp_bns_
);
...
...
@@ -138,6 +138,7 @@ void Operator::GenKernelConf(
data_type
=
GetDataTypeFromBnInOpVec
(
GetBlobDesc4BnInOp
,
input_bns_
);
}
kernel_conf
->
set_data_type
(
data_type
);
kernel_conf
->
set_device_type
(
device_type
);
VirtualGenKernelConf
(
GetBlobDesc4BnInOp
,
parallel_ctx
,
kernel_conf
);
}
...
...
oneflow/core/operator/operator.h
浏览文件 @
3c7a4240
...
...
@@ -106,6 +106,11 @@ class Operator {
// Read: shape of input_blobs
// Write: shape of output_blobs, model_blobs, data_tmp_blobs, model_tmp_blobs
virtual
void
InferBlobDescs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc4BnInOp
,
const
ParallelContext
*
parallel_ctx
,
DeviceType
device_type
)
const
{
InferBlobDescs
(
GetBlobDesc4BnInOp
,
parallel_ctx
);
}
virtual
void
InferBlobDescs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
)
>
GetBlobDesc4BnInOp
,
const
ParallelContext
*
parallel_ctx
)
const
{
...
...
@@ -118,14 +123,13 @@ class Operator {
virtual
int32_t
MaxModelSplitNum
()
const
{
return
-
1
;
}
void
GenKernelConf
(
std
::
function
<
const
BlobDesc
*
(
const
std
::
string
&
)
>
GetBlobDesc4BnInOp
,
bool
is_forward
,
const
ParallelContext
*
parallel_ctx
,
KernelConf
*
kernel_conf
)
const
;
bool
is_forward
,
DeviceType
,
const
ParallelContext
*
,
KernelConf
*
)
const
;
protected:
virtual
void
VirtualFixParallelDesc
(
ParallelDesc
*
pr_desc
)
const
{}
virtual
void
VirtualGenKernelConf
(
std
::
function
<
const
BlobDesc
*
(
const
std
::
string
&
)
>
GetBlobDesc4BnInOp
,
const
ParallelContext
*
parallel_ctx
,
KernelConf
*
kernel_conf
)
const
{}
const
ParallelContext
*
,
KernelConf
*
)
const
{}
virtual
std
::
string
ibn2lbn
(
const
std
::
string
&
input_bn
)
const
;
virtual
std
::
string
obn2lbn
(
const
std
::
string
&
output_bn
)
const
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录