Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow_yolov3
提交
2d6352ab
O
oneflow_yolov3
项目概览
Oneflow-Inc
/
oneflow_yolov3
9 个月 前同步成功
通知
4
Star
6
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow_yolov3
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
2d6352ab
编写于
7月 20, 2020
作者:
F
Flowingsun007
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
replace upsample op with flow.layers.upsample_2d
上级
ea9c95f4
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
3 additions
and
199 deletions
+3
-199
oneflow_yolov3/model/yolo_net.py
oneflow_yolov3/model/yolo_net.py
+3
-3
oneflow_yolov3/ops/__init__.py
oneflow_yolov3/ops/__init__.py
+0
-3
oneflow_yolov3/ops/upsample_nearest.py
oneflow_yolov3/ops/upsample_nearest.py
+0
-17
ops/upsample_nearest.cu
ops/upsample_nearest.cu
+0
-176
未找到文件。
oneflow_yolov3/model/yolo_net.py
浏览文件 @
2d6352ab
import
oneflow
as
flow
import
oneflow.core.operator.op_conf_pb2
as
op_conf_util
from
oneflow_yolov3.ops.upsample_nearest
import
upsample_nearest
#
from oneflow_yolov3.ops.upsample_nearest import upsample_nearest
from
oneflow_yolov3.ops.yolo_detect
import
yolo_detect
,
yolo_box_diff
,
yolo_prob_loss
,
logistic
from
oneflow_yolov3.ops.yolo_nms
import
yolo_nms
...
...
@@ -117,8 +117,8 @@ def _leaky_relu(input, alpha=None, name=None):
def _upsample(input, name=None):
    """2x nearest-neighbor upsampling for an NCHW feature map.

    Replaces the custom upsample_nearest user op with the built-in
    flow.layers.upsample_2d (this commit's change).
    """
    # return flow.detection.upsample_nearest(input, name=name, scale=2, data_format="channels_first")
    # return upsample_nearest(input, name=name, scale=2, data_format="channels_first")
    return flow.layers.upsample_2d(
        input, size=2, data_format='NCHW', interpolation="nearest"
    )
def
conv_unit
(
data
,
num_filter
=
1
,
kernel
=
(
1
,
1
),
stride
=
(
1
,
1
),
pad
=
"same"
,
data_format
=
"NCHW"
,
use_bias
=
False
,
...
...
oneflow_yolov3/ops/__init__.py
浏览文件 @
2d6352ab
from
.upsample_nearest
import
upsample_nearest
# Export every public (non-underscore-prefixed) name defined in this module.
__all__ = [sym for sym in globals().keys() if not sym.startswith("_")]
oneflow_yolov3/ops/upsample_nearest.py
已删除
100644 → 0
浏览文件 @
ea9c95f4
from
__future__
import
absolute_import
import
oneflow
as
flow
def upsample_nearest(x, scale, name, data_format="channels_first"):
    """Build and run the custom "upsample_nearest" user op.

    Args:
        x: input blob to upsample.
        scale: integer upsampling factor applied to both spatial dims.
        name: unique op name.
        data_format: layout string; the kernel accepts "channels_first" (NCHW).

    Returns:
        The single output blob "y" of the op.
    """
    op = (
        flow.user_op_builder(name)
        .Op("upsample_nearest")
        .Input("x", [x])
        .Output("y")
        .Attr("scale", scale)
        .Attr("data_format", data_format)
        .Build()
    )
    return op.InferAndTryRun().RemoteBlobList()[0]
ops/upsample_nearest.cu
已删除
100644 → 0
浏览文件 @
ea9c95f4
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
namespace oneflow {

namespace {

// Nearest-neighbor upsampling forward pass, NCHW layout.
// One thread per output element: each output pixel (h, w) reads the input
// pixel at round(h * scale_h) / round(w * scale_w) when align_corners is set,
// otherwise floor(...), clamped to the valid input range.
template<typename T>
__global__ void UpsampleNearestForward(const int64_t nthreads, const T* in_dptr,
                                       const int64_t channel_num, const int64_t height,
                                       const int64_t width, const int64_t new_height,
                                       const int64_t new_width, const float scale_h,
                                       const float scale_w, const bool align_corners,
                                       T* out_dptr) {
  const int64_t new_area = new_height * new_width;
  const int64_t channel_area = channel_num * height * width;
  const int64_t channel_new_area = channel_num * new_height * new_width;
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Decompose the flat output index into (n, c, h, w).
    const int64_t h = (index / new_width) % new_height;
    const int64_t w = index % new_width;
    const int64_t c = (index / new_area) % channel_num;
    const int64_t n = index / channel_new_area;
    // Map the output coordinate back onto the (smaller) input grid.
    const int64_t in_h = min((align_corners) ? static_cast<int64_t>(roundf(h * scale_h))
                                             : static_cast<int64_t>(floorf(h * scale_h)),
                             height - 1);
    const int64_t in_w = min((align_corners) ? static_cast<int64_t>(roundf(w * scale_w))
                                             : static_cast<int64_t>(floorf(w * scale_w)),
                             width - 1);
    out_dptr[index] = in_dptr[n * channel_area + (c * height + in_h) * width + in_w];
  }
}
// Nearest-neighbor upsampling backward pass, NCHW layout.
// Each dy element is scattered back to the input-gradient cell it was copied
// from in the forward pass. Several output pixels can map to the same input
// pixel, so the accumulation must be atomic.
//
// Fix: the original computed `const int64_t area = height * width;` but never
// used it — the dead local has been removed.
template<typename T>
__global__ void UpsampleNearestBackward(const int64_t nthreads, const T* dy_dptr,
                                        const int64_t channel_num, const int64_t height,
                                        const int64_t width, const int64_t new_height,
                                        const int64_t new_width, const float scale_h,
                                        const float scale_w, const bool align_corners,
                                        T* dx_dptr) {
  const int64_t new_area = new_height * new_width;
  const int64_t channel_area = channel_num * height * width;
  const int64_t channel_new_area = channel_num * new_height * new_width;
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // Decompose the flat dy index into (n, c, h, w).
    const int64_t h = (index / new_width) % new_height;
    const int64_t w = index % new_width;
    const int64_t c = (index / new_area) % channel_num;
    const int64_t n = index / channel_new_area;
    // Same coordinate mapping as the forward kernel.
    const int64_t in_h = min((align_corners) ? static_cast<int64_t>(roundf(h * scale_h))
                                             : static_cast<int64_t>(floorf(h * scale_h)),
                             height - 1);
    const int64_t in_w = min((align_corners) ? static_cast<int64_t>(roundf(w * scale_w))
                                             : static_cast<int64_t>(floorf(w * scale_w)),
                             width - 1);
    atomicAdd(dx_dptr + n * channel_area + (c * height + in_h) * width + in_w, dy_dptr[index]);
  }
}

}  // namespace
// GPU kernel wrapper for the "upsample_nearest" op: launches
// UpsampleNearestForward with one thread per output element.
template<typename T>
class UpsampleNearestGPUKernel final : public user_op::OpKernel {
 public:
  UpsampleNearestGPUKernel() = default;
  ~UpsampleNearestGPUKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* x_blob = ctx->Tensor4ArgNameAndIndex("x", 0);
    user_op::Tensor* y_blob = ctx->Tensor4ArgNameAndIndex("y", 0);
    const int32_t scale = ctx->Attr<int32_t>("scale");
    const int64_t elem_cnt = y_blob->shape().elem_cnt();
    // scale_h == scale_w == 1/scale; align_corners is fixed to false here.
    UpsampleNearestForward<T>
        <<<BlocksNum4ThreadsNum(elem_cnt), 1024, 0, ctx->device_ctx()->cuda_stream()>>>(
            elem_cnt, x_blob->dptr<T>(), x_blob->shape().At(1), x_blob->shape().At(2),
            x_blob->shape().At(3), y_blob->shape().At(2), y_blob->shape().At(3), 1.f / scale,
            1.f / scale, false, y_blob->mut_dptr<T>());
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// GPU kernel wrapper for "upsample_nearest_grad": zero-fills dx, then launches
// UpsampleNearestBackward to atomically accumulate dy into dx.
template<typename T>
class UpsampleNearestGradGPUKernel final : public user_op::OpKernel {
 public:
  UpsampleNearestGradGPUKernel() = default;
  ~UpsampleNearestGradGPUKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    user_op::Tensor* dx_blob = ctx->Tensor4ArgNameAndIndex("dx", 0);
    // Nothing to do when the input gradient is not requested.
    if (dx_blob == nullptr) { return; }
    // dx must start at zero because the backward kernel accumulates into it.
    Memset<DeviceType::kGPU>(ctx->device_ctx(), dx_blob->mut_dptr<T>(), 0,
                             dx_blob->shape().elem_cnt() * sizeof(T));
    const user_op::Tensor* dy_blob = ctx->Tensor4ArgNameAndIndex("dy", 0);
    const int32_t scale = ctx->Attr<int32_t>("scale");
    const int64_t elem_cnt = dy_blob->shape().elem_cnt();
    UpsampleNearestBackward<T>
        <<<BlocksNum4ThreadsNum(elem_cnt), 1024, 0, ctx->device_ctx()->cuda_stream()>>>(
            elem_cnt, dy_blob->dptr<T>(), dx_blob->shape().At(1), dx_blob->shape().At(2),
            dx_blob->shape().At(3), dy_blob->shape().At(2), dy_blob->shape().At(3), 1.f / scale,
            1.f / scale, false, dx_blob->mut_dptr<T>());
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Register forward and backward GPU kernels for one element dtype.
#define REGISTER_UPSAMPLE_NEAREST_GPU_KERNEL(dtype)       \
  REGISTER_USER_KERNEL("upsample_nearest")                \
      .SetCreateFn<UpsampleNearestGPUKernel<dtype>>()     \
      .SetIsMatchedHob(user_op::HobTrue());               \
  REGISTER_USER_KERNEL("upsample_nearest_grad")           \
      .SetCreateFn<UpsampleNearestGradGPUKernel<dtype>>() \
      .SetIsMatchedHob(user_op::HobTrue());

REGISTER_UPSAMPLE_NEAREST_GPU_KERNEL(float)
// Op definition for "upsample_nearest": y = nearest-upsample(x, scale),
// NCHW only. Shape inference multiplies the two spatial dims by `scale`;
// SBP allows splitting both blobs along the batch axis.
REGISTER_USER_OP("upsample_nearest")
    .Input("x")
    .Output("y")
    .Attr("scale", UserOpAttrType::kAtInt32)
    .Attr("data_format", UserOpAttrType::kAtString)
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0);
      Shape* y_shape = ctx->Shape4ArgNameAndIndex("y", 0);
      const int32_t scale = ctx->Attr<int32_t>("scale");
      if (ctx->Attr<std::string>("data_format") != "channels_first" || x_shape->NumAxes() != 4) {
        LOG(FATAL) << "upsample_nearest only supports NCHW";
      }
      *y_shape = Shape(
          {x_shape->At(0), x_shape->At(1), scale * x_shape->At(2), scale * x_shape->At(3)});
      return Maybe<void>::Ok();
    })
    .SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe<void> {
      // Output inherits the input's batch axis.
      *ctx->BatchAxis4ArgNameAndIndex("y", 0) = *ctx->BatchAxis4ArgNameAndIndex("x", 0);
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
      ctx->NewBuilder()
          .Split(user_op::OpArg("x", 0), 0)
          .Split(user_op::OpArg("y", 0), 0)
          .Build();
      return Maybe<void>::Ok();
    });
// Op definition for "upsample_nearest_grad": dx = scatter(dy) with the two
// spatial dims divided by `scale`. NCHW only, same split-by-batch SBP as the
// forward op.
REGISTER_USER_OP("upsample_nearest_grad")
    .Input("dy")
    .Output("dx")
    .Attr("scale", UserOpAttrType::kAtInt32)
    .Attr("data_format", UserOpAttrType::kAtString)
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      const Shape* dy_shape = ctx->Shape4ArgNameAndIndex("dy", 0);
      Shape* dx_shape = ctx->Shape4ArgNameAndIndex("dx", 0);
      const int32_t scale = ctx->Attr<int32_t>("scale");
      if (ctx->Attr<std::string>("data_format") != "channels_first" || dy_shape->NumAxes() != 4) {
        LOG(FATAL) << "upsample_nearest only supports NCHW";
      }
      *dx_shape = Shape(
          {dy_shape->At(0), dy_shape->At(1), dy_shape->At(2) / scale, dy_shape->At(3) / scale});
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
      ctx->NewBuilder()
          .Split(user_op::OpArg("dy", 0), 0)
          .Split(user_op::OpArg("dx", 0), 0)
          .Build();
      return Maybe<void>::Ok();
    });
// Autograd hookup: when x needs a gradient, emit an "upsample_nearest_grad"
// op fed by y's gradient and bind its "dx" output as x's gradient.
REGISTER_USER_OP_GRAD("upsample_nearest")
    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
      if (op.NeedGenGradTensor4OpInput("x", 0)) {
        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
        user_op::UserOpConfWrapper grad_op =
            builder.Op("upsample_nearest_grad")
                .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
                .Output("dx")
                .Attr("scale", op.attr<int32_t>("scale"))
                .Attr("data_format", op.attr<std::string>("data_format"))
                .Build();
        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);
        AddOp(grad_op);
      }
    });

}  // namespace oneflow
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录