Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b548ecbc
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b548ecbc
编写于
8月 22, 2018
作者:
X
Xin Pan
提交者:
sneaxiy
8月 22, 2018
浏览文件
操作
浏览文件
下载
差异文件
add stack_op
上级
bc4f5375
a2c0e52f
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
540 addition
and
79 deletion
+540
-79
paddle/fluid/framework/array.h
paddle/fluid/framework/array.h
+48
-0
paddle/fluid/operators/stack_op.cc
paddle/fluid/operators/stack_op.cc
+66
-0
paddle/fluid/operators/stack_op.cu
paddle/fluid/operators/stack_op.cu
+109
-0
paddle/fluid/operators/stack_op.h
paddle/fluid/operators/stack_op.h
+192
-0
paddle/fluid/operators/while_op.cc
paddle/fluid/operators/while_op.cc
+5
-5
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+28
-74
python/paddle/fluid/tests/unittests/test_stack_op.py
python/paddle/fluid/tests/unittests/test_stack_op.py
+92
-0
未找到文件。
paddle/fluid/framework/array.h
0 → 100644
浏览文件 @
b548ecbc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include "paddle/fluid/platform/hostdevice.h"
namespace
paddle
{
namespace
framework
{
template
<
typename
T
,
size_t
N
>
class
Array
{
static_assert
(
N
>
0
,
"The size of array must be larger than 0"
);
public:
HOSTDEVICE
Array
()
{}
HOSTDEVICE
explicit
Array
(
const
T
&
val
)
{
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
data_
[
i
]
=
val
;
}
HOSTDEVICE
const
T
*
Get
()
const
{
return
data_
;
}
HOSTDEVICE
T
*
GetMutable
()
{
return
data_
;
}
HOSTDEVICE
T
&
operator
[](
size_t
index
)
{
return
data_
[
index
];
}
HOSTDEVICE
const
T
&
operator
[](
size_t
index
)
const
{
return
data_
[
index
];
}
HOSTDEVICE
constexpr
size_t
size
()
const
{
return
N
;
}
private:
T
data_
[
N
];
};
}
// namespace framework
}
// namespace paddle
paddle/fluid/operators/stack_op.cc
0 → 100644
浏览文件 @
b548ecbc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/stack_op.h"
namespace
paddle
{
namespace
operators
{
struct
CPUStackFunctor
{
template
<
typename
DeviceContext
,
typename
T
>
void
operator
()(
const
DeviceContext
&
ctx
,
const
std
::
vector
<
const
T
*>&
x
,
T
*
y
,
int
pre
,
int
n
,
int
post
)
const
{
int
total_num
=
pre
*
post
*
n
;
for
(
int
idx
=
0
;
idx
<
total_num
;
++
idx
)
{
int
i
=
idx
/
(
n
*
post
);
int
which_x
=
idx
/
post
-
i
*
n
;
int
x_index
=
i
*
post
+
idx
%
post
;
y
[
idx
]
=
x
[
which_x
][
x_index
];
}
}
};
struct
CPUStackGradFunctor
{
template
<
typename
DeviceContext
,
typename
T
>
void
operator
()(
const
DeviceContext
&
ctx
,
std
::
vector
<
T
*>&
dx
,
// NOLINT
const
T
*
dy
,
int
pre
,
int
n
,
int
post
)
const
{
int
total_num
=
pre
*
post
*
n
;
for
(
int
idx
=
0
;
idx
<
total_num
;
++
idx
)
{
int
i
=
idx
/
(
n
*
post
);
int
which_x
=
idx
/
post
-
i
*
n
;
int
x_index
=
i
*
post
+
idx
%
post
;
dx
[
which_x
][
x_index
]
=
dy
[
idx
];
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
plat
=
paddle
::
platform
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
stack
,
ops
::
StackOp
,
ops
::
StackOpMaker
,
ops
::
StackGradOpDescMaker
);
REGISTER_OPERATOR
(
stack_grad
,
ops
::
StackOpGrad
);
REGISTER_OP_CPU_KERNEL
(
stack
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
float
,
ops
::
CPUStackFunctor
>
,
ops
::
StackKernel
<
plat
::
CPUDeviceContext
,
double
,
ops
::
CPUStackFunctor
>
);
REGISTER_OP_CPU_KERNEL
(
stack_grad
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
float
,
ops
::
CPUStackGradFunctor
>
,
ops
::
StackGradKernel
<
plat
::
CPUDeviceContext
,
double
,
ops
::
CPUStackGradFunctor
>
);
paddle/fluid/operators/stack_op.cu
0 → 100644
浏览文件 @
b548ecbc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/operators/stack_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
typename
VecXType
>
__global__
void
StackCUDAKernel
(
VecXType
x
,
T
*
y
,
int
total_num
,
int
n
,
int
post
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
total_num
)
{
int
i
=
idx
/
(
n
*
post
);
int
which_x
=
idx
/
post
-
i
*
n
;
int
x_index
=
i
*
post
+
idx
%
post
;
y
[
idx
]
=
x
[
which_x
][
x_index
];
}
}
template
<
typename
T
,
typename
VecDxType
>
__global__
void
StackGradCUDAKernel
(
VecDxType
dx
,
const
T
*
dy
,
int
total_num
,
int
n
,
int
post
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
idx
<
total_num
)
{
int
i
=
idx
/
(
n
*
post
);
int
which_x
=
idx
/
post
-
i
*
n
;
int
x_index
=
i
*
post
+
idx
%
post
;
dx
[
which_x
][
x_index
]
=
dy
[
idx
];
}
}
struct
GPUStackFunctor
{
template
<
typename
DeviceContext
,
typename
T
>
void
operator
()(
const
DeviceContext
&
ctx
,
const
std
::
vector
<
const
T
*>&
x
,
T
*
y
,
int
pre
,
int
n
,
int
post
)
const
{
int
total_num
=
pre
*
post
*
n
;
int
threads
=
512
;
int
grid
=
(
total_num
+
threads
-
1
)
/
threads
;
constexpr
auto
kMaxThreshold
=
16
;
if
(
n
<=
kMaxThreshold
)
{
framework
::
Array
<
const
T
*
,
kMaxThreshold
>
arr
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
arr
[
i
]
=
x
[
i
];
StackCUDAKernel
<<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
arr
,
y
,
total_num
,
n
,
post
);
}
else
{
VLOG
(
10
)
<<
"Stack more than "
<<
kMaxThreshold
<<
" tensors may be slow on GPU."
;
thrust
::
device_vector
<
const
T
*>
dev_x
(
x
);
StackCUDAKernel
<<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
dev_x
.
data
().
get
(),
y
,
total_num
,
n
,
post
);
}
}
};
struct
GPUStackGradFunctor
{
template
<
typename
DeviceContext
,
typename
T
>
void
operator
()(
const
DeviceContext
&
ctx
,
std
::
vector
<
T
*>&
dx
,
// NOLINT
const
T
*
dy
,
int
pre
,
int
n
,
int
post
)
const
{
int
total_num
=
pre
*
post
*
n
;
int
threads
=
512
;
int
grid
=
(
total_num
+
threads
-
1
)
/
threads
;
constexpr
auto
kMaxThreshold
=
16
;
if
(
n
<=
kMaxThreshold
)
{
framework
::
Array
<
T
*
,
kMaxThreshold
>
arr
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
arr
[
i
]
=
dx
[
i
];
StackGradCUDAKernel
<<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
arr
,
dy
,
total_num
,
n
,
post
);
}
else
{
VLOG
(
10
)
<<
"Stack more than "
<<
kMaxThreshold
<<
" tensors may be slow on GPU."
;
thrust
::
device_vector
<
T
*>
dev_dx
(
dx
);
StackGradCUDAKernel
<<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
dev_dx
.
data
().
get
(),
dy
,
total_num
,
n
,
post
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
plat
=
paddle
::
platform
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
stack
,
ops
::
StackKernel
<
plat
::
CUDADeviceContext
,
float
,
ops
::
GPUStackFunctor
>
,
ops
::
StackKernel
<
plat
::
CUDADeviceContext
,
double
,
ops
::
GPUStackFunctor
>
);
REGISTER_OP_CUDA_KERNEL
(
stack_grad
,
ops
::
StackGradKernel
<
plat
::
CUDADeviceContext
,
float
,
ops
::
GPUStackGradFunctor
>
,
ops
::
StackGradKernel
<
plat
::
CUDADeviceContext
,
double
,
ops
::
GPUStackGradFunctor
>
);
paddle/fluid/operators/stack_op.h
0 → 100644
浏览文件 @
b548ecbc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
inline
void
GetPrePostForStackOp
(
const
framework
::
DDim
&
dim
,
int
axis
,
int
*
pre
,
int
*
post
)
{
*
pre
=
1
;
for
(
auto
i
=
0
;
i
<
axis
;
++
i
)
(
*
pre
)
*=
dim
[
i
];
*
post
=
1
;
for
(
auto
i
=
axis
;
i
<
dim
.
size
();
++
i
)
(
*
post
)
*=
dim
[
i
];
}
class
StackOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_GT
(
ctx
->
Inputs
(
"X"
).
size
(),
0
,
"Number of Inputs(X) must be larger than 0"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) must exist."
);
auto
input_dims
=
ctx
->
GetInputsDim
(
"X"
);
for
(
size_t
i
=
1
;
i
<
input_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
input_dims
[
i
],
input_dims
[
0
],
"Dims of all Inputs(X) must be the same"
);
}
// Only lod of X[0] would be shared with Y
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Y"
);
int
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
int
rank
=
input_dims
[
0
].
size
();
PADDLE_ENFORCE
(
axis
>=
-
(
rank
+
1
)
&&
axis
<
rank
+
1
,
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d"
,
rank
);
if
(
axis
<
0
)
axis
+=
(
rank
+
1
);
auto
vec
=
framework
::
vectorize2int
(
input_dims
[
0
]);
vec
.
insert
(
vec
.
begin
()
+
axis
,
input_dims
.
size
());
ctx
->
SetOutputDim
(
"Y"
,
framework
::
make_ddim
(
vec
));
}
};
class
StackOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"The input of stack op."
).
AsDuplicable
();
AddOutput
(
"Y"
,
"The output of stack op."
);
AddAttr
<
int
>
(
"axis"
,
"The axis along which all of the Inputs(X) should be stacked."
)
.
SetDefault
(
0
);
AddComment
(
R"DOC(
Stack Operator.
Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same.
)DOC"
);
}
};
template
<
typename
DeviceContext
,
typename
T
,
typename
Functor
>
class
StackKernel
:
public
framework
::
OpKernel
<
T
>
{
using
Tensor
=
framework
::
LoDTensor
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
x
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
if
(
axis
<
0
)
axis
+=
(
x
[
0
]
->
dims
().
size
()
+
1
);
int
n
=
static_cast
<
int
>
(
x
.
size
());
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
std
::
vector
<
const
T
*>
x_datas
(
n
);
for
(
int
i
=
0
;
i
<
n
;
i
++
)
x_datas
[
i
]
=
x
[
i
]
->
data
<
T
>
();
int
pre
=
1
,
post
=
1
;
auto
&
dim
=
x
[
0
]
->
dims
();
for
(
auto
i
=
0
;
i
<
axis
;
++
i
)
pre
*=
dim
[
i
];
for
(
auto
i
=
axis
;
i
<
dim
.
size
();
++
i
)
post
*=
dim
[
i
];
Functor
functor
;
functor
(
ctx
.
template
device_context
<
DeviceContext
>(),
x_datas
,
y_data
,
pre
,
n
,
post
);
}
};
class
StackOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@Grad) must exist."
);
int
axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
);
auto
dy_dim
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
int
rank
=
dy_dim
.
size
();
PADDLE_ENFORCE
(
axis
>=
-
rank
&&
axis
<
rank
,
"Attr(axis) must be inside [-rank, rank), where rank = %d"
,
rank
);
if
(
axis
<
0
)
axis
+=
rank
;
PADDLE_ENFORCE_EQ
(
ctx
->
Outputs
(
framework
::
GradVarName
(
"X"
)).
size
(),
static_cast
<
size_t
>
(
dy_dim
[
axis
]),
"Number of Outputs(X@Grad) is wrong"
);
auto
vec
=
framework
::
vectorize2int
(
dy_dim
);
vec
.
erase
(
vec
.
begin
()
+
axis
);
ctx
->
SetOutputsDim
(
framework
::
GradVarName
(
"X"
),
std
::
vector
<
framework
::
DDim
>
(
dy_dim
[
axis
],
framework
::
make_ddim
(
vec
)));
}
};
class
StackGradOpDescMaker
:
public
framework
::
SingleGradOpDescMaker
/*framework::GradOpDescMakerBase*/
{
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
/*
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator ()() const override {
auto x_grads = InputGrad("X", false);
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Y");
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
[&og](const std::string& x_grad) {
auto* grad_op = new framework::OpDesc();
grad_op->SetInput("X", og);
grad_op->SetOutput("Y", {x_grad});
grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(grad_op);
});
return grad_ops;
}
*/
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
op
(
new
framework
::
OpDesc
());
op
->
SetType
(
"stack_grad"
);
op
->
SetInput
(
framework
::
GradVarName
(
"Y"
),
OutputGrad
(
"Y"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
,
false
));
op
->
SetAttrMap
(
Attrs
());
return
op
;
}
};
template
<
typename
DeviceContext
,
typename
T
,
typename
GradFunctor
>
class
StackGradKernel
:
public
framework
::
OpKernel
<
T
>
{
using
Tensor
=
framework
::
LoDTensor
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
dy
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
dx
=
ctx
.
MultiOutput
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
if
(
axis
<
0
)
axis
+=
dy
->
dims
().
size
();
int
n
=
dy
->
dims
()[
axis
];
std
::
vector
<
T
*>
dx_datas
(
n
);
// NOLINT
for
(
int
i
=
0
;
i
<
n
;
i
++
)
dx_datas
[
i
]
=
dx
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dy_data
=
dy
->
data
<
T
>
();
int
pre
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
pre
*=
dy
->
dims
()[
i
];
int
post
=
dy
->
numel
()
/
(
n
*
pre
);
GradFunctor
functor
;
functor
(
ctx
.
template
device_context
<
DeviceContext
>(),
dx_datas
,
dy_data
,
pre
,
n
,
post
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/while_op.cc
浏览文件 @
b548ecbc
...
...
@@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
cond
.
place
()),
"Condition of while op must in CPU memory."
);
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
());
while
(
cond
.
data
<
bool
>
()[
0
])
{
auto
&
current_scope
=
scope
.
NewScope
();
step_scopes
->
push_back
(
&
current_scope
);
executor
.
Run
(
*
program
,
&
current_scope
,
block
->
ID
(),
false
/*create_local_scope*/
);
executor
.
RunPreparedContext
(
ctx
.
get
(),
&
current_scope
,
false
);
}
}
};
...
...
@@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase {
framework
::
Executor
executor
(
dev_place
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
auto
*
program
=
block
->
Program
();
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
());
auto
*
step_scopes
=
scope
.
FindVar
(
Input
(
kStepScopes
))
->
GetMutable
<
StepScopeVar
>
();
...
...
@@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase {
}
}
}
executor
.
Run
(
*
program
,
*
cur_scope_iter
,
block
->
ID
(),
false
);
executor
.
RunPreparedContext
(
ctx
.
get
(),
*
cur_scope_iter
,
false
);
auto
&
pg_names
=
Outputs
(
kXGRAD
);
auto
&
p_names
=
Inputs
(
kX
);
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
b548ecbc
...
...
@@ -29,80 +29,21 @@ from .. import unique_name
from
functools
import
reduce
__all__
=
[
'fc'
,
'embedding'
,
'dynamic_lstm'
,
'dynamic_lstmp'
,
'dynamic_gru'
,
'gru_unit'
,
'linear_chain_crf'
,
'crf_decoding'
,
'cos_sim'
,
'cross_entropy'
,
'square_error_cost'
,
'chunk_eval'
,
'sequence_conv'
,
'conv2d'
,
'conv3d'
,
'sequence_pool'
,
'sequence_softmax'
,
'softmax'
,
'pool2d'
,
'pool3d'
,
'batch_norm'
,
'beam_search_decode'
,
'conv2d_transpose'
,
'conv3d_transpose'
,
'sequence_expand'
,
'lstm_unit'
,
'reduce_sum'
,
'reduce_mean'
,
'reduce_max'
,
'reduce_min'
,
'reduce_prod'
,
'sequence_first_step'
,
'sequence_last_step'
,
'dropout'
,
'split'
,
'ctc_greedy_decoder'
,
'edit_distance'
,
'l2_normalize'
,
'matmul'
,
'topk'
,
'warpctc'
,
'sequence_reshape'
,
'transpose'
,
'im2sequence'
,
'nce'
,
'hsigmoid'
,
'beam_search'
,
'row_conv'
,
'multiplex'
,
'layer_norm'
,
'softmax_with_cross_entropy'
,
'smooth_l1'
,
'one_hot'
,
'autoincreased_step_counter'
,
'reshape'
,
'lod_reset'
,
'lrn'
,
'pad'
,
'label_smooth'
,
'roi_pool'
,
'dice_loss'
,
'image_resize'
,
'image_resize_short'
,
'resize_bilinear'
,
'gather'
,
'scatter'
,
'random_crop'
,
'mean_iou'
,
'relu'
,
'log'
,
'crop'
,
'rank_loss'
,
'prelu'
,
'flatten'
,
'fc'
,
'embedding'
,
'dynamic_lstm'
,
'dynamic_lstmp'
,
'dynamic_gru'
,
'gru_unit'
,
'linear_chain_crf'
,
'crf_decoding'
,
'cos_sim'
,
'cross_entropy'
,
'square_error_cost'
,
'chunk_eval'
,
'sequence_conv'
,
'conv2d'
,
'conv3d'
,
'sequence_pool'
,
'sequence_softmax'
,
'softmax'
,
'pool2d'
,
'pool3d'
,
'batch_norm'
,
'beam_search_decode'
,
'conv2d_transpose'
,
'conv3d_transpose'
,
'sequence_expand'
,
'lstm_unit'
,
'reduce_sum'
,
'reduce_mean'
,
'reduce_max'
,
'reduce_min'
,
'reduce_prod'
,
'sequence_first_step'
,
'sequence_last_step'
,
'dropout'
,
'split'
,
'ctc_greedy_decoder'
,
'edit_distance'
,
'l2_normalize'
,
'matmul'
,
'topk'
,
'warpctc'
,
'sequence_reshape'
,
'transpose'
,
'im2sequence'
,
'nce'
,
'hsigmoid'
,
'beam_search'
,
'row_conv'
,
'multiplex'
,
'layer_norm'
,
'softmax_with_cross_entropy'
,
'smooth_l1'
,
'one_hot'
,
'autoincreased_step_counter'
,
'reshape'
,
'lod_reset'
,
'lrn'
,
'pad'
,
'label_smooth'
,
'roi_pool'
,
'dice_loss'
,
'image_resize'
,
'image_resize_short'
,
'resize_bilinear'
,
'gather'
,
'scatter'
,
'random_crop'
,
'mean_iou'
,
'relu'
,
'log'
,
'crop'
,
'rank_loss'
,
'prelu'
,
'flatten'
,
'stack'
]
...
...
@@ -5517,3 +5458,16 @@ def flatten(x, axis=1, name=None):
outputs
=
{
'Out'
:
out
},
attrs
=
{
"axis"
:
axis
})
return
out
def
stack
(
x
,
axis
=
0
):
helper
=
LayerHelper
(
'stack'
,
**
locals
())
axis
=
0
if
axis
is
None
else
axis
if
not
isinstance
(
x
,
list
)
and
not
isinstance
(
x
,
tuple
):
x
=
[
x
]
out
=
helper
.
create_tmp_variable
(
x
[
0
].
dtype
)
helper
.
append_op
(
type
=
'stack'
,
inputs
=
{
'X'
:
x
},
outpus
=
{
'Y'
:
out
},
attrs
=
{
'axis'
:
axis
})
return
out
python/paddle/fluid/tests/unittests/test_stack_op.py
0 → 100644
浏览文件 @
b548ecbc
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
op_test
import
OpTest
import
numpy
as
np
import
unittest
class
TestStackOpBase
(
OpTest
):
def
initDefaultParameters
(
self
):
self
.
num_inputs
=
4
self
.
input_dim
=
(
5
,
6
,
7
)
self
.
axis
=
0
self
.
dtype
=
'float32'
def
initParameters
(
self
):
pass
def
get_x_names
(
self
):
x_names
=
[]
for
i
in
range
(
self
.
num_inputs
):
x_names
.
append
(
'x{}'
.
format
(
i
))
return
x_names
def
setUp
(
self
):
self
.
initDefaultParameters
()
self
.
initParameters
()
self
.
op_type
=
'stack'
self
.
x
=
[]
for
i
in
range
(
self
.
num_inputs
):
self
.
x
.
append
(
np
.
random
.
random
(
size
=
self
.
input_dim
).
astype
(
self
.
dtype
))
tmp
=
[]
x_names
=
self
.
get_x_names
()
for
i
in
range
(
self
.
num_inputs
):
tmp
.
append
((
x_names
[
i
],
self
.
x
[
i
]))
self
.
inputs
=
{
'X'
:
tmp
}
self
.
outputs
=
{
'Y'
:
np
.
stack
(
self
.
x
,
axis
=
self
.
axis
)}
self
.
attrs
=
{
'axis'
:
self
.
axis
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
(
self
.
get_x_names
(),
'Y'
)
class
TestStackOp1
(
TestStackOpBase
):
def
initParameters
(
self
):
self
.
num_inputs
=
16
class
TestStackOp2
(
TestStackOpBase
):
def
initParameters
(
self
):
self
.
num_inputs
=
20
class
TestStackOp3
(
TestStackOpBase
):
def
initParameters
(
self
):
self
.
axis
=
-
1
class
TestStackOp4
(
TestStackOpBase
):
def
initParameters
(
self
):
self
.
axis
=
-
4
class
TestStackOp5
(
TestStackOpBase
):
def
initParameters
(
self
):
self
.
axis
=
1
class
TestStackOp6
(
TestStackOpBase
):
def
initParameters
(
self
):
self
.
axis
=
3
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录