Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
039c0bf2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
039c0bf2
编写于
1月 13, 2017
作者:
H
hedaoyuan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add some constructors for generating object that only contains shape (do not contains data).
上级
2a20fdc1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
49 addition
and
2 deletion
+49
-2
paddle/function/BufferArg.h
paddle/function/BufferArg.h
+31
-2
paddle/function/FunctionTest.cpp
paddle/function/FunctionTest.cpp
+18
-0
未找到文件。
paddle/function/BufferArg.h
浏览文件 @
039c0bf2
...
...
@@ -39,7 +39,6 @@ enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
class
BufferArg
;
class
SequenceArg
;
class
SparseMatrixArg
;
typedef
std
::
shared_ptr
<
BufferArg
>
BufferArgPtr
;
/**
* \brief BufferArg used as the argument type of Function.
...
...
@@ -50,6 +49,11 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
* 3. SequenceArg for a Buffer of sequence data.
* 4. SparseMatrixArg for a Buffer of sparse matrix.
*
* Buffer shape
* For most buffers, the first dimension `shape()[0]` represents
* the size of the mini-batch.
*
* Buffer argType
* There is an ArgType property for the BufferArg used as Function Output.
* Whether the result of the Function calculation is assigned to the
* output Buffer or added to the output Buffer is determined by the
...
...
@@ -71,6 +75,14 @@ public:
ArgType
getArgType
()
const
{
return
argType_
;
}
public:
BufferArg
(
ValueType
valueType
,
const
TensorShape
&
shape
,
ArgType
argType
=
UNSPECIFIED
)
:
buf_
(
nullptr
),
valueType_
(
valueType
),
shape_
(
shape
),
argType_
(
argType
)
{}
BufferArg
(
void
*
buf
,
ValueType
valueType
,
const
TensorShape
&
shape
,
...
...
@@ -170,6 +182,12 @@ protected:
// if a < b then value_.buf_[a] < value_.buf_[b]
class
SequenceIdArg
:
public
BufferArg
{
public:
SequenceIdArg
(
const
TensorShape
&
shape
,
ArgType
argType
=
UNSPECIFIED
)
:
BufferArg
(
VALUE_TYPE_INT32
,
shape
,
argType
)
{
CHECK_EQ
(
shape_
.
ndims
(),
(
size_t
)
1
);
numSeqs_
=
shape_
[
0
]
-
1
;
}
SequenceIdArg
(
void
*
buf
,
const
TensorShape
&
shape
,
ArgType
argType
=
UNSPECIFIED
)
...
...
@@ -190,9 +208,18 @@ private:
size_t
numSeqs_
;
};
// sequence data
// sequences data
// For mini-batch calculate,
// one batch can contain more than one sequence of data.
// SequenceArg can be used to represent sequences that contain multiple
// unequal lengths.
class
SequenceArg
:
public
BufferArg
{
public:
SequenceArg
(
ValueType
valueType
,
const
TensorShape
&
shape
,
ArgType
argType
=
UNSPECIFIED
)
:
BufferArg
(
valueType
,
shape
,
argType
),
startPositions_
(
TensorShape
())
{}
SequenceArg
(
void
*
buf
,
ValueType
valueType
,
const
TensorShape
&
shape
,
...
...
@@ -210,6 +237,8 @@ public:
void
*
getIdBuf
()
const
{
return
startPositions_
.
data
();
}
size_t
numSeqs
()
const
{
return
startPositions_
.
numSeqs
();
}
SequenceIdArg
&
getSequenceId
()
{
return
startPositions_
;
}
const
SequenceIdArg
&
getSequenceId
()
const
{
return
startPositions_
;
}
private:
SequenceIdArg
startPositions_
;
...
...
paddle/function/FunctionTest.cpp
浏览文件 @
039c0bf2
...
...
@@ -84,6 +84,10 @@ void testBufferArgs(const BufferArgs& inputs,
}
}
void
testBufferArgs
(
const
BufferArgs
&
inputs
,
const
CheckBufferArg
&
check
)
{
check
(
inputs
[
0
]);
}
TEST
(
Arguments
,
Matrix
)
{
MatrixPtr
matrix
=
Matrix
::
create
(
100
,
200
);
CheckBufferArg
check
=
[
=
](
const
BufferArg
&
arg
)
{
...
...
@@ -144,4 +148,18 @@ TEST(Arguments, CpuSparseMatrix) {
testBufferArgs
(
argments
,
checkFunc
);
}
TEST
(
Arguments
,
BufferArg
)
{
BufferArg
arg
(
nullptr
,
VALUE_TYPE_FLOAT
,
{
1
,
2
,
3
});
CheckBufferArg
check
=
[
=
](
const
BufferArg
&
arg
)
{
EXPECT_EQ
(
arg
.
shape
().
ndims
(),
3
);
EXPECT_EQ
(
arg
.
shape
()[
0
],
1
);
EXPECT_EQ
(
arg
.
shape
()[
1
],
2
);
EXPECT_EQ
(
arg
.
shape
()[
2
],
3
);
};
BufferArgs
argments
;
argments
.
addArg
(
arg
);
testBufferArgs
(
argments
,
check
);
}
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录