Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
91b72482
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
91b72482
编写于
11月 07, 2017
作者:
Y
Yu Yang
提交者:
GitHub
11月 07, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5443 from reyoung/feature/InferKernelKey
Polish OpWithKernel
上级
fcc50cb4
46610e9a
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
185 addition
and
126 deletion
+185
-126
doc/design/float16.md
doc/design/float16.md
+1
-1
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+1
-2
paddle/framework/operator.cc
paddle/framework/operator.cc
+34
-3
paddle/framework/operator.h
paddle/framework/operator.h
+27
-52
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+2
-2
paddle/operators/accuracy_op.cc
paddle/operators/accuracy_op.cc
+4
-3
paddle/operators/auc_op.cc
paddle/operators/auc_op.cc
+4
-3
paddle/operators/batch_norm_op.cc
paddle/operators/batch_norm_op.cc
+4
-2
paddle/operators/crf_decoding_op.cc
paddle/operators/crf_decoding_op.cc
+4
-2
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+8
-4
paddle/operators/fill_constant_batch_size_like_op.cc
paddle/operators/fill_constant_batch_size_like_op.cc
+4
-2
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+3
-2
paddle/operators/gather_op.cc
paddle/operators/gather_op.cc
+8
-4
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+4
-2
paddle/operators/linear_chain_crf_op.cc
paddle/operators/linear_chain_crf_op.cc
+10
-5
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+8
-4
paddle/operators/lstm_op.cc
paddle/operators/lstm_op.cc
+8
-6
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+8
-4
paddle/operators/positive_negative_pair_op.cc
paddle/operators/positive_negative_pair_op.cc
+4
-2
paddle/operators/precision_recall_op.cc
paddle/operators/precision_recall_op.cc
+4
-2
paddle/operators/scatter_op.cc
paddle/operators/scatter_op.cc
+8
-4
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+4
-2
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+9
-5
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+10
-6
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+4
-2
未找到文件。
doc/design/float16.md
浏览文件 @
91b72482
...
...
@@ -55,6 +55,6 @@ After float16 class is available, some of the future items are below:
-
Update pybind/tensor_py.h to bind c++ float16 with numpy float16.
-
Modify
`
IndicateData
Type()`
method in
`framework/operator.h`
to make it compatible with float16.
-
Modify
`
GetKernel
Type()`
method in
`framework/operator.h`
to make it compatible with float16.
-
Create a type-casting operator that can convert the data type in tensor between float16 and other types.
paddle/framework/op_registry.h
浏览文件 @
91b72482
...
...
@@ -92,8 +92,7 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
void
operator
()(
const
char
*
op_type
)
const
{
using
T
=
typename
KERNEL_TYPE
::
ELEMENT_TYPE
;
OperatorWithKernel
::
OpKernelKey
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
());
OpKernelType
key
(
ToDataType
(
std
::
type_index
(
typeid
(
T
))),
PlaceType
());
OperatorWithKernel
::
AllOpKernels
()[
op_type
][
key
].
reset
(
new
KERNEL_TYPE
);
constexpr
auto
size
=
std
::
tuple_size
<
std
::
tuple
<
KernelTypes
...
>>::
value
;
...
...
paddle/framework/operator.cc
浏览文件 @
91b72482
...
...
@@ -254,8 +254,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return
res
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
OperatorWithKernel
::
OpKernelKey
&
kernel_key
)
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
OpKernelType
&
kernel_key
)
{
os
<<
"place["
<<
kernel_key
.
place_
<<
"]:data_type["
<<
kernel_key
.
data_type_
<<
"]"
;
return
os
;
...
...
@@ -432,7 +431,7 @@ void OperatorWithKernel::Run(const Scope& scope,
// check if op[type] have kernel for kernel_key
OpKernelMap
&
kernels
=
kernels_iter
->
second
;
auto
kernel_key
=
OpKernelKey
(
IndicateDataType
(
ctx
),
dev_
ctx
);
auto
kernel_key
=
GetKernelType
(
ctx
);
auto
kernel_iter
=
kernels
.
find
(
kernel_key
);
if
(
kernel_iter
==
kernels
.
end
())
{
...
...
@@ -444,6 +443,38 @@ void OperatorWithKernel::Run(const Scope& scope,
// throws errors if have.
dev_ctx
.
Finish
();
}
OpKernelType
OperatorWithKernel
::
GetKernelType
(
const
ExecutionContext
&
ctx
)
const
{
return
OpKernelType
(
IndicateDataType
(
ctx
),
ctx
.
device_context
());
}
DataType
OperatorWithKernel
::
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
{
auto
&
scope
=
ctx
.
scope
();
int
data_type
=
-
1
;
for
(
auto
&
input
:
this
->
inputs_
)
{
for
(
auto
&
ipt_name
:
input
.
second
)
{
auto
*
var
=
scope
.
FindVar
(
ipt_name
);
if
(
var
!=
nullptr
)
{
const
Tensor
*
t
=
nullptr
;
if
(
var
->
IsType
<
Tensor
>
())
{
t
=
&
var
->
Get
<
Tensor
>
();
}
else
if
(
var
->
IsType
<
LoDTensor
>
())
{
t
=
&
var
->
Get
<
LoDTensor
>
();
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
}
if
(
t
!=
nullptr
)
{
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()));
PADDLE_ENFORCE
(
tmp
==
data_type
||
data_type
==
-
1
,
"DataType of Paddle Op %s must be the same."
,
Type
());
data_type
=
tmp
;
}
}
}
}
PADDLE_ENFORCE
(
data_type
!=
-
1
,
"DataType should be indicated by input"
);
return
static_cast
<
DataType
>
(
data_type
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/operator.h
浏览文件 @
91b72482
...
...
@@ -345,27 +345,10 @@ class OpKernel : public OpKernelBase {
using
ELEMENT_TYPE
=
T
;
};
class
OperatorWithKernel
:
public
OperatorBase
{
public:
struct
OpKernelKey
{
platform
::
Place
place_
;
DataType
data_type_
;
OpKernelKey
(
DataType
data_type
,
platform
::
Place
place
)
:
place_
(
place
),
data_type_
(
data_type
)
{}
OpKernelKey
(
DataType
data_type
,
const
platform
::
DeviceContext
&
dev_ctx
)
:
place_
(
dev_ctx
.
GetPlace
()),
data_type_
(
data_type
)
{}
bool
operator
==
(
const
OpKernelKey
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
data_type_
==
o
.
data_type_
;
}
};
struct
OpKernelHash
{
struct
OpKernelType
{
struct
Hash
{
std
::
hash
<
int
>
hash_
;
size_t
operator
()(
const
OpKernel
Key
&
key
)
const
{
size_t
operator
()(
const
OpKernel
Type
&
key
)
const
{
int
place
=
key
.
place_
.
which
();
int
data_type
=
static_cast
<
int
>
(
key
.
data_type_
);
int
pre_hash
=
data_type
<<
NUM_PLACE_TYPE_LIMIT_IN_BIT
|
...
...
@@ -374,9 +357,26 @@ class OperatorWithKernel : public OperatorBase {
}
};
platform
::
Place
place_
;
DataType
data_type_
;
OpKernelType
(
DataType
data_type
,
platform
::
Place
place
)
:
place_
(
place
),
data_type_
(
data_type
)
{}
OpKernelType
(
DataType
data_type
,
const
platform
::
DeviceContext
&
dev_ctx
)
:
place_
(
dev_ctx
.
GetPlace
()),
data_type_
(
data_type
)
{}
bool
operator
==
(
const
OpKernelType
&
o
)
const
{
return
platform
::
places_are_same_class
(
place_
,
o
.
place_
)
&&
data_type_
==
o
.
data_type_
;
}
};
class
OperatorWithKernel
:
public
OperatorBase
{
public:
using
OpKernelMap
=
std
::
unordered_map
<
OpKernel
Key
,
std
::
unique_ptr
<
OpKernelBase
>
,
OpKernelHash
>
;
std
::
unordered_map
<
OpKernel
Type
,
std
::
unique_ptr
<
OpKernelBase
>
,
OpKernel
Type
::
Hash
>
;
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
...
...
@@ -404,40 +404,15 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
virtual
OpKernelType
GetKernelType
(
const
ExecutionContext
&
ctx
)
const
;
private:
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
{
auto
&
scope
=
ctx
.
scope
();
int
data_type
=
-
1
;
for
(
auto
&
input
:
this
->
inputs_
)
{
for
(
auto
&
ipt_name
:
input
.
second
)
{
auto
*
var
=
scope
.
FindVar
(
ipt_name
);
if
(
var
!=
nullptr
)
{
const
Tensor
*
t
=
nullptr
;
if
(
var
->
IsType
<
Tensor
>
())
{
t
=
&
var
->
Get
<
Tensor
>
();
}
else
if
(
var
->
IsType
<
LoDTensor
>
())
{
t
=
&
var
->
Get
<
LoDTensor
>
();
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
t
=
&
(
var
->
Get
<
SelectedRows
>
().
value
());
}
if
(
t
!=
nullptr
)
{
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()));
PADDLE_ENFORCE
(
tmp
==
data_type
||
data_type
==
-
1
,
"DataType of Paddle Op %s must be the same."
,
Type
());
data_type
=
tmp
;
}
}
}
}
PADDLE_ENFORCE
(
data_type
!=
-
1
,
"DataType should be indicated by input"
);
return
static_cast
<
DataType
>
(
data_type
);
}
DataType
IndicateDataType
(
const
ExecutionContext
&
ctx
)
const
;
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
OperatorWithKernel
::
OpKernelKey
&
kernel_key
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
OpKernelType
&
kernel_key
);
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
...
...
paddle/framework/operator_test.cc
浏览文件 @
91b72482
...
...
@@ -114,8 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
DataType
IndicateData
Type
(
const
ExecutionContext
&
ctx
)
const
override
{
return
DataType
::
FP32
;
OpKernelType
GetKernel
Type
(
const
ExecutionContext
&
ctx
)
const
override
{
return
OpKernelType
(
DataType
::
FP32
,
ctx
.
device_context
())
;
}
};
...
...
paddle/operators/accuracy_op.cc
浏览文件 @
91b72482
...
...
@@ -47,10 +47,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
}
protected:
// IndicateDataType
framework
::
DataType
IndicateDataType
(
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/auc_op.cc
浏览文件 @
91b72482
...
...
@@ -39,10 +39,11 @@ class AucOp : public framework::OperatorWithKernel {
}
protected:
// IndicateDataType
framework
::
DataType
IndicateDataType
(
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/batch_norm_op.cc
浏览文件 @
91b72482
...
...
@@ -303,7 +303,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
{
C
});
}
framework
::
DataType
IndicateDataType
(
protected:
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
var
=
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
));
if
(
var
==
nullptr
)
{
...
...
@@ -318,7 +319,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
if
(
t
==
nullptr
)
{
PADDLE_THROW
(
"can't find Y@GRAD"
);
}
return
framework
::
ToDataType
(
t
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
t
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/crf_decoding_op.cc
浏览文件 @
91b72482
...
...
@@ -120,9 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
()),
ctx
.
device_context
());
}
};
}
// namespace operators
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
91b72482
...
...
@@ -51,9 +51,11 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -98,9 +100,11 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/fill_constant_batch_size_like_op.cc
浏览文件 @
91b72482
...
...
@@ -49,9 +49,11 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
return
framework
::
OpKernelType
(
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
)),
ctx
.
device_context
());
}
};
...
...
paddle/operators/fill_constant_op.cc
浏览文件 @
91b72482
...
...
@@ -33,11 +33,12 @@ class FillConstantOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
int
data_type
=
ctx
.
Attr
<
int
>
(
"data_type"
);
VLOG
(
10
)
<<
" FillConstant data_type = "
<<
data_type
;
return
static_cast
<
framework
::
DataType
>
(
data_type
);
return
framework
::
OpKernelType
(
static_cast
<
framework
::
DataType
>
(
data_type
),
ctx
.
device_context
());
}
};
...
...
paddle/operators/gather_op.cc
浏览文件 @
91b72482
...
...
@@ -40,9 +40,11 @@ class GatherOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -55,9 +57,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
91b72482
...
...
@@ -57,9 +57,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
return
framework
::
OpKernelType
(
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
)),
ctx
.
device_context
());
}
};
...
...
paddle/operators/linear_chain_crf_op.cc
浏览文件 @
91b72482
...
...
@@ -183,9 +183,11 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission".
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -240,10 +242,13 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood.
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"LogLikelihood"
))
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"LogLikelihood"
))
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
91b72482
...
...
@@ -41,9 +41,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"W"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"W"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -97,9 +99,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"W"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"W"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/lstm_op.cc
浏览文件 @
91b72482
...
...
@@ -84,10 +84,11 @@ class LSTMOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -245,10 +246,11 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/multiplex_op.cc
浏览文件 @
91b72482
...
...
@@ -51,9 +51,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -107,9 +109,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/positive_negative_pair_op.cc
浏览文件 @
91b72482
...
...
@@ -85,9 +85,11 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Score"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Score"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/precision_recall_op.cc
浏览文件 @
91b72482
...
...
@@ -80,9 +80,11 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"MaxProbs"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"MaxProbs"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/scatter_op.cc
浏览文件 @
91b72482
...
...
@@ -49,9 +49,11 @@ class ScatterOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -66,9 +68,11 @@ class ScatterGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Ref"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/sequence_pool_op.cc
浏览文件 @
91b72482
...
...
@@ -107,9 +107,11 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
91b72482
...
...
@@ -121,9 +121,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
()),
ctx
.
device_context
());
}
};
...
...
@@ -160,10 +162,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
))
->
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
))
->
type
()),
ctx
.
device_context
());
}
};
...
...
paddle/operators/sum_op.cc
浏览文件 @
91b72482
...
...
@@ -47,20 +47,24 @@ class SumOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
x_vars
=
ctx
.
MultiInputVar
(
"X"
);
if
(
x_vars
[
0
]
->
IsType
<
framework
::
LoDTensor
>
())
{
return
framework
::
ToDataType
(
x_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
().
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
x_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
().
type
()),
ctx
.
device_context
());
}
else
if
(
x_vars
[
0
]
->
IsType
<
framework
::
SelectedRows
>
())
{
return
framework
::
ToDataType
(
x_vars
[
0
]
->
Get
<
framework
::
SelectedRows
>
().
value
().
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
x_vars
[
0
]
->
Get
<
framework
::
SelectedRows
>
().
value
().
type
()),
ctx
.
device_context
());
}
else
if
(
x_vars
[
0
]
->
IsType
<
framework
::
LoDTensorArray
>
())
{
auto
&
array
=
x_vars
[
0
]
->
Get
<
framework
::
LoDTensorArray
>
();
for
(
auto
&
each
:
array
)
{
if
(
each
.
numel
()
!=
0
)
{
return
framework
::
ToDataType
(
each
.
type
());
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
each
.
type
()),
ctx
.
device_context
());
}
}
}
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
91b72482
...
...
@@ -63,9 +63,11 @@ class UniformRandomOp : public framework::OperatorWithKernel {
}
protected:
framework
::
DataType
IndicateData
Type
(
framework
::
OpKernelType
GetKernel
Type
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
));
return
framework
::
OpKernelType
(
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"data_type"
)),
ctx
.
device_context
());
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录