Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3628d894
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3628d894
编写于
12月 17, 2018
作者:
Y
Yu Yang
提交者:
GitHub
12月 17, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14858 from reyoung/feature/tensor_type
Feature/tensor type
上级
363bf8a4
bacf1d23
变更
142
显示空白变更内容
内联
并排
Showing
142 changed file
with
445 addition
and
627 deletion
+445
-627
paddle/fluid/framework/data_layout_transform.cc
paddle/fluid/framework/data_layout_transform.cc
+3
-3
paddle/fluid/framework/data_layout_transform.h
paddle/fluid/framework/data_layout_transform.h
+8
-8
paddle/fluid/framework/data_type.cc
paddle/fluid/framework/data_type.cc
+7
-17
paddle/fluid/framework/data_type.h
paddle/fluid/framework/data_type.h
+45
-32
paddle/fluid/framework/data_type_test.cc
paddle/fluid/framework/data_type_test.cc
+4
-4
paddle/fluid/framework/details/all_reduce_op_handle.cc
paddle/fluid/framework/details/all_reduce_op_handle.cc
+1
-1
paddle/fluid/framework/details/fuse_vars_op_handle.h
paddle/fluid/framework/details/fuse_vars_op_handle.h
+2
-2
paddle/fluid/framework/details/reduce_op_handle.cc
paddle/fluid/framework/details/reduce_op_handle.cc
+7
-7
paddle/fluid/framework/dlpack_tensor.cc
paddle/fluid/framework/dlpack_tensor.cc
+17
-20
paddle/fluid/framework/dlpack_tensor_test.cc
paddle/fluid/framework/dlpack_tensor_test.cc
+4
-16
paddle/fluid/framework/executor_thread_worker.cc
paddle/fluid/framework/executor_thread_worker.cc
+13
-33
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+3
-3
paddle/fluid/framework/ngraph_operator.cc
paddle/fluid/framework/ngraph_operator.cc
+5
-9
paddle/fluid/framework/op_kernel_type_test.cc
paddle/fluid/framework/op_kernel_type_test.cc
+2
-1
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+7
-7
paddle/fluid/framework/selected_rows.cc
paddle/fluid/framework/selected_rows.cc
+2
-2
paddle/fluid/framework/tensor.cc
paddle/fluid/framework/tensor.cc
+2
-2
paddle/fluid/framework/tensor.h
paddle/fluid/framework/tensor.h
+5
-5
paddle/fluid/framework/tensor_impl.h
paddle/fluid/framework/tensor_impl.h
+5
-7
paddle/fluid/framework/tensor_util.cc
paddle/fluid/framework/tensor_util.cc
+4
-6
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+2
-2
paddle/fluid/inference/api/api_impl.cc
paddle/fluid/inference/api/api_impl.cc
+2
-2
paddle/fluid/inference/api/api_impl_tester.cc
paddle/fluid/inference/api/api_impl_tester.cc
+2
-2
paddle/fluid/inference/tests/api/tester_helper.h
paddle/fluid/inference/tests/api/tester_helper.h
+2
-2
paddle/fluid/operators/affine_grid_op.cc
paddle/fluid/operators/affine_grid_op.cc
+4
-4
paddle/fluid/operators/arg_max_op.cc
paddle/fluid/operators/arg_max_op.cc
+0
-1
paddle/fluid/operators/arg_max_op.cu
paddle/fluid/operators/arg_max_op.cu
+0
-2
paddle/fluid/operators/arg_min_op.cc
paddle/fluid/operators/arg_min_op.cc
+0
-1
paddle/fluid/operators/arg_min_op.cu
paddle/fluid/operators/arg_min_op.cu
+0
-2
paddle/fluid/operators/array_to_lod_tensor_op.cc
paddle/fluid/operators/array_to_lod_tensor_op.cc
+2
-2
paddle/fluid/operators/attention_lstm_op.cc
paddle/fluid/operators/attention_lstm_op.cc
+2
-3
paddle/fluid/operators/average_accumulates_op.cc
paddle/fluid/operators/average_accumulates_op.cc
+2
-3
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+7
-13
paddle/fluid/operators/beam_search_decode_op.cc
paddle/fluid/operators/beam_search_decode_op.cc
+1
-1
paddle/fluid/operators/beam_search_op.cc
paddle/fluid/operators/beam_search_op.cc
+1
-2
paddle/fluid/operators/bpr_loss_op.cc
paddle/fluid/operators/bpr_loss_op.cc
+4
-6
paddle/fluid/operators/controlflow/conditional_block_op.cc
paddle/fluid/operators/controlflow/conditional_block_op.cc
+6
-7
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+1
-1
paddle/fluid/operators/conv_op.cc
paddle/fluid/operators/conv_op.cc
+5
-7
paddle/fluid/operators/conv_transpose_op.cc
paddle/fluid/operators/conv_transpose_op.cc
+4
-6
paddle/fluid/operators/crf_decoding_op.cc
paddle/fluid/operators/crf_decoding_op.cc
+2
-3
paddle/fluid/operators/crop_op.cc
paddle/fluid/operators/crop_op.cc
+3
-6
paddle/fluid/operators/cross_entropy_op.cc
paddle/fluid/operators/cross_entropy_op.cc
+4
-6
paddle/fluid/operators/ctc_align_op.cc
paddle/fluid/operators/ctc_align_op.cc
+2
-3
paddle/fluid/operators/detection/anchor_generator_op.cc
paddle/fluid/operators/detection/anchor_generator_op.cc
+1
-2
paddle/fluid/operators/detection/bipartite_match_op.cc
paddle/fluid/operators/detection/bipartite_match_op.cc
+2
-3
paddle/fluid/operators/detection/density_prior_box_op.cc
paddle/fluid/operators/detection/density_prior_box_op.cc
+1
-2
paddle/fluid/operators/detection/generate_proposals_op.cc
paddle/fluid/operators/detection/generate_proposals_op.cc
+2
-3
paddle/fluid/operators/detection/mine_hard_examples_op.cc
paddle/fluid/operators/detection/mine_hard_examples_op.cc
+1
-2
paddle/fluid/operators/detection/multiclass_nms_op.cc
paddle/fluid/operators/detection/multiclass_nms_op.cc
+1
-2
paddle/fluid/operators/detection/prior_box_op.cc
paddle/fluid/operators/detection/prior_box_op.cc
+1
-2
paddle/fluid/operators/detection/roi_perspective_transform_op.cc
...fluid/operators/detection/roi_perspective_transform_op.cc
+4
-6
paddle/fluid/operators/detection/rpn_target_assign_op.cc
paddle/fluid/operators/detection/rpn_target_assign_op.cc
+1
-2
paddle/fluid/operators/detection/target_assign_op.cc
paddle/fluid/operators/detection/target_assign_op.cc
+2
-3
paddle/fluid/operators/detection_map_op.cc
paddle/fluid/operators/detection_map_op.cc
+1
-2
paddle/fluid/operators/distributed/grpc_serde.cc
paddle/fluid/operators/distributed/grpc_serde.cc
+1
-2
paddle/fluid/operators/distributed/sendrecvop_utils.cc
paddle/fluid/operators/distributed/sendrecvop_utils.cc
+2
-4
paddle/fluid/operators/distributed/sendrecvop_utils.h
paddle/fluid/operators/distributed/sendrecvop_utils.h
+7
-6
paddle/fluid/operators/distributed/variable_response.cc
paddle/fluid/operators/distributed/variable_response.cc
+7
-8
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
+1
-3
paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc
...e/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc
+1
-3
paddle/fluid/operators/elementwise/elementwise_op.h
paddle/fluid/operators/elementwise/elementwise_op.h
+2
-2
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+4
-6
paddle/fluid/operators/fc_op.cc
paddle/fluid/operators/fc_op.cc
+4
-6
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+2
-2
paddle/fluid/operators/fill_op.cc
paddle/fluid/operators/fill_op.cc
+2
-2
paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
+4
-6
paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
+1
-2
paddle/fluid/operators/fused/fusion_gru_op.cc
paddle/fluid/operators/fused/fusion_gru_op.cc
+2
-3
paddle/fluid/operators/fused/fusion_lstm_op.cc
paddle/fluid/operators/fused/fusion_lstm_op.cc
+2
-3
paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
...le/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
+2
-3
paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
...le/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
+2
-3
paddle/fluid/operators/gather_op.cc
paddle/fluid/operators/gather_op.cc
+4
-6
paddle/fluid/operators/grid_sampler_op.cc
paddle/fluid/operators/grid_sampler_op.cc
+6
-6
paddle/fluid/operators/group_norm_op.cc
paddle/fluid/operators/group_norm_op.cc
+1
-2
paddle/fluid/operators/hierarchical_sigmoid_op.cc
paddle/fluid/operators/hierarchical_sigmoid_op.cc
+4
-6
paddle/fluid/operators/interpolate_op.cc
paddle/fluid/operators/interpolate_op.cc
+4
-4
paddle/fluid/operators/is_empty_op.cc
paddle/fluid/operators/is_empty_op.cc
+1
-2
paddle/fluid/operators/isfinite_op.cc
paddle/fluid/operators/isfinite_op.cc
+2
-3
paddle/fluid/operators/layer_norm_op.cc
paddle/fluid/operators/layer_norm_op.cc
+1
-2
paddle/fluid/operators/linear_chain_crf_op.cc
paddle/fluid/operators/linear_chain_crf_op.cc
+3
-6
paddle/fluid/operators/load_combine_op.cc
paddle/fluid/operators/load_combine_op.cc
+1
-1
paddle/fluid/operators/load_op.cc
paddle/fluid/operators/load_op.cc
+1
-1
paddle/fluid/operators/lod_reset_op.cc
paddle/fluid/operators/lod_reset_op.cc
+4
-6
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+1
-1
paddle/fluid/operators/lookup_sparse_table_op.cc
paddle/fluid/operators/lookup_sparse_table_op.cc
+1
-2
paddle/fluid/operators/lrn_op.cc
paddle/fluid/operators/lrn_op.cc
+2
-3
paddle/fluid/operators/lstm_op.cc
paddle/fluid/operators/lstm_op.cc
+2
-4
paddle/fluid/operators/lstmp_op.cc
paddle/fluid/operators/lstmp_op.cc
+2
-4
paddle/fluid/operators/math/math_function.cc
paddle/fluid/operators/math/math_function.cc
+2
-4
paddle/fluid/operators/math/math_function.cu
paddle/fluid/operators/math/math_function.cu
+1
-1
paddle/fluid/operators/mean_iou_op.cc
paddle/fluid/operators/mean_iou_op.cc
+2
-3
paddle/fluid/operators/mean_op.cc
paddle/fluid/operators/mean_op.cc
+1
-3
paddle/fluid/operators/merge_lod_tensor_op.cc
paddle/fluid/operators/merge_lod_tensor_op.cc
+1
-3
paddle/fluid/operators/metrics/accuracy_op.cc
paddle/fluid/operators/metrics/accuracy_op.cc
+2
-3
paddle/fluid/operators/metrics/auc_op.cc
paddle/fluid/operators/metrics/auc_op.cc
+2
-3
paddle/fluid/operators/metrics/precision_recall_op.cc
paddle/fluid/operators/metrics/precision_recall_op.cc
+2
-3
paddle/fluid/operators/multiplex_op.cc
paddle/fluid/operators/multiplex_op.cc
+4
-6
paddle/fluid/operators/nce_op.cc
paddle/fluid/operators/nce_op.cc
+4
-6
paddle/fluid/operators/optimizers/adadelta_op.cc
paddle/fluid/operators/optimizers/adadelta_op.cc
+2
-3
paddle/fluid/operators/optimizers/adagrad_op.cc
paddle/fluid/operators/optimizers/adagrad_op.cc
+2
-3
paddle/fluid/operators/optimizers/adam_op.cc
paddle/fluid/operators/optimizers/adam_op.cc
+1
-2
paddle/fluid/operators/optimizers/adamax_op.cc
paddle/fluid/operators/optimizers/adamax_op.cc
+2
-3
paddle/fluid/operators/optimizers/decayed_adagrad_op.cc
paddle/fluid/operators/optimizers/decayed_adagrad_op.cc
+2
-3
paddle/fluid/operators/optimizers/ftrl_op.cc
paddle/fluid/operators/optimizers/ftrl_op.cc
+1
-2
paddle/fluid/operators/optimizers/proximal_adagrad_op.cc
paddle/fluid/operators/optimizers/proximal_adagrad_op.cc
+2
-3
paddle/fluid/operators/optimizers/proximal_gd_op.cc
paddle/fluid/operators/optimizers/proximal_gd_op.cc
+2
-3
paddle/fluid/operators/pad2d_op.cc
paddle/fluid/operators/pad2d_op.cc
+4
-4
paddle/fluid/operators/pad_constant_like_op.cc
paddle/fluid/operators/pad_constant_like_op.cc
+4
-6
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+3
-4
paddle/fluid/operators/pool_with_index_op.cc
paddle/fluid/operators/pool_with_index_op.cc
+4
-6
paddle/fluid/operators/positive_negative_pair_op.cc
paddle/fluid/operators/positive_negative_pair_op.cc
+2
-3
paddle/fluid/operators/prelu_op.cc
paddle/fluid/operators/prelu_op.cc
+4
-6
paddle/fluid/operators/print_op.cc
paddle/fluid/operators/print_op.cc
+1
-1
paddle/fluid/operators/psroi_pool_op.cc
paddle/fluid/operators/psroi_pool_op.cc
+4
-6
paddle/fluid/operators/random_crop_op.cc
paddle/fluid/operators/random_crop_op.cc
+2
-3
paddle/fluid/operators/reader/create_batch_reader_op.cc
paddle/fluid/operators/reader/create_batch_reader_op.cc
+2
-2
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/recurrent_op.cc
+1
-1
paddle/fluid/operators/reshape_op.cc
paddle/fluid/operators/reshape_op.cc
+5
-9
paddle/fluid/operators/rnn_memory_helper_op.cc
paddle/fluid/operators/rnn_memory_helper_op.cc
+1
-1
paddle/fluid/operators/roi_align_op.cc
paddle/fluid/operators/roi_align_op.cc
+4
-6
paddle/fluid/operators/roi_pool_op.cc
paddle/fluid/operators/roi_pool_op.cc
+4
-6
paddle/fluid/operators/save_combine_op.cc
paddle/fluid/operators/save_combine_op.cc
+1
-1
paddle/fluid/operators/save_op.cc
paddle/fluid/operators/save_op.cc
+1
-1
paddle/fluid/operators/scatter_op.cc
paddle/fluid/operators/scatter_op.cc
+4
-6
paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
+2
-3
paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
+4
-6
paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
+4
-6
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
+2
-2
paddle/fluid/operators/similarity_focus_op.cc
paddle/fluid/operators/similarity_focus_op.cc
+2
-3
paddle/fluid/operators/slice_op.cc
paddle/fluid/operators/slice_op.cc
+2
-3
paddle/fluid/operators/softmax_op.cc
paddle/fluid/operators/softmax_op.cc
+3
-4
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+3
-5
paddle/fluid/operators/sum_op.cc
paddle/fluid/operators/sum_op.cc
+6
-7
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+1
-4
paddle/fluid/operators/transpose_op.cc
paddle/fluid/operators/transpose_op.cc
+3
-6
paddle/fluid/operators/unpool_op.cc
paddle/fluid/operators/unpool_op.cc
+4
-6
paddle/fluid/operators/warpctc_op.cc
paddle/fluid/operators/warpctc_op.cc
+4
-6
paddle/fluid/operators/yolov3_loss_op.cc
paddle/fluid/operators/yolov3_loss_op.cc
+4
-6
paddle/fluid/platform/nccl_helper.h
paddle/fluid/platform/nccl_helper.h
+6
-5
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+1
-1
paddle/fluid/pybind/tensor_py.h
paddle/fluid/pybind/tensor_py.h
+1
-1
未找到文件。
paddle/fluid/framework/data_layout_transform.cc
浏览文件 @
3628d894
...
@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
...
@@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out
->
mutable_data
(
expected_kernel_type
.
place_
,
in
.
type
());
out
->
mutable_data
(
expected_kernel_type
.
place_
,
in
.
type
());
framework
::
VisitDataType
(
framework
::
VisitDataType
(
framework
::
ToDataType
(
in
.
type
()
),
in
.
type
(
),
CastDataLayout
(
pool
.
Get
(
expected_kernel_type
.
place_
),
axis
,
in
,
out
));
CastDataLayout
(
pool
.
Get
(
expected_kernel_type
.
place_
),
axis
,
in
,
out
));
out
->
set_layout
(
expected_kernel_type
.
data_layout_
);
out
->
set_layout
(
expected_kernel_type
.
data_layout_
);
...
@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
...
@@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case
mkldnn
::
memory
::
data_type
::
f32
:
case
mkldnn
::
memory
::
data_type
::
f32
:
return
platform
::
to_void_cast
(
tensor
.
data
<
float
>
());
return
platform
::
to_void_cast
(
tensor
.
data
<
float
>
());
case
mkldnn
::
memory
::
data_type
::
s8
:
case
mkldnn
::
memory
::
data_type
::
s8
:
return
platform
::
to_void_cast
(
tensor
.
data
<
char
>
());
return
platform
::
to_void_cast
(
tensor
.
data
<
int8_t
>
());
case
mkldnn
::
memory
::
data_type
::
u8
:
case
mkldnn
::
memory
::
data_type
::
u8
:
return
platform
::
to_void_cast
(
tensor
.
data
<
unsigned
char
>
());
return
platform
::
to_void_cast
(
tensor
.
data
<
unsigned
char
>
());
case
mkldnn
::
memory
::
data_type
::
s16
:
case
mkldnn
::
memory
::
data_type
::
s16
:
...
@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
...
@@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory
::
data_type
in_type
=
ToMKLDNNDataType
(
in
.
type
());
memory
::
data_type
in_type
=
ToMKLDNNDataType
(
in
.
type
());
PADDLE_ENFORCE
(
in_type
!=
memory
::
data_type
::
data_undef
,
PADDLE_ENFORCE
(
in_type
!=
memory
::
data_type
::
data_undef
,
"Input tensor type is not supported:
"
,
in
.
type
().
nam
e
());
"Input tensor type is not supported:
%s"
,
in
.
typ
e
());
memory
::
data_type
out_type
=
in_type
;
memory
::
data_type
out_type
=
in_type
;
auto
in_format
=
platform
::
MKLDNNFormatForSize
(
in_tz
.
size
(),
in
.
format
());
auto
in_format
=
platform
::
MKLDNNFormatForSize
(
in_tz
.
size
(),
in
.
format
());
...
...
paddle/fluid/framework/data_layout_transform.h
浏览文件 @
3628d894
...
@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
...
@@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
}
}
}
}
inline
MKLDNNDataType
ToMKLDNNDataType
(
const
std
::
type_index
type
)
{
inline
MKLDNNDataType
ToMKLDNNDataType
(
proto
::
VarType
::
Type
type
)
{
static
const
std
::
map
<
std
::
type_index
,
MKLDNNDataType
>
dict
{
static
std
::
unordered_map
<
int
,
MKLDNNDataType
>
dict
{
{
std
::
type_index
(
typeid
(
float
)),
MKLDNNDataType
::
f32
},
// NOLINT
{
DataTypeTrait
<
float
>::
DataType
,
MKLDNNDataType
::
f32
},
{
std
::
type_index
(
typeid
(
char
)),
MKLDNNDataType
::
s8
},
// NOLINT
{
DataTypeTrait
<
int8_t
>::
DataType
,
MKLDNNDataType
::
s8
},
{
std
::
type_index
(
typeid
(
unsigned
char
))
,
MKLDNNDataType
::
u8
},
{
DataTypeTrait
<
uint8_t
>::
DataType
,
MKLDNNDataType
::
u8
},
{
std
::
type_index
(
typeid
(
int16_t
))
,
MKLDNNDataType
::
s16
},
{
DataTypeTrait
<
int16_t
>::
DataType
,
MKLDNNDataType
::
s16
},
{
std
::
type_index
(
typeid
(
int32_t
))
,
MKLDNNDataType
::
s32
}};
{
DataTypeTrait
<
int32_t
>::
DataType
,
MKLDNNDataType
::
s32
}};
auto
iter
=
dict
.
find
(
type
);
auto
iter
=
dict
.
find
(
static_cast
<
int
>
(
type
)
);
if
(
iter
!=
dict
.
end
())
return
iter
->
second
;
if
(
iter
!=
dict
.
end
())
return
iter
->
second
;
return
MKLDNNDataType
::
data_undef
;
return
MKLDNNDataType
::
data_undef
;
}
}
...
...
paddle/fluid/framework/data_type.cc
浏览文件 @
3628d894
...
@@ -26,7 +26,7 @@ struct DataTypeMap {
...
@@ -26,7 +26,7 @@ struct DataTypeMap {
std
::
unordered_map
<
std
::
type_index
,
proto
::
VarType
::
Type
>
cpp_to_proto_
;
std
::
unordered_map
<
std
::
type_index
,
proto
::
VarType
::
Type
>
cpp_to_proto_
;
std
::
unordered_map
<
int
,
std
::
type_index
>
proto_to_cpp_
;
std
::
unordered_map
<
int
,
std
::
type_index
>
proto_to_cpp_
;
std
::
unordered_map
<
int
,
std
::
string
>
proto_to_str_
;
std
::
unordered_map
<
int
,
std
::
string
>
proto_to_str_
;
std
::
unordered_map
<
std
::
type_index
,
size_t
>
cpp
_to_size_
;
std
::
unordered_map
<
int
,
size_t
>
proto
_to_size_
;
};
};
static
DataTypeMap
*
InitDataTypeMap
();
static
DataTypeMap
*
InitDataTypeMap
();
...
@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
...
@@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
map
->
proto_to_cpp_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
typeid
(
T
));
map
->
proto_to_cpp_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
typeid
(
T
));
map
->
cpp_to_proto_
.
emplace
(
typeid
(
T
),
proto_type
);
map
->
cpp_to_proto_
.
emplace
(
typeid
(
T
),
proto_type
);
map
->
proto_to_str_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
name
);
map
->
proto_to_str_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
name
);
map
->
cpp_to_size_
.
emplace
(
typeid
(
T
),
sizeof
(
T
));
map
->
proto_to_size_
.
emplace
(
static_cast
<
int
>
(
proto_type
),
sizeof
(
T
));
}
}
static
DataTypeMap
*
InitDataTypeMap
()
{
static
DataTypeMap
*
InitDataTypeMap
()
{
...
@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
...
@@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \
#define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type)
RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here.
_ForEachDataType_
(
RegType
);
RegType
(
float16
,
proto
::
VarType
::
FP16
);
RegType
(
float
,
proto
::
VarType
::
FP32
);
RegType
(
double
,
proto
::
VarType
::
FP64
);
RegType
(
int
,
proto
::
VarType
::
INT32
);
RegType
(
int64_t
,
proto
::
VarType
::
INT64
);
RegType
(
bool
,
proto
::
VarType
::
BOOL
);
RegType
(
size_t
,
proto
::
VarType
::
SIZE_T
);
RegType
(
int16_t
,
proto
::
VarType
::
INT16
);
RegType
(
uint8_t
,
proto
::
VarType
::
UINT8
);
RegType
(
int8_t
,
proto
::
VarType
::
INT8
);
#undef RegType
#undef RegType
return
retv
;
return
retv
;
...
@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
...
@@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast
<
int
>
(
type
));
static_cast
<
int
>
(
type
));
}
}
size_t
SizeOfType
(
std
::
type_index
type
)
{
size_t
SizeOfType
(
proto
::
VarType
::
Type
type
)
{
auto
it
=
gDataTypeMap
().
cpp_to_size_
.
find
(
type
);
auto
it
=
gDataTypeMap
().
proto_to_size_
.
find
(
static_cast
<
int
>
(
type
)
);
if
(
it
!=
gDataTypeMap
().
cpp
_to_size_
.
end
())
{
if
(
it
!=
gDataTypeMap
().
proto
_to_size_
.
end
())
{
return
it
->
second
;
return
it
->
second
;
}
}
PADDLE_THROW
(
"Not support %s as tensor type"
,
type
.
name
(
));
PADDLE_THROW
(
"Not support %s as tensor type"
,
DataTypeToString
(
type
));
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/data_type.h
浏览文件 @
3628d894
...
@@ -22,46 +22,59 @@ limitations under the License. */
...
@@ -22,46 +22,59 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
template
<
typename
T
>
struct
DataTypeTrait
{};
// Stub handle for void
template
<
>
struct
DataTypeTrait
<
void
>
{
constexpr
static
auto
DataType
=
proto
::
VarType
::
RAW
;
};
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
constexpr static auto DataType = proto_type; \
}
_ForEachDataType_
(
DefineDataTypeTrait
);
#undef DefineDataTypeTrait
extern
proto
::
VarType
::
Type
ToDataType
(
std
::
type_index
type
);
extern
proto
::
VarType
::
Type
ToDataType
(
std
::
type_index
type
);
extern
std
::
type_index
ToTypeIndex
(
proto
::
VarType
::
Type
type
);
extern
std
::
type_index
ToTypeIndex
(
proto
::
VarType
::
Type
type
);
template
<
typename
Visitor
>
template
<
typename
Visitor
>
inline
void
VisitDataType
(
proto
::
VarType
::
Type
type
,
Visitor
visitor
)
{
inline
void
VisitDataType
(
proto
::
VarType
::
Type
type
,
Visitor
visitor
)
{
switch
(
type
)
{
#define VisitDataTypeCallback(cpp_type, proto_type) \
case
proto
::
VarType
::
FP16
:
do { \
visitor
.
template
apply
<
platform
::
float16
>();
if (type == proto_type) { \
break
;
visitor.template apply<cpp_type>(); \
case
proto
::
VarType
::
FP32
:
return; \
visitor
.
template
apply
<
float
>();
} \
break
;
} while (0)
case
proto
::
VarType
::
FP64
:
visitor
.
template
apply
<
double
>();
_ForEachDataType_
(
VisitDataTypeCallback
);
break
;
#undef VisitDataTypeCallback
case
proto
::
VarType
::
INT32
:
visitor
.
template
apply
<
int
>();
break
;
case
proto
::
VarType
::
INT64
:
visitor
.
template
apply
<
int64_t
>();
break
;
case
proto
::
VarType
::
BOOL
:
visitor
.
template
apply
<
bool
>();
break
;
case
proto
::
VarType
::
UINT8
:
visitor
.
template
apply
<
uint8_t
>();
break
;
case
proto
::
VarType
::
INT16
:
visitor
.
template
apply
<
int16_t
>();
break
;
case
proto
::
VarType
::
INT8
:
visitor
.
template
apply
<
int8_t
>();
break
;
default:
PADDLE_THROW
(
"Not supported %d"
,
type
);
PADDLE_THROW
(
"Not supported %d"
,
type
);
}
}
}
extern
std
::
string
DataTypeToString
(
const
proto
::
VarType
::
Type
type
);
extern
std
::
string
DataTypeToString
(
const
proto
::
VarType
::
Type
type
);
extern
size_t
SizeOfType
(
std
::
type_index
type
);
extern
size_t
SizeOfType
(
proto
::
VarType
::
Type
type
);
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
out
,
const
proto
::
VarType
::
Type
&
type
)
{
const
proto
::
VarType
::
Type
&
type
)
{
out
<<
DataTypeToString
(
type
);
out
<<
DataTypeToString
(
type
);
...
...
paddle/fluid/framework/data_type_test.cc
浏览文件 @
3628d894
...
@@ -26,15 +26,15 @@ TEST(DataType, float16) {
...
@@ -26,15 +26,15 @@ TEST(DataType, float16) {
Tensor
tensor
;
Tensor
tensor
;
CPUPlace
cpu
;
CPUPlace
cpu
;
tensor
.
mutable_data
(
cpu
,
f
::
ToTypeIndex
(
dtype
)
);
tensor
.
mutable_data
(
cpu
,
dtype
);
// test fp16 tensor
// test fp16 tensor
EXPECT_EQ
(
tensor
.
type
(),
std
::
type_index
(
typeid
(
float16
)));
EXPECT_EQ
(
tensor
.
type
(),
f
::
ToDataType
(
typeid
(
float16
)));
// test fp16 size
// test fp16 size
EXPECT_EQ
(
f
::
SizeOfType
(
f
::
ToTypeIndex
(
dtype
)
),
2u
);
EXPECT_EQ
(
f
::
SizeOfType
(
dtype
),
2u
);
// test debug info
// test debug info
std
::
string
type
=
"float16"
;
std
::
string
type
=
"
::paddle::platform::
float16"
;
EXPECT_STREQ
(
f
::
DataTypeToString
(
dtype
).
c_str
(),
type
.
c_str
());
EXPECT_STREQ
(
f
::
DataTypeToString
(
dtype
).
c_str
(),
type
.
c_str
());
}
}
paddle/fluid/framework/details/all_reduce_op_handle.cc
浏览文件 @
3628d894
...
@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
...
@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
// Reduce All Tensor to trg in CPU
// Reduce All Tensor to trg in CPU
ReduceLoDTensor
func
(
lod_tensors
,
&
trg
);
ReduceLoDTensor
func
(
lod_tensors
,
&
trg
);
VisitDataType
(
ToDataType
(
lod_tensors
[
0
]
->
type
()
),
func
);
VisitDataType
(
lod_tensors
[
0
]
->
type
(
),
func
);
for
(
size_t
i
=
1
;
i
<
local_scopes_
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
local_scopes_
.
size
();
++
i
)
{
auto
&
scope
=
auto
&
scope
=
...
...
paddle/fluid/framework/details/fuse_vars_op_handle.h
浏览文件 @
3628d894
...
@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
...
@@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
FuseVarsOpHandle
(
ir
::
Node
*
node
,
Scope
*
local_scope
,
FuseVarsOpHandle
(
ir
::
Node
*
node
,
Scope
*
local_scope
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
const
std
::
unordered_map
<
std
::
string
,
int64_t
>
&
inputs_numel
,
const
std
::
unordered_map
<
std
::
string
,
int64_t
>
&
inputs_numel
,
const
std
::
type_index
&
var_type
)
const
proto
::
VarType
::
Type
var_type
)
:
OpHandleBase
(
node
),
:
OpHandleBase
(
node
),
local_scope_
(
local_scope
),
local_scope_
(
local_scope
),
place_
(
place
),
place_
(
place
),
...
@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
...
@@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
Scope
*
local_scope_
;
Scope
*
local_scope_
;
const
platform
::
Place
place_
;
const
platform
::
Place
place_
;
const
std
::
unordered_map
<
std
::
string
,
int64_t
>
inputs_numel_
;
const
std
::
unordered_map
<
std
::
string
,
int64_t
>
inputs_numel_
;
const
std
::
type_index
type_
;
const
proto
::
VarType
::
Type
type_
;
int64_t
total_numel_
;
int64_t
total_numel_
;
};
};
}
// namespace details
}
// namespace details
...
...
paddle/fluid/framework/details/reduce_op_handle.cc
浏览文件 @
3628d894
...
@@ -218,18 +218,18 @@ void ReduceOpHandle::RunImpl() {
...
@@ -218,18 +218,18 @@ void ReduceOpHandle::RunImpl() {
}
}
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
if
(
framework
::
IsType
<
const
float
>
(
in_selected_rows
[
0
]
->
value
().
type
()))
{
if
(
in_selected_rows
[
0
]
->
value
().
type
()
==
framework
::
proto
::
VarType
::
FP32
)
{
GatherSelectedRows
<
platform
::
CUDADeviceContext
,
float
>
(
GatherSelectedRows
<
platform
::
CUDADeviceContext
,
float
>
(
in_selected_rows
,
in_places
,
dev_ctxes_
,
out_var_handle
,
t_out_p
,
in_selected_rows
,
in_places
,
dev_ctxes_
,
out_var_handle
,
t_out_p
,
out_var
->
GetMutable
<
framework
::
SelectedRows
>
());
out_var
->
GetMutable
<
framework
::
SelectedRows
>
());
}
else
if
(
framework
::
IsType
<
const
double
>
(
}
else
if
(
in_selected_rows
[
0
]
->
value
().
type
()
==
in_selected_rows
[
0
]
->
value
().
type
())
)
{
framework
::
proto
::
VarType
::
FP64
)
{
GatherSelectedRows
<
platform
::
CUDADeviceContext
,
double
>
(
GatherSelectedRows
<
platform
::
CUDADeviceContext
,
double
>
(
in_selected_rows
,
in_places
,
dev_ctxes_
,
out_var_handle
,
t_out_p
,
in_selected_rows
,
in_places
,
dev_ctxes_
,
out_var_handle
,
t_out_p
,
out_var
->
GetMutable
<
framework
::
SelectedRows
>
());
out_var
->
GetMutable
<
framework
::
SelectedRows
>
());
}
else
{
}
else
{
PADDLE_ENFORCE
(
false
,
PADDLE_THROW
(
"only support double or float when gather SelectedRows"
);
"only support double or float when gahter SelectedRows"
);
}
}
#endif
#endif
});
});
...
@@ -246,7 +246,7 @@ void ReduceOpHandle::RunImpl() {
...
@@ -246,7 +246,7 @@ void ReduceOpHandle::RunImpl() {
if
(
!
FLAGS_cpu_deterministic
)
{
if
(
!
FLAGS_cpu_deterministic
)
{
ReduceLoDTensor
func
(
lod_tensors
,
ReduceLoDTensor
func
(
lod_tensors
,
out_var
->
GetMutable
<
framework
::
LoDTensor
>
());
out_var
->
GetMutable
<
framework
::
LoDTensor
>
());
VisitDataType
(
ToDataType
(
lod_tensors
[
0
]
->
type
()
),
func
);
VisitDataType
(
lod_tensors
[
0
]
->
type
(
),
func
);
}
else
{
}
else
{
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
...
@@ -256,7 +256,7 @@ void ReduceOpHandle::RunImpl() {
...
@@ -256,7 +256,7 @@ void ReduceOpHandle::RunImpl() {
->
FindVar
(
out_var_handle
->
name_
)
->
FindVar
(
out_var_handle
->
name_
)
->
GetMutable
<
framework
::
LoDTensor
>
();
->
GetMutable
<
framework
::
LoDTensor
>
();
ReduceLoDTensor
func
(
lod_tensors
,
&
reduce_sum_trg
);
ReduceLoDTensor
func
(
lod_tensors
,
&
reduce_sum_trg
);
VisitDataType
(
ToDataType
(
lod_tensors
[
0
]
->
type
()
),
func
);
VisitDataType
(
lod_tensors
[
0
]
->
type
(
),
func
);
auto
trg
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
trg
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
if
(
reduce_sum_trg
.
data
<
void
>
()
!=
trg
->
data
<
void
>
())
{
if
(
reduce_sum_trg
.
data
<
void
>
()
!=
trg
->
data
<
void
>
())
{
...
...
paddle/fluid/framework/dlpack_tensor.cc
浏览文件 @
3628d894
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/data_type.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
...
@@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
return
dtype
;
return
dtype
;
}
}
static
DLDataType
GetDLDataTypeFromTypeIndex
(
const
std
::
type_index
&
type
)
{
static
std
::
unordered_map
<
int
,
::
DLDataType
>
CreateDLDataTypeMap
()
{
#define REG_DL_DATA_TYPE(type) \
static
std
::
unordered_map
<
int
,
::
DLDataType
>
result
;
{ std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
static
const
std
::
unordered_map
<
std
::
type_index
,
::
DLDataType
>
#define REG_DL_DATA_TYPE(cpp_type, proto_type) \
type_to_dtype_map
({
result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
REG_DL_DATA_TYPE
(
platform
::
float16
),
// NOLINT
REG_DL_DATA_TYPE
(
float
),
// NOLINT
_ForEachDataType_
(
REG_DL_DATA_TYPE
);
REG_DL_DATA_TYPE
(
double
),
// NOLINT
#undef REG_DL_DATA_TYPE
REG_DL_DATA_TYPE
(
int
),
// NOLINT
return
result
;
REG_DL_DATA_TYPE
(
int64_t
),
// NOLINT
}
REG_DL_DATA_TYPE
(
bool
),
// NOLINT
REG_DL_DATA_TYPE
(
size_t
),
// NOLINT
static
DLDataType
GetDLDataTypeFromTypeIndex
(
proto
::
VarType
::
Type
type
)
{
REG_DL_DATA_TYPE
(
int16_t
),
// NOLINT
static
auto
type_to_dtype_map
=
CreateDLDataTypeMap
();
REG_DL_DATA_TYPE
(
uint8_t
),
// NOLINT
REG_DL_DATA_TYPE
(
int8_t
)
// NOLINT
});
static
auto
type_to_dtype_map_end_it
=
type_to_dtype_map
.
end
();
static
auto
type_to_dtype_map_end_it
=
type_to_dtype_map
.
end
();
auto
it
=
type_to_dtype_map
.
find
(
type
);
auto
it
=
type_to_dtype_map
.
find
(
static_cast
<
int
>
(
type
)
);
PADDLE_ENFORCE
(
it
!=
type_to_dtype_map_end_it
,
"Unsupported data type %
s
"
,
PADDLE_ENFORCE
(
it
!=
type_to_dtype_map_end_it
,
"Unsupported data type %
d
"
,
type
.
name
()
);
type
);
return
it
->
second
;
return
it
->
second
;
#undef REG_DL_DATA_TYPE
#undef REG_DL_DATA_TYPE
}
}
...
...
paddle/fluid/framework/dlpack_tensor_test.cc
浏览文件 @
3628d894
...
@@ -91,23 +91,11 @@ void TestMainLoop() {
...
@@ -91,23 +91,11 @@ void TestMainLoop() {
}
}
}
}
}
}
TEST
(
dlpack
,
test_all
)
{
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
#define PADDLE_DLPACK_TEST(type) \
_ForEachDataType_
(
TestCallback
);
TEST(dlpack, test_##type) { TestMainLoop<type>(); }
}
using
float16
=
platform
::
float16
;
PADDLE_DLPACK_TEST
(
float16
);
PADDLE_DLPACK_TEST
(
float
);
PADDLE_DLPACK_TEST
(
double
);
PADDLE_DLPACK_TEST
(
int
);
PADDLE_DLPACK_TEST
(
int64_t
);
PADDLE_DLPACK_TEST
(
bool
);
PADDLE_DLPACK_TEST
(
size_t
);
PADDLE_DLPACK_TEST
(
int16_t
);
PADDLE_DLPACK_TEST
(
uint8_t
);
PADDLE_DLPACK_TEST
(
int8_t
);
#undef PADDLE_DLPACK_TEST
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/executor_thread_worker.cc
浏览文件 @
3628d894
...
@@ -139,39 +139,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
...
@@ -139,39 +139,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
std
::
cout
<<
sstream
.
str
()
<<
std
::
endl
;
std
::
cout
<<
sstream
.
str
()
<<
std
::
endl
;
}
}
void
print_fetch_var
(
Scope
*
scope
,
std
::
string
var_name
)
{
static
void
print_fetch_var
(
Scope
*
scope
,
const
std
::
string
&
var_name
)
{
const
LoDTensor
&
tensor
=
scope
->
FindVar
(
var_name
)
->
Get
<
LoDTensor
>
();
auto
&
tensor
=
scope
->
FindVar
(
var_name
)
->
Get
<
LoDTensor
>
();
if
(
std
::
type_index
(
tensor
.
type
())
==
#define PrintLoDTensorCallback(cpp_type, proto_type) \
std
::
type_index
(
typeid
(
platform
::
float16
)))
{
do { \
print_lod_tensor
<
platform
::
float16
>
(
var_name
,
tensor
);
if (tensor.type() == proto_type) { \
}
else
if
(
std
::
type_index
(
tensor
.
type
())
==
std
::
type_index
(
typeid
(
float
)))
{
print_lod_tensor<cpp_type>(var_name, tensor); \
print_lod_tensor
<
float
>
(
var_name
,
tensor
);
return; \
}
else
if
(
std
::
type_index
(
tensor
.
type
())
==
} \
std
::
type_index
(
typeid
(
double
)))
{
} while (0)
print_lod_tensor
<
double
>
(
var_name
,
tensor
);
}
else
if
(
std
::
type_index
(
tensor
.
type
())
==
std
::
type_index
(
typeid
(
int
)))
{
_ForEachDataType_
(
PrintLoDTensorCallback
);
print_lod_tensor
<
int
>
(
var_name
,
tensor
);
VLOG
(
1
)
<<
"print_fetch_var: unrecognized data type:"
<<
tensor
.
type
();
}
else
if
(
std
::
type_index
(
tensor
.
type
())
==
std
::
type_index
(
typeid
(
int64_t
)))
{
print_lod_tensor
<
int64_t
>
(
var_name
,
tensor
);
}
else
if
(
std
::
type_index
(
tensor
.
type
())
==
std
::
type_index
(
typeid
(
bool
)))
{
print_lod_tensor
<
bool
>
(
var_name
,
tensor
);
}
else
if
(
std
::
type_index
(
tensor
.
type
())
==
std
::
type_index
(
typeid
(
uint8_t
)))
{
print_lod_tensor
<
uint8_t
>
(
var_name
,
tensor
);
}
else
if
(
std
::
type_index
(
tensor
.
type
())
==
std
::
type_index
(
typeid
(
int16_t
)))
{
print_lod_tensor
<
int16_t
>
(
var_name
,
tensor
);
}
else
if
(
std
::
type_index
(
tensor
.
type
())
==
std
::
type_index
(
typeid
(
int8_t
)))
{
print_lod_tensor
<
int8_t
>
(
var_name
,
tensor
);
}
else
{
VLOG
(
1
)
<<
"print_fetch_var: unrecognized data type:"
<<
tensor
.
type
().
name
();
}
return
;
}
}
void
ExecutorThreadWorker
::
TrainFiles
()
{
void
ExecutorThreadWorker
::
TrainFiles
()
{
...
...
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
3628d894
...
@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
...
@@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements
// only print first ten elements
int64_t
size
=
t
.
numel
()
<
10
?
t
.
numel
()
:
10
;
int64_t
size
=
t
.
numel
()
<
10
?
t
.
numel
()
:
10
;
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
if
(
IsType
<
float
>
(
t
.
type
())
)
{
if
(
t
.
type
()
==
proto
::
VarType
::
FP32
)
{
os
<<
t
.
data
<
float
>
()[
i
]
<<
" "
;
os
<<
t
.
data
<
float
>
()[
i
]
<<
" "
;
}
else
if
(
IsType
<
int64_t
>
(
t
.
type
())
)
{
}
else
if
(
t
.
type
()
==
proto
::
VarType
::
INT64
)
{
os
<<
t
.
data
<
int64_t
>
()[
i
]
<<
" "
;
os
<<
t
.
data
<
int64_t
>
()[
i
]
<<
" "
;
}
else
{
}
else
{
PADDLE_THROW
(
"LoDTensor data type not in [float, int64_t]"
);
PADDLE_THROW
(
"LoDTensor data type not in [float, int64_t]"
);
...
@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
...
@@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE
(
!
lod_tensors
.
empty
());
PADDLE_ENFORCE
(
!
lod_tensors
.
empty
());
framework
::
DDim
new_dim
=
lod_tensors
[
0
]
->
dims
();
framework
::
DDim
new_dim
=
lod_tensors
[
0
]
->
dims
();
std
::
type_index
new_type
=
lod_tensors
[
0
]
->
type
();
auto
new_type
=
lod_tensors
[
0
]
->
type
();
framework
::
DataLayout
new_layout
=
lod_tensors
[
0
]
->
layout
();
framework
::
DataLayout
new_layout
=
lod_tensors
[
0
]
->
layout
();
LoD
new_lod
=
lod_tensors
[
0
]
->
lod
();
LoD
new_lod
=
lod_tensors
[
0
]
->
lod
();
for
(
size_t
i
=
1
;
i
<
lod_tensors
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
lod_tensors
.
size
();
++
i
)
{
...
...
paddle/fluid/framework/ngraph_operator.cc
浏览文件 @
3628d894
...
@@ -471,27 +471,23 @@ void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
...
@@ -471,27 +471,23 @@ void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const {
auto
*
tensor_pd
=
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
auto
*
tensor_pd
=
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
PADDLE_ENFORCE
(
sp
==
Ddim2Shape
(
tensor_pd
->
dims
()),
PADDLE_ENFORCE
(
sp
==
Ddim2Shape
(
tensor_pd
->
dims
()),
"Ensure ngraph tensor layout align with paddle tensor"
);
"Ensure ngraph tensor layout align with paddle tensor"
);
if
(
tensor_pd
->
type
().
hash_code
()
==
if
(
tensor_pd
->
type
()
==
proto
::
VarType
::
FP32
)
{
typeid
(
float
).
hash_code
())
{
// NOLINT
const
float
*
arr
=
tensor_pd
->
data
<
float
>
();
const
float
*
arr
=
tensor_pd
->
data
<
float
>
();
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
f32
,
sp
,
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
f32
,
sp
,
const_cast
<
float
*>
(
arr
));
const_cast
<
float
*>
(
arr
));
}
else
if
(
tensor_pd
->
type
().
hash_code
()
==
}
else
if
(
tensor_pd
->
type
()
==
proto
::
VarType
::
INT32
)
{
typeid
(
int
).
hash_code
())
{
// NOLINT
const
int
*
arr
=
tensor_pd
->
data
<
int
>
();
const
int
*
arr
=
tensor_pd
->
data
<
int
>
();
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
i32
,
sp
,
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
i32
,
sp
,
const_cast
<
int
*>
(
arr
));
const_cast
<
int
*>
(
arr
));
}
else
if
(
tensor_pd
->
type
()
.
hash_code
()
==
typeid
(
int64_t
).
hash_code
()
)
{
}
else
if
(
tensor_pd
->
type
()
==
proto
::
VarType
::
INT64
)
{
const
int64_t
*
arr
=
tensor_pd
->
data
<
int64_t
>
();
const
int64_t
*
arr
=
tensor_pd
->
data
<
int64_t
>
();
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
i64
,
sp
,
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
i64
,
sp
,
const_cast
<
int64_t
*>
(
arr
));
const_cast
<
int64_t
*>
(
arr
));
}
else
if
(
tensor_pd
->
type
().
hash_code
()
==
}
else
if
(
tensor_pd
->
type
()
==
proto
::
VarType
::
FP64
)
{
typeid
(
double
).
hash_code
())
{
// NOLINT
const
double
*
arr
=
tensor_pd
->
data
<
double
>
();
const
double
*
arr
=
tensor_pd
->
data
<
double
>
();
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
f64
,
sp
,
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
f64
,
sp
,
const_cast
<
double
*>
(
arr
));
const_cast
<
double
*>
(
arr
));
}
else
if
(
tensor_pd
->
type
().
hash_code
()
==
}
else
if
(
tensor_pd
->
type
()
==
proto
::
VarType
::
BOOL
)
{
typeid
(
bool
).
hash_code
())
{
// NOLINT
const
bool
*
arr
=
tensor_pd
->
data
<
bool
>
();
const
bool
*
arr
=
tensor_pd
->
data
<
bool
>
();
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
boolean
,
sp
,
ti
=
backend_
->
create_tensor
(
ngraph
::
element
::
boolean
,
sp
,
const_cast
<
bool
*>
(
arr
));
const_cast
<
bool
*>
(
arr
));
...
...
paddle/fluid/framework/op_kernel_type_test.cc
浏览文件 @
3628d894
...
@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) {
...
@@ -34,7 +34,8 @@ TEST(OpKernelType, ToString) {
OpKernelType
op_kernel_type2
(
DataType
::
FP16
,
CUDAPlace
(
0
),
DataLayout
::
kNCHW
,
OpKernelType
op_kernel_type2
(
DataType
::
FP16
,
CUDAPlace
(
0
),
DataLayout
::
kNCHW
,
LibraryType
::
kCUDNN
);
LibraryType
::
kCUDNN
);
ASSERT_EQ
(
paddle
::
framework
::
KernelTypeToString
(
op_kernel_type2
),
ASSERT_EQ
(
paddle
::
framework
::
KernelTypeToString
(
op_kernel_type2
),
"data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_"
"data_type[::paddle::platform::float16]:data_layout[NCHW]:place["
"CUDAPlace(0)]:library_"
"type[CUDNN]"
);
"type[CUDNN]"
);
}
}
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
3628d894
...
@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
...
@@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
proto
::
VarType
::
Type
GetDataTypeOfVar
(
const
Variable
*
var
)
{
proto
::
VarType
::
Type
GetDataTypeOfVar
(
const
Variable
*
var
)
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
return
framework
::
ToDataType
(
var
->
Get
<
framework
::
LoDTensor
>
().
type
()
);
return
var
->
Get
<
framework
::
LoDTensor
>
().
type
(
);
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
}
else
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
return
framework
::
ToDataType
(
return
var
->
Get
<
framework
::
SelectedRows
>
().
value
().
type
();
var
->
Get
<
framework
::
SelectedRows
>
().
value
().
type
());
}
else
{
}
else
{
PADDLE_THROW
(
"Var should be LoDTensor or SelectedRows"
);
PADDLE_THROW
(
"Var should be LoDTensor or SelectedRows"
);
}
}
...
@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
...
@@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
if
(
UNLIKELY
(
!
tensor
.
IsInitialized
()))
{
if
(
UNLIKELY
(
!
tensor
.
IsInitialized
()))
{
return
""
;
return
""
;
}
}
return
DataTypeToString
(
ToDataType
(
tensor
.
type
()
));
return
DataTypeToString
(
tensor
.
type
(
));
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
auto
tensor
=
var
->
Get
<
SelectedRows
>
().
value
();
auto
tensor
=
var
->
Get
<
SelectedRows
>
().
value
();
if
(
UNLIKELY
(
!
tensor
.
IsInitialized
()))
{
if
(
UNLIKELY
(
!
tensor
.
IsInitialized
()))
{
return
"uninited"
;
return
"uninited"
;
}
else
{
}
else
{
return
DataTypeToString
(
ToDataType
(
tensor
.
type
()
));
return
DataTypeToString
(
tensor
.
type
(
));
}
}
}
else
{
}
else
{
return
""
;
return
""
;
...
@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name,
...
@@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if
(
tensor
.
memory_size
()
==
0
)
{
if
(
tensor
.
memory_size
()
==
0
)
{
return
;
return
;
}
}
if
(
!
IsType
<
float
>
(
tensor
.
type
())
&&
!
IsType
<
double
>
(
tensor
.
type
()))
{
if
(
tensor
.
type
()
!=
proto
::
VarType
::
FP32
&&
tensor
.
type
()
!=
proto
::
VarType
::
FP64
)
{
return
;
return
;
}
}
PADDLE_ENFORCE
(
!
framework
::
TensorContainsInf
(
tensor
),
PADDLE_ENFORCE
(
!
framework
::
TensorContainsInf
(
tensor
),
...
@@ -881,7 +881,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
...
@@ -881,7 +881,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
if
(
t
!=
nullptr
)
{
if
(
t
!=
nullptr
)
{
PADDLE_ENFORCE
(
t
->
IsInitialized
(),
"Input %s is not initialized: %s"
,
PADDLE_ENFORCE
(
t
->
IsInitialized
(),
"Input %s is not initialized: %s"
,
ipt_name
,
DebugString
());
ipt_name
,
DebugString
());
int
tmp
=
static_cast
<
int
>
(
ToDataType
(
t
->
type
()
));
int
tmp
=
static_cast
<
int
>
(
t
->
type
(
));
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
tmp
==
data_type
||
data_type
==
-
1
,
tmp
==
data_type
||
data_type
==
-
1
,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)"
,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)"
,
...
...
paddle/fluid/framework/selected_rows.cc
浏览文件 @
3628d894
...
@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
...
@@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
if
(
index
<
0
)
{
if
(
index
<
0
)
{
VLOG
(
5
)
<<
"id "
<<
id
<<
" not in the table, return 0"
;
VLOG
(
5
)
<<
"id "
<<
id
<<
" not in the table, return 0"
;
framework
::
VisitDataType
(
framework
::
VisitDataType
(
framework
::
ToDataType
(
value_
->
type
()
),
value_
->
type
(
),
TensorFillVisitor
(
value
,
i
*
value_width
,
value_width
,
0.0
));
TensorFillVisitor
(
value
,
i
*
value_width
,
value_width
,
0.0
));
}
else
{
}
else
{
framework
::
VisitDataType
(
framework
::
VisitDataType
(
framework
::
ToDataType
(
value_
->
type
()
),
value_
->
type
(
),
TensorCopyVisitor
(
value
,
i
*
value_width
,
*
value_
.
get
(),
TensorCopyVisitor
(
value
,
i
*
value_width
,
*
value_
.
get
(),
index
*
value_width
,
value_width
));
index
*
value_width
,
value_width
));
}
}
...
...
paddle/fluid/framework/tensor.cc
浏览文件 @
3628d894
...
@@ -16,7 +16,7 @@ limitations under the License. */
...
@@ -16,7 +16,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
extern
size_t
SizeOfType
(
std
::
type_index
type
);
extern
size_t
SizeOfType
(
proto
::
VarType
::
Type
type
);
void
Tensor
::
check_memory_size
()
const
{
void
Tensor
::
check_memory_size
()
const
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
holder_
,
"Tensor holds no memory. Call Tensor::mutable_data first."
);
holder_
,
"Tensor holds no memory. Call Tensor::mutable_data first."
);
...
@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
...
@@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
return
holder_
==
nullptr
?
0UL
:
holder_
->
size
()
-
offset_
;
return
holder_
==
nullptr
?
0UL
:
holder_
->
size
()
-
offset_
;
}
}
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
,
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
,
proto
::
VarType
::
Type
type
,
memory
::
Allocator
::
Attr
attr
,
memory
::
Allocator
::
Attr
attr
,
size_t
requested_size
)
{
size_t
requested_size
)
{
type_
=
type
;
type_
=
type
;
...
...
paddle/fluid/framework/tensor.h
浏览文件 @
3628d894
...
@@ -19,9 +19,9 @@ limitations under the License. */
...
@@ -19,9 +19,9 @@ limitations under the License. */
#include <memory>
#include <memory>
#include <typeindex>
#include <typeindex>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -67,7 +67,7 @@ class Tensor {
...
@@ -67,7 +67,7 @@ class Tensor {
friend
struct
EigenVector
;
friend
struct
EigenVector
;
public:
public:
Tensor
()
:
type_
(
typeid
(
float
)
),
offset_
(
0
)
{}
Tensor
()
:
type_
(
proto
::
VarType
::
FP32
),
offset_
(
0
)
{}
/*! Return a pointer to mutable memory block. */
/*! Return a pointer to mutable memory block. */
template
<
typename
T
>
template
<
typename
T
>
...
@@ -88,7 +88,7 @@ class Tensor {
...
@@ -88,7 +88,7 @@ class Tensor {
memory
::
Allocator
::
Attr
attr
=
memory
::
Allocator
::
kDefault
,
memory
::
Allocator
::
Attr
attr
=
memory
::
Allocator
::
kDefault
,
size_t
requested_size
=
0
);
size_t
requested_size
=
0
);
void
*
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
,
void
*
mutable_data
(
platform
::
Place
place
,
proto
::
VarType
::
Type
type
,
memory
::
Allocator
::
Attr
attr
=
memory
::
Allocator
::
kDefault
,
memory
::
Allocator
::
Attr
attr
=
memory
::
Allocator
::
kDefault
,
size_t
requested_size
=
0
);
size_t
requested_size
=
0
);
...
@@ -138,7 +138,7 @@ class Tensor {
...
@@ -138,7 +138,7 @@ class Tensor {
return
holder_
->
place
();
return
holder_
->
place
();
}
}
std
::
type_index
type
()
const
{
proto
::
VarType
::
Type
type
()
const
{
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
holder_
,
"Tensor not initialized yet when Tensor::type() is called."
);
holder_
,
"Tensor not initialized yet when Tensor::type() is called."
);
return
type_
;
return
type_
;
...
@@ -165,7 +165,7 @@ class Tensor {
...
@@ -165,7 +165,7 @@ class Tensor {
private:
private:
/*! holds the memory block if allocated. */
/*! holds the memory block if allocated. */
std
::
shared_ptr
<
memory
::
Allocation
>
holder_
;
std
::
shared_ptr
<
memory
::
Allocation
>
holder_
;
std
::
type_index
type_
;
proto
::
VarType
::
Type
type_
;
/**
/**
* @brief points to elements dimensions.
* @brief points to elements dimensions.
*
*
...
...
paddle/fluid/framework/tensor_impl.h
浏览文件 @
3628d894
...
@@ -24,9 +24,8 @@ template <typename T>
...
@@ -24,9 +24,8 @@ template <typename T>
inline
const
T
*
Tensor
::
data
()
const
{
inline
const
T
*
Tensor
::
data
()
const
{
check_memory_size
();
check_memory_size
();
bool
valid
=
bool
valid
=
std
::
is_same
<
T
,
void
>::
value
||
type_
==
std
::
type_index
(
typeid
(
T
));
std
::
is_same
<
T
,
void
>::
value
||
type_
==
DataTypeTrait
<
T
>::
DataType
;
PADDLE_ENFORCE
(
valid
,
"Tensor holds the wrong type, it holds %s"
,
PADDLE_ENFORCE
(
valid
,
"Tensor holds the wrong type, it holds %d"
,
type_
);
type_
.
name
());
return
reinterpret_cast
<
const
T
*>
(
return
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
...
@@ -38,9 +37,8 @@ template <typename T>
...
@@ -38,9 +37,8 @@ template <typename T>
inline
T
*
Tensor
::
data
()
{
inline
T
*
Tensor
::
data
()
{
check_memory_size
();
check_memory_size
();
bool
valid
=
bool
valid
=
std
::
is_same
<
T
,
void
>::
value
||
type_
==
std
::
type_index
(
typeid
(
T
));
std
::
is_same
<
T
,
void
>::
value
||
type_
==
DataTypeTrait
<
T
>::
DataType
;
PADDLE_ENFORCE
(
valid
,
"Tensor holds the wrong type, it holds %s"
,
PADDLE_ENFORCE
(
valid
,
"Tensor holds the wrong type, it holds %s"
,
type_
);
type_
.
name
());
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
offset_
);
}
}
...
@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
...
@@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
size_t
requested_size
)
{
size_t
requested_size
)
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
return
reinterpret_cast
<
T
*>
(
return
reinterpret_cast
<
T
*>
(
mutable_data
(
place
,
typeid
(
T
)
,
attr
,
requested_size
));
mutable_data
(
place
,
DataTypeTrait
<
T
>::
DataType
,
attr
,
requested_size
));
}
}
inline
Tensor
ReshapeToMatrix
(
const
Tensor
&
src
,
int
num_col_dims
)
{
inline
Tensor
ReshapeToMatrix
(
const
Tensor
&
src
,
int
num_col_dims
)
{
...
...
paddle/fluid/framework/tensor_util.cc
浏览文件 @
3628d894
...
@@ -186,7 +186,7 @@ struct AnyDTypeVisitor {
...
@@ -186,7 +186,7 @@ struct AnyDTypeVisitor {
template
<
typename
Predicate
,
typename
DevCtx
>
template
<
typename
Predicate
,
typename
DevCtx
>
inline
void
AnyImpl
(
Predicate
predicate
,
const
framework
::
Tensor
&
tensor
,
inline
void
AnyImpl
(
Predicate
predicate
,
const
framework
::
Tensor
&
tensor
,
const
DevCtx
&
ctx
,
framework
::
Tensor
*
out
)
{
const
DevCtx
&
ctx
,
framework
::
Tensor
*
out
)
{
VisitDataType
(
ToDataType
(
tensor
.
type
()
),
AnyDTypeVisitor
<
Predicate
,
DevCtx
>
(
VisitDataType
(
tensor
.
type
(
),
AnyDTypeVisitor
<
Predicate
,
DevCtx
>
(
predicate
,
tensor
,
ctx
,
out
));
predicate
,
tensor
,
ctx
,
out
));
}
}
...
@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
...
@@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
// int32_t size
// int32_t size
// void* protobuf message
// void* protobuf message
proto
::
VarType
::
TensorDesc
desc
;
proto
::
VarType
::
TensorDesc
desc
;
desc
.
set_data_type
(
framework
::
ToDataType
(
tensor
.
type
()
));
desc
.
set_data_type
(
tensor
.
type
(
));
auto
dims
=
framework
::
vectorize
(
tensor
.
dims
());
auto
dims
=
framework
::
vectorize
(
tensor
.
dims
());
auto
*
pb_dims
=
desc
.
mutable_dims
();
auto
*
pb_dims
=
desc
.
mutable_dims
();
pb_dims
->
Resize
(
static_cast
<
int
>
(
dims
.
size
()),
0
);
pb_dims
->
Resize
(
static_cast
<
int
>
(
dims
.
size
()),
0
);
...
@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
...
@@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor
->
Resize
(
framework
::
make_ddim
(
dims
));
tensor
->
Resize
(
framework
::
make_ddim
(
dims
));
void
*
buf
;
void
*
buf
;
auto
ctx
=
platform
::
CPUDeviceContext
();
auto
ctx
=
platform
::
CPUDeviceContext
();
size_t
size
=
size_t
size
=
tensor
->
numel
()
*
framework
::
SizeOfType
(
desc
.
data_type
());
tensor
->
numel
()
*
framework
::
SizeOfType
(
framework
::
ToTypeIndex
(
desc
.
data_type
()));
if
(
platform
::
is_gpu_place
(
dev_ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
dev_ctx
.
GetPlace
()))
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
Tensor
cpu_tensor
;
Tensor
cpu_tensor
;
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
3628d894
...
@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
...
@@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto
type
=
fetch
.
type
();
auto
type
=
fetch
.
type
();
auto
output
=
&
(
outputs
->
at
(
i
));
auto
output
=
&
(
outputs
->
at
(
i
));
output
->
name
=
fetchs_
[
idx
]
->
Input
(
"X"
)[
0
];
output
->
name
=
fetchs_
[
idx
]
->
Input
(
"X"
)[
0
];
if
(
type
==
typeid
(
float
)
)
{
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
GetFetchOne
<
float
>
(
fetch
,
output
);
GetFetchOne
<
float
>
(
fetch
,
output
);
output
->
dtype
=
PaddleDType
::
FLOAT32
;
output
->
dtype
=
PaddleDType
::
FLOAT32
;
}
else
if
(
type
==
typeid
(
int64_t
)
)
{
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT64
)
{
GetFetchOne
<
int64_t
>
(
fetch
,
output
);
GetFetchOne
<
int64_t
>
(
fetch
,
output
);
output
->
dtype
=
PaddleDType
::
INT64
;
output
->
dtype
=
PaddleDType
::
INT64
;
}
else
{
}
else
{
...
...
paddle/fluid/inference/api/api_impl.cc
浏览文件 @
3628d894
...
@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
...
@@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto
type
=
fetch
.
type
();
auto
type
=
fetch
.
type
();
auto
output
=
&
(
outputs
->
at
(
i
));
auto
output
=
&
(
outputs
->
at
(
i
));
output
->
name
=
fetchs_
[
idx
]
->
Input
(
"X"
)[
0
];
output
->
name
=
fetchs_
[
idx
]
->
Input
(
"X"
)[
0
];
if
(
type
==
typeid
(
float
)
)
{
if
(
type
==
framework
::
DataTypeTrait
<
float
>::
DataType
)
{
GetFetchOne
<
float
>
(
fetch
,
output
);
GetFetchOne
<
float
>
(
fetch
,
output
);
output
->
dtype
=
PaddleDType
::
FLOAT32
;
output
->
dtype
=
PaddleDType
::
FLOAT32
;
}
else
if
(
type
==
typeid
(
int64_t
)
)
{
}
else
if
(
type
==
framework
::
DataTypeTrait
<
int64_t
>::
DataType
)
{
GetFetchOne
<
int64_t
>
(
fetch
,
output
);
GetFetchOne
<
int64_t
>
(
fetch
,
output
);
output
->
dtype
=
PaddleDType
::
INT64
;
output
->
dtype
=
PaddleDType
::
INT64
;
}
else
{
}
else
{
...
...
paddle/fluid/inference/api/api_impl_tester.cc
浏览文件 @
3628d894
...
@@ -36,10 +36,10 @@ namespace paddle {
...
@@ -36,10 +36,10 @@ namespace paddle {
PaddleTensor
LodTensorToPaddleTensor
(
framework
::
LoDTensor
*
t
)
{
PaddleTensor
LodTensorToPaddleTensor
(
framework
::
LoDTensor
*
t
)
{
PaddleTensor
pt
;
PaddleTensor
pt
;
if
(
t
->
type
()
==
typeid
(
int64_t
)
)
{
if
(
t
->
type
()
==
framework
::
proto
::
VarType
::
INT64
)
{
pt
.
data
.
Reset
(
t
->
data
<
void
>
(),
t
->
numel
()
*
sizeof
(
int64_t
));
pt
.
data
.
Reset
(
t
->
data
<
void
>
(),
t
->
numel
()
*
sizeof
(
int64_t
));
pt
.
dtype
=
PaddleDType
::
INT64
;
pt
.
dtype
=
PaddleDType
::
INT64
;
}
else
if
(
t
->
type
()
==
typeid
(
float
)
)
{
}
else
if
(
t
->
type
()
==
framework
::
proto
::
VarType
::
FP32
)
{
pt
.
data
.
Reset
(
t
->
data
<
void
>
(),
t
->
numel
()
*
sizeof
(
float
));
pt
.
data
.
Reset
(
t
->
data
<
void
>
(),
t
->
numel
()
*
sizeof
(
float
));
pt
.
dtype
=
PaddleDType
::
FLOAT32
;
pt
.
dtype
=
PaddleDType
::
FLOAT32
;
}
else
{
}
else
{
...
...
paddle/fluid/inference/tests/api/tester_helper.h
浏览文件 @
3628d894
...
@@ -373,7 +373,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
...
@@ -373,7 +373,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
}
}
for
(
size_t
i
=
0
;
i
<
a_size
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
a_size
;
i
++
)
{
if
(
a
.
type
()
==
typeid
(
float
)
)
{
if
(
a
.
type
()
==
framework
::
proto
::
VarType
::
FP32
)
{
const
auto
*
a_data
=
a
.
data
<
float
>
();
const
auto
*
a_data
=
a
.
data
<
float
>
();
const
auto
*
b_data
=
b
.
data
<
float
>
();
const
auto
*
b_data
=
b
.
data
<
float
>
();
if
(
std
::
abs
(
a_data
[
i
]
-
b_data
[
i
])
>
1e-3
)
{
if
(
std
::
abs
(
a_data
[
i
]
-
b_data
[
i
])
>
1e-3
)
{
...
@@ -382,7 +382,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
...
@@ -382,7 +382,7 @@ static bool CompareTensorData(const framework::LoDTensor &a,
b_data
[
i
]);
b_data
[
i
]);
return
false
;
return
false
;
}
}
}
else
if
(
a
.
type
()
==
typeid
(
int64_t
)
)
{
}
else
if
(
a
.
type
()
==
framework
::
proto
::
VarType
::
INT64
)
{
const
auto
*
a_data
=
a
.
data
<
int64_t
>
();
const
auto
*
a_data
=
a
.
data
<
int64_t
>
();
const
auto
*
b_data
=
b
.
data
<
int64_t
>
();
const
auto
*
b_data
=
b
.
data
<
int64_t
>
();
if
(
std
::
abs
(
a_data
[
i
]
-
b_data
[
i
])
>
1e-3
)
{
if
(
std
::
abs
(
a_data
[
i
]
-
b_data
[
i
])
>
1e-3
)
{
...
...
paddle/fluid/operators/affine_grid_op.cc
浏览文件 @
3628d894
...
@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
...
@@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
library
=
framework
::
LibraryType
::
kCUDNN
;
library
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
auto
data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Theta"
)
->
type
()
);
auto
data_type
=
ctx
.
Input
<
Tensor
>
(
"Theta"
)
->
type
(
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
return
framework
::
OpKernelType
(
data_type
,
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library
);
framework
::
DataLayout
::
kAnyLayout
,
library
);
}
}
...
@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
...
@@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Theta"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Theta"
)
->
type
()
),
ctx
.
GetPlace
(
),
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library_
);
framework
::
DataLayout
::
kAnyLayout
,
library_
);
}
}
};
};
...
...
paddle/fluid/operators/arg_max_op.cc
浏览文件 @
3628d894
...
@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t
>
,
int32_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int16_t
>
,
int16_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
size_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
uint8_t
>
);
uint8_t
>
);
paddle/fluid/operators/arg_max_op.cu
浏览文件 @
3628d894
...
@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t
>
,
int32_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int16_t
>
,
int16_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
size_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
);
uint8_t
>
);
paddle/fluid/operators/arg_min_op.cc
浏览文件 @
3628d894
...
@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t
>
,
int32_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int16_t
>
,
int16_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
size_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
uint8_t
>
);
uint8_t
>
);
paddle/fluid/operators/arg_min_op.cu
浏览文件 @
3628d894
...
@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
...
@@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t
>
,
int32_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int16_t
>
,
int16_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
size_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
uint8_t
>
);
uint8_t
>
);
paddle/fluid/operators/array_to_lod_tensor_op.cc
浏览文件 @
3628d894
...
@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
...
@@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor<void> {
ArrayToLoDFunctorImpl
<
DeviceContext
>
functor
;
ArrayToLoDFunctorImpl
<
DeviceContext
>
functor
;
functor
.
dev_ctx_
=
dev_ctx
;
functor
.
dev_ctx_
=
dev_ctx
;
functor
.
prev_functor_
=
this
;
functor
.
prev_functor_
=
this
;
framework
::
VisitDataType
(
framework
::
ToDataType
(
out
->
type
()
),
functor
);
framework
::
VisitDataType
(
out
->
type
(
),
functor
);
}
}
};
};
...
@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
...
@@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
PADDLE_ENFORCE
(
!
x
.
empty
(),
"There's no element in the input array."
);
PADDLE_ENFORCE
(
!
x
.
empty
(),
"There's no element in the input array."
);
int
rank
=
x
[
0
].
dims
().
size
();
int
rank
=
x
[
0
].
dims
().
size
();
platform
::
Place
place
=
x
[
0
].
place
();
platform
::
Place
place
=
x
[
0
].
place
();
std
::
type_index
data_type
=
x
[
0
].
type
();
auto
data_type
=
x
[
0
].
type
();
int64_t
batch_size
=
x
[
0
].
dims
()[
0
];
int64_t
batch_size
=
x
[
0
].
dims
()[
0
];
framework
::
DDim
ins_dims
=
rank
>
1
framework
::
DDim
ins_dims
=
rank
>
1
?
framework
::
slice_ddim
(
x
[
0
].
dims
(),
1
,
rank
)
?
framework
::
slice_ddim
(
x
[
0
].
dims
(),
1
,
rank
)
...
...
paddle/fluid/operators/attention_lstm_op.cc
浏览文件 @
3628d894
...
@@ -121,8 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -121,8 +121,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
AttentionLSTMOp
::
GetExpectedKernelType
(
framework
::
OpKernelType
AttentionLSTMOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
...
...
paddle/fluid/operators/average_accumulates_op.cc
浏览文件 @
3628d894
...
@@ -103,8 +103,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
...
@@ -103,8 +103,7 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"param"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"param"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
3628d894
...
@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
...
@@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
();
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
// By default, the type of the scale, bias, mean,
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
// or double (For double input tensor).
...
@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
...
@@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP64
)
{
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP64
)
{
bn_param_type
=
framework
::
proto
::
VarType
::
FP64
;
bn_param_type
=
framework
::
proto
::
VarType
::
FP64
;
}
}
PADDLE_ENFORCE_EQ
(
bn_param_type
,
PADDLE_ENFORCE_EQ
(
bn_param_type
,
ctx
.
Input
<
Tensor
>
(
"Scale"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Scale"
)
->
type
()),
"Scale input should be of float type"
);
"Scale input should be of float type"
);
PADDLE_ENFORCE_EQ
(
bn_param_type
,
PADDLE_ENFORCE_EQ
(
bn_param_type
,
ctx
.
Input
<
Tensor
>
(
"Bias"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Bias"
)
->
type
()),
"Bias input should be of float type"
);
"Bias input should be of float type"
);
PADDLE_ENFORCE_EQ
(
bn_param_type
,
PADDLE_ENFORCE_EQ
(
bn_param_type
,
ctx
.
Input
<
Tensor
>
(
"Mean"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Mean"
)
->
type
()),
"Mean input should be of float type"
);
"Mean input should be of float type"
);
PADDLE_ENFORCE_EQ
(
bn_param_type
,
framework
::
ToDataType
(
PADDLE_ENFORCE_EQ
(
bn_param_type
,
ctx
.
Input
<
Tensor
>
(
"Variance"
)
->
type
(),
ctx
.
Input
<
Tensor
>
(
"Variance"
)
->
type
()),
"Variance input should be of float type"
);
"Variance input should be of float type"
);
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
...
@@ -413,9 +408,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
...
@@ -413,9 +408,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
layout
,
library
);
layout
,
library
);
}
}
};
};
...
...
paddle/fluid/operators/beam_search_decode_op.cc
浏览文件 @
3628d894
...
@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
...
@@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
LoDTensor
*
sentenceScores
=
ctx
.
Output
<
LoDTensor
>
(
"SentenceScores"
);
LoDTensor
*
sentenceScores
=
ctx
.
Output
<
LoDTensor
>
(
"SentenceScores"
);
framework
::
VisitDataType
(
framework
::
VisitDataType
(
framework
::
ToDataType
(
scores
->
at
(
0
).
type
()
),
scores
->
at
(
0
).
type
(
),
BeamSearchDecodeFunctor
(
*
ids
,
*
scores
,
sentenceIds
,
sentenceScores
,
BeamSearchDecodeFunctor
(
*
ids
,
*
scores
,
sentenceIds
,
sentenceScores
,
beam_size
,
end_id
));
beam_size
,
end_id
));
}
}
...
...
paddle/fluid/operators/beam_search_op.cc
浏览文件 @
3628d894
...
@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
...
@@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"pre_ids"
)
->
type
(),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"pre_ids"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
return
kt
;
return
kt
;
}
}
...
...
paddle/fluid/operators/bpr_loss_op.cc
浏览文件 @
3628d894
...
@@ -47,8 +47,7 @@ class BprLossOp : public framework::OperatorWithKernel {
...
@@ -47,8 +47,7 @@ class BprLossOp : public framework::OperatorWithKernel {
// is determined by its input "X".
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
@@ -94,8 +93,7 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
...
@@ -94,8 +93,7 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X".
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/controlflow/conditional_block_op.cc
浏览文件 @
3628d894
...
@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
...
@@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase {
if
(
!
(
ips
.
size
()
==
1UL
&&
ips
[
0
]
->
IsInitialized
()))
{
if
(
!
(
ips
.
size
()
==
1UL
&&
ips
[
0
]
->
IsInitialized
()))
{
PADDLE_THROW
(
"should have one initialized input as condition"
);
PADDLE_THROW
(
"should have one initialized input as condition"
);
}
}
if
(
!
(
framework
::
IsType
<
bool
>
(
ips
[
0
]
->
type
())
&&
// NOLINT
ips
[
0
]
->
numel
()
==
1
))
{
PADDLE_ENFORCE
(
ips
[
0
]
->
type
()
==
framework
::
proto
::
VarType
::
BOOL
&&
PADDLE_THROW
(
ips
[
0
]
->
numel
()
==
1
,
"condition input's data type should be bool, "
"condition input's data type should be bool, "
"numel should be 1, actual numel is %d"
,
"numel should be 1, actual numel is %d"
,
ips
[
0
]
->
numel
());
ips
[
0
]
->
numel
());
}
bool
res
=
false
;
bool
res
=
false
;
if
(
platform
::
is_gpu_place
(
ips
[
0
]
->
place
()))
{
if
(
platform
::
is_gpu_place
(
ips
[
0
]
->
place
()))
{
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
3628d894
...
@@ -261,7 +261,7 @@ class WhileGradOp : public framework::OperatorBase {
...
@@ -261,7 +261,7 @@ class WhileGradOp : public framework::OperatorBase {
if
(
var
->
IsType
<
LoDTensor
>
())
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
auto
&
inside_tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
inside_tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
framework
::
AttributeMap
attrs
;
framework
::
AttributeMap
attrs
;
attrs
[
"dtype"
]
=
framework
::
ToDataType
(
inside_tensor
.
type
()
);
attrs
[
"dtype"
]
=
inside_tensor
.
type
(
);
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
inside_tensor
.
dims
());
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
inside_tensor
.
dims
());
attrs
[
"value"
]
=
0.0
f
;
attrs
[
"value"
]
=
0.0
f
;
...
...
paddle/fluid/operators/conv_op.cc
浏览文件 @
3628d894
...
@@ -97,10 +97,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
...
@@ -97,10 +97,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
}
#endif
#endif
auto
input_data_type
=
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
();
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
());
auto
filter_data_type
=
ctx
.
Input
<
Tensor
>
(
"Filter"
)
->
type
();
auto
filter_data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Filter"
)
->
type
());
PADDLE_ENFORCE_EQ
(
input_data_type
,
filter_data_type
,
PADDLE_ENFORCE_EQ
(
input_data_type
,
filter_data_type
,
"input and filter data type should be consistent"
);
"input and filter data type should be consistent"
);
...
@@ -384,9 +382,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
...
@@ -384,9 +382,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
GetPlace
()
,
ctx
.
GetPlace
(),
layout_
,
library_
,
layout_
,
library_
,
customized_type_value
);
customized_type_value
);
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/conv_transpose_op.cc
浏览文件 @
3628d894
...
@@ -104,9 +104,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
...
@@ -104,9 +104,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
layout_
,
library_
);
layout_
,
library_
);
}
}
void
Conv2DTransposeOpMaker
::
Make
()
{
void
Conv2DTransposeOpMaker
::
Make
()
{
...
@@ -335,9 +334,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
...
@@ -335,9 +334,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
framework
::
DataLayout
layout_
=
framework
::
StringToDataLayout
(
data_format
);
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
layout_
,
library_
);
layout_
,
library_
);
}
}
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/crf_decoding_op.cc
浏览文件 @
3628d894
...
@@ -118,8 +118,7 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
...
@@ -118,8 +118,7 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/crop_op.cc
浏览文件 @
3628d894
...
@@ -51,8 +51,7 @@ class CropOp : public framework::OperatorWithKernel {
...
@@ -51,8 +51,7 @@ class CropOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
...
@@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
(),
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/cross_entropy_op.cc
浏览文件 @
3628d894
...
@@ -57,8 +57,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -57,8 +57,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
// is determined by its input "X".
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -111,8 +110,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...
@@ -111,8 +110,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X".
// is determined by its input "X".
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/ctc_align_op.cc
浏览文件 @
3628d894
...
@@ -36,8 +36,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
...
@@ -36,8 +36,7 @@ class CTCAlignOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/detection/anchor_generator_op.cc
浏览文件 @
3628d894
...
@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
...
@@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
Input
<
framework
::
Tensor
>
(
"Input"
)
->
type
(),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/detection/bipartite_match_op.cc
浏览文件 @
3628d894
...
@@ -45,8 +45,7 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
...
@@ -45,8 +45,7 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
LoDTensor
>
(
"DistMat"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"DistMat"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/detection/density_prior_box_op.cc
浏览文件 @
3628d894
...
@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
...
@@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
Input
<
framework
::
Tensor
>
(
"Input"
)
->
type
(),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/detection/generate_proposals_op.cc
浏览文件 @
3628d894
...
@@ -66,8 +66,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
...
@@ -66,8 +66,7 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Anchors"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Anchors"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/detection/mine_hard_examples_op.cc
浏览文件 @
3628d894
...
@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
...
@@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"ClsLoss"
)
->
type
()),
ctx
.
Input
<
framework
::
Tensor
>
(
"ClsLoss"
)
->
type
(),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/detection/multiclass_nms_op.cc
浏览文件 @
3628d894
...
@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
...
@@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Scores"
)
->
type
(),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Scores"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/detection/prior_box_op.cc
浏览文件 @
3628d894
...
@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
...
@@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
Input
<
framework
::
Tensor
>
(
"Input"
)
->
type
(),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/detection/roi_perspective_transform_op.cc
浏览文件 @
3628d894
...
@@ -498,8 +498,7 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
...
@@ -498,8 +498,7 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -519,8 +518,7 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
...
@@ -519,8 +518,7 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/detection/rpn_target_assign_op.cc
浏览文件 @
3628d894
...
@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
...
@@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Anchor"
)
->
type
(),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Anchor"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/detection/target_assign_op.cc
浏览文件 @
3628d894
...
@@ -57,8 +57,7 @@ class TargetAssignOp : public framework::OperatorWithKernel {
...
@@ -57,8 +57,7 @@ class TargetAssignOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/detection_map_op.cc
浏览文件 @
3628d894
...
@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
...
@@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"DetectRes"
)
->
type
(),
ctx
.
Input
<
framework
::
Tensor
>
(
"DetectRes"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/distributed/grpc_serde.cc
浏览文件 @
3628d894
...
@@ -115,8 +115,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -115,8 +115,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
ProtoEncodeHelper
e2
(
static_cast
<
char
*>
(
buf
),
128
);
size_t
rows_memory_size
=
size_t
rows_memory_size
=
slr
->
rows
().
size
()
*
sizeof
(
int64_t
);
slr
->
rows
().
size
()
*
framework
::
SizeOfType
(
typeid
(
int64_t
));
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
e2
.
WriteVarlengthBeginning
(
VarMsg
::
kRowsFieldNumber
,
rows_memory_size
);
slices
[
2
]
=
::
grpc
::
Slice
(
e2
.
size
());
slices
[
2
]
=
::
grpc
::
Slice
(
e2
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
[
2
].
begin
()),
e2
.
data
(),
e2
.
size
());
memcpy
(
const_cast
<
uint8_t
*>
(
slices
[
2
].
begin
()),
e2
.
data
(),
e2
.
size
());
...
...
paddle/fluid/operators/distributed/sendrecvop_utils.cc
浏览文件 @
3628d894
...
@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
...
@@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var,
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
auto
tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
// FIXME(wuyi): data types in send_recv.proto is copied from
// FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto
// framework.proto
request
->
set_data_type
(
request
->
set_data_type
(
static_cast
<
VarMsg
::
Type
>
(
tensor
.
type
()));
static_cast
<
VarMsg
::
Type
>
(
framework
::
ToDataType
(
tensor
.
type
())));
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
for
(
auto
&
dim
:
framework
::
vectorize
(
tensor
.
dims
()))
{
request
->
add_dims
(
dim
);
request
->
add_dims
(
dim
);
}
}
...
@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
...
@@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
)
{
VarMsg
*
request
)
{
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
var
->
GetMutable
<
framework
::
SelectedRows
>
();
request
->
set_data_type
(
request
->
set_data_type
(
static_cast
<
VarMsg
::
Type
>
(
slr
->
value
().
type
()));
static_cast
<
VarMsg
::
Type
>
(
framework
::
ToDataType
(
slr
->
value
().
type
())));
request
->
set_lod_level
(
0
);
request
->
set_lod_level
(
0
);
request
->
set_slr_height
(
slr
->
height
());
request
->
set_slr_height
(
slr
->
height
());
...
...
paddle/fluid/operators/distributed/sendrecvop_utils.h
浏览文件 @
3628d894
...
@@ -65,18 +65,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
...
@@ -65,18 +65,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
request
);
VarMsg
*
request
);
inline
std
::
type_index
ToTypeIndex
(
sendrecv
::
VariableMessage
::
Type
type
)
{
inline
framework
::
proto
::
VarType
::
Type
ToVarType
(
sendrecv
::
VariableMessage
::
Type
type
)
{
switch
(
type
)
{
switch
(
type
)
{
case
sendrecv
::
VariableMessage
::
FP32
:
case
sendrecv
::
VariableMessage
::
FP32
:
return
typeid
(
float
)
;
// NOLINT
return
framework
::
proto
::
VarType
::
FP32
;
// NOLINT
case
sendrecv
::
VariableMessage
::
FP64
:
case
sendrecv
::
VariableMessage
::
FP64
:
return
typeid
(
double
)
;
// NOLINT
return
framework
::
proto
::
VarType
::
FP64
;
// NOLINT
case
sendrecv
::
VariableMessage
::
INT32
:
case
sendrecv
::
VariableMessage
::
INT32
:
return
typeid
(
int
)
;
// NOLINT
return
framework
::
proto
::
VarType
::
INT32
;
// NOLINT
case
sendrecv
::
VariableMessage
::
INT64
:
case
sendrecv
::
VariableMessage
::
INT64
:
return
typeid
(
int64_t
)
;
// NOLINT
return
framework
::
proto
::
VarType
::
INT64
;
// NOLINT
case
sendrecv
::
VariableMessage
::
BOOL
:
case
sendrecv
::
VariableMessage
::
BOOL
:
return
typeid
(
bool
)
;
// NOLINT
return
framework
::
proto
::
VarType
::
BOOL
;
// NOLINT
default:
default:
PADDLE_THROW
(
"Not support type %d"
,
type
);
PADDLE_THROW
(
"Not support type %d"
,
type
);
}
}
...
...
paddle/fluid/operators/distributed/variable_response.cc
浏览文件 @
3628d894
...
@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
...
@@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData(
tensor
->
set_lod
(
lod
);
tensor
->
set_lod
(
lod
);
void
*
tensor_data
=
void
*
tensor_data
=
tensor
->
mutable_data
(
ctx
.
GetPlace
(),
To
TypeIndex
(
meta_
.
data_type
()));
tensor
->
mutable_data
(
ctx
.
GetPlace
(),
To
VarType
(
meta_
.
data_type
()));
VLOG
(
6
)
<<
"Tensor.memory_size = "
<<
tensor
->
memory_size
()
VLOG
(
6
)
<<
"Tensor.memory_size = "
<<
tensor
->
memory_size
()
<<
", Buffer Size = "
<<
length
;
<<
", Buffer Size = "
<<
length
;
...
@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
...
@@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData(
slr
->
set_height
(
meta_
.
slr_height
());
slr
->
set_height
(
meta_
.
slr_height
());
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
tensor
=
slr
->
mutable_value
();
tensor
->
Resize
(
dims
);
tensor
->
Resize
(
dims
);
PADDLE_ENFORCE_EQ
(
static_cast
<
size_t
>
(
tensor
->
numel
()),
PADDLE_ENFORCE_EQ
(
length
/
framework
::
SizeOfType
(
static_cast
<
size_t
>
(
tensor
->
numel
()),
paddle
::
operators
::
distributed
::
ToTypeIndex
(
length
/
framework
::
SizeOfType
(
paddle
::
operators
::
distributed
::
ToVarType
(
meta_
.
data_type
())));
meta_
.
data_type
())));
void
*
tensor_data
=
tensor
->
mutable_data
(
void
*
tensor_data
=
tensor
->
mutable_data
(
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
paddle
::
operators
::
distributed
::
To
TypeIndex
(
meta_
.
data_type
()));
paddle
::
operators
::
distributed
::
To
VarType
(
meta_
.
data_type
()));
if
(
!
ReadRaw
(
input
,
ctx
,
tensor
->
place
(),
tensor_data
,
length
))
{
if
(
!
ReadRaw
(
input
,
ctx
,
tensor
->
place
(),
tensor_data
,
length
))
{
return
false
;
return
false
;
...
@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
...
@@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData(
const
platform
::
DeviceContext
&
ctx
,
int
length
)
{
const
platform
::
DeviceContext
&
ctx
,
int
length
)
{
auto
*
slr
=
GetVar
()
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
slr
=
GetVar
()
->
GetMutable
<
framework
::
SelectedRows
>
();
slr
->
mutable_rows
()
->
clear
();
slr
->
mutable_rows
()
->
clear
();
slr
->
mutable_rows
()
->
resize
(
length
/
slr
->
mutable_rows
()
->
resize
(
length
/
sizeof
(
int64_t
));
// int64
framework
::
SizeOfType
(
typeid
(
int64_t
)));
// int64
int64_t
*
rows_data
=
slr
->
mutable_rows
()
->
data
();
int64_t
*
rows_data
=
slr
->
mutable_rows
()
->
data
();
// copy rows CPU data, GPU data will be copied lazily.
// copy rows CPU data, GPU data will be copied lazily.
...
...
paddle/fluid/operators/distributed_ops/merge_ids_op.cc
浏览文件 @
3628d894
...
@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
...
@@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
).
front
()
->
type
(),
ctx
.
GetPlace
());
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
).
front
()
->
type
()),
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc
浏览文件 @
3628d894
...
@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
...
@@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
)[
0
]
->
type
(),
ctx
.
GetPlace
());
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/elementwise/elementwise_op.h
浏览文件 @
3628d894
...
@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
...
@@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
framework
::
ToDataType
(
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
()
)
;
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
();
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
if
(
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
if
(
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
...
...
paddle/fluid/operators/fake_quantize_op.cc
浏览文件 @
3628d894
...
@@ -115,8 +115,7 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
...
@@ -115,8 +115,7 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -175,8 +174,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
...
@@ -175,8 +174,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/fc_op.cc
浏览文件 @
3628d894
...
@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
...
@@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType(
library
=
framework
::
LibraryType
::
kMKLDNN
;
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
layout
,
library
);
layout
,
library
);
}
}
void
FCOpGrad
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
void
FCOpGrad
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
...
@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
...
@@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
library
=
framework
::
LibraryType
::
kMKLDNN
;
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
layout
,
library
);
layout
,
library
);
}
}
void
FCOpMaker
::
Make
()
{
void
FCOpMaker
::
Make
()
{
...
...
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
3628d894
...
@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase {
...
@@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase {
if
(
force_cpu
)
{
if
(
force_cpu
)
{
auto
cpu
=
platform
::
CPUPlace
();
auto
cpu
=
platform
::
CPUPlace
();
tensor
->
mutable_data
(
cpu
,
framework
::
ToTypeIndex
(
data_type
)
);
tensor
->
mutable_data
(
cpu
,
data_type
);
}
else
{
}
else
{
tensor
->
mutable_data
(
dev_place
,
framework
::
ToTypeIndex
(
data_type
)
);
tensor
->
mutable_data
(
dev_place
,
data_type
);
}
}
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
...
...
paddle/fluid/operators/fill_op.cc
浏览文件 @
3628d894
...
@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase {
...
@@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase {
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
Attr
<
int
>
(
"dtype"
));
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
Attr
<
int
>
(
"dtype"
));
platform
::
CPUPlace
cpu
;
platform
::
CPUPlace
cpu
;
auto
force_cpu
=
Attr
<
bool
>
(
"force_cpu"
);
auto
force_cpu
=
Attr
<
bool
>
(
"force_cpu"
);
out
.
mutable_data
(
force_cpu
?
cpu
:
place
,
framework
::
ToTypeIndex
(
dtype
)
);
out
.
mutable_data
(
force_cpu
?
cpu
:
place
,
dtype
);
framework
::
LoDTensor
tensor
;
framework
::
LoDTensor
tensor
;
...
@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase {
...
@@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase {
}
else
{
}
else
{
// Always make tensor in CPU memory.
// Always make tensor in CPU memory.
tensor
.
Resize
(
out
.
dims
());
tensor
.
Resize
(
out
.
dims
());
tensor
.
mutable_data
(
cpu
,
framework
::
ToTypeIndex
(
dtype
)
);
tensor
.
mutable_data
(
cpu
,
dtype
);
}
}
framework
::
VisitDataType
(
framework
::
VisitDataType
(
...
...
paddle/fluid/operators/fused/fused_elemwise_activation_op.cc
浏览文件 @
3628d894
...
@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
...
@@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
)
->
type
(),
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
)
->
type
(),
"The element's type of input should be the same."
);
"The element's type of input should be the same."
);
auto
input_data_type
=
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
());
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
...
@@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type_index
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
)
->
type
();
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
)
->
type
(),
auto
input_data_type
=
framework
::
ToDataType
(
input_data_type_index
);
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
浏览文件 @
3628d894
...
@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
...
@@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
framework
::
OpKernelType
FusedEmbeddingFCLSTMOp
::
GetExpectedKernelType
(
framework
::
OpKernelType
FusedEmbeddingFCLSTMOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Embeddings"
)
->
type
(),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Embeddings"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
...
...
paddle/fluid/operators/fused/fusion_gru_op.cc
浏览文件 @
3628d894
...
@@ -93,8 +93,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -93,8 +93,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
FusionGRUOp
::
GetExpectedKernelType
(
framework
::
OpKernelType
FusionGRUOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
...
...
paddle/fluid/operators/fused/fusion_lstm_op.cc
浏览文件 @
3628d894
...
@@ -117,8 +117,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -117,8 +117,7 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
OpKernelType
FusionLSTMOp
::
GetExpectedKernelType
(
framework
::
OpKernelType
FusionLSTMOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
...
...
paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
浏览文件 @
3628d894
...
@@ -61,8 +61,7 @@ void FusionSeqConvEltAddReluOp::InferShape(
...
@@ -61,8 +61,7 @@ void FusionSeqConvEltAddReluOp::InferShape(
framework
::
OpKernelType
FusionSeqConvEltAddReluOp
::
GetExpectedKernelType
(
framework
::
OpKernelType
FusionSeqConvEltAddReluOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
...
...
paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
浏览文件 @
3628d894
...
@@ -67,8 +67,7 @@ void FusionSeqExpandConcatFCOp::InferShape(
...
@@ -67,8 +67,7 @@ void FusionSeqExpandConcatFCOp::InferShape(
framework
::
OpKernelType
FusionSeqExpandConcatFCOp
::
GetExpectedKernelType
(
framework
::
OpKernelType
FusionSeqExpandConcatFCOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
MultiInput
<
LoDTensor
>
(
"X"
)[
0
]
->
type
(),
framework
::
ToDataType
(
ctx
.
MultiInput
<
LoDTensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
...
...
paddle/fluid/operators/gather_op.cc
浏览文件 @
3628d894
...
@@ -42,8 +42,7 @@ class GatherOp : public framework::OperatorWithKernel {
...
@@ -42,8 +42,7 @@ class GatherOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -60,8 +59,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
...
@@ -60,8 +59,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/grid_sampler_op.cc
浏览文件 @
3628d894
...
@@ -63,8 +63,8 @@ class GridSampleOp : public framework::OperatorWithKernel {
...
@@ -63,8 +63,8 @@ class GridSampleOp : public framework::OperatorWithKernel {
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library_
);
framework
::
DataLayout
::
kAnyLayout
,
library_
);
}
}
};
};
...
@@ -159,8 +159,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
...
@@ -159,8 +159,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
library_
=
framework
::
LibraryType
::
kCUDNN
;
library_
=
framework
::
LibraryType
::
kCUDNN
;
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
ctx
.
GetPlace
(),
framework
::
DataLayout
::
kAnyLayout
,
library_
);
framework
::
DataLayout
::
kAnyLayout
,
library_
);
}
}
};
};
...
...
paddle/fluid/operators/group_norm_op.cc
浏览文件 @
3628d894
...
@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
...
@@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
if
(
t
==
nullptr
)
{
if
(
t
==
nullptr
)
{
PADDLE_THROW
(
"can't find Y@GRAD"
);
PADDLE_THROW
(
"can't find Y@GRAD"
);
}
}
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
t
->
type
()),
return
framework
::
OpKernelType
(
t
->
type
(),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/hierarchical_sigmoid_op.cc
浏览文件 @
3628d894
...
@@ -76,8 +76,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
...
@@ -76,8 +76,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
@@ -163,8 +162,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
...
@@ -163,8 +162,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/interpolate_op.cc
浏览文件 @
3628d894
...
@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
...
@@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
...
@@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/is_empty_op.cc
浏览文件 @
3628d894
...
@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel {
...
@@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
return
kt
;
return
kt
;
}
}
};
};
...
...
paddle/fluid/operators/isfinite_op.cc
浏览文件 @
3628d894
...
@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel {
...
@@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel {
int
dtype
=
-
1
;
int
dtype
=
-
1
;
auto
*
x_var
=
ctx
.
InputVar
(
"X"
);
auto
*
x_var
=
ctx
.
InputVar
(
"X"
);
if
(
x_var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
x_var
->
IsType
<
framework
::
LoDTensor
>
())
{
dtype
=
framework
::
ToDataType
(
x_var
->
Get
<
framework
::
LoDTensor
>
().
type
()
);
dtype
=
x_var
->
Get
<
framework
::
LoDTensor
>
().
type
(
);
}
else
if
(
x_var
->
IsType
<
framework
::
SelectedRows
>
())
{
}
else
if
(
x_var
->
IsType
<
framework
::
SelectedRows
>
())
{
dtype
=
framework
::
ToDataType
(
dtype
=
x_var
->
Get
<
framework
::
SelectedRows
>
().
value
().
type
();
x_var
->
Get
<
framework
::
SelectedRows
>
().
value
().
type
());
}
else
{
}
else
{
PADDLE_THROW
(
"Cannot find the input data type by all input data"
);
PADDLE_THROW
(
"Cannot find the input data type by all input data"
);
}
}
...
...
paddle/fluid/operators/layer_norm_op.cc
浏览文件 @
3628d894
...
@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
...
@@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
if
(
t
==
nullptr
)
{
if
(
t
==
nullptr
)
{
PADDLE_THROW
(
"can't find Y@GRAD"
);
PADDLE_THROW
(
"can't find Y@GRAD"
);
}
}
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
t
->
type
()),
return
framework
::
OpKernelType
(
t
->
type
(),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/linear_chain_crf_op.cc
浏览文件 @
3628d894
...
@@ -184,8 +184,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
...
@@ -184,8 +184,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
// is determined by its input "Emission".
// is determined by its input "Emission".
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
"Emission"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
...
@@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"LogLikelihood"
))
->
type
(),
ctx
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"LogLikelihood"
))
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/load_combine_op.cc
浏览文件 @
3628d894
...
@@ -69,7 +69,7 @@ class LoadCombineOp : public framework::OperatorBase {
...
@@ -69,7 +69,7 @@ class LoadCombineOp : public framework::OperatorBase {
// Get data from fin to tensor
// Get data from fin to tensor
DeserializeFromStream
(
*
buffer
,
tensor
,
dev_ctx
);
DeserializeFromStream
(
*
buffer
,
tensor
,
dev_ctx
);
auto
in_dtype
=
framework
::
ToDataType
(
tensor
->
type
()
);
auto
in_dtype
=
tensor
->
type
(
);
auto
out_dtype
=
auto
out_dtype
=
load_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
load_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
...
...
paddle/fluid/operators/load_op.cc
浏览文件 @
3628d894
...
@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase {
...
@@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase {
DeserializeFromStream
(
fin
,
tensor
,
dev_ctx
);
DeserializeFromStream
(
fin
,
tensor
,
dev_ctx
);
auto
load_as_fp16
=
Attr
<
bool
>
(
"load_as_fp16"
);
auto
load_as_fp16
=
Attr
<
bool
>
(
"load_as_fp16"
);
auto
in_dtype
=
framework
::
ToDataType
(
tensor
->
type
()
);
auto
in_dtype
=
tensor
->
type
(
);
auto
out_dtype
=
load_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
auto
out_dtype
=
load_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
if
(
in_dtype
!=
out_dtype
)
{
if
(
in_dtype
!=
out_dtype
)
{
...
...
paddle/fluid/operators/lod_reset_op.cc
浏览文件 @
3628d894
...
@@ -39,8 +39,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
...
@@ -39,8 +39,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -144,8 +143,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
...
@@ -144,8 +143,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
3628d894
...
@@ -72,7 +72,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
...
@@ -72,7 +72,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
LoDTensorToArrayFunctorImpl
<
DeviceContext
>
func
;
LoDTensorToArrayFunctorImpl
<
DeviceContext
>
func
;
func
.
prev_functor_
=
this
;
func
.
prev_functor_
=
this
;
func
.
dev_ctx_
=
dev_ctx
;
func
.
dev_ctx_
=
dev_ctx
;
framework
::
VisitDataType
(
framework
::
ToDataType
(
input_
.
type
()
),
func
);
framework
::
VisitDataType
(
input_
.
type
(
),
func
);
}
}
};
};
...
...
paddle/fluid/operators/lookup_sparse_table_op.cc
浏览文件 @
3628d894
...
@@ -63,8 +63,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
...
@@ -63,8 +63,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
out_shape
[
0
]
=
ids_t
.
numel
();
out_shape
[
0
]
=
ids_t
.
numel
();
out_t
->
Resize
(
out_shape
);
out_t
->
Resize
(
out_shape
);
out_t
->
mutable_data
(
cpu
,
w_t
->
value
().
type
());
out_t
->
mutable_data
(
cpu
,
w_t
->
value
().
type
());
PADDLE_ENFORCE_EQ
(
framework
::
ToDataType
(
w_t
->
value
().
type
()),
PADDLE_ENFORCE_EQ
(
w_t
->
value
().
type
(),
framework
::
proto
::
VarType
::
FP32
,
framework
::
proto
::
VarType
::
FP32
,
"The sparse table only support FP32"
);
"The sparse table only support FP32"
);
w_t
->
Get
(
ids_t
,
out_t
,
true
,
is_test
);
w_t
->
Get
(
ids_t
,
out_t
,
true
,
is_test
);
out_t
->
set_lod
(
ids_t
.
lod
());
out_t
->
set_lod
(
ids_t
.
lod
());
...
...
paddle/fluid/operators/lrn_op.cc
浏览文件 @
3628d894
...
@@ -145,8 +145,7 @@ framework::OpKernelType GetExpectedLRNKernel(
...
@@ -145,8 +145,7 @@ framework::OpKernelType GetExpectedLRNKernel(
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
ctx
.
GetPlace
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
layout_
,
library_
);
layout_
,
library_
);
}
}
}
// namespace
}
// namespace
...
...
paddle/fluid/operators/lstm_op.cc
浏览文件 @
3628d894
...
@@ -96,8 +96,7 @@ class LSTMOp : public framework::OperatorWithKernel {
...
@@ -96,8 +96,7 @@ class LSTMOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
()),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
(),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -261,8 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
...
@@ -261,8 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
()),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
(),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/lstmp_op.cc
浏览文件 @
3628d894
...
@@ -113,8 +113,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
...
@@ -113,8 +113,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
()),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
(),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -312,8 +311,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
...
@@ -312,8 +311,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
()),
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Input"
)
->
type
(),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/math/math_function.cc
浏览文件 @
3628d894
...
@@ -77,16 +77,14 @@ template <>
...
@@ -77,16 +77,14 @@ template <>
void
set_constant_with_place
<
platform
::
CPUPlace
>
(
void
set_constant_with_place
<
platform
::
CPUPlace
>
(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
float
value
)
{
float
value
)
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
tensor
->
type
()),
framework
::
VisitDataType
(
tensor
->
type
(),
TensorSetConstantCPU
(
tensor
,
value
));
TensorSetConstantCPU
(
tensor
,
value
));
}
}
template
<
>
template
<
>
void
set_constant_with_place
<
platform
::
CUDAPinnedPlace
>
(
void
set_constant_with_place
<
platform
::
CUDAPinnedPlace
>
(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
float
value
)
{
float
value
)
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
tensor
->
type
()),
framework
::
VisitDataType
(
tensor
->
type
(),
TensorSetConstantCPU
(
tensor
,
value
));
TensorSetConstantCPU
(
tensor
,
value
));
}
}
struct
TensorSetConstantWithPlace
:
public
boost
::
static_visitor
<
void
>
{
struct
TensorSetConstantWithPlace
:
public
boost
::
static_visitor
<
void
>
{
...
...
paddle/fluid/operators/math/math_function.cu
浏览文件 @
3628d894
...
@@ -65,7 +65,7 @@ template <>
...
@@ -65,7 +65,7 @@ template <>
void
set_constant_with_place
<
platform
::
CUDAPlace
>
(
void
set_constant_with_place
<
platform
::
CUDAPlace
>
(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
*
tensor
,
float
value
)
{
float
value
)
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
tensor
->
type
()
),
framework
::
VisitDataType
(
tensor
->
type
(
),
TensorSetConstantGPU
(
context
,
tensor
,
value
));
TensorSetConstantGPU
(
context
,
tensor
,
value
));
}
}
...
...
paddle/fluid/operators/mean_iou_op.cc
浏览文件 @
3628d894
...
@@ -44,8 +44,7 @@ class MeanIoUOp : public framework::OperatorWithKernel {
...
@@ -44,8 +44,7 @@ class MeanIoUOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Predictions"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Predictions"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/mean_op.cc
浏览文件 @
3628d894
...
@@ -61,9 +61,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
...
@@ -61,9 +61,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
();
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/merge_lod_tensor_op.cc
浏览文件 @
3628d894
...
@@ -63,9 +63,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
...
@@ -63,9 +63,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
platform
::
Place
place
=
dev_place
;
platform
::
Place
place
=
dev_place
;
int64_t
batch_size
=
in_true
.
dims
()[
0
]
+
in_false
.
dims
()[
0
];
int64_t
batch_size
=
in_true
.
dims
()[
0
]
+
in_false
.
dims
()[
0
];
auto
data_type
=
in_true
.
IsInitialized
()
?
in_true
.
type
()
:
in_false
.
type
();
std
::
type_index
data_type
=
in_true
.
IsInitialized
()
?
in_true
.
type
()
:
in_false
.
type
();
int
rank
;
int
rank
;
framework
::
DDim
in_dims
;
framework
::
DDim
in_dims
;
if
(
in_true
.
IsInitialized
())
{
if
(
in_true
.
IsInitialized
())
{
...
...
paddle/fluid/operators/metrics/accuracy_op.cc
浏览文件 @
3628d894
...
@@ -55,8 +55,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
...
@@ -55,8 +55,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Out"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/metrics/auc_op.cc
浏览文件 @
3628d894
...
@@ -51,8 +51,7 @@ class AucOp : public framework::OperatorWithKernel {
...
@@ -51,8 +51,7 @@ class AucOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Predict"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Predict"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/metrics/precision_recall_op.cc
浏览文件 @
3628d894
...
@@ -82,8 +82,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
...
@@ -82,8 +82,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"MaxProbs"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"MaxProbs"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/multiplex_op.cc
浏览文件 @
3628d894
...
@@ -53,8 +53,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
...
@@ -53,8 +53,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
(),
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -123,8 +122,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
...
@@ -123,8 +122,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
(),
framework
::
ToDataType
(
ctx
.
MultiInput
<
Tensor
>
(
"X"
)[
0
]
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/nce_op.cc
浏览文件 @
3628d894
...
@@ -69,8 +69,7 @@ class NCEOp : public framework::OperatorWithKernel {
...
@@ -69,8 +69,7 @@ class NCEOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
@@ -214,8 +213,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
...
@@ -214,8 +213,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/optimizers/adadelta_op.cc
浏览文件 @
3628d894
...
@@ -70,9 +70,8 @@ class AdadeltaOp : public framework::OperatorWithKernel {
...
@@ -70,9 +70,8 @@ class AdadeltaOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/optimizers/adagrad_op.cc
浏览文件 @
3628d894
...
@@ -59,9 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel {
...
@@ -59,9 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel {
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/optimizers/adam_op.cc
浏览文件 @
3628d894
...
@@ -75,8 +75,7 @@ class AdamOp : public framework::OperatorWithKernel {
...
@@ -75,8 +75,7 @@ class AdamOp : public framework::OperatorWithKernel {
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
();
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/optimizers/adamax_op.cc
浏览文件 @
3628d894
...
@@ -76,9 +76,8 @@ class AdamaxOp : public framework::OperatorWithKernel {
...
@@ -76,9 +76,8 @@ class AdamaxOp : public framework::OperatorWithKernel {
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/optimizers/decayed_adagrad_op.cc
浏览文件 @
3628d894
...
@@ -64,9 +64,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
...
@@ -64,9 +64,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/optimizers/ftrl_op.cc
浏览文件 @
3628d894
...
@@ -66,8 +66,7 @@ class FTRLOp : public framework::OperatorWithKernel {
...
@@ -66,8 +66,7 @@ class FTRLOp : public framework::OperatorWithKernel {
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
();
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/optimizers/proximal_adagrad_op.cc
浏览文件 @
3628d894
...
@@ -58,9 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
...
@@ -58,9 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/optimizers/proximal_gd_op.cc
浏览文件 @
3628d894
...
@@ -46,9 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel {
...
@@ -46,9 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel {
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Param"
)
->
type
());
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/pad2d_op.cc
浏览文件 @
3628d894
...
@@ -511,8 +511,8 @@ class Pad2dOp : public framework::OperatorWithKernel {
...
@@ -511,8 +511,8 @@ class Pad2dOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
@@ -612,8 +612,8 @@ class Pad2dOpGrad : public framework::OperatorWithKernel {
...
@@ -612,8 +612,8 @@ class Pad2dOpGrad : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/pad_constant_like_op.cc
浏览文件 @
3628d894
...
@@ -47,8 +47,7 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
...
@@ -47,8 +47,7 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -171,8 +170,7 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
...
@@ -171,8 +170,7 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
3628d894
...
@@ -104,8 +104,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
...
@@ -104,8 +104,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
ctx
.
GetPlace
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
layout_
,
library_
);
layout_
,
library_
);
}
}
...
@@ -135,7 +134,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
...
@@ -135,7 +134,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
}
}
#endif
#endif
auto
input_data_type
=
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()
);
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(
);
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP16
)
{
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP16
)
{
PADDLE_ENFORCE_EQ
(
library_
,
framework
::
LibraryType
::
kCUDNN
,
PADDLE_ENFORCE_EQ
(
library_
,
framework
::
LibraryType
::
kCUDNN
,
"float16 can only be used when CUDNN is used"
);
"float16 can only be used when CUDNN is used"
);
...
...
paddle/fluid/operators/pool_with_index_op.cc
浏览文件 @
3628d894
...
@@ -76,8 +76,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
...
@@ -76,8 +76,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -97,8 +96,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
...
@@ -97,8 +96,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/positive_negative_pair_op.cc
浏览文件 @
3628d894
...
@@ -87,8 +87,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
...
@@ -87,8 +87,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Score"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Score"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/prelu_op.cc
浏览文件 @
3628d894
...
@@ -56,8 +56,7 @@ class PReluOp : public framework::OperatorWithKernel {
...
@@ -56,8 +56,7 @@ class PReluOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -113,8 +112,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
...
@@ -113,8 +112,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/print_op.cc
浏览文件 @
3628d894
...
@@ -172,7 +172,7 @@ class TensorPrintOp : public framework::OperatorBase {
...
@@ -172,7 +172,7 @@ class TensorPrintOp : public framework::OperatorBase {
formater
.
name
=
printed_var_name
;
formater
.
name
=
printed_var_name
;
}
}
if
(
Attr
<
bool
>
(
"print_tensor_type"
))
{
if
(
Attr
<
bool
>
(
"print_tensor_type"
))
{
formater
.
dtype
=
printed_tensor
.
type
(
);
formater
.
dtype
=
framework
::
ToTypeIndex
(
printed_tensor
.
type
()
);
}
}
if
(
Attr
<
bool
>
(
"print_tensor_shape"
))
{
if
(
Attr
<
bool
>
(
"print_tensor_shape"
))
{
auto
&
dims
=
printed_tensor
.
dims
();
auto
&
dims
=
printed_tensor
.
dims
();
...
...
paddle/fluid/operators/psroi_pool_op.cc
浏览文件 @
3628d894
...
@@ -129,8 +129,7 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
...
@@ -129,8 +129,7 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -150,8 +149,7 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
...
@@ -150,8 +149,7 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/random_crop_op.cc
浏览文件 @
3628d894
...
@@ -22,8 +22,7 @@ class RandomCropOp : public framework::OperatorWithKernel {
...
@@ -22,8 +22,7 @@ class RandomCropOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/reader/create_batch_reader_op.cc
浏览文件 @
3628d894
...
@@ -99,10 +99,10 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
...
@@ -99,10 +99,10 @@ void BatchReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
out
->
reserve
(
out_num
);
out
->
reserve
(
out_num
);
for
(
size_t
j
=
0
;
j
<
out_num
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
out_num
;
++
j
)
{
// Merge shape and check date type
// Merge shape and check date type
std
::
type_index
batch_type
=
buffer_
[
0
][
j
].
type
();
auto
batch_type
=
buffer_
[
0
][
j
].
type
();
framework
::
DDim
batch_shape
=
buffer_
[
0
][
j
].
dims
();
framework
::
DDim
batch_shape
=
buffer_
[
0
][
j
].
dims
();
for
(
size_t
i
=
1
;
i
<
buffer_
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
buffer_
.
size
();
++
i
)
{
std
::
type_index
ins_type
=
buffer_
[
i
][
j
].
type
();
auto
ins_type
=
buffer_
[
i
][
j
].
type
();
framework
::
DDim
ins_shape
=
buffer_
[
i
][
j
].
dims
();
framework
::
DDim
ins_shape
=
buffer_
[
i
][
j
].
dims
();
PADDLE_ENFORCE_EQ
(
batch_type
,
ins_type
);
PADDLE_ENFORCE_EQ
(
batch_type
,
ins_type
);
PADDLE_ENFORCE_EQ
(
slice_ddim
(
batch_shape
,
1
,
batch_shape
.
size
()),
PADDLE_ENFORCE_EQ
(
slice_ddim
(
batch_shape
,
1
,
batch_shape
.
size
()),
...
...
paddle/fluid/operators/recurrent_op.cc
浏览文件 @
3628d894
...
@@ -414,7 +414,7 @@ class RecurrentGradOp : public RecurrentBase {
...
@@ -414,7 +414,7 @@ class RecurrentGradOp : public RecurrentBase {
auto
&
inside_tensor
=
cur_scope
.
FindVar
(
inside_grad_name
)
auto
&
inside_tensor
=
cur_scope
.
FindVar
(
inside_grad_name
)
->
Get
<
framework
::
LoDTensor
>
();
->
Get
<
framework
::
LoDTensor
>
();
framework
::
AttributeMap
attrs
;
framework
::
AttributeMap
attrs
;
attrs
[
"dtype"
]
=
framework
::
ToDataType
(
inside_tensor
.
type
()
);
attrs
[
"dtype"
]
=
inside_tensor
.
type
(
);
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
inside_tensor
.
dims
());
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
inside_tensor
.
dims
());
attrs
[
"value"
]
=
0.0
f
;
attrs
[
"value"
]
=
0.0
f
;
...
...
paddle/fluid/operators/reshape_op.cc
浏览文件 @
3628d894
...
@@ -108,8 +108,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
...
@@ -108,8 +108,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -189,8 +188,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
...
@@ -189,8 +188,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -322,9 +320,7 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
...
@@ -322,9 +320,7 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
(),
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/rnn_memory_helper_op.cc
浏览文件 @
3628d894
...
@@ -99,7 +99,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
...
@@ -99,7 +99,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
auto
&
in_var_tensor
=
in_var
->
Get
<
framework
::
LoDTensor
>
();
auto
&
in_var_tensor
=
in_var
->
Get
<
framework
::
LoDTensor
>
();
framework
::
AttributeMap
attrs
;
framework
::
AttributeMap
attrs
;
attrs
[
"dtype"
]
=
framework
::
ToDataType
(
in_var_tensor
.
type
()
);
attrs
[
"dtype"
]
=
in_var_tensor
.
type
(
);
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
in_var_tensor
.
dims
());
attrs
[
"shape"
]
=
framework
::
vectorize2int
(
in_var_tensor
.
dims
());
attrs
[
"value"
]
=
0.0
f
;
attrs
[
"value"
]
=
0.0
f
;
...
...
paddle/fluid/operators/roi_align_op.cc
浏览文件 @
3628d894
...
@@ -62,8 +62,7 @@ class ROIAlignOp : public framework::OperatorWithKernel {
...
@@ -62,8 +62,7 @@ class ROIAlignOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -83,8 +82,7 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
...
@@ -83,8 +82,7 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/roi_pool_op.cc
浏览文件 @
3628d894
...
@@ -69,8 +69,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
...
@@ -69,8 +69,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -90,8 +89,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
...
@@ -90,8 +89,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/save_combine_op.cc
浏览文件 @
3628d894
...
@@ -75,7 +75,7 @@ class SaveCombineOp : public framework::OperatorBase {
...
@@ -75,7 +75,7 @@ class SaveCombineOp : public framework::OperatorBase {
// Serialize tensors one by one
// Serialize tensors one by one
// Check types to see if a fp16 transformation is required
// Check types to see if a fp16 transformation is required
auto
in_dtype
=
framework
::
ToDataType
(
tensor
.
type
()
);
auto
in_dtype
=
tensor
.
type
(
);
auto
out_dtype
=
auto
out_dtype
=
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
...
...
paddle/fluid/operators/save_op.cc
浏览文件 @
3628d894
...
@@ -85,7 +85,7 @@ class SaveOp : public framework::OperatorBase {
...
@@ -85,7 +85,7 @@ class SaveOp : public framework::OperatorBase {
filename
);
filename
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
auto
save_as_fp16
=
Attr
<
bool
>
(
"save_as_fp16"
);
auto
in_dtype
=
framework
::
ToDataType
(
tensor
.
type
()
);
auto
in_dtype
=
tensor
.
type
(
);
auto
out_dtype
=
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
auto
out_dtype
=
save_as_fp16
?
framework
::
proto
::
VarType
::
FP16
:
in_dtype
;
if
(
in_dtype
!=
out_dtype
)
{
if
(
in_dtype
!=
out_dtype
)
{
...
...
paddle/fluid/operators/scatter_op.cc
浏览文件 @
3628d894
...
@@ -51,8 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel {
...
@@ -51,8 +51,7 @@ class ScatterOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -70,8 +69,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
...
@@ -70,8 +69,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
浏览文件 @
3628d894
...
@@ -114,8 +114,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
...
@@ -114,8 +114,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
浏览文件 @
3628d894
...
@@ -112,8 +112,7 @@ class SequenceScatterOp : public framework::OperatorWithKernel {
...
@@ -112,8 +112,7 @@ class SequenceScatterOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
@@ -131,8 +130,7 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
...
@@ -131,8 +130,7 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
浏览文件 @
3628d894
...
@@ -50,8 +50,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
...
@@ -50,8 +50,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -71,8 +70,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
...
@@ -71,8 +70,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
浏览文件 @
3628d894
...
@@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
...
@@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
}
}
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()
),
ctx
.
GetPlace
(),
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(
),
ctx
.
GetPlace
(),
framework
::
StringToDataLayout
(
data_format
),
library_
);
framework
::
StringToDataLayout
(
data_format
),
library_
);
}
}
};
};
...
@@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
...
@@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
}
}
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
std
::
string
data_format
=
ctx
.
Attr
<
std
::
string
>
(
"data_format"
);
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()
),
ctx
.
GetPlace
(),
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(
),
ctx
.
GetPlace
(),
framework
::
StringToDataLayout
(
data_format
),
library_
);
framework
::
StringToDataLayout
(
data_format
),
library_
);
}
}
};
};
...
...
paddle/fluid/operators/similarity_focus_op.cc
浏览文件 @
3628d894
...
@@ -70,8 +70,7 @@ class SimilarityFocusOp : public framework::OperatorWithKernel {
...
@@ -70,8 +70,7 @@ class SimilarityFocusOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/operators/slice_op.cc
浏览文件 @
3628d894
...
@@ -59,8 +59,7 @@ class SliceOp : public framework::OperatorWithKernel {
...
@@ -59,8 +59,7 @@ class SliceOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Input"
)
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
}
}
};
};
...
...
paddle/fluid/operators/softmax_op.cc
浏览文件 @
3628d894
...
@@ -62,8 +62,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
...
@@ -62,8 +62,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
}
}
#endif
#endif
auto
input_data_type
=
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
();
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP16
)
{
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP16
)
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"float16 can only be used on GPU place"
);
"float16 can only be used on GPU place"
);
...
@@ -169,8 +168,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
...
@@ -169,8 +168,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
layout_
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
#endif
#endif
auto
input_data_type
=
framework
::
ToDataType
(
auto
input_data_type
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
()
)
;
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
();
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP16
)
{
if
(
input_data_type
==
framework
::
proto
::
VarType
::
FP16
)
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"float16 can only be used on GPU place"
);
"float16 can only be used on GPU place"
);
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
3628d894
...
@@ -131,8 +131,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -131,8 +131,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -173,8 +172,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
...
@@ -173,8 +172,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
))
->
type
(),
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
))
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/sum_op.cc
浏览文件 @
3628d894
...
@@ -91,9 +91,9 @@ class SumOp : public framework::OperatorWithKernel {
...
@@ -91,9 +91,9 @@ class SumOp : public framework::OperatorWithKernel {
continue
;
continue
;
}
}
if
(
dtype
==
-
1
)
{
if
(
dtype
==
-
1
)
{
dtype
=
framework
::
ToDataType
(
tensor
->
type
()
);
dtype
=
tensor
->
type
(
);
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
dtype
,
framework
::
ToDataType
(
tensor
->
type
()
));
PADDLE_ENFORCE_EQ
(
dtype
,
tensor
->
type
(
));
}
}
}
}
PADDLE_ENFORCE_NE
(
dtype
,
-
1
,
PADDLE_ENFORCE_NE
(
dtype
,
-
1
,
...
@@ -106,8 +106,8 @@ class SumOp : public framework::OperatorWithKernel {
...
@@ -106,8 +106,8 @@ class SumOp : public framework::OperatorWithKernel {
for
(
auto
&
var
:
x_vars
)
{
for
(
auto
&
var
:
x_vars
)
{
auto
&
value
=
var
->
Get
<
framework
::
SelectedRows
>
().
value
();
auto
&
value
=
var
->
Get
<
framework
::
SelectedRows
>
().
value
();
if
(
value
.
IsInitialized
())
{
if
(
value
.
IsInitialized
())
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
value
.
type
()
),
return
framework
::
OpKernelType
(
value
.
type
(),
ctx
.
device_context
(
),
ctx
.
device_context
(),
layout
,
library
);
layout
,
library
);
}
}
}
}
// if input sparse vars are not initialized, use an default kernel type.
// if input sparse vars are not initialized, use an default kernel type.
...
@@ -118,9 +118,8 @@ class SumOp : public framework::OperatorWithKernel {
...
@@ -118,9 +118,8 @@ class SumOp : public framework::OperatorWithKernel {
auto
&
array
=
x_var
->
Get
<
framework
::
LoDTensorArray
>
();
auto
&
array
=
x_var
->
Get
<
framework
::
LoDTensorArray
>
();
for
(
auto
&
each
:
array
)
{
for
(
auto
&
each
:
array
)
{
if
(
each
.
numel
()
!=
0
)
{
if
(
each
.
numel
()
!=
0
)
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
each
.
type
()),
return
framework
::
OpKernelType
(
each
.
type
(),
ctx
.
device_context
(),
ctx
.
device_context
(),
layout
,
layout
,
library
);
library
);
}
}
}
}
}
}
...
...
paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
浏览文件 @
3628d894
...
@@ -76,10 +76,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel {
...
@@ -76,10 +76,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel {
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input0
=
ctx
.
Inputs
(
"Xs"
).
front
();
auto
input0
=
ctx
.
Inputs
(
"Xs"
).
front
();
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
OpKernelType
kt
=
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
scope
()
ctx
.
scope
().
FindVar
(
input0
)
->
GetMutable
<
framework
::
LoDTensor
>
()
->
type
(),
.
FindVar
(
input0
)
->
GetMutable
<
framework
::
LoDTensor
>
()
->
type
()),
ctx
.
GetPlace
());
ctx
.
GetPlace
());
return
kt
;
return
kt
;
}
}
...
...
paddle/fluid/operators/transpose_op.cc
浏览文件 @
3628d894
...
@@ -144,8 +144,7 @@ class Transpose2Op : public TransposeOp {
...
@@ -144,8 +144,7 @@ class Transpose2Op : public TransposeOp {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
@@ -194,9 +193,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
...
@@ -194,9 +193,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
(),
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
))
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/unpool_op.cc
浏览文件 @
3628d894
...
@@ -74,8 +74,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
...
@@ -74,8 +74,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
...
@@ -113,8 +112,7 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
...
@@ -113,8 +112,7 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
...
...
paddle/fluid/operators/warpctc_op.cc
浏览文件 @
3628d894
...
@@ -56,8 +56,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
...
@@ -56,8 +56,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
}
}
#endif
#endif
framework
::
DataLayout
layout_
=
framework
::
DataLayout
::
kAnyLayout
;
framework
::
DataLayout
layout_
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
()),
ctx
.
device_context
(),
layout_
,
library_
);
ctx
.
device_context
(),
layout_
,
library_
);
}
}
};
};
...
@@ -136,8 +135,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
...
@@ -136,8 +135,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"Logits"
)
->
type
()),
ctx
.
device_context
());
ctx
.
device_context
());
}
}
};
};
...
...
paddle/fluid/operators/yolov3_loss_op.cc
浏览文件 @
3628d894
...
@@ -64,8 +64,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
...
@@ -64,8 +64,7 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
@@ -180,8 +179,7 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
...
@@ -180,8 +179,7 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
platform
::
CPUPlace
());
platform
::
CPUPlace
());
}
}
};
};
...
...
paddle/fluid/platform/nccl_helper.h
浏览文件 @
3628d894
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <typeindex>
#include <typeindex>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -28,14 +29,14 @@
...
@@ -28,14 +29,14 @@
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
inline
ncclDataType_t
ToNCCLDataType
(
std
::
type_index
type
)
{
inline
ncclDataType_t
ToNCCLDataType
(
framework
::
proto
::
VarType
::
Type
type
)
{
if
(
type
==
typeid
(
float
))
{
// NOLINT
if
(
type
==
framework
::
proto
::
VarType
::
FP32
)
{
return
ncclFloat
;
return
ncclFloat
;
}
else
if
(
type
==
typeid
(
double
))
{
// NOLINT
}
else
if
(
type
==
framework
::
proto
::
VarType
::
FP64
)
{
return
ncclDouble
;
return
ncclDouble
;
}
else
if
(
type
==
typeid
(
int
))
{
// NOLINT
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT32
)
{
return
ncclInt
;
return
ncclInt
;
}
else
if
(
type
==
typeid
(
int64_t
))
{
// NOLINT
}
else
if
(
type
==
framework
::
proto
::
VarType
::
INT64
)
{
return
ncclInt64
;
return
ncclInt64
;
}
else
{
}
else
{
PADDLE_THROW
(
"Not supported"
);
PADDLE_THROW
(
"Not supported"
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
3628d894
...
@@ -214,7 +214,7 @@ PYBIND11_MODULE(core, m) {
...
@@ -214,7 +214,7 @@ PYBIND11_MODULE(core, m) {
.
def
(
"_get_float_element"
,
TensorGetElement
<
float
>
)
.
def
(
"_get_float_element"
,
TensorGetElement
<
float
>
)
.
def
(
"_set_double_element"
,
TensorSetElement
<
double
>
)
.
def
(
"_set_double_element"
,
TensorSetElement
<
double
>
)
.
def
(
"_get_double_element"
,
TensorGetElement
<
double
>
)
.
def
(
"_get_double_element"
,
TensorGetElement
<
double
>
)
.
def
(
"_dtype"
,
[](
Tensor
&
self
)
{
return
ToDataType
(
self
.
type
()
);
});
.
def
(
"_dtype"
,
[](
Tensor
&
self
)
{
return
self
.
type
(
);
});
py
::
class_
<
LoDTensor
,
Tensor
>
(
m
,
"LoDTensor"
,
R"DOC(
py
::
class_
<
LoDTensor
,
Tensor
>
(
m
,
"LoDTensor"
,
R"DOC(
LoDTensor is a Tensor with optional LoD information.
LoDTensor is a Tensor with optional LoD information.
...
...
paddle/fluid/pybind/tensor_py.h
浏览文件 @
3628d894
...
@@ -43,7 +43,7 @@ template <size_t I, typename... ARGS>
...
@@ -43,7 +43,7 @@ template <size_t I, typename... ARGS>
struct
CastToPyBufferImpl
<
true
,
I
,
ARGS
...
>
{
struct
CastToPyBufferImpl
<
true
,
I
,
ARGS
...
>
{
using
CUR_TYPE
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
ARGS
...
>>::
type
;
using
CUR_TYPE
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
ARGS
...
>>::
type
;
pybind11
::
buffer_info
operator
()(
const
framework
::
Tensor
&
tensor
)
{
pybind11
::
buffer_info
operator
()(
const
framework
::
Tensor
&
tensor
)
{
if
(
std
::
type_index
(
typeid
(
CUR_TYPE
))
==
tensor
.
type
())
{
if
(
framework
::
DataTypeTrait
<
CUR_TYPE
>::
DataType
==
tensor
.
type
())
{
auto
dim_vec
=
framework
::
vectorize
(
tensor
.
dims
());
auto
dim_vec
=
framework
::
vectorize
(
tensor
.
dims
());
std
::
vector
<
size_t
>
dims_outside
;
std
::
vector
<
size_t
>
dims_outside
;
std
::
vector
<
size_t
>
strides
;
std
::
vector
<
size_t
>
strides
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录