Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
93810a0d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
93810a0d
编写于
8月 21, 2020
作者:
C
Cathy Wong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
C++ API: Minor fixes for dataset parameters
上级
3eef4a4e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
18 deletion
+18
-18
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+2
-2
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+11
-11
mindspore/ccsrc/minddata/dataset/include/samplers.h
mindspore/ccsrc/minddata/dataset/include/samplers.h
+5
-5
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
93810a0d
...
...
@@ -218,7 +218,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
}
// Function to create a TextFileDataset.
std
::
shared_ptr
<
TextFileDataset
>
TextFile
(
const
std
::
vector
<
std
::
string
>
&
dataset_files
,
int
32
_t
num_samples
,
std
::
shared_ptr
<
TextFileDataset
>
TextFile
(
const
std
::
vector
<
std
::
string
>
&
dataset_files
,
int
64
_t
num_samples
,
ShuffleMode
shuffle
,
int32_t
num_shards
,
int32_t
shard_id
)
{
auto
ds
=
std
::
make_shared
<
TextFileDataset
>
(
dataset_files
,
num_samples
,
shuffle
,
num_shards
,
shard_id
);
...
...
@@ -1331,7 +1331,7 @@ bool TextFileDataset::ValidateParams() {
return
false
;
}
if
(
!
ValidateDatasetShardParams
(
"Text
f
ileDataset"
,
num_shards_
,
shard_id_
))
{
if
(
!
ValidateDatasetShardParams
(
"Text
F
ileDataset"
,
num_shards_
,
shard_id_
))
{
return
false
;
}
...
...
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
93810a0d
...
...
@@ -84,10 +84,10 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "");
// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'.
/// \param[in] decode Decode the images after reading (default=False).
/// \param[in] extensions List of file extensions to be included in the dataset (default=None).
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \param[in] decode Decode the images after reading (default=false).
/// \param[in] extensions Set of file extensions to be included in the dataset (default={}).
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
CelebADataset
>
CelebA
(
const
std
::
string
&
dataset_dir
,
const
std
::
string
&
dataset_type
=
"all"
,
const
std
::
shared_ptr
<
SamplerObj
>
&
sampler
=
nullptr
,
bool
decode
=
false
,
...
...
@@ -199,11 +199,11 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir,
/// \notes The generated dataset has two columns ['image', 'label']
/// \param[in] dataset_file The dataset file to be read
/// \param[in] usage Need "train", "eval" or "inference" data (default="train")
/// \param[in] decode Decode the images after reading (default=false).
/// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder
/// names will be sorted alphabetically and each class will be given a unique index starting from 0).
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
/// A `RandomSampler` will be used to randomly iterate the entire dataset
/// \param[in] class_indexing A str-to-int mapping from label name to index (default={}, the folder
/// names will be sorted alphabetically and each class will be given a unique index starting from 0).
/// \param[in] decode Decode the images after reading (default=false).
/// \return Shared pointer to the current ManifestDataset
std
::
shared_ptr
<
ManifestDataset
>
Manifest
(
std
::
string
dataset_file
,
std
::
string
usage
=
"train"
,
std
::
shared_ptr
<
SamplerObj
>
sampler
=
nullptr
,
...
...
@@ -230,13 +230,13 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
/// \brief Function to create a RandomDataset
/// \param[in] total_rows Number of rows for the dataset to generate (default=0, number of rows is random)
/// \param[in] schema SchemaObj to set column type, data type and data shape
/// \param[in] columns_list List of columns to be read (default=
None
, read all columns)
/// \param[in] columns_list List of columns to be read (default=
{}
, read all columns)
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
template
<
typename
T
=
std
::
shared_ptr
<
SchemaObj
>
>
std
::
shared_ptr
<
RandomDataset
>
RandomData
(
const
int32_t
&
total_rows
=
0
,
T
schema
=
nullptr
,
std
::
vector
<
std
::
string
>
columns_list
=
{},
const
std
::
vector
<
std
::
string
>
&
columns_list
=
{},
std
::
shared_ptr
<
SamplerObj
>
sampler
=
nullptr
)
{
auto
ds
=
std
::
make_shared
<
RandomDataset
>
(
total_rows
,
schema
,
std
::
move
(
columns_list
),
std
::
move
(
sampler
));
return
ds
->
ValidateParams
()
?
ds
:
nullptr
;
...
...
@@ -257,7 +257,7 @@ std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, T schem
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current TextFileDataset
std
::
shared_ptr
<
TextFileDataset
>
TextFile
(
const
std
::
vector
<
std
::
string
>
&
dataset_files
,
int
32
_t
num_samples
=
0
,
std
::
shared_ptr
<
TextFileDataset
>
TextFile
(
const
std
::
vector
<
std
::
string
>
&
dataset_files
,
int
64
_t
num_samples
=
0
,
ShuffleMode
shuffle
=
ShuffleMode
::
kGlobal
,
int32_t
num_shards
=
1
,
int32_t
shard_id
=
0
);
...
...
@@ -302,7 +302,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
virtual
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
Build
()
=
0
;
/// \brief Pure virtual function for derived class to implement parameters validation
/// \return bool
True if all the param
s are valid
/// \return bool
true if all the parameter
s are valid
virtual
bool
ValidateParams
()
=
0
;
/// \brief Setter function for runtime number of workers
...
...
@@ -767,8 +767,8 @@ class RandomDataset : public Dataset {
static
constexpr
int32_t
kMaxDimValue
=
32
;
/// \brief Constructor
RandomDataset
(
const
int32_t
&
total_rows
,
std
::
shared_ptr
<
SchemaObj
>
schema
,
std
::
vector
<
std
::
string
>
columns_list
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
RandomDataset
(
const
int32_t
&
total_rows
,
std
::
shared_ptr
<
SchemaObj
>
schema
,
const
std
::
vector
<
std
::
string
>
&
columns_list
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
:
total_rows_
(
total_rows
),
schema_path_
(
""
),
schema_
(
std
::
move
(
schema
)),
...
...
mindspore/ccsrc/minddata/dataset/include/samplers.h
浏览文件 @
93810a0d
...
...
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_
API
_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_
API
_SAMPLERS_H_
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_
INCLUDE
_SAMPLERS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_
INCLUDE
_SAMPLERS_H_
#include <vector>
#include <memory>
...
...
@@ -70,7 +70,7 @@ std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle = false, i
/// Function to create a Random Sampler.
/// \notes Samples the elements randomly.
/// \param[in] replacement - If
T
rue, put the sample ID back for the next draw.
/// \param[in] replacement - If
t
rue, put the sample ID back for the next draw.
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \return Shared pointer to the current Sampler.
std
::
shared_ptr
<
RandomSamplerObj
>
RandomSampler
(
bool
replacement
=
false
,
int64_t
num_samples
=
0
);
...
...
@@ -94,7 +94,7 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t>
/// weights (probabilities).
/// \param[in] weights - A vector sequence of weights, not necessarily summing up to 1.
/// \param[in] num_samples - The number of samples to draw (default to all elements).
/// \param[in] replacement - If
T
rue, put the sample ID back for the next draw.
/// \param[in] replacement - If
t
rue, put the sample ID back for the next draw.
/// \return Shared pointer to the current Sampler.
std
::
shared_ptr
<
WeightedRandomSamplerObj
>
WeightedRandomSampler
(
std
::
vector
<
double
>
weights
,
int64_t
num_samples
=
0
,
bool
replacement
=
true
);
...
...
@@ -199,4 +199,4 @@ class WeightedRandomSamplerObj : public SamplerObj {
}
// namespace api
}
// namespace dataset
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_
API
_SAMPLERS_H_
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_
INCLUDE
_SAMPLERS_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录