Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f29eacbb
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f29eacbb
编写于
4月 23, 2020
作者:
A
Adel Shafiei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixed an input validation error for uniform augment op
上级
d132b61b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
10 addition
and
15 deletion
+10
-15
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
+7
-13
mindspore/dataset/transforms/vision/validators.py
mindspore/dataset/transforms/vision/validators.py
+3
-2
未找到文件。
mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc
浏览文件 @
f29eacbb
...
...
@@ -42,34 +42,28 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
output
)
{
IO_CHECK_VECTOR
(
input
,
output
);
// variables to generate random number to select ops from the list
std
::
vector
<
int
>
random_indexes
;
// variables to copy the result to output if it is not already
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
even_out
;
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
*
even_out_ptr
=
&
even_out
;
int
count
=
1
;
// select random indexes for candidates to be applied
for
(
int
i
=
0
;
i
<
num_ops_
;
++
i
)
{
random_indexes
.
insert
(
random_indexes
.
end
(),
std
::
uniform_int_distribution
<
int
>
(
0
,
tensor_op_list_
.
size
()
-
1
)(
rnd_
));
}
// randomly select ops to be applied
std
::
vector
<
std
::
shared_ptr
<
TensorOp
>>
selected_tensor_ops
;
std
::
sample
(
tensor_op_list_
.
begin
(),
tensor_op_list_
.
end
(),
std
::
back_inserter
(
selected_tensor_ops
),
num_ops_
,
rnd_
);
for
(
auto
it
=
random_indexes
.
begin
();
it
!=
random_indexes
.
end
();
++
it
)
{
for
(
auto
tensor_op
=
selected_tensor_ops
.
begin
();
tensor_op
!=
selected_tensor_ops
.
end
();
++
tensor_op
)
{
// Do NOT apply the op, if second random generator returned zero
if
(
std
::
uniform_int_distribution
<
int
>
(
0
,
1
)(
rnd_
))
{
continue
;
}
std
::
shared_ptr
<
TensorOp
>
tensor_op
=
tensor_op_list_
[
*
it
];
// apply python/C++ op
if
(
count
==
1
)
{
(
*
tensor_op
).
Compute
(
input
,
output
);
(
*
*
tensor_op
).
Compute
(
input
,
output
);
}
else
if
(
count
%
2
==
0
)
{
(
*
tensor_op
).
Compute
(
*
output
,
even_out_ptr
);
(
*
*
tensor_op
).
Compute
(
*
output
,
even_out_ptr
);
}
else
{
(
*
tensor_op
).
Compute
(
even_out
,
output
);
(
*
*
tensor_op
).
Compute
(
even_out
,
output
);
}
count
++
;
}
...
...
mindspore/dataset/transforms/vision/validators.py
浏览文件 @
f29eacbb
...
...
@@ -17,11 +17,12 @@
import
numbers
from
functools
import
wraps
from
mindspore._c_dataengine
import
TensorOp
from
.utils
import
Inter
,
Border
from
...transforms.validators
import
check_pos_int32
,
check_pos_float32
,
check_value
,
check_uint8
,
FLOAT_MAX_INTEGER
,
\
check_bool
,
check_2tuple
,
check_range
,
check_list
,
check_type
,
check_positive
,
INT32_MAX
def
check_inter_mode
(
mode
):
if
not
isinstance
(
mode
,
Inter
):
raise
ValueError
(
"Invalid interpolation mode."
)
...
...
@@ -836,7 +837,7 @@ def check_uniform_augmentation(method):
if
not
isinstance
(
operations
,
list
):
raise
ValueError
(
"operations is not a python list"
)
for
op
in
operations
:
if
not
callable
(
op
):
if
not
callable
(
op
)
and
not
isinstance
(
op
,
TensorOp
)
:
raise
ValueError
(
"non-callable op in operations list"
)
kwargs
[
"num_ops"
]
=
num_ops
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录