Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7ec0b585
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7ec0b585
编写于
8月 06, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3977 [MD] Refactor Concatenate Op
Merge pull request !3977 from nhussain/multi_dim_concat_2
上级
13c63f02
61769b2d
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
144 addition
and
103 deletion
+144
-103
mindspore/ccsrc/minddata/dataset/core/tensor.cc
mindspore/ccsrc/minddata/dataset/core/tensor.cc
+26
-41
mindspore/ccsrc/minddata/dataset/core/tensor.h
mindspore/ccsrc/minddata/dataset/core/tensor.h
+3
-4
mindspore/ccsrc/minddata/dataset/include/tensor.h
mindspore/ccsrc/minddata/dataset/include/tensor.h
+3
-4
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
+39
-43
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h
+0
-5
tests/ut/cpp/dataset/concatenate_op_test.cc
tests/ut/cpp/dataset/concatenate_op_test.cc
+69
-2
tests/ut/cpp/dataset/tensor_test.cc
tests/ut/cpp/dataset/tensor_test.cc
+3
-3
tests/ut/python/dataset/test_concatenate_op.py
tests/ut/python/dataset/test_concatenate_op.py
+1
-1
未找到文件。
mindspore/ccsrc/minddata/dataset/core/tensor.cc
浏览文件 @
7ec0b585
...
...
@@ -526,16 +526,34 @@ Status Tensor::StartAddrOfIndex(std::vector<dsize_t> ind, uchar **start_addr_of_
return
Status
::
OK
();
}
Status
Tensor
::
InsertTensor
(
const
std
::
vector
<
dsize_t
>
&
ind
,
const
std
::
shared_ptr
<
Tensor
>
&
tensor
)
{
Status
Tensor
::
InsertTensor
(
const
std
::
vector
<
dsize_t
>
&
ind
,
const
std
::
shared_ptr
<
Tensor
>
&
tensor
,
const
bool
partial_insert
)
{
std
::
string
err_msg
;
err_msg
+=
(
this
->
type
()
==
DataType
::
DE_STRING
)
?
"[Tensor] Cannot batch tensors of type string
\n
"
:
""
;
err_msg
+=
(
!
this
->
shape
().
known
()
||
!
tensor
->
shape
().
known
())
?
"[Tensor] unknown shape
\n
"
:
""
;
err_msg
+=
(
ind
.
size
()
+
tensor
->
Rank
()
!=
this
->
Rank
())
?
"[Tensor] incorrect index
\n
"
:
""
;
err_msg
+=
tensor
->
type
().
SizeInBytes
()
!=
this
->
type
().
SizeInBytes
()
?
"[Tensor] incorrect datatype
\n
"
:
""
;
if
(
partial_insert
)
{
err_msg
+=
(
ind
.
size
()
!=
1
)
?
"[Tensor] only supports 1D insertion of elements not along the full length of the axis
\n
"
:
""
;
err_msg
+=
(
ind
.
at
(
0
)
+
tensor
->
shape
().
NumOfElements
()
>
shape
().
NumOfElements
())
?
"[Tensor] incorrect index
\n
"
:
""
;
}
else
{
err_msg
+=
(
ind
.
size
()
+
tensor
->
Rank
()
!=
Rank
())
?
"[Tensor] incorrect index
\n
"
:
""
;
}
err_msg
+=
(
type
()
==
DataType
::
DE_STRING
)
?
"[Tensor] Cannot insert into a tensor of type string
\n
"
:
""
;
err_msg
+=
(
!
shape
().
known
()
||
!
tensor
->
shape
().
known
())
?
"[Tensor] unknown shape
\n
"
:
""
;
err_msg
+=
tensor
->
type
().
SizeInBytes
()
!=
type
().
SizeInBytes
()
?
"[Tensor] incorrect datatype
\n
"
:
""
;
uchar
*
start_addr_of_ind
=
nullptr
;
TensorShape
remaining_shape
=
TensorShape
::
CreateUnknownRankShape
();
err_msg
+=
(
!
StartAddrOfIndex
(
ind
,
&
start_addr_of_ind
,
&
remaining_shape
).
IsOk
())
?
"[Tensor] incorrect index
\n
"
:
""
;
err_msg
+=
!
(
remaining_shape
==
tensor
->
shape
())
?
"[Tensor] memory error
\n
"
:
""
;
if
(
partial_insert
)
{
TensorShape
remaining_shape
=
tensor
->
shape
();
err_msg
+=
(
!
StartAddrOfIndex
(
ind
,
&
start_addr_of_ind
,
&
remaining_shape
).
IsOk
())
?
"[Tensor] incorrect index
\n
"
:
""
;
}
else
{
TensorShape
remaining_shape
=
TensorShape
::
CreateUnknownRankShape
();
err_msg
+=
(
!
StartAddrOfIndex
(
ind
,
&
start_addr_of_ind
,
&
remaining_shape
).
IsOk
())
?
"[Tensor] incorrect index
\n
"
:
""
;
err_msg
+=
!
(
remaining_shape
==
tensor
->
shape
())
?
"[Tensor] memory error
\n
"
:
""
;
}
if
(
!
err_msg
.
empty
())
{
MS_LOG
(
DEBUG
)
<<
"Insert tensor message: "
<<
err_msg
;
RETURN_STATUS_UNEXPECTED
(
err_msg
);
...
...
@@ -556,39 +574,6 @@ Status Tensor::InsertTensor(const std::vector<dsize_t> &ind, const std::shared_p
}
}
Status
Tensor
::
Concatenate
(
const
std
::
vector
<
dsize_t
>
&
index
,
const
std
::
shared_ptr
<
Tensor
>
&
tensor
)
{
std
::
string
err_msg
;
err_msg
+=
(
index
.
size
()
!=
1
)
?
"[Tensor] only supports 1d concatenation
\n
"
:
""
;
err_msg
+=
(
type
()
==
DataType
::
DE_STRING
)
?
"[Tensor] Cannot batch tensors of type string
\n
"
:
""
;
err_msg
+=
(
!
shape
().
known
()
||
!
tensor
->
shape
().
known
())
?
"[Tensor] unknown shape
\n
"
:
""
;
err_msg
+=
(
index
.
at
(
0
)
+
tensor
->
shape
().
NumOfElements
()
>
this
->
shape
().
NumOfElements
())
?
"[Tensor] incorrect index
\n
"
:
""
;
err_msg
+=
tensor
->
type
().
SizeInBytes
()
!=
this
->
type
().
SizeInBytes
()
?
"[Tensor] incorrect datatype
\n
"
:
""
;
uchar
*
start_addr_of_ind
=
nullptr
;
TensorShape
remaining_shape
=
tensor
->
shape
();
StartAddrOfIndex
(
index
,
&
start_addr_of_ind
,
&
remaining_shape
);
err_msg
+=
(
start_addr_of_ind
==
nullptr
)
?
"Failed to create memory for Tensor.
\n
"
:
""
;
if
(
!
err_msg
.
empty
())
{
MS_LOG
(
DEBUG
)
<<
"Insert tensor message: "
<<
err_msg
;
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
else
{
int
ret_code
=
memcpy_s
(
start_addr_of_ind
,
tensor
->
SizeInBytes
(),
tensor
->
GetMutableBuffer
(),
tensor
->
SizeInBytes
());
if
(
ret_code
==
0
)
{
return
Status
::
OK
();
}
else
{
err_msg
+=
"[Tensor] error in memcpy_s when inserting tensor
\n
"
;
MS_LOG
(
DEBUG
)
<<
"Tensor message: "
<<
err_msg
;
RETURN_STATUS_UNEXPECTED
(
err_msg
);
}
}
}
Status
Tensor
::
ExpandDim
(
const
dsize_t
&
axis
)
{
if
(
axis
>
Rank
())
{
std
::
string
err
=
"Axis is out of bound"
;
...
...
mindspore/ccsrc/minddata/dataset/core/tensor.h
浏览文件 @
7ec0b585
...
...
@@ -330,8 +330,10 @@ class Tensor {
/// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
/// \param index
/// \param input
/// \param partial_insert: boolean to determine if insertion along the full axis is enforced
/// \return Status code
Status
InsertTensor
(
const
std
::
vector
<
dsize_t
>
&
index
,
const
std
::
shared_ptr
<
Tensor
>
&
input
);
Status
InsertTensor
(
const
std
::
vector
<
dsize_t
>
&
index
,
const
std
::
shared_ptr
<
Tensor
>
&
input
,
const
bool
partial_insert
=
false
);
/// Find the address of the given index. Used in InsertTensor.
/// Example:
...
...
@@ -393,9 +395,6 @@ class Tensor {
static
Status
GetBufferInfo
(
Tensor
*
t
,
py
::
buffer_info
*
out
);
#endif
/// Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor
Status
Concatenate
(
const
std
::
vector
<
dsize_t
>
&
index
,
const
std
::
shared_ptr
<
Tensor
>
&
input
);
/// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
/// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6
/// \tparam T type of values in the Tensor Iterator
...
...
mindspore/ccsrc/minddata/dataset/include/tensor.h
浏览文件 @
7ec0b585
...
...
@@ -330,8 +330,10 @@ class Tensor {
/// Ex: shape <2,3>, inserting into index {0} will replace the first row. index {1,2} will replace the last cell.
/// \param index
/// \param input
/// \param partial_insert: boolean to determine if insertion along the full axis is enforced
/// \return Status code
Status
InsertTensor
(
const
std
::
vector
<
dsize_t
>
&
index
,
const
std
::
shared_ptr
<
Tensor
>
&
input
);
Status
InsertTensor
(
const
std
::
vector
<
dsize_t
>
&
index
,
const
std
::
shared_ptr
<
Tensor
>
&
input
,
const
bool
partial_insert
=
false
);
/// Find the address of the given index. Used in InsertTensor.
/// Example:
...
...
@@ -393,9 +395,6 @@ class Tensor {
static
Status
GetBufferInfo
(
Tensor
*
t
,
py
::
buffer_info
*
out
);
#endif
/// Concatenate based on given tensor, can fill in current tensor with a smaller one, unlike InsertTensor
Status
Concatenate
(
const
std
::
vector
<
dsize_t
>
&
index
,
const
std
::
shared_ptr
<
Tensor
>
&
input
);
/// TensorIterator is a linear iterator that can be used to iterate over the elements of the Tensor
/// The order elements is as the memory layout (i.e., row-major) [[1,2,3],[4,5,6] --> 1,2,3,4,5,6
/// \tparam T type of values in the Tensor Iterator
...
...
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc
浏览文件 @
7ec0b585
...
...
@@ -580,77 +580,73 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
Status
Concatenate
(
const
TensorRow
&
input
,
TensorRow
*
output
,
int8_t
axis
,
std
::
shared_ptr
<
Tensor
>
prepend
,
std
::
shared_ptr
<
Tensor
>
append
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
input
[
0
]
->
shape
().
Rank
()
==
1
,
"Only 1D tensors supported"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
axis
==
0
||
axis
==
-
1
,
"Only concatenation along the last dimension supported"
);
axis
=
Tensor
::
HandleNeg
(
axis
,
input
[
0
]
->
shape
().
Rank
());
CHECK_FAIL_RETURN_UNEXPECTED
(
axis
==
0
,
"Only axis=0 is supported"
);
std
::
shared_ptr
<
Tensor
>
out
;
TensorShape
t
=
TensorShape
::
CreateScalar
();
DataType
first_dtype
=
input
[
0
]
->
type
();
TensorRow
tensor_list
;
if
(
prepend
!=
nullptr
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
first_dtype
==
prepend
->
type
(),
"Tensor types do not match"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
prepend
->
shape
().
Rank
()
==
1
,
"Only 1D tensors supported"
);
RETURN_IF_NOT_OK
(
ConcatenateHelper
(
prepend
,
&
out
,
axis
,
input
[
0
]));
}
else
{
out
=
input
[
0
];
tensor_list
.
emplace_back
(
prepend
);
}
for
(
dsize_t
i
=
1
;
i
<
input
.
size
();
i
++
)
{
std
::
shared_ptr
<
Tensor
>
out_t
;
for
(
dsize_t
i
=
0
;
i
<
input
.
size
();
i
++
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
first_dtype
==
input
[
i
]
->
type
(),
"Tensor types do not match"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
input
[
i
]
->
shape
().
Rank
()
==
1
,
"Only 1D tensors supported"
);
RETURN_IF_NOT_OK
(
ConcatenateHelper
(
out
,
&
out_t
,
axis
,
input
[
i
]));
out
=
out_t
;
tensor_list
.
emplace_back
(
input
[
i
]);
}
std
::
shared_ptr
<
Tensor
>
out_t
;
if
(
append
!=
nullptr
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
first_dtype
==
append
->
type
(),
"Tensor types do not match"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
append
->
shape
().
Rank
()
==
1
,
"Only 1D tensors supported"
);
RETURN_IF_NOT_OK
(
ConcatenateHelper
(
out
,
&
out_t
,
axis
,
append
));
}
else
{
out_t
=
out
;
tensor_list
.
emplace_back
(
append
);
}
output
->
push_back
(
out_t
);
return
Status
::
OK
();
}
Status
ConcatenateHelper
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
,
int8_t
axis
,
std
::
shared_ptr
<
Tensor
>
append
)
{
CHECK_FAIL_RETURN_UNEXPECTED
(
input
->
type
()
==
append
->
type
(),
"Tensor types do not match"
);
TensorShape
t
({});
for
(
dsize_t
i
=
0
;
i
<
input
->
shape
().
Rank
();
i
++
)
{
// create final shape
for
(
dsize_t
i
=
0
;
i
<
tensor_list
[
0
]
->
shape
().
Rank
();
i
++
)
{
if
(
i
!=
axis
)
{
t
=
t
.
AppendDim
(
input
->
shape
()[
i
]);
t
=
t
.
AppendDim
(
tensor_list
[
0
]
->
shape
()[
i
]);
}
else
{
dsize_t
new_shape
=
input
->
shape
()[
i
]
+
append
->
shape
()[
i
];
dsize_t
new_shape
=
0
;
for
(
dsize_t
j
=
0
;
j
<
tensor_list
.
size
();
j
++
)
{
new_shape
=
tensor_list
[
j
]
->
shape
()[
i
]
+
new_shape
;
}
t
=
t
.
AppendDim
(
new_shape
);
}
}
std
::
shared_ptr
<
Tensor
>
out
;
if
(
input
->
type
().
IsNumeric
())
{
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
t
,
input
->
type
(),
&
out
));
if
(
input
[
0
]
->
type
().
IsNumeric
())
{
RETURN_IF_NOT_OK
(
Tensor
::
CreateEmpty
(
t
,
tensor_list
[
0
]
->
type
(),
&
out
));
std
::
vector
<
dsize_t
>
index
(
axis
+
1
,
0
);
RETURN_IF_NOT_OK
(
out
->
Concatenate
({
0
},
input
));
RETURN_IF_NOT_OK
(
out
->
Concatenate
({
input
->
shape
()[
0
]},
append
));
*
output
=
out
;
int
n
=
index
.
size
()
-
1
;
for
(
dsize_t
i
=
0
;
i
<
tensor_list
.
size
();
i
++
)
{
RETURN_IF_NOT_OK
(
out
->
InsertTensor
({
index
},
tensor_list
[
i
],
true
));
index
[
n
]
=
index
[
n
]
+
tensor_list
[
i
]
->
shape
()[
axis
];
}
}
else
{
std
::
vector
<
std
::
string
>
strings
;
auto
itr
=
input
->
begin
<
std
::
string_view
>
();
for
(;
itr
!=
input
->
end
<
std
::
string_view
>
();
itr
++
)
{
strings
.
emplace_back
(
*
itr
);
}
itr
=
append
->
begin
<
std
::
string_view
>
();
for
(;
itr
!=
append
->
end
<
std
::
string_view
>
();
itr
++
)
{
strings
.
emplace_back
(
*
itr
);
for
(
dsize_t
i
=
0
;
i
<
tensor_list
.
size
();
i
++
)
{
auto
itr
=
tensor_list
[
i
]
->
begin
<
std
::
string_view
>
();
for
(;
itr
!=
tensor_list
[
i
]
->
end
<
std
::
string_view
>
();
itr
++
)
{
strings
.
emplace_back
(
*
itr
);
}
}
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromVector
(
strings
,
t
,
&
out
));
*
output
=
out
;
}
output
->
push_back
(
out
);
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h
浏览文件 @
7ec0b585
...
...
@@ -152,11 +152,6 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
Status
Concatenate
(
const
TensorRow
&
input
,
TensorRow
*
output
,
int8_t
axis
,
std
::
shared_ptr
<
Tensor
>
prepend
,
std
::
shared_ptr
<
Tensor
>
append
);
// helper for concat, always append to the input, and pass that to the output
Status
ConcatenateHelper
(
const
std
::
shared_ptr
<
Tensor
>
&
input
,
std
::
shared_ptr
<
Tensor
>
*
output
,
int8_t
axis
,
std
::
shared_ptr
<
Tensor
>
append
);
}
// namespace dataset
}
// namespace mindspore
...
...
tests/ut/cpp/dataset/concatenate_op_test.cc
浏览文件 @
7ec0b585
...
...
@@ -28,9 +28,8 @@ class MindDataTestConcatenateOp : public UT::Common {
};
TEST_F
(
MindDataTestConcatenateOp
,
TestOp
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestConcatenate-TestOp."
;
MS_LOG
(
INFO
)
<<
"Doing MindDataTestConcatenate-TestOp
-SingleRowinput
."
;
std
::
vector
<
uint64_t
>
labels
=
{
1
,
1
,
2
};
TensorShape
shape
({
3
});
std
::
shared_ptr
<
Tensor
>
input
;
Tensor
::
CreateFromVector
(
labels
,
&
input
);
...
...
@@ -57,3 +56,71 @@ TEST_F(MindDataTestConcatenateOp, TestOp) {
MS_LOG
(
DEBUG
)
<<
*
expected
<<
std
::
endl
;
ASSERT_TRUE
(
*
output
==
*
expected
);
}
TEST_F
(
MindDataTestConcatenateOp
,
TestOp2
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestConcatenate-TestOp2-MultiInput."
;
std
::
vector
<
uint64_t
>
labels
=
{
1
,
12
,
2
};
std
::
shared_ptr
<
Tensor
>
row_1
;
Tensor
::
CreateFromVector
(
labels
,
&
row_1
);
std
::
shared_ptr
<
Tensor
>
row_2
;
Tensor
::
CreateFromVector
(
labels
,
&
row_2
);
std
::
vector
<
uint64_t
>
append_labels
=
{
4
,
4
,
4
};
std
::
shared_ptr
<
Tensor
>
append
;
Tensor
::
CreateFromVector
(
append_labels
,
&
append
);
TensorRow
tensor_list
;
tensor_list
.
push_back
(
row_1
);
tensor_list
.
push_back
(
row_2
);
std
::
shared_ptr
<
Tensor
>
output
;
std
::
unique_ptr
<
ConcatenateOp
>
op
(
new
ConcatenateOp
(
0
,
nullptr
,
append
));
TensorRow
out_row
;
Status
s
=
op
->
Compute
(
tensor_list
,
&
out_row
);
std
::
vector
<
uint64_t
>
out
=
{
1
,
12
,
2
,
1
,
12
,
2
,
4
,
4
,
4
};
std
::
shared_ptr
<
Tensor
>
expected
;
Tensor
::
CreateFromVector
(
out
,
&
expected
);
output
=
out_row
[
0
];
EXPECT_TRUE
(
s
.
IsOk
());
ASSERT_TRUE
(
output
->
shape
()
==
expected
->
shape
());
ASSERT_TRUE
(
output
->
type
()
==
expected
->
type
());
MS_LOG
(
DEBUG
)
<<
*
output
<<
std
::
endl
;
MS_LOG
(
DEBUG
)
<<
*
expected
<<
std
::
endl
;
ASSERT_TRUE
(
*
output
==
*
expected
);
}
TEST_F
(
MindDataTestConcatenateOp
,
TestOp3
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestConcatenate-TestOp3-Strings."
;
std
::
vector
<
std
::
string
>
labels
=
{
"hello"
,
"bye"
};
std
::
shared_ptr
<
Tensor
>
row_1
;
Tensor
::
CreateFromVector
(
labels
,
&
row_1
);
std
::
vector
<
std
::
string
>
append_labels
=
{
"1"
,
"2"
,
"3"
};
std
::
shared_ptr
<
Tensor
>
append
;
Tensor
::
CreateFromVector
(
append_labels
,
&
append
);
TensorRow
tensor_list
;
tensor_list
.
push_back
(
row_1
);
std
::
shared_ptr
<
Tensor
>
output
;
std
::
unique_ptr
<
ConcatenateOp
>
op
(
new
ConcatenateOp
(
0
,
nullptr
,
append
));
TensorRow
out_row
;
Status
s
=
op
->
Compute
(
tensor_list
,
&
out_row
);
std
::
vector
<
std
::
string
>
out
=
{
"hello"
,
"bye"
,
"1"
,
"2"
,
"3"
};
std
::
shared_ptr
<
Tensor
>
expected
;
Tensor
::
CreateFromVector
(
out
,
&
expected
);
output
=
out_row
[
0
];
EXPECT_TRUE
(
s
.
IsOk
());
ASSERT_TRUE
(
output
->
shape
()
==
expected
->
shape
());
ASSERT_TRUE
(
output
->
type
()
==
expected
->
type
());
MS_LOG
(
DEBUG
)
<<
*
output
<<
std
::
endl
;
MS_LOG
(
DEBUG
)
<<
*
expected
<<
std
::
endl
;
ASSERT_TRUE
(
*
output
==
*
expected
);
}
tests/ut/cpp/dataset/tensor_test.cc
浏览文件 @
7ec0b585
...
...
@@ -432,7 +432,7 @@ TEST_F(MindDataTestTensorDE, TensorSlice) {
ASSERT_EQ
(
*
t2
,
*
t
);
}
TEST_F
(
MindDataTestTensorDE
,
Tensor
Concatenate
)
{
TEST_F
(
MindDataTestTensorDE
,
Tensor
PartialInsert
)
{
std
::
vector
<
uint32_t
>
values1
=
{
1
,
2
,
3
,
0
,
0
,
0
};
std
::
vector
<
uint32_t
>
values2
=
{
4
,
5
,
6
};
std
::
vector
<
uint32_t
>
expected
=
{
1
,
2
,
3
,
4
,
5
,
6
};
...
...
@@ -445,7 +445,7 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) {
std
::
shared_ptr
<
Tensor
>
out
;
Tensor
::
CreateFromVector
(
expected
,
&
out
);
Status
s
=
t1
->
Concatenate
({
3
},
t2
);
Status
s
=
t1
->
InsertTensor
({
3
},
t2
,
true
);
EXPECT_TRUE
(
s
.
IsOk
());
auto
i
=
out
->
begin
<
uint32_t
>
();
...
...
@@ -455,7 +455,7 @@ TEST_F(MindDataTestTensorDE, TensorConcatenate) {
}
// should fail if the concatenated vector is too large
s
=
t1
->
Concatenate
({
5
},
t2
);
s
=
t1
->
InsertTensor
({
5
},
t2
,
true
);
EXPECT_FALSE
(
s
.
IsOk
());
}
...
...
tests/ut/python/dataset/test_concatenate_op.py
浏览文件 @
7ec0b585
...
...
@@ -130,7 +130,7 @@ def test_concatenate_op_incorrect_dim():
def
gen
():
yield
(
np
.
array
([[
"ss"
,
"ad"
],
[
"ss"
,
"ad"
]],
dtype
=
'S'
),)
prepend_tensor
=
np
.
array
([
3
,
5
],
dtype
=
np
.
float
)
prepend_tensor
=
np
.
array
([
"ss"
,
"ss"
],
dtype
=
'S'
)
concatenate_op
=
data_trans
.
Concatenate
(
0
,
prepend_tensor
)
data
=
ds
.
GeneratorDataset
(
gen
,
column_names
=
[
"col"
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录