Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d0c5071c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d0c5071c
编写于
6月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1993 [Dataset] Fix codedex.
Merge pull request !1993 from luoyang/pylint
上级
2ecd5bdf
dee8471d
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
52 addition
and
166 deletion
+52
-166
mindspore/ccsrc/dataset/core/tensor_row.h
mindspore/ccsrc/dataset/core/tensor_row.h
+7
-7
mindspore/ccsrc/dataset/engine/dataset_iterator.cc
mindspore/ccsrc/dataset/engine/dataset_iterator.cc
+6
-1
mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc
...rc/dataset/engine/datasetops/bucket_batch_by_length_op.cc
+0
-1
mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h
...src/dataset/engine/datasetops/bucket_batch_by_length_op.h
+0
-1
mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h
mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h
+2
-0
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
+0
-2
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h
+0
-1
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc
+2
-2
mindspore/ccsrc/dataset/engine/perf/connector_size.h
mindspore/ccsrc/dataset/engine/perf/connector_size.h
+3
-1
mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h
...pore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h
+2
-2
mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h
mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h
+2
-2
mindspore/ccsrc/dataset/engine/perf/monitor.cc
mindspore/ccsrc/dataset/engine/perf/monitor.cc
+1
-0
mindspore/ccsrc/dataset/engine/perf/monitor.h
mindspore/ccsrc/dataset/engine/perf/monitor.h
+2
-0
mindspore/ccsrc/dataset/engine/perf/profiling.h
mindspore/ccsrc/dataset/engine/perf/profiling.h
+1
-1
mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc
mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc
+2
-1
mindspore/ccsrc/dataset/kernels/data/fill_op.cc
mindspore/ccsrc/dataset/kernels/data/fill_op.cc
+0
-1
mindspore/ccsrc/dataset/kernels/data/fill_op.h
mindspore/ccsrc/dataset/kernels/data/fill_op.h
+0
-3
mindspore/ccsrc/mindrecord/meta/shard_column.cc
mindspore/ccsrc/mindrecord/meta/shard_column.cc
+7
-3
mindspore/dataset/transforms/vision/py_transforms_util.py
mindspore/dataset/transforms/vision/py_transforms_util.py
+15
-13
tests/ut/python/dataset/prep_data.py
tests/ut/python/dataset/prep_data.py
+0
-124
未找到文件。
mindspore/ccsrc/dataset/core/tensor_row.h
浏览文件 @
d0c5071c
...
...
@@ -35,13 +35,13 @@ class TensorRow {
static
constexpr
row_id_type
kDefaultRowId
=
-
1
;
// Default row id
// Type definitions
typedef
dsize_t
size_type
;
typedef
std
::
shared_ptr
<
Tensor
>
value_type
;
typedef
std
::
shared_ptr
<
Tensor
>
&
reference
;
typedef
const
std
::
shared_ptr
<
Tensor
>
&
const_reference
;
typedef
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
vector_type
;
typedef
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>::
iterator
iterator
;
typedef
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>::
const_iterator
const_iterator
;
using
size_type
=
dsize_t
;
using
value_type
=
std
::
shared_ptr
<
Tensor
>
;
using
reference
=
std
::
shared_ptr
<
Tensor
>
&
;
using
const_reference
=
const
std
::
shared_ptr
<
Tensor
>
&
;
using
vector_type
=
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
;
using
iterator
=
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>::
iterator
;
using
const_iterator
=
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>::
const_iterator
;
TensorRow
()
noexcept
;
...
...
mindspore/ccsrc/dataset/engine/dataset_iterator.cc
浏览文件 @
d0c5071c
...
...
@@ -84,7 +84,12 @@ Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) {
// Constructor of the DatasetIterator
DatasetIterator
::
DatasetIterator
(
std
::
shared_ptr
<
ExecutionTree
>
exe_tree
)
:
IteratorBase
(),
root_
(
exe_tree
->
root
()),
tracing_
(
nullptr
),
cur_batch_num_
(
0
),
cur_connector_size_
(
0
)
{
:
IteratorBase
(),
root_
(
exe_tree
->
root
()),
tracing_
(
nullptr
),
cur_batch_num_
(
0
),
cur_connector_size_
(
0
),
cur_connector_capacity_
(
0
)
{
std
::
shared_ptr
<
Tracing
>
node
;
Status
s
=
exe_tree
->
GetProfilingManager
()
->
GetTracingNode
(
kDatasetIteratorTracingName
,
&
node
);
if
(
s
.
IsOk
())
{
...
...
mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.cc
浏览文件 @
d0c5071c
...
...
@@ -237,6 +237,5 @@ Status BucketBatchByLengthOp::Reset() {
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/bucket_batch_by_length_op.h
浏览文件 @
d0c5071c
...
...
@@ -146,7 +146,6 @@ class BucketBatchByLengthOp : public PipelineOp {
std
::
unique_ptr
<
ChildIterator
>
child_iterator_
;
std
::
vector
<
std
::
unique_ptr
<
TensorQTable
>>
buckets_
;
};
}
// namespace dataset
}
// namespace mindspore
...
...
mindspore/ccsrc/dataset/engine/datasetops/build_vocab_op.h
浏览文件 @
d0c5071c
...
...
@@ -112,6 +112,8 @@ class BuildVocabOp : public ParallelOp {
BuildVocabOp
(
std
::
shared_ptr
<
Vocab
>
vocab
,
std
::
vector
<
std
::
string
>
col_names
,
std
::
pair
<
int64_t
,
int64_t
>
freq_range
,
int64_t
top_k
,
int32_t
num_workers
,
int32_t
op_connector_size
);
~
BuildVocabOp
()
=
default
;
Status
WorkerEntry
(
int32_t
worker_id
)
override
;
// collect the work product from each worker
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc
浏览文件 @
d0c5071c
...
...
@@ -30,7 +30,6 @@
namespace
mindspore
{
namespace
dataset
{
ClueOp
::
Builder
::
Builder
()
:
builder_device_id_
(
0
),
builder_num_devices_
(
1
),
builder_num_samples_
(
0
),
builder_shuffle_files_
(
false
)
{
std
::
shared_ptr
<
ConfigManager
>
config_manager
=
GlobalContext
::
config_manager
();
...
...
@@ -545,6 +544,5 @@ Status ClueOp::CountAllFileRows(const std::vector<std::string> &files, int64_t *
}
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h
浏览文件 @
d0c5071c
...
...
@@ -264,7 +264,6 @@ class ClueOp : public ParallelOp {
bool
load_jagged_connector_
;
ColKeyMap
cols_to_keyword_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_CLUE_OP_H_
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc
浏览文件 @
d0c5071c
...
...
@@ -59,8 +59,8 @@ CocoOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
Status
CocoOp
::
Builder
::
Build
(
std
::
shared_ptr
<
CocoOp
>
*
ptr
)
{
RETURN_IF_NOT_OK
(
SanityCheck
());
if
(
builder_sampler_
==
nullptr
)
{
int64_t
num_samples
=
0
;
int64_t
start_index
=
0
;
const
int64_t
num_samples
=
0
;
const
int64_t
start_index
=
0
;
builder_sampler_
=
std
::
make_shared
<
SequentialSampler
>
(
start_index
,
num_samples
);
}
builder_schema_
=
std
::
make_unique
<
DataSchema
>
();
...
...
mindspore/ccsrc/dataset/engine/perf/connector_size.h
浏览文件 @
d0c5071c
...
...
@@ -44,6 +44,8 @@ class ConnectorSize : public Sampling {
public:
explicit
ConnectorSize
(
ExecutionTree
*
tree
)
:
tree_
(
tree
)
{}
~
ConnectorSize
()
=
default
;
// Driver function for connector size sampling.
// This function samples the connector size of every nodes within the ExecutionTree
Status
Sample
()
override
;
...
...
@@ -54,7 +56,7 @@ class ConnectorSize : public Sampling {
// @return Status - The error code return
Status
SaveToFile
()
override
;
Status
Init
(
const
std
::
string
&
dir_path
,
const
std
::
string
&
device_id
);
Status
Init
(
const
std
::
string
&
dir_path
,
const
std
::
string
&
device_id
)
override
;
// Parse op infomation and transform to json format
json
ParseOpInfo
(
const
DatasetOp
&
node
,
const
std
::
vector
<
int32_t
>
&
size
);
...
...
mindspore/ccsrc/dataset/engine/perf/dataset_iterator_tracing.h
浏览文件 @
d0c5071c
...
...
@@ -28,7 +28,7 @@ class DatasetIteratorTracing : public Tracing {
DatasetIteratorTracing
()
=
default
;
// Destructor
~
DatasetIteratorTracing
()
=
default
;
~
DatasetIteratorTracing
()
override
=
default
;
// Record tracing data
// @return Status - The error code return
...
...
@@ -40,7 +40,7 @@ class DatasetIteratorTracing : public Tracing {
// @return Status - The error code return
Status
SaveToFile
()
override
;
Status
Init
(
const
std
::
string
&
dir_path
,
const
std
::
string
&
device_id
);
Status
Init
(
const
std
::
string
&
dir_path
,
const
std
::
string
&
device_id
)
override
;
private:
std
::
vector
<
std
::
string
>
value_
;
...
...
mindspore/ccsrc/dataset/engine/perf/device_queue_tracing.h
浏览文件 @
d0c5071c
...
...
@@ -29,7 +29,7 @@ class DeviceQueueTracing : public Tracing {
DeviceQueueTracing
()
=
default
;
// Destructor
~
DeviceQueueTracing
()
=
default
;
~
DeviceQueueTracing
()
override
=
default
;
// Record tracing data
// @return Status - The error code return
...
...
@@ -41,7 +41,7 @@ class DeviceQueueTracing : public Tracing {
// @return Status - The error code return
Status
SaveToFile
()
override
;
Status
Init
(
const
std
::
string
&
dir_path
,
const
std
::
string
&
device_id
);
Status
Init
(
const
std
::
string
&
dir_path
,
const
std
::
string
&
device_id
)
override
;
private:
std
::
vector
<
std
::
string
>
value_
;
...
...
mindspore/ccsrc/dataset/engine/perf/monitor.cc
浏览文件 @
d0c5071c
...
...
@@ -25,6 +25,7 @@ namespace dataset {
Monitor
::
Monitor
(
ExecutionTree
*
tree
)
:
tree_
(
tree
)
{
std
::
shared_ptr
<
ConfigManager
>
cfg
=
GlobalContext
::
config_manager
();
sampling_interval_
=
cfg
->
monitor_sampling_interval
();
max_samples_
=
0
;
}
Status
Monitor
::
operator
()()
{
...
...
mindspore/ccsrc/dataset/engine/perf/monitor.h
浏览文件 @
d0c5071c
...
...
@@ -33,6 +33,8 @@ class Monitor {
Monitor
()
=
default
;
~
Monitor
()
=
default
;
// Functor for Perf Monitor main loop.
// This function will be the entry point of Mindspore::Dataset::Task
Status
operator
()();
...
...
mindspore/ccsrc/dataset/engine/perf/profiling.h
浏览文件 @
d0c5071c
...
...
@@ -99,7 +99,7 @@ class ProfilingManager {
// If profiling is enabled.
bool
IsProfilingEnable
()
const
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Sampling
>>
&
GetSamplingNodes
()
{
return
sampling_nodes_
;
}
const
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Sampling
>>
&
GetSamplingNodes
()
{
return
sampling_nodes_
;
}
private:
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tracing
>>
tracing_nodes_
;
...
...
mindspore/ccsrc/dataset/engine/tdt/tdt_plugin.cc
浏览文件 @
d0c5071c
...
...
@@ -119,7 +119,8 @@ TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &i
data_item
.
tensorShape_
=
dataShapes
;
data_item
.
tensorType_
=
datatype
;
data_item
.
dataLen_
=
ts
->
SizeInBytes
();
data_item
.
dataPtr_
=
std
::
shared_ptr
<
void
>
(
reinterpret_cast
<
uchar
*>
(
&
(
*
ts
->
begin
<
uint8_t
>
())),
[](
void
*
elem
)
{});
data_item
.
dataPtr_
=
std
::
shared_ptr
<
void
>
(
reinterpret_cast
<
uchar
*>
(
&
(
*
ts
->
begin
<
uint8_t
>
())),
[](
const
void
*
elem
)
{});
items
.
emplace_back
(
data_item
);
MS_LOG
(
DEBUG
)
<<
"TDT data type is "
<<
datatype
<<
", data shape is "
<<
dataShapes
<<
", data length is "
<<
ts
->
Size
()
<<
"."
;
...
...
mindspore/ccsrc/dataset/kernels/data/fill_op.cc
浏览文件 @
d0c5071c
...
...
@@ -21,7 +21,6 @@
namespace
mindspore
{
namespace
dataset
{
Status
FillOp
::
Compute
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
)
{
IO_CHECK
(
input
,
output
);
Status
s
=
Fill
(
input
,
output
,
fill_value_
);
...
...
mindspore/ccsrc/dataset/kernels/data/fill_op.h
浏览文件 @
d0c5071c
...
...
@@ -26,7 +26,6 @@
namespace
mindspore
{
namespace
dataset
{
class
FillOp
:
public
TensorOp
{
public:
explicit
FillOp
(
std
::
shared_ptr
<
Tensor
>
value
)
:
fill_value_
(
value
)
{}
...
...
@@ -39,9 +38,7 @@ class FillOp : public TensorOp {
private:
std
::
shared_ptr
<
Tensor
>
fill_value_
;
};
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_FILL_OP_H
mindspore/ccsrc/mindrecord/meta/shard_column.cc
浏览文件 @
d0c5071c
...
...
@@ -351,7 +351,7 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const
// Write this int to destination blob
uint64_t
u_n
=
*
reinterpret_cast
<
uint64_t
*>
(
&
i_n
);
auto
temp_bytes
=
UIntToBytesLittle
(
u_n
,
dst_int_type
);
for
(
uint64_t
j
=
0
;
j
<
(
kUnsignedOne
<<
dst_int_type
);
j
++
)
{
for
(
uint64_t
j
=
0
;
j
<
(
kUnsignedOne
<<
static_cast
<
uint8_t
>
(
dst_int_type
)
);
j
++
)
{
dst_bytes
[
i_dst
++
]
=
temp_bytes
[
j
];
}
...
...
@@ -406,7 +406,10 @@ MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<
auto
data
=
reinterpret_cast
<
const
unsigned
char
*>
(
array_data
.
get
());
*
data_ptr
=
std
::
make_unique
<
unsigned
char
[]
>
(
*
num_bytes
);
memcpy_s
(
data_ptr
->
get
(),
*
num_bytes
,
data
,
*
num_bytes
);
int
ret_code
=
memcpy_s
(
data_ptr
->
get
(),
*
num_bytes
,
data
,
*
num_bytes
);
if
(
ret_code
!=
0
)
{
MS_LOG
(
ERROR
)
<<
"Failed to copy data!"
;
}
return
SUCCESS
;
}
...
...
@@ -444,7 +447,8 @@ int64_t ShardColumn::BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_a
const
IntegerType
&
src_i_type
,
IntegerType
*
dst_i_type
)
{
uint64_t
u_temp
=
0
;
for
(
uint64_t
i
=
0
;
i
<
(
kUnsignedOne
<<
static_cast
<
uint8_t
>
(
src_i_type
));
i
++
)
{
u_temp
=
(
u_temp
<<
kBitsOfByte
)
+
bytes_array
[
pos
+
(
kUnsignedOne
<<
src_i_type
)
-
kUnsignedOne
-
i
];
u_temp
=
(
u_temp
<<
kBitsOfByte
)
+
bytes_array
[
pos
+
(
kUnsignedOne
<<
static_cast
<
uint8_t
>
(
src_i_type
))
-
kUnsignedOne
-
i
];
}
int64_t
i_out
;
...
...
mindspore/dataset/transforms/vision/py_transforms_util.py
浏览文件 @
d0c5071c
...
...
@@ -554,26 +554,28 @@ def adjust_hue(img, hue_factor):
Returns:
img (PIL Image), Hue adjusted image.
"""
if
not
-
0.5
<=
hue_factor
<=
0.5
:
raise
ValueError
(
'hue_factor {} is not in [-0.5, 0.5].'
.
format
(
hue_factor
))
image
=
img
image_hue_factor
=
hue_factor
if
not
-
0.5
<=
image_hue_factor
<=
0.5
:
raise
ValueError
(
'image_hue_factor {} is not in [-0.5, 0.5].'
.
format
(
image_hue_factor
))
if
not
is_pil
(
im
g
):
raise
TypeError
(
augment_error_message
.
format
(
type
(
im
g
)))
if
not
is_pil
(
im
age
):
raise
TypeError
(
augment_error_message
.
format
(
type
(
im
age
)))
input_mode
=
img
.
mode
if
input_
mode
in
{
'L'
,
'1'
,
'I'
,
'F'
}:
return
im
g
mode
=
image
.
mode
if
mode
in
{
'L'
,
'1'
,
'I'
,
'F'
}:
return
im
age
h
,
s
,
v
=
img
.
convert
(
'HSV'
).
split
()
h
ue
,
saturation
,
value
=
img
.
convert
(
'HSV'
).
split
()
np_h
=
np
.
array
(
h
,
dtype
=
np
.
uint8
)
np_h
ue
=
np
.
array
(
hue
,
dtype
=
np
.
uint8
)
with
np
.
errstate
(
over
=
'ignore'
):
np_h
+=
np
.
uint8
(
hue_factor
*
255
)
h
=
Image
.
fromarray
(
np_h
,
'L'
)
np_h
ue
+=
np
.
uint8
(
image_
hue_factor
*
255
)
h
ue
=
Image
.
fromarray
(
np_hue
,
'L'
)
im
g
=
Image
.
merge
(
'HSV'
,
(
h
,
s
,
v
)).
convert
(
input_
mode
)
return
im
g
im
age
=
Image
.
merge
(
'HSV'
,
(
hue
,
saturation
,
value
)).
convert
(
mode
)
return
im
age
def
to_type
(
img
,
output_type
):
...
...
tests/ut/python/dataset/prep_data.py
已删除
100644 → 0
浏览文件 @
2ecd5bdf
# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# import jsbeautifier
import
os
import
urllib
import
urllib.request
def
create_data_cache_dir
():
cwd
=
os
.
getcwd
()
target_directory
=
os
.
path
.
join
(
cwd
,
"data_cache"
)
try
:
if
not
os
.
path
.
exists
(
target_directory
):
os
.
mkdir
(
target_directory
)
except
OSError
:
print
(
"Creation of the directory %s failed"
%
target_directory
)
return
target_directory
def
download_and_uncompress
(
files
,
source_url
,
target_directory
,
is_tar
=
False
):
for
f
in
files
:
url
=
source_url
+
f
target_file
=
os
.
path
.
join
(
target_directory
,
f
)
##check if file already downloaded
if
not
(
os
.
path
.
exists
(
target_file
)
or
os
.
path
.
exists
(
target_file
[:
-
3
])):
urllib
.
request
.
urlretrieve
(
url
,
target_file
)
if
is_tar
:
print
(
"extracting from local tar file "
+
target_file
)
rc
=
os
.
system
(
"tar -C "
+
target_directory
+
" -xvf "
+
target_file
)
else
:
print
(
"unzipping "
+
target_file
)
rc
=
os
.
system
(
"gunzip -f "
+
target_file
)
if
rc
!=
0
:
print
(
"Failed to uncompress "
,
target_file
,
" removing"
)
os
.
system
(
"rm "
+
target_file
)
##exit with error so that build script will fail
raise
SystemError
else
:
print
(
"Using cached dataset at "
,
target_file
)
def
download_mnist
(
target_directory
=
None
):
if
target_directory
is
None
:
target_directory
=
create_data_cache_dir
()
##create mnst directory
target_directory
=
os
.
path
.
join
(
target_directory
,
"mnist"
)
try
:
if
not
os
.
path
.
exists
(
target_directory
):
os
.
mkdir
(
target_directory
)
except
OSError
:
print
(
"Creation of the directory %s failed"
%
target_directory
)
MNIST_URL
=
"http://yann.lecun.com/exdb/mnist/"
files
=
[
'train-images-idx3-ubyte.gz'
,
'train-labels-idx1-ubyte.gz'
,
't10k-images-idx3-ubyte.gz'
,
't10k-labels-idx1-ubyte.gz'
]
download_and_uncompress
(
files
,
MNIST_URL
,
target_directory
,
is_tar
=
False
)
return
target_directory
,
os
.
path
.
join
(
target_directory
,
"datasetSchema.json"
)
CIFAR_URL
=
"https://www.cs.toronto.edu/~kriz/"
def
download_cifar
(
target_directory
,
files
,
directory_from_tar
):
if
target_directory
is
None
:
target_directory
=
create_data_cache_dir
()
download_and_uncompress
([
files
],
CIFAR_URL
,
target_directory
,
is_tar
=
True
)
##if target dir was specify move data from directory created by tar
##and put data into target dir
if
target_directory
is
not
None
:
tar_dir_full_path
=
os
.
path
.
join
(
target_directory
,
directory_from_tar
)
all_files
=
os
.
path
.
join
(
tar_dir_full_path
,
"*"
)
cmd
=
"mv "
+
all_files
+
" "
+
target_directory
if
os
.
path
.
exists
(
tar_dir_full_path
):
print
(
"copy files back to target_directory"
)
print
(
"Executing: "
,
cmd
)
rc1
=
os
.
system
(
cmd
)
rc2
=
os
.
system
(
"rm -r "
+
tar_dir_full_path
)
if
rc1
!=
0
or
rc2
!=
0
:
print
(
"error when running command: "
,
cmd
)
download_file
=
os
.
path
.
join
(
target_directory
,
files
)
print
(
"removing "
+
download_file
)
os
.
system
(
"rm "
+
download_file
)
##exit with error so that build script will fail
raise
SystemError
##change target directory to directory after tar
return
os
.
path
.
join
(
target_directory
,
directory_from_tar
)
def
download_cifar10
(
target_directory
=
None
):
return
download_cifar
(
target_directory
,
"cifar-10-binary.tar.gz"
,
"cifar-10-batches-bin"
)
def
download_cifar100
(
target_directory
=
None
):
return
download_cifar
(
target_directory
,
"cifar-100-binary.tar.gz"
,
"cifar-100-binary"
)
def
download_all_for_test
(
cwd
):
download_mnist
(
os
.
path
.
join
(
cwd
,
"testMnistData"
))
##Download all datasets to existing test directories
if
__name__
==
"__main__"
:
download_all_for_test
(
os
.
getcwd
())
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录