Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
889e5834
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
889e5834
编写于
12月 13, 2022
作者:
R
Ryan
提交者:
GitHub
12月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2St] transforms.RandomVerticalFlip Support static mode (#49024)
* add static RandomVerticalFlip * object => unittest.TestCase
上级
31922692
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
131 addition
and
2 deletion
+131
-2
python/paddle/tests/test_transforms_static.py
python/paddle/tests/test_transforms_static.py
+102
-0
python/paddle/vision/transforms/functional.py
python/paddle/vision/transforms/functional.py
+5
-1
python/paddle/vision/transforms/functional_tensor.py
python/paddle/vision/transforms/functional_tensor.py
+11
-1
python/paddle/vision/transforms/transforms.py
python/paddle/vision/transforms/transforms.py
+13
-0
未找到文件。
python/paddle/tests/test_transforms_static.py
0 → 100644
浏览文件 @
889e5834
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
paddle
from
paddle.vision.transforms
import
transforms
SEED
=
2022
class
TestTransformUnitTestBase
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
img
=
(
np
.
random
.
rand
(
*
self
.
get_shape
())
*
255.0
).
astype
(
np
.
float32
)
self
.
set_trans_api
()
def
get_shape
(
self
):
return
(
64
,
64
,
3
)
def
set_trans_api
(
self
):
self
.
api
=
transforms
.
Resize
(
size
=
16
)
def
dynamic_transform
(
self
):
paddle
.
seed
(
SEED
)
img_t
=
paddle
.
to_tensor
(
self
.
img
)
return
self
.
api
(
img_t
)
def
static_transform
(
self
):
paddle
.
enable_static
()
paddle
.
seed
(
SEED
)
main_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
):
x
=
paddle
.
static
.
data
(
shape
=
self
.
get_shape
(),
dtype
=
paddle
.
float32
,
name
=
'img'
)
out
=
self
.
api
(
x
)
exe
=
paddle
.
static
.
Executor
()
res
=
exe
.
run
(
main_program
,
fetch_list
=
[
out
],
feed
=
{
'img'
:
self
.
img
})
paddle
.
disable_static
()
return
res
[
0
]
def
test_transform
(
self
):
dy_res
=
self
.
dynamic_transform
()
st_res
=
self
.
static_transform
()
np
.
testing
.
assert_almost_equal
(
dy_res
,
st_res
)
class
TestResize
(
TestTransformUnitTestBase
):
def
set_trans_api
(
self
):
self
.
api
=
transforms
.
Resize
(
size
=
(
16
,
16
))
class
TestResizeError
(
TestTransformUnitTestBase
):
def
test_transform
(
self
):
pass
def
test_error
(
self
):
paddle
.
enable_static
()
# Not support while w<=0 or h<=0, but received w=-1, h=-1
with
self
.
assertRaises
(
NotImplementedError
):
main_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
):
x
=
paddle
.
static
.
data
(
shape
=
[
-
1
,
-
1
,
-
1
],
dtype
=
paddle
.
float32
,
name
=
'img'
)
self
.
api
(
x
)
paddle
.
disable_static
()
class
TestRandomVerticalFlip0
(
TestTransformUnitTestBase
):
def
set_trans_api
(
self
):
self
.
api
=
transforms
.
RandomVerticalFlip
(
prob
=
0
)
class
TestRandomVerticalFlip1
(
TestTransformUnitTestBase
):
def
set_trans_api
(
self
):
self
.
api
=
transforms
.
RandomVerticalFlip
(
prob
=
1
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/vision/transforms/functional.py
浏览文件 @
889e5834
...
...
@@ -20,6 +20,7 @@ from PIL import Image
import
paddle
from
...fluid.framework
import
Variable
from
.
import
functional_cv2
as
F_cv2
from
.
import
functional_pil
as
F_pil
from
.
import
functional_tensor
as
F_t
...
...
@@ -32,7 +33,10 @@ def _is_pil_image(img):
def
_is_tensor_image
(
img
):
return
isinstance
(
img
,
paddle
.
Tensor
)
"""
Return True if img is a Tensor for dynamic mode or Variable for static mode.
"""
return
isinstance
(
img
,
(
paddle
.
Tensor
,
Variable
))
def
_is_numpy_image
(
img
):
...
...
python/paddle/vision/transforms/functional_tensor.py
浏览文件 @
889e5834
...
...
@@ -18,12 +18,14 @@ import numbers
import
paddle
import
paddle.nn.functional
as
F
from
...fluid.framework
import
Variable
__all__
=
[]
def
_assert_image_tensor
(
img
,
data_format
):
if
(
not
isinstance
(
img
,
paddle
.
Tensor
)
not
isinstance
(
img
,
(
paddle
.
Tensor
,
Variable
)
)
or
img
.
ndim
<
3
or
img
.
ndim
>
4
or
not
data_format
.
lower
()
in
(
'chw'
,
'hwc'
)
...
...
@@ -725,6 +727,14 @@ def resize(img, size, interpolation='bilinear', data_format='CHW'):
if
isinstance
(
size
,
int
):
w
,
h
=
_get_image_size
(
img
,
data_format
)
# TODO(Aurelius84): In static mode, w and h will be -1 for dynamic shape.
# We should consider to support this case in future.
if
w
<=
0
or
h
<=
0
:
raise
NotImplementedError
(
"Not support while w<=0 or h<=0, but received w={}, h={}"
.
format
(
w
,
h
)
)
if
(
w
<=
h
and
w
==
size
)
or
(
h
<=
w
and
h
==
size
):
return
img
if
w
<
h
:
...
...
python/paddle/vision/transforms/transforms.py
浏览文件 @
889e5834
...
...
@@ -653,10 +653,23 @@ class RandomVerticalFlip(BaseTransform):
self
.
prob
=
prob
def
_apply_image
(
self
,
img
):
if
paddle
.
in_dynamic_mode
():
return
self
.
_dynamic_apply_image
(
img
)
else
:
return
self
.
_static_apply_image
(
img
)
def
_dynamic_apply_image
(
self
,
img
):
if
random
.
random
()
<
self
.
prob
:
return
F
.
vflip
(
img
)
return
img
def
_static_apply_image
(
self
,
img
):
return
paddle
.
static
.
nn
.
cond
(
paddle
.
rand
(
shape
=
(
1
,))
<
self
.
prob
,
lambda
:
F
.
vflip
(
img
),
lambda
:
img
,
)
class
Normalize
(
BaseTransform
):
"""Normalize the input data with mean and standard deviation.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录