Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
91b72482
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
91b72482
编写于
11月 07, 2017
作者:
Y
Yu Yang
提交者:
GitHub
11月 07, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5443 from reyoung/feature/InferKernelKey
Polish OpWithKernel
上级
fcc50cb4
46610e9a
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
185 addition
and
126 deletion
+185
-126
doc/design/float16.md
doc/design/float16.md
+1
-1
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+1
-2
paddle/framework/operator.cc
paddle/framework/operator.cc
+34
-3
paddle/framework/operator.h
paddle/framework/operator.h
+27
-52
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+2
-2
paddle/operators/accuracy_op.cc
paddle/operators/accuracy_op.cc
+4
-3
paddle/operators/auc_op.cc
paddle/operators/auc_op.cc
+4
-3
paddle/operators/batch_norm_op.cc
paddle/operators/batch_norm_op.cc
+4
-2
paddle/operators/crf_decoding_op.cc
paddle/operators/crf_decoding_op.cc
+4
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+8
-4
paddle/operators/fill_constant_batch_size_like_op.cc
paddle/operators/fill_constant_batch_size_like_op.cc
+4
-2
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+3
-2
paddle/operators/gather_op.cc
paddle/operators/gather_op.cc
+8
-4
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+4
-2
paddle/operators/linear_chain_crf_op.cc
paddle/operators/linear_chain_crf_op.cc
+10
-5
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+8
-4
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+8
-6
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+8
-4
paddle/operators/positive_negative_pair_op.cc
paddle/operators/positive_negative_pair_op.cc
+4
-2
paddle/operators/precision_recall_op.cc
paddle/operators/precision_recall_op.cc
+4
-2
paddle/operators/scatter_op.cc
paddle/operators/scatter_op.cc
+8
-4
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+4
-2
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+9
-5
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+10
-6
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+4
-2
未找到文件。
doc/design/float16.md
浏览文件 @
91b72482
...
...
@@ -55,6 +55,6 @@ After float16 class is available, some of the future items are below:
-
Update pybind/tensor_py.h to bind c++ float16 with numpy float16.
-
Modify
`
IndicateData
Type()`
method in
`framework/operator.h`
to make it compatible with float16.
-
Modify
`
GetKernel
Type()`
method in
`framework/operator.h`
to make it compatible with float16.
-
Create a type-casting operator that can convert the data type in tensor between float16 and other types.
paddle/framework/op_registry.h
浏览文件 @
91b72482
...
...
@@ -92,8 +92,7 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void
operator
()(
const
char
*
op_type
)
const
{
using
T
=
typename
KERNEL_TYPE
::
ELEMENT_TYPE
;
OperatorWithKernel
::
OpKernelKey
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
());
OpKernelType
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
());
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
].
reset
(
new
KERNEL_TYPE
);
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelTypes
...
>>::
value
;
...
...
paddle/framework/operator.cc
浏览文件 @
91b72482
...
...
@@ -254,8 +254,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return
res
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
OperatorWithKernel
::
OpKernelKey
&
kernel_key
)
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
OpKernelType
&
kernel_key
)
{
os
<<
"place["
<<
kernel_key
.
place_
<<
"]:data_type["
<<
kernel_key
.
data_type_
<<
"]"
;
return
os
;
...
...
@@ -432,7 +431,7 @@ void OperatorWithKernel::Run(const Scope& scope,
// check if op[type] have kernel for kernel_key
OpKernelMap
&
kernels
=
kernels_iter
->
second
;
auto
kernel_key
=
OpKernelKey
(
IndicateDataType
(
ctx
),
dev_
ctx
);
auto
kernel_key
=
GetKernelType
(
ctx
);
auto
kernel_iter
=
kernels
.
find
(
kernel_key
);
if
(
kernel_iter
==
kernels
.
end
())
{
...
...
@@ -444,6 +443,38 @@ void OperatorWithKernel::Run(const Scope& scope,
// throws errors if have.
dev_ctx
.
Finish
();
}
OpKernelType
OperatorWithKernel
::
GetKernelType
(
const
ExecutionContext
&
ctx
)
const
{
return
OpKernelType
(
IndicateDataType
(
ctx
),
ctx
.
device_context
());
}
DataType
OperatorWithKernel
::
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
{
auto
&
scope
=
ctx
.
scope
();
int
data_type
=
-
1
;
for
(
auto
&
input
:
this
->
inputs_
)
{
for
(
auto
&
ipt_name
:
input
.
second
)
{
auto
*
var
=
scope
.
FindVar
(
ipt_name
);
if
(
var
!=
nullptr
)
{
const
Tensor
*
t
=
nullptr
;
if
(
var
->
IsType
<
Tensor
>
())
{
t
=
&
var
->
Get
<
Tensor
>
();
}
else
if
(
var
->
IsType
<
LoDTensor
>
())
{
t
=
&
var
->
Get
<
LoDTensor
>
();
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
}
if
(
t
!=
nullptr
)
{
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()));
PADDLE_ENFORCE
(
tmp
==
data_type
||
data_type
==
-
1
,
"DataType of Paddle Op %s must be the same."
,
Type
());
data_type
=
tmp
;
}
}
}
}
PADDLE_ENFORCE
(
data_type
!=
-
1
,
"DataType should be indicated by input"
);
return
static_cast
<
DataType
>
(
data_type
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/operator.h
浏览文件 @
91b72482
...
...
@@ -345,27 +345,10 @@ class OpKernel : public OpKernelBase {
using
ELEMENT_TYPE
=
T
;
};
class
OperatorWithKernel
:
public
OperatorBase
{
public:
struct
OpKernelKey
{
platform
::
Place
place_
;
DataType
data_type_
;
OpKernelKey
(
DataType
data_type
,
platform
::
Place
place
)
:
place_
(
place
),
data_type_
(
data_type
)
{}
OpKernelKey
(
DataType
data_type
,
const
platform
::
DeviceContext
&
dev_ctx
)
:
place_
(
dev_ctx
.
GetPlace
()),
data_type_
(
data_type
)
{}
bool
operator
==
(
const
OpKernelKey
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
data_type_
==
o
.
data_type_
;
}
};
struct
OpKernelHash
{
struct
OpKernelType
{
struct
Hash
{
std
::
hash
<
int
>
hash_
;
size_t
operator
()(
const
OpKernel
Key
&
key
)
const
{
size_t
operator
()(
const
OpKernel
Type
&
key
)
const
{
int
place
=
key
.
place_
.
which
();
int
data_type
=
static_cast
<
int
>
(
key
.
data_type_
);
int
pre_hash
=
data_type
<<
NUM_PLACE_TYPE_LIMIT_IN_BIT
|
...
...
@@ -374,9 +357,26 @@ class OperatorWithKernel : public OperatorBase {
}
};
platform
::
Place
place_
;
DataType
data_type_
;
OpKernelType
(
DataType
data_type
,
platform
::
Place
place
)
:
place_
(
place
),
data_type_
(
data_type
)
{}
OpKernelType
(
DataType
data_type
,
const
platform
::
DeviceContext
&
dev_ctx
)
:
place_
(
dev_ctx
.
GetPlace
()),
data_type_
(
data_type
)
{}
bool
operator
==
(
const
OpKernelType
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
data_type_
==
o
.
data_type_
;
}
};
class
OperatorWithKernel
:
public
OperatorBase
{
public:
using
OpKernelMap
=
std
::
unordered_map
<
OpKernel
Key
,
std
::
unique_ptr
<
OpKernelBase
>
,
OpKernelHash
>
;
std
::
unordered_map
<
OpKernel
Type
,
std
::
unique_ptr
<
OpKernelBase
>
,
OpKernel
Type
::
Hash
>
;
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
...
...
@@ -404,40 +404,15 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
virtual
OpKernelType
GetKernelType
(
const
ExecutionContext
&
ctx
)
const
;
private:
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
{
auto
&
scope
=
ctx
.
scope
();
int
data_type
=
-
1
;
for
(
auto
&
input
:
this
->
inputs_
)
{
for
(
auto
&
ipt_name
:
input
.
second
)
{
auto
*
var
=
scope
.
FindVar
(
ipt_name
);
if
(
var
!=
nullptr
)
{
const
Tensor
*
t
=
nullptr
;
if
(
var
->
IsType
<
Tensor
>
())
{
t
=
&
var
->
Get
<
Tensor
>
();
}
else
if
(
var
->
IsType
<
LoDTensor
>
())
{
t
=
&
var
->
Get
<
LoDTensor
>
();
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
}
if
(
t
!=
nullptr
)
{
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()));
PADDLE_ENFORCE
(
tmp
==
data_type
||
data_type
==
-
1
,
"DataType of Paddle Op %s must be the same."
,
Type
());
data_type
=
tmp
;
}
}
}
}
PADDLE_ENFORCE
(
data_type
!=
-
1
,
"DataType should be indicated by input"
);
return
static_cast
<
DataType
>
(
data_type
);
}
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
OperatorWithKernel
::
OpKernelKey
&
kernel_key
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
OpKernelType
&
kernel_key
);
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
...
...
paddle/framework/operator_test.cc
浏览文件 @
91b72482
...
...
@@ -114,8 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
DataType
IndicateData
Type
(
const
ExecutionContext
&
ctx
)
const
override
{
return
DataType
::
FP32
;
OpKernelType
GetKernel
Type
(
const
ExecutionContext
&
ctx
)
const
override
{
return
OpKernelType
(
DataType
::
FP32
,
ctx
.
device_context
())
;
}
};
...
...
paddle/operators/accuracy_op.cc
浏览文件 @
91b72482
...
...
@@ -47,10 +47,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
}
protected:
// IndicateDataType
framework
::
DataType
IndicateDataType
(
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/auc_op.cc
浏览文件 @
91b72482
...
...
@@ -39,10 +39,11 @@ class AucOp : public framework::OperatorWithKernel {
}
protected:
// IndicateDataType
framework
::
DataType
IndicateDataType
(
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/batch_norm_op.cc
浏览文件 @
91b72482
...
...
@@ -303,7 +303,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
{
C
});
}
framework
::
DataType
IndicateDataType
(
protected:
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
var
=
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
));
if
(
var
==
nullptr
)
{
...
...
@@ -318,7 +319,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
if
(
t
==
nullptr
)
{
PADDLE_THROW
(
"can't find Y@GRAD"
);
}
return
framework
::
ToDataType
(
t
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
t
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/crf_decoding_op.cc
浏览文件 @
91b72482
...
...
@@ -120,9 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
()),
ctx
.
device_context
());
}
};
}
// namespace operators
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
91b72482
...
...
@@ -51,9 +51,11 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -98,9 +100,11 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/fill_constant_batch_size_like_op.cc
浏览文件 @
91b72482
...
...
@@ -49,9 +49,11 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
return
framework
::
OpKernelType
(
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
)),
ctx
.
device_context
());
}
};
...
...
paddle/operators/fill_constant_op.cc
浏览文件 @
91b72482
...
...
@@ -33,11 +33,12 @@ class FillConstantOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
int
data_type
=
ctx
.
Attr
<
int
>
(
"data_type"
);
VLOG
(
10
)
<<
" FillConstant data_type = "
<<
data_type
;
return
static_cast
<
framework
::
DataType
>
(
data_type
);
return
framework
::
OpKernelType
(
static_cast
<
framework
::
DataType
>
(
data_type
),
ctx
.
device_context
());
}
};
...
...
paddle/operators/gather_op.cc
浏览文件 @
91b72482
...
...
@@ -40,9 +40,11 @@ class GatherOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -55,9 +57,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
91b72482
...
...
@@ -57,9 +57,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
return
framework
::
OpKernelType
(
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
)),
ctx
.
device_context
());
}
};
...
...
paddle/operators/linear_chain_crf_op.cc
浏览文件 @
91b72482
...
...
@@ -183,9 +183,11 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission".
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -240,10 +242,13 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood.
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"LogLikelihood"
))
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"LogLikelihood"
))
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
91b72482
...
...
@@ -41,9 +41,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"W"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"W"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -97,9 +99,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"W"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"W"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/lstm_op.cc
浏览文件 @
91b72482
...
...
@@ -84,10 +84,11 @@ class LSTMOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -245,10 +246,11 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
91b72482
...
...
@@ -51,9 +51,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -107,9 +109,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/positive_negative_pair_op.cc
浏览文件 @
91b72482
...
...
@@ -85,9 +85,11 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Score"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Score"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/precision_recall_op.cc
浏览文件 @
91b72482
...
...
@@ -80,9 +80,11 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"MaxProbs"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"MaxProbs"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/scatter_op.cc
浏览文件 @
91b72482
...
...
@@ -49,9 +49,11 @@ class ScatterOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -66,9 +68,11 @@ class ScatterGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
91b72482
...
...
@@ -107,9 +107,11 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
91b72482
...
...
@@ -121,9 +121,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -160,10 +162,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
))
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
))
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/sum_op.cc
浏览文件 @
91b72482
...
...
@@ -47,20 +47,24 @@ class SumOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
x_vars
=
ctx
.
MultiInputVar
(
"X"
);
if
(
x_vars
[
0
]
->
IsType
<
framework
::
LoDTensor
>
())
{
return
framework
::
ToDataType
(
x_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
().
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
x_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
().
type
()),
ctx
.
device_context
());
}
else
if
(
x_vars
[
0
]
->
IsType
<
framework
::
SelectedRows
>
())
{
return
framework
::
ToDataType
(
x_vars
[
0
]
->
Get
<
framework
::
SelectedRows
>
().
value
().
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
x_vars
[
0
]
->
Get
<
framework
::
SelectedRows
>
().
value
().
type
()),
ctx
.
device_context
());
}
else
if
(
x_vars
[
0
]
->
IsType
<
framework
::
LoDTensorArray
>
())
{
auto
&
array
=
x_vars
[
0
]
->
Get
<
framework
::
LoDTensorArray
>
();
for
(
auto
&
each
:
array
)
{
if
(
each
.
numel
()
!=
0
)
{
return
framework
::
ToDataType
(
each
.
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
each
.
type
()),
ctx
.
device_context
());
}
}
}
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
91b72482
...
...
@@ -63,9 +63,11 @@ class UniformRandomOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
return
framework
::
OpKernelType
(
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
)),
ctx
.
device_context
());
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录