Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
04e3d8b2
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
04e3d8b2
编写于
12月 04, 2018
作者:
S
Shiyuan Shang-Guan
提交者:
Li Xinqi
12月 04, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add scalar_mul (#1553)
Former-commit-id: bc5b1c935372311367de69e38c37db543b30a19d
上级
e0c32b4f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
107 addition
and
0 deletion
+107
-0
oneflow/core/kernel/scalar_mul_kernel.cpp
oneflow/core/kernel/scalar_mul_kernel.cpp
+31
-0
oneflow/core/kernel/scalar_mul_kernel.h
oneflow/core/kernel/scalar_mul_kernel.h
+27
-0
oneflow/core/operator/op_conf.proto
oneflow/core/operator/op_conf.proto
+7
-0
oneflow/core/operator/scalar_mul_op.cpp
oneflow/core/operator/scalar_mul_op.cpp
+17
-0
oneflow/core/operator/scalar_mul_op.h
oneflow/core/operator/scalar_mul_op.h
+25
-0
未找到文件。
oneflow/core/kernel/scalar_mul_kernel.cpp
0 → 100644
浏览文件 @
04e3d8b2
#include "oneflow/core/kernel/scalar_mul_kernel.h"
namespace
oneflow
{
template
<
DeviceType
device_type
,
typename
T
>
void
ScalarMulKernel
<
device_type
,
T
>::
ForwardDataContent
(
const
KernelCtx
&
ctx
,
std
::
function
<
Blob
*
(
const
std
::
string
&
)
>
BnInOp2Blob
)
const
{
const
Blob
*
in_blob
=
BnInOp2Blob
(
"in"
);
Blob
*
out_blob
=
BnInOp2Blob
(
"out"
);
Memcpy
<
device_type
>
(
ctx
.
device_ctx
,
out_blob
->
mut_dptr
<
T
>
(),
in_blob
->
dptr
<
T
>
(),
out_blob
->
ByteSizeOfDataContentField
());
KernelUtil
<
device_type
,
T
>::
Scal
(
ctx
.
device_ctx
,
out_blob
->
shape
().
elem_cnt
(),
static_cast
<
T
>
(
this
->
op_conf
().
scalar_mul_conf
().
scalar
()),
out_blob
->
mut_dptr
<
T
>
(),
1
);
}
template
<
DeviceType
device_type
,
typename
T
>
void
ScalarMulKernel
<
device_type
,
T
>::
BackwardDataContent
(
const
KernelCtx
&
ctx
,
std
::
function
<
Blob
*
(
const
std
::
string
&
)
>
BnInOp2Blob
)
const
{
const
Blob
*
out_diff_blob
=
BnInOp2Blob
(
GenDiffBn
(
"out"
));
Blob
*
in_diff_blob
=
BnInOp2Blob
(
GenDiffBn
(
"in"
));
Memcpy
<
device_type
>
(
ctx
.
device_ctx
,
in_diff_blob
->
mut_dptr
<
T
>
(),
out_diff_blob
->
dptr
<
T
>
(),
out_diff_blob
->
ByteSizeOfDataContentField
());
KernelUtil
<
device_type
,
T
>::
Scal
(
ctx
.
device_ctx
,
in_diff_blob
->
shape
().
elem_cnt
(),
static_cast
<
T
>
(
this
->
op_conf
().
scalar_mul_conf
().
scalar
()),
in_diff_blob
->
mut_dptr
<
T
>
(),
1
);
}
ADD_DEFAULT_KERNEL_CREATOR
(
OperatorConf
::
kScalarMulConf
,
ScalarMulKernel
,
FLOATING_DATA_TYPE_SEQ
);
}
// namespace oneflow
oneflow/core/kernel/scalar_mul_kernel.h
0 → 100644
浏览文件 @
04e3d8b2
#ifndef ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace
oneflow
{
template
<
DeviceType
device_type
,
typename
T
>
class
ScalarMulKernel
final
:
public
KernelIf
<
device_type
>
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ScalarMulKernel
);
ScalarMulKernel
()
=
default
;
~
ScalarMulKernel
()
=
default
;
private:
void
ForwardDataContent
(
const
KernelCtx
&
ctx
,
std
::
function
<
Blob
*
(
const
std
::
string
&
)
>
BnInOp2Blob
)
const
override
;
void
BackwardDataContent
(
const
KernelCtx
&
ctx
,
std
::
function
<
Blob
*
(
const
std
::
string
&
)
>
BnInOp2Blob
)
const
override
;
const
PbMessage
&
GetCustomizedOpConf
()
const
override
{
return
this
->
op_conf
().
scalar_mul_conf
();
}
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_SCALAR_MUL_KERNEL_H_
oneflow/core/operator/op_conf.proto
浏览文件 @
04e3d8b2
...
...
@@ -674,6 +674,12 @@ message HingeLossOpConf {
optional
Norm
norm
=
7
[
default
=
L1
];
}
message
ScalarMulOpConf
{
required
string
in
=
1
;
required
string
out
=
2
;
required
float
scalar
=
3
;
}
message
OperatorConf
{
required
string
name
=
1
;
optional
string
model_load_dir
=
2
;
...
...
@@ -742,6 +748,7 @@ message OperatorConf {
LossPrintOpConf
loss_print_conf
=
235
;
DefineTestBlobConf
define_test_blob_conf
=
236
;
PReluOpConf
prelu_conf
=
237
;
ScalarMulOpConf
scalar_mul_conf
=
238
;
}
}
...
...
oneflow/core/operator/scalar_mul_op.cpp
0 → 100644
浏览文件 @
04e3d8b2
#include "oneflow/core/operator/scalar_mul_op.h"
namespace
oneflow
{
void
ScalarMulOp
::
InitFromOpConf
()
{
EnrollInputBn
(
"in"
);
EnrollOutputBn
(
"out"
);
}
void
ScalarMulOp
::
InferBlobDescs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
&
)
>
GetBlobDesc4BnInOp
,
const
ParallelContext
*
parallel_ctx
)
const
{
*
GetBlobDesc4BnInOp
(
"out"
)
=
*
GetBlobDesc4BnInOp
(
"in"
);
}
REGISTER_OP
(
OperatorConf
::
kScalarMulConf
,
ScalarMulOp
);
}
// namespace oneflow
oneflow/core/operator/scalar_mul_op.h
0 → 100644
浏览文件 @
04e3d8b2
#ifndef ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
#define ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
#include "oneflow/core/operator/operator.h"
namespace
oneflow
{
class
ScalarMulOp
final
:
public
Operator
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
ScalarMulOp
);
ScalarMulOp
()
=
default
;
~
ScalarMulOp
()
=
default
;
void
InitFromOpConf
()
override
;
void
InferBlobDescs
(
std
::
function
<
BlobDesc
*
(
const
std
::
string
&
)
>
GetBlobDesc4BnInOp
,
const
ParallelContext
*
parallel_ctx
)
const
override
;
bool
IsElemWiseOp
()
const
override
{
return
true
;
}
const
PbMessage
&
GetCustomizedConf
()
const
override
{
return
op_conf
().
scalar_mul_conf
();
}
bool
NeedInBlobWhenBackward
()
const
override
{
return
false
;
}
bool
NeedOutBlobWhenBackward
()
const
override
{
return
false
;
}
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_SCALAR_MUL_OP_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录