BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle), commit 1ce478f1
Commit 1ce478f1 (unverified), authored Jul 02, 2018 by yuyang18
Polish reshape op

Parent: 81f22bb2

Showing 4 changed files with 157 additions and 90 deletions (+157 -90)
paddle/fluid/framework/op_registry.h     +74 -10
paddle/fluid/operators/reshape_op.cc     +65 -9
paddle/fluid/operators/reshape_op.cu.cc  +8  -10
paddle/fluid/operators/reshape_op.h      +10 -61
paddle/fluid/framework/op_registry.h
@@ -76,13 +76,8 @@ class OpRegistry {
 template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
 struct OpKernelRegistrarFunctor;
 
-template <typename PlaceType, size_t I, typename... KernelTypes>
-struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
-  using KERNEL_TYPE =
-      typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
-
-  void operator()(const char* op_type, const char* library_type) const {
-    using T = typename KERNEL_TYPE::ELEMENT_TYPE;
+template <typename PlaceType, typename T, typename KernelType>
+inline void RegisterKernelClass(const char* op_type, const char* library_type) {
   std::string library(library_type);
   std::string data_layout = "ANYLAYOUT";
   if (library == "MKLDNN") {
@@ -91,8 +86,17 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
   OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
                    StringToDataLayout(data_layout),
                    StringToLibraryType(library_type));
-  OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
+  OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType());
+}
+
+template <typename PlaceType, size_t I, typename... KernelTypes>
+struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
+  using KERNEL_TYPE =
+      typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
+
+  void operator()(const char* op_type, const char* library_type) const {
+    using T = typename KERNEL_TYPE::ELEMENT_TYPE;
+    RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
 
     constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
     OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
         func;
@@ -116,6 +120,47 @@ class OpKernelRegistrar : public Registrar {
   }
 };
 
+template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
+struct OpKernelRegistrarFunctorEx;
+
+template <typename PlaceType, typename... DataTypeAndKernelType>
+class OpKernelRegistrarEx : public Registrar {
+ public:
+  explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) {
+    OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
+        func;
+    func(op_type, library_type);
+  }
+};
+
+template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
+struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
+                                  DataTypeAndKernelType...> {
+  void operator()(const char* op_type, const char* library_type) const {}
+};
+
+template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
+struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
+                                  DataTypeAndKernelType...> {
+  using KERNEL_TYPE =
+      typename std::tuple_element<I + 1,
+                                  std::tuple<DataTypeAndKernelType...>>::type;
+  using T =
+      typename std::tuple_element<I,
+                                  std::tuple<DataTypeAndKernelType...>>::type;
+
+  void operator()(const char* op_type, const char* library_type) const {
+    RegisterKernelClass<PlaceType, T, KERNEL_TYPE>(op_type, library_type);
+
+    constexpr auto size =
+        std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
+    OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
+                               DataTypeAndKernelType...>
+        func;
+    func(op_type, library_type);
+  }
+};
+
 /**
  * check if MACRO is used in GLOBAL NAMESPACE.
  */
@@ -174,6 +219,25 @@ class OpKernelRegistrar : public Registrar {
 #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
   REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
 
+#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...)      \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                           \
+      __reg_op_kernel_##op_type##_##library_type##__,                       \
+      "REGISTER_OP_KERNEL_EX must be called in global namespace");          \
+  static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
+      __op_kernel_registrar_##op_type##_##library_type##__(#op_type,        \
+                                                           #library_type);  \
+  int TouchOpKernelRegistrar_##op_type##_##library_type() {                 \
+    __op_kernel_registrar_##op_type##_##library_type##__.Touch();           \
+    return 0;                                                               \
+  }
+
+#define REGISTER_OP_CUDA_KERNEL_EX(op_type, ...)                      \
+  REGISTER_OP_KERNEL_EX(p_type, CUDA, ::paddle::platform::CUDAPlace, \
+                        __VA_ARGS__)
+
+#define REGISTER_OP_CPU_KERNEL_EX(op_type, ...) \
+  REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
+
 /**
  * Macro to mark what Operator and Kernel
  * we will use and tell the compiler to
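Note: the new OpKernelRegistrarEx / OpKernelRegistrarFunctorEx machinery walks its template arguments as (data type, kernel class) pairs, advancing the index by two until it runs off the end of the pack. Below is a minimal standalone sketch of that pair-walking pattern; the names PairRegistrar, RegisterPair, FloatKernel and DoubleKernel are hypothetical placeholders, not Paddle code, and the "registration" is just a print statement.

// Standalone sketch (hypothetical names): walk a <data type, kernel type>
// pair list two template arguments at a time, as the *_FunctorEx does.
#include <iostream>
#include <tuple>
#include <typeinfo>

// Stand-in for RegisterKernelClass: report what would be registered.
template <typename T, typename KernelType>
void RegisterPair(const char* op_type) {
  std::cout << op_type << ": element type " << typeid(T).name()
            << " -> kernel " << typeid(KernelType).name() << "\n";
}

template <bool at_end, size_t I, typename... DataTypeAndKernelType>
struct PairRegistrar;

// Recursion end: nothing left to register.
template <size_t I, typename... DataTypeAndKernelType>
struct PairRegistrar<true, I, DataTypeAndKernelType...> {
  void operator()(const char*) const {}
};

// Register the pair at positions (I, I + 1), then advance I by two.
template <size_t I, typename... DataTypeAndKernelType>
struct PairRegistrar<false, I, DataTypeAndKernelType...> {
  using T = typename std::tuple_element<
      I, std::tuple<DataTypeAndKernelType...>>::type;
  using KernelType = typename std::tuple_element<
      I + 1, std::tuple<DataTypeAndKernelType...>>::type;

  void operator()(const char* op_type) const {
    RegisterPair<T, KernelType>(op_type);
    constexpr size_t size = sizeof...(DataTypeAndKernelType);
    PairRegistrar<(I + 2 >= size), I + 2, DataTypeAndKernelType...> next;
    next(op_type);
  }
};

// Placeholder kernel classes standing in for ops::ReshapeKernel.
struct FloatKernel {};
struct DoubleKernel {};

int main() {
  PairRegistrar<false, 0, float, FloatKernel, double, DoubleKernel> reg;
  reg("reshape");  // prints one line per (type, kernel) pair
}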
paddle/fluid/operators/reshape_op.cc
@@ -107,19 +107,75 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
+void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const {
+  auto *out = ctx.Output<framework::LoDTensor>("Out");
+  auto *in = ctx.Input<framework::LoDTensor>("X");
+
+  auto *shape_tensor = ctx.HasInput("Shape")
+                           ? ctx.Input<framework::LoDTensor>("Shape")
+                           : nullptr;
+
+  framework::DDim out_dims = out->dims();
+
+  if (shape_tensor) {
+    auto *shape_data = shape_tensor->data<int>();
+    framework::Tensor cpu_shape_tensor;
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
+      shape_data = cpu_shape_tensor.data<int>();
+    }
+    auto shape =
+        std::vector<int>(shape_data, shape_data + shape_tensor->numel());
+    out_dims = ReshapeOp::ValidateShape(shape, in->dims());
+  }
+  if (!in->lod().empty()) {
+    PADDLE_ENFORCE_EQ(out_dims[0], in->dims()[0],
+                      "Reshape operator cannot reshape an input sequence batch "
+                      "into an output sequence batch that has a different "
+                      "number of time steps. Please consider using "
+                      "sequence_reshape op.");
+  }
+
+  bool inplace = ctx.Attr<bool>("inplace");
+  out->Resize(out_dims);
+  if (!inplace) {
+    out->mutable_data(ctx.GetPlace(), in->type());
+    framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    out->Resize(out_dims);
+  } else {
+    out->ShareDataWith(*in);
+    out->Resize(out_dims);
+  }
+}
+
+void ReshapeGradKernelBase::Compute(
+    const framework::ExecutionContext &ctx) const {
+  auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+  auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+  d_x->mutable_data(ctx.GetPlace(), d_out->type());
+  bool inplace = ctx.Attr<bool>("inplace");
+
+  auto in_dims = d_x->dims();
+  if (!inplace) {
+    framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
+    ctx.device_context().Wait();
+    d_x->Resize(in_dims);
+  } else {
+    d_x->ShareDataWith(*d_out);
+    d_x->Resize(in_dims);
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
 
 REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
-REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel<CPU, float>,
-                       ops::ReshapeKernel<CPU, double>,
-                       ops::ReshapeKernel<CPU, int>,
-                       ops::ReshapeKernel<CPU, int64_t>);
-REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<CPU, float>,
-                       ops::ReshapeGradKernel<CPU, double>,
-                       ops::ReshapeGradKernel<CPU, int>,
-                       ops::ReshapeGradKernel<CPU, int64_t>);
+REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
+                          ops::ReshapeKernel, int, ops::ReshapeKernel,
+                          int64_t, ops::ReshapeKernel);
+REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel<float>,
+                       ops::ReshapeGradKernel<double>,
+                       ops::ReshapeGradKernel<int>,
+                       ops::ReshapeGradKernel<int64_t>);
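Note: with REGISTER_OP_CPU_KERNEL_EX the element type is passed alongside the kernel class rather than as a template parameter of the kernel, which is what allows the single, non-templated ReshapeKernel::Compute above to serve every dtype (note mutable_data(ctx.GetPlace(), in->type())). Reshape only moves bytes and rewrites shape metadata, so it never needs the element type at compile time. Below is a minimal standalone sketch of that type-erased idea; Tensor, KernelBase and ReshapeLikeKernel are hypothetical names, not Paddle's API.

// Standalone sketch (hypothetical names): one kernel class covers all dtypes
// because "reshape" is a metadata change plus an optional byte copy.
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy tensor: raw bytes plus an element size, so no dtype template is needed.
struct Tensor {
  std::vector<char> data;
  std::vector<int> dims;
  size_t element_size = 0;
};

struct KernelBase {
  virtual void Compute(const Tensor& in, Tensor* out,
                       const std::vector<int>& new_dims) const = 0;
  virtual ~KernelBase() = default;
};

// Single implementation for float/double/int/...: copy bytes, set new shape.
struct ReshapeLikeKernel : public KernelBase {
  void Compute(const Tensor& in, Tensor* out,
               const std::vector<int>& new_dims) const override {
    out->data = in.data;             // byte-wise copy (the non-inplace path)
    out->element_size = in.element_size;
    out->dims = new_dims;            // only the shape metadata really changes
  }
};

int main() {
  std::map<std::string, std::unique_ptr<KernelBase>> registry;
  registry["reshape"] = std::make_unique<ReshapeLikeKernel>();

  Tensor in;
  in.data.resize(6 * sizeof(float));
  in.element_size = sizeof(float);
  in.dims = {2, 3};

  Tensor out;
  registry["reshape"]->Compute(in, &out, {3, 2});
  std::cout << "reshaped to " << out.dims[0] << "x" << out.dims[1] << "\n";
}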
paddle/fluid/operators/reshape_op.cu → paddle/fluid/operators/reshape_op.cu.cc
@@ -13,14 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/reshape_op.h"
-using CUDA = paddle::platform::CUDADeviceContext;
 
-REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel<CUDA, float>,
-                        paddle::operators::ReshapeKernel<CUDA, double>,
-                        paddle::operators::ReshapeKernel<CUDA, int>,
-                        paddle::operators::ReshapeKernel<CUDA, int64_t>);
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double,
+                           ops::ReshapeKernel, int, ops::ReshapeKernel,
+                           int64_t, ops::ReshapeKernel);
 REGISTER_OP_CUDA_KERNEL(reshape_grad,
-                        paddle::operators::ReshapeGradKernel<CUDA, float>,
-                        paddle::operators::ReshapeGradKernel<CUDA, double>,
-                        paddle::operators::ReshapeGradKernel<CUDA, int>,
-                        paddle::operators::ReshapeGradKernel<CUDA, int64_t>);
+                        paddle::operators::ReshapeGradKernel<float>,
+                        paddle::operators::ReshapeGradKernel<double>,
+                        paddle::operators::ReshapeGradKernel<int>,
+                        paddle::operators::ReshapeGradKernel<int64_t>);
paddle/fluid/operators/reshape_op.h
@@ -118,72 +118,21 @@ class ReshapeOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename DeviceContext, typename T>
-class ReshapeKernel : public framework::OpKernel<T> {
+class ReshapeKernel : public framework::OpKernelBase {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const {
-    auto *out = ctx.Output<framework::LoDTensor>("Out");
-    auto *in = ctx.Input<framework::LoDTensor>("X");
-
-    auto *shape_tensor = ctx.HasInput("Shape")
-                             ? ctx.Input<framework::LoDTensor>("Shape")
-                             : nullptr;
-
-    framework::DDim out_dims = out->dims();
-
-    if (shape_tensor) {
-      auto *shape_data = shape_tensor->data<int>();
-      framework::Tensor cpu_shape_tensor;
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor);
-        shape_data = cpu_shape_tensor.data<int>();
-      }
-      auto shape =
-          std::vector<int>(shape_data, shape_data + shape_tensor->numel());
-      out_dims = ReshapeOp::ValidateShape(shape, in->dims());
-    }
-    if (!in->lod().empty()) {
-      PADDLE_ENFORCE_EQ(
-          out_dims[0], in->dims()[0],
-          "Reshape operator cannot reshape an input sequence batch "
-          "into an output sequence batch that has a different "
-          "number of time steps. Please consider using "
-          "sequence_reshape op.");
-    }
-
-    bool inplace = ctx.Attr<bool>("inplace");
-    out->Resize(out_dims);
-    if (!inplace) {
-      out->mutable_data<T>(ctx.GetPlace());
-      framework::TensorCopySync(*in, ctx.GetPlace(), out);
-      out->Resize(out_dims);
-    } else {
-      out->ShareDataWith(*in);
-      out->Resize(out_dims);
-    }
-  }
+  void Compute(const framework::ExecutionContext &ctx) const final;
+};
+
+class ReshapeGradKernelBase : public framework::OpKernelBase {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const;
 };
 
-template <typename DeviceContext, typename T>
-class ReshapeGradKernel : public framework::OpKernel<T> {
+template <typename T>
+class ReshapeGradKernel : public ReshapeGradKernelBase {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const {
-    auto *d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-
-    d_x->mutable_data<T>(ctx.GetPlace());
-    bool inplace = ctx.Attr<bool>("inplace");
-
-    auto in_dims = d_x->dims();
-    if (!inplace) {
-      framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
-      ctx.device_context().Wait();
-      d_x->Resize(in_dims);
-    } else {
-      d_x->ShareDataWith(*d_out);
-      d_x->Resize(in_dims);
-    }
-  }
+  // Tell register element type.
+  using ELEMENT_TYPE = T;
 };
 
 }  // namespace operators
 }  // namespace paddle
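Note: in the polished header the actual work lives in a non-templated base class (ReshapeGradKernelBase), while the per-type derived classes exist only to carry ELEMENT_TYPE so the registration code knows which dtype each entry stands for. Below is a minimal standalone sketch of that thin-wrapper pattern; GradKernelBase, GradKernel and Register are hypothetical names rather than Paddle's API.

// Standalone sketch (hypothetical names): one shared implementation, plus
// empty per-type wrappers whose only job is to expose ELEMENT_TYPE.
#include <iostream>
#include <typeinfo>

struct GradKernelBase {
  // One implementation, independent of the element type.
  void Compute() const { std::cout << "copy gradient bytes back\n"; }
};

template <typename T>
struct GradKernel : public GradKernelBase {
  // Only purpose: tell the registration machinery which element type this
  // entry represents (mirrors ELEMENT_TYPE in the real header).
  using ELEMENT_TYPE = T;
};

// Stand-in for the registry: read ELEMENT_TYPE back out of the wrapper.
template <typename KernelType>
void Register(const char* op_type) {
  using T = typename KernelType::ELEMENT_TYPE;
  std::cout << op_type << " registered for " << typeid(T).name() << "\n";
  KernelType().Compute();  // every instantiation shares the base implementation
}

int main() {
  Register<GradKernel<float>>("reshape_grad");
  Register<GradKernel<double>>("reshape_grad");
}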