Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9401b64d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9401b64d
编写于
7月 03, 2018
作者:
Y
Yu Yang
提交者:
GitHub
7月 03, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11877 from reyoung/feature/fix_reshape_op_size
User can register a standard C++ functor as Kernel
上级
ebadb4c0
550ab8d7
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
261 addition
and
239 deletion
+261
-239
paddle/fluid/framework/op_registry.h
paddle/fluid/framework/op_registry.h
+78
-10
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+1
-1
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+2
-2
paddle/fluid/operators/fc_mkldnn_op.cc
paddle/fluid/operators/fc_mkldnn_op.cc
+1
-0
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+179
-11
paddle/fluid/operators/reshape_op.cu
paddle/fluid/operators/reshape_op.cu
+0
-26
paddle/fluid/operators/reshape_op.h
paddle/fluid/operators/reshape_op.h
+0
-189
未找到文件。
paddle/fluid/framework/op_registry.h
浏览文件 @
9401b64d
...
...
@@ -76,6 +76,20 @@ class OpRegistry {
template
<
typename
PlaceType
,
bool
at_end
,
size_t
I
,
typename
...
KernelType
>
struct
OpKernelRegistrarFunctor
;
template
<
typename
PlaceType
,
typename
T
,
typename
Func
>
inline
void
RegisterKernelClass
(
const
char
*
op_type
,
const
char
*
library_type
,
Func
func
)
{
std
::
string
library
(
library_type
);
std
::
string
data_layout
=
"ANYLAYOUT"
;
if
(
library
==
"MKLDNN"
)
{
data_layout
=
"MKLDNNLAYOUT"
;
}
OpKernelType
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
(),
StringToDataLayout
(
data_layout
),
StringToLibraryType
(
library_type
));
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
]
=
func
;
}
template
<
typename
PlaceType
,
size_t
I
,
typename
...
KernelTypes
>
struct
OpKernelRegistrarFunctor
<
PlaceType
,
false
,
I
,
KernelTypes
...
>
{
using
KERNEL_TYPE
=
...
...
@@ -83,16 +97,10 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{
using
T
=
typename
KERNEL_TYPE
::
ELEMENT_TYPE
;
std
::
string
library
(
library_type
);
std
::
string
data_layout
=
"ANYLAYOUT"
;
if
(
library
==
"MKLDNN"
)
{
data_layout
=
"MKLDNNLAYOUT"
;
}
OpKernelType
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
(),
StringToDataLayout
(
data_layout
),
StringToLibraryType
(
library_type
));
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
].
reset
(
new
KERNEL_TYPE
);
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
[](
const
framework
::
ExecutionContext
&
ctx
)
{
KERNEL_TYPE
().
Compute
(
ctx
);
});
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelTypes
...
>>::
value
;
OpKernelRegistrarFunctor
<
PlaceType
,
I
+
1
==
size
,
I
+
1
,
KernelTypes
...
>
func
;
...
...
@@ -116,6 +124,47 @@ class OpKernelRegistrar : public Registrar {
}
};
template
<
typename
PlaceType
,
bool
at_end
,
size_t
I
,
typename
...
KernelType
>
struct
OpKernelRegistrarFunctorEx
;
template
<
typename
PlaceType
,
typename
...
DataTypeAndKernelType
>
class
OpKernelRegistrarEx
:
public
Registrar
{
public:
explicit
OpKernelRegistrarEx
(
const
char
*
op_type
,
const
char
*
library_type
)
{
OpKernelRegistrarFunctorEx
<
PlaceType
,
false
,
0
,
DataTypeAndKernelType
...
>
func
;
func
(
op_type
,
library_type
);
}
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
DataTypeAndKernelType
>
struct
OpKernelRegistrarFunctorEx
<
PlaceType
,
true
,
I
,
DataTypeAndKernelType
...
>
{
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{}
};
template
<
typename
PlaceType
,
size_t
I
,
typename
...
DataTypeAndKernelType
>
struct
OpKernelRegistrarFunctorEx
<
PlaceType
,
false
,
I
,
DataTypeAndKernelType
...
>
{
using
Functor
=
typename
std
::
tuple_element
<
I
+
1
,
std
::
tuple
<
DataTypeAndKernelType
...
>>::
type
;
using
T
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
DataTypeAndKernelType
...
>>::
type
;
void
operator
()(
const
char
*
op_type
,
const
char
*
library_type
)
const
{
RegisterKernelClass
<
PlaceType
,
T
>
(
op_type
,
library_type
,
Functor
());
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
DataTypeAndKernelType
...
>>::
value
;
OpKernelRegistrarFunctorEx
<
PlaceType
,
I
+
2
>=
size
,
I
+
2
,
DataTypeAndKernelType
...
>
func
;
func
(
op_type
,
library_type
);
}
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
...
...
@@ -174,6 +223,25 @@ class OpKernelRegistrar : public Registrar {
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \
#library_type); \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \
return 0; \
}
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \
__VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/**
* Macro to mark what Operator and Kernel
* we will use and tell the compiler to
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
9401b64d
...
...
@@ -651,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx
=
pool
.
Get
(
expected_kernel_key
.
place_
);
}
kernel_iter
->
second
->
Compute
(
ExecutionContext
(
*
this
,
exec_scope
,
*
dev_ctx
));
kernel_iter
->
second
(
ExecutionContext
(
*
this
,
exec_scope
,
*
dev_ctx
));
if
(
!
transfered_inplace_vars
.
empty
())
{
// there is inplace variable has been transfered.
...
...
paddle/fluid/framework/operator.h
浏览文件 @
9401b64d
...
...
@@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase {
class
OperatorWithKernel
:
public
OperatorBase
{
public:
using
OpKernelFunc
=
std
::
function
<
void
(
const
ExecutionContext
&
)
>
;
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelType
,
std
::
unique_ptr
<
OpKernelBase
>
,
OpKernelType
::
Hash
>
;
std
::
unordered_map
<
OpKernelType
,
OpKernelFunc
,
OpKernelType
::
Hash
>
;
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
...
...
paddle/fluid/operators/fc_mkldnn_op.cc
浏览文件 @
9401b64d
...
...
@@ -115,6 +115,7 @@ class MKLDNNMemory {
template
<
typename
T
>
class
FCMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
paddle
::
platform
::
is_cpu_place
(
ctx
.
GetPlace
()),
"It must use CPUPlace."
);
...
...
paddle/fluid/operators/reshape_op.cc
浏览文件 @
9401b64d
...
...
@@ -12,14 +12,108 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
ReshapeOp
:
public
framework
::
OperatorWithKernel
{
public:
ReshapeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ReshapeOp should not be null."
);
const
std
::
vector
<
int
>
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE
(
!
shape
.
empty
(),
"The shape information must be set by Attr(shape)."
);
if
(
ctx
->
HasInput
(
"Shape"
)
&&
ctx
->
IsRuntime
())
{
// If true, set the shape of Output(Out) according to Input(Shape) in
// ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
return
;
}
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ValidateShape
(
shape
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
if
(
x_dims
[
0
]
==
out_dims
[
0
])
{
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
static
framework
::
DDim
ValidateShape
(
const
std
::
vector
<
int
>
shape
,
const
framework
::
DDim
&
in_dims
)
{
const
int64_t
in_size
=
framework
::
product
(
in_dims
);
// only one dimension can be set to -1, whose size will be automatically
// infered.
const
int64_t
unk_dim_val
=
-
1
;
const
int64_t
copy_dim_val
=
0
;
std
::
vector
<
int64_t
>
output_shape
(
shape
.
size
(),
0
);
int64_t
capacity
=
1
;
int
unk_dim_idx
=
-
1
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
if
(
shape
[
i
]
==
unk_dim_val
)
{
PADDLE_ENFORCE
(
unk_dim_idx
==
-
1
,
"Only one input dimension of Attr(shape) can be unknown."
);
unk_dim_idx
=
i
;
}
else
if
(
shape
[
i
]
==
copy_dim_val
)
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
i
)
<
in_dims
.
size
(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape."
);
}
else
{
PADDLE_ENFORCE
(
shape
[
i
]
>
0
,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension."
);
}
capacity
*=
(
shape
[
i
]
?
shape
[
i
]
:
in_dims
[
i
]);
output_shape
[
i
]
=
(
shape
[
i
]
?
static_cast
<
int64_t
>
(
shape
[
i
])
:
in_dims
[
i
]);
}
if
(
unk_dim_idx
!=
-
1
)
{
if
(
in_size
>
0
)
{
// in_size < 0 and is un-determinate in compile time, skip the check,
// for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
// capacity = -24, in_size = -8, output_shape[0] = 0
// the following check will fail.
output_shape
[
unk_dim_idx
]
=
-
in_size
/
capacity
;
PADDLE_ENFORCE_EQ
(
output_shape
[
unk_dim_idx
]
*
capacity
,
-
in_size
,
"Invalid shape is given."
);
}
else
{
output_shape
[
unk_dim_idx
]
=
-
1
;
}
}
else
{
PADDLE_ENFORCE_EQ
(
capacity
,
in_size
,
"Invalid shape is given."
);
}
return
framework
::
make_ddim
(
output_shape
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
class
ReshapeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -107,19 +201,93 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
}
};
class
ReshapeKernel
{
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
shape_tensor
=
ctx
.
HasInput
(
"Shape"
)
?
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Shape"
)
:
nullptr
;
framework
::
DDim
out_dims
=
out
->
dims
();
if
(
shape_tensor
)
{
auto
*
shape_data
=
shape_tensor
->
data
<
int
>
();
framework
::
Tensor
cpu_shape_tensor
;
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
TensorCopySync
(
*
shape_tensor
,
platform
::
CPUPlace
(),
&
cpu_shape_tensor
);
shape_data
=
cpu_shape_tensor
.
data
<
int
>
();
}
auto
shape
=
std
::
vector
<
int
>
(
shape_data
,
shape_data
+
shape_tensor
->
numel
());
out_dims
=
ReshapeOp
::
ValidateShape
(
shape
,
in
->
dims
());
}
if
(
!
in
->
lod
().
empty
())
{
PADDLE_ENFORCE_EQ
(
out_dims
[
0
],
in
->
dims
()[
0
],
"Reshape operator cannot reshape an input sequence batch "
"into an output sequence batch that has a different "
"number of time steps. Please consider using "
"sequence_reshape op."
);
}
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
out
->
Resize
(
out_dims
);
if
(
!
inplace
)
{
out
->
mutable_data
(
ctx
.
GetPlace
(),
in
->
type
());
framework
::
TensorCopySync
(
*
in
,
ctx
.
GetPlace
(),
out
);
out
->
Resize
(
out_dims
);
}
else
{
out
->
ShareDataWith
(
*
in
);
out
->
Resize
(
out_dims
);
}
}
};
class
ReshapeGradKernel
{
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
d_x
->
mutable_data
(
ctx
.
GetPlace
(),
d_out
->
type
());
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
auto
in_dims
=
d_x
->
dims
();
if
(
!
inplace
)
{
framework
::
TensorCopy
(
*
d_out
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
ctx
.
device_context
().
Wait
();
d_x
->
Resize
(
in_dims
);
}
else
{
d_x
->
ShareDataWith
(
*
d_out
);
d_x
->
Resize
(
in_dims
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
reshape
,
ops
::
ReshapeOp
,
ops
::
ReshapeOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
reshape_grad
,
ops
::
ReshapeGradOp
);
REGISTER_OP_CPU_KERNEL
(
reshape
,
ops
::
ReshapeKernel
<
CPU
,
float
>
,
ops
::
ReshapeKernel
<
CPU
,
double
>
,
ops
::
ReshapeKernel
<
CPU
,
int
>
,
ops
::
ReshapeKernel
<
CPU
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
reshape_grad
,
ops
::
ReshapeGradKernel
<
CPU
,
float
>
,
ops
::
ReshapeGradKernel
<
CPU
,
double
>
,
ops
::
ReshapeGradKernel
<
CPU
,
int
>
,
ops
::
ReshapeGradKernel
<
CPU
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
int64_t
,
ops
::
ReshapeKernel
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
reshape_grad
,
float
,
ops
::
ReshapeGradKernel
,
double
,
ops
::
ReshapeGradKernel
,
int
,
ops
::
ReshapeGradKernel
,
int64_t
,
ops
::
ReshapeGradKernel
);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR
(
reshape
,
float
,
ops
::
ReshapeKernel
,
double
,
ops
::
ReshapeKernel
,
int
,
ops
::
ReshapeKernel
,
int64_t
,
ops
::
ReshapeKernel
);
REGISTER_OP_CUDA_KERNEL_FUNCTOR
(
reshape_grad
,
float
,
ops
::
ReshapeGradKernel
,
double
,
ops
::
ReshapeGradKernel
,
int
,
ops
::
ReshapeGradKernel
,
int64_t
,
ops
::
ReshapeGradKernel
);
#endif
paddle/fluid/operators/reshape_op.cu
已删除
100644 → 0
浏览文件 @
ebadb4c0
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reshape_op.h"
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
reshape
,
paddle
::
operators
::
ReshapeKernel
<
CUDA
,
float
>
,
paddle
::
operators
::
ReshapeKernel
<
CUDA
,
double
>
,
paddle
::
operators
::
ReshapeKernel
<
CUDA
,
int
>
,
paddle
::
operators
::
ReshapeKernel
<
CUDA
,
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
reshape_grad
,
paddle
::
operators
::
ReshapeGradKernel
<
CUDA
,
float
>
,
paddle
::
operators
::
ReshapeGradKernel
<
CUDA
,
double
>
,
paddle
::
operators
::
ReshapeGradKernel
<
CUDA
,
int
>
,
paddle
::
operators
::
ReshapeGradKernel
<
CUDA
,
int64_t
>
);
paddle/fluid/operators/reshape_op.h
已删除
100644 → 0
浏览文件 @
ebadb4c0
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
ReshapeOp
:
public
framework
::
OperatorWithKernel
{
public:
ReshapeOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ReshapeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of ReshapeOp should not be null."
);
const
std
::
vector
<
int
>
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE
(
!
shape
.
empty
(),
"The shape information must be set by Attr(shape)."
);
if
(
ctx
->
HasInput
(
"Shape"
)
&&
ctx
->
IsRuntime
())
{
// If true, set the shape of Output(Out) according to Input(Shape) in
// ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
return
;
}
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
out_dims
=
ValidateShape
(
shape
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
out_dims
);
if
(
x_dims
[
0
]
==
out_dims
[
0
])
{
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
static
framework
::
DDim
ValidateShape
(
const
std
::
vector
<
int
>
shape
,
const
framework
::
DDim
&
in_dims
)
{
const
int64_t
in_size
=
framework
::
product
(
in_dims
);
// only one dimension can be set to -1, whose size will be automatically
// infered.
const
int64_t
unk_dim_val
=
-
1
;
const
int64_t
copy_dim_val
=
0
;
std
::
vector
<
int64_t
>
output_shape
(
shape
.
size
(),
0
);
int64_t
capacity
=
1
;
int
unk_dim_idx
=
-
1
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
if
(
shape
[
i
]
==
unk_dim_val
)
{
PADDLE_ENFORCE
(
unk_dim_idx
==
-
1
,
"Only one input dimension of Attr(shape) can be unknown."
);
unk_dim_idx
=
i
;
}
else
if
(
shape
[
i
]
==
copy_dim_val
)
{
PADDLE_ENFORCE
(
static_cast
<
int
>
(
i
)
<
in_dims
.
size
(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape."
);
}
else
{
PADDLE_ENFORCE
(
shape
[
i
]
>
0
,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension."
);
}
capacity
*=
(
shape
[
i
]
?
shape
[
i
]
:
in_dims
[
i
]);
output_shape
[
i
]
=
(
shape
[
i
]
?
static_cast
<
int64_t
>
(
shape
[
i
])
:
in_dims
[
i
]);
}
if
(
unk_dim_idx
!=
-
1
)
{
if
(
in_size
>
0
)
{
// in_size < 0 and is un-determinate in compile time, skip the check,
// for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
// capacity = -24, in_size = -8, output_shape[0] = 0
// the following check will fail.
output_shape
[
unk_dim_idx
]
=
-
in_size
/
capacity
;
PADDLE_ENFORCE_EQ
(
output_shape
[
unk_dim_idx
]
*
capacity
,
-
in_size
,
"Invalid shape is given."
);
}
else
{
output_shape
[
unk_dim_idx
]
=
-
1
;
}
}
else
{
PADDLE_ENFORCE_EQ
(
capacity
,
in_size
,
"Invalid shape is given."
);
}
return
framework
::
make_ddim
(
output_shape
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ReshapeKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
shape_tensor
=
ctx
.
HasInput
(
"Shape"
)
?
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Shape"
)
:
nullptr
;
framework
::
DDim
out_dims
=
out
->
dims
();
if
(
shape_tensor
)
{
auto
*
shape_data
=
shape_tensor
->
data
<
int
>
();
framework
::
Tensor
cpu_shape_tensor
;
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
TensorCopySync
(
*
shape_tensor
,
platform
::
CPUPlace
(),
&
cpu_shape_tensor
);
shape_data
=
cpu_shape_tensor
.
data
<
int
>
();
}
auto
shape
=
std
::
vector
<
int
>
(
shape_data
,
shape_data
+
shape_tensor
->
numel
());
out_dims
=
ReshapeOp
::
ValidateShape
(
shape
,
in
->
dims
());
}
if
(
!
in
->
lod
().
empty
())
{
PADDLE_ENFORCE_EQ
(
out_dims
[
0
],
in
->
dims
()[
0
],
"Reshape operator cannot reshape an input sequence batch "
"into an output sequence batch that has a different "
"number of time steps. Please consider using "
"sequence_reshape op."
);
}
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
out
->
Resize
(
out_dims
);
if
(
!
inplace
)
{
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
TensorCopySync
(
*
in
,
ctx
.
GetPlace
(),
out
);
out
->
Resize
(
out_dims
);
}
else
{
out
->
ShareDataWith
(
*
in
);
out
->
Resize
(
out_dims
);
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
ReshapeGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
d_out
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
bool
inplace
=
ctx
.
Attr
<
bool
>
(
"inplace"
);
auto
in_dims
=
d_x
->
dims
();
if
(
!
inplace
)
{
framework
::
TensorCopy
(
*
d_out
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
d_x
);
ctx
.
device_context
().
Wait
();
d_x
->
Resize
(
in_dims
);
}
else
{
d_x
->
ShareDataWith
(
*
d_out
);
d_x
->
Resize
(
in_dims
);
}
}
};
}
// namespace operators
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录