From 26cc1fe5080b3e53616920dc93aaa6be2c816b5d Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Mon, 28 Oct 2019 10:11:03 +0800 Subject: [PATCH] Replace risky GetInputType method with secure IndicateVarDataType interface (#20668) * replace part of the old implementation, test=develop * restore concat op, test=develop * update all ops implemention & delete GetDataTypeOfVar func, test=develop --- paddle/fluid/framework/operator.cc | 10 ---- paddle/fluid/framework/operator.h | 1 - paddle/fluid/operators/activation_op.cc | 5 +- .../operators/add_position_encoding_op.cc | 11 ++-- paddle/fluid/operators/affine_channel_op.cc | 6 +-- paddle/fluid/operators/affine_grid_op.cc | 8 +-- paddle/fluid/operators/assign_op.cc | 5 +- paddle/fluid/operators/attention_lstm_op.cc | 4 +- .../fluid/operators/average_accumulates_op.cc | 4 +- paddle/fluid/operators/batch_norm_op.cc | 7 +-- paddle/fluid/operators/beam_search_op.cc | 5 +- paddle/fluid/operators/bpr_loss_op.cc | 10 ++-- paddle/fluid/operators/center_loss_op.cc | 8 +-- .../operators/collective/c_allreduce_op.h | 4 +- .../operators/collective/c_broadcast_op.cc | 4 +- paddle/fluid/operators/concat_op.cc | 7 ++- paddle/fluid/operators/conv_op.cc | 14 +++--- paddle/fluid/operators/conv_transpose_op.cc | 10 ++-- paddle/fluid/operators/crf_decoding_op.cc | 5 +- paddle/fluid/operators/crop_op.cc | 11 ++-- paddle/fluid/operators/crop_tensor_op.cc | 11 ++-- paddle/fluid/operators/cross_entropy_op.cc | 11 ++-- paddle/fluid/operators/ctc_align_op.cc | 5 +- paddle/fluid/operators/cvm_op.cc | 10 ++-- paddle/fluid/operators/data_norm_op.cc | 17 ++++--- paddle/fluid/operators/deformable_conv_op.cc | 10 ++-- .../fluid/operators/deformable_conv_v1_op.cc | 10 ++-- .../operators/deformable_psroi_pooling_op.cc | 10 ++-- paddle/fluid/operators/dequantize_op.cc | 5 +- .../detection/anchor_generator_op.cc | 3 +- .../operators/detection/bipartite_match_op.cc | 5 +- .../detection/collect_fpn_proposals_op.cc | 2 +- .../detection/density_prior_box_op.cc | 2 +- .../detection/distribute_fpn_proposals_op.cc | 2 +- .../detection/generate_mask_labels_op.cc | 2 +- .../detection/generate_proposal_labels_op.cc | 2 +- .../detection/generate_proposals_op.cc | 5 +- .../detection/mine_hard_examples_op.cc | 3 +- .../operators/detection/multiclass_nms_op.cc | 2 +- .../fluid/operators/detection/prior_box_op.cc | 3 +- .../retinanet_detection_output_op.cc | 3 +- .../detection/roi_perspective_transform_op.cc | 10 ++-- .../detection/rpn_target_assign_op.cc | 4 +- .../detection/sigmoid_focal_loss_op.cc | 10 ++-- .../operators/detection/target_assign_op.cc | 5 +- .../fluid/operators/detection/yolo_box_op.cc | 4 +- .../operators/detection/yolov3_loss_op.cc | 10 ++-- paddle/fluid/operators/detection_map_op.cc | 2 +- .../operators/distributed_ops/allreduce_op.cc | 4 +- .../distributed_lookup_table_op.cc | 2 +- .../operators/distributed_ops/merge_ids_op.cc | 2 +- .../distributed_ops/ref_by_trainer_id_op.cc | 2 +- .../operators/distributed_ops/split_ids_op.cc | 3 +- paddle/fluid/operators/dropout_op.cc | 6 +-- .../elementwise/elementwise_div_op.h | 2 +- .../operators/elementwise/elementwise_op.h | 14 +++--- paddle/fluid/operators/expand_op.cc | 11 ++-- paddle/fluid/operators/fake_quantize_op.cc | 23 +++++---- paddle/fluid/operators/fc_op.cc | 5 +- paddle/fluid/operators/filter_by_instag_op.cc | 6 +-- paddle/fluid/operators/flatten_op.cc | 16 +++--- paddle/fluid/operators/fsp_op.cc | 11 ++-- .../fused/fused_elemwise_activation_op.cc | 8 +-- .../fused/fused_embedding_fc_lstm_op.cc | 2 +- 
.../fused/fused_embedding_seq_pool_op.cc | 4 +- .../fused/fusion_conv_inception_op.cc | 3 +- paddle/fluid/operators/fused/fusion_gru_op.cc | 4 +- .../fluid/operators/fused/fusion_lstm_op.cc | 4 +- .../fused/fusion_repeated_fc_relu_op.cc | 4 +- .../fused/fusion_seqconv_eltadd_relu_op.cc | 4 +- .../fused/fusion_seqexpand_concat_fc_op.cc | 4 +- .../fused/fusion_seqpool_concat_op.cc | 2 +- .../fused/fusion_seqpool_cvm_concat_op.cc | 2 +- .../fused/fusion_squared_mat_sub_op.cc | 4 +- paddle/fluid/operators/gather_nd_op.cc | 8 +-- paddle/fluid/operators/gather_op.cc | 11 ++-- paddle/fluid/operators/gather_tree_op.cc | 5 +- .../get_tensor_from_selected_rows_op.cc | 3 +- paddle/fluid/operators/grid_sampler_op.cc | 12 ++--- .../operators/hierarchical_sigmoid_op.cc | 8 +-- paddle/fluid/operators/instance_norm_op.cc | 10 ++-- paddle/fluid/operators/interpolate_op.cc | 10 ++-- paddle/fluid/operators/is_empty_op.cc | 3 +- paddle/fluid/operators/kldiv_loss_op.cc | 8 +-- paddle/fluid/operators/linear_chain_crf_op.cc | 8 +-- paddle/fluid/operators/linspace_op.cc | 4 +- paddle/fluid/operators/lod_reset_op.cc | 11 ++-- paddle/fluid/operators/lookup_table_op.cc | 6 +-- paddle/fluid/operators/lookup_table_v2_op.cc | 6 +-- paddle/fluid/operators/lrn_op.cc | 50 +++++++++++-------- paddle/fluid/operators/lstm_op.cc | 6 ++- paddle/fluid/operators/lstmp_op.cc | 5 +- paddle/fluid/operators/mean_iou_op.cc | 5 +- paddle/fluid/operators/mean_op.cc | 4 +- paddle/fluid/operators/metrics/accuracy_op.cc | 4 +- paddle/fluid/operators/metrics/auc_op.cc | 5 +- .../operators/metrics/precision_recall_op.cc | 5 +- paddle/fluid/operators/multiplex_op.cc | 11 ++-- paddle/fluid/operators/nce_op.cc | 10 ++-- paddle/fluid/operators/one_hot_op.cc | 5 +- paddle/fluid/operators/one_hot_v2_op.cc | 5 +- .../fluid/operators/optimizers/adadelta_op.cc | 4 +- .../fluid/operators/optimizers/adagrad_op.cc | 4 +- paddle/fluid/operators/optimizers/adam_op.cc | 2 +- .../fluid/operators/optimizers/adamax_op.cc | 4 +- .../optimizers/decayed_adagrad_op.cc | 4 +- paddle/fluid/operators/optimizers/dpsgd_op.cc | 4 +- paddle/fluid/operators/optimizers/ftrl_op.cc | 3 +- .../fluid/operators/optimizers/momentum_op.h | 3 +- .../optimizers/proximal_adagrad_op.cc | 4 +- .../operators/optimizers/proximal_gd_op.cc | 4 +- paddle/fluid/operators/optimizers/sgd_op.cc | 2 +- paddle/fluid/operators/pad2d_op.cc | 10 ++-- .../fluid/operators/pad_constant_like_op.cc | 10 ++-- paddle/fluid/operators/pool_op.cc | 7 +-- paddle/fluid/operators/pool_with_index_op.cc | 10 ++-- .../operators/positive_negative_pair_op.cc | 5 +- paddle/fluid/operators/prelu_op.cc | 10 ++-- paddle/fluid/operators/prroi_pool_op.cc | 10 ++-- paddle/fluid/operators/psroi_pool_op.cc | 10 ++-- paddle/fluid/operators/pull_box_sparse_op.cc | 7 ++- paddle/fluid/operators/quantize_op.cc | 5 +- paddle/fluid/operators/random_crop_op.cc | 5 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 6 +-- paddle/fluid/operators/requantize_op.cc | 5 +- paddle/fluid/operators/reshape_op.cc | 21 ++++---- paddle/fluid/operators/roi_align_op.cc | 10 ++-- paddle/fluid/operators/roi_pool_op.cc | 10 ++-- paddle/fluid/operators/sample_logits_op.cc | 6 +-- paddle/fluid/operators/scatter_nd_add_op.cc | 10 ++-- paddle/fluid/operators/scatter_op.cc | 11 ++-- paddle/fluid/operators/selu_op.cc | 7 ++- .../sequence_ops/sequence_concat_op.cc | 6 +-- .../sequence_ops/sequence_expand_as_op.cc | 10 ++-- .../sequence_ops/sequence_expand_op.cc | 10 ++-- .../sequence_ops/sequence_mask_op.cc | 5 +- 
.../operators/sequence_ops/sequence_pad_op.cc | 6 +-- .../sequence_ops/sequence_pool_op.cc | 6 +-- .../sequence_ops/sequence_scatter_op.cc | 11 ++-- .../sequence_ops/sequence_slice_op.cc | 11 ++-- .../sequence_ops/sequence_softmax_op.cc | 4 +- .../sequence_topk_avg_pooling_op.cc | 2 +- .../sequence_ops/sequence_unpad_op.cc | 6 +-- paddle/fluid/operators/shard_index_op.cc | 5 +- paddle/fluid/operators/shuffle_channel_op.cc | 11 ++-- paddle/fluid/operators/similarity_focus_op.cc | 5 +- paddle/fluid/operators/slice_op.cc | 11 ++-- paddle/fluid/operators/softmax_op.cc | 6 +-- .../softmax_with_cross_entropy_op.cc | 11 ++-- paddle/fluid/operators/space_to_depth_op.cc | 6 +-- paddle/fluid/operators/spectral_norm_op.cc | 8 +-- .../fluid/operators/squared_l2_distance_op.cc | 5 +- paddle/fluid/operators/squeeze_op.cc | 16 +++--- paddle/fluid/operators/strided_slice_op.cc | 11 ++-- .../teacher_student_sigmoid_loss_op.cc | 10 ++-- paddle/fluid/operators/temporal_shift_op.cc | 10 ++-- paddle/fluid/operators/top_k_op.cc | 5 +- paddle/fluid/operators/transpose_op.cc | 22 ++++---- paddle/fluid/operators/tree_conv_op.cc | 10 ++-- paddle/fluid/operators/unfold_op.cc | 11 ++-- paddle/fluid/operators/unpool_op.cc | 10 ++-- paddle/fluid/operators/unsqueeze_op.cc | 6 +-- paddle/fluid/operators/warpctc_op.cc | 10 ++-- 163 files changed, 637 insertions(+), 529 deletions(-) mode change 100755 => 100644 paddle/fluid/operators/linear_chain_crf_op.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d0a12fcfd12..019ad295cc4 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -48,16 +48,6 @@ std::vector> kKernelPriority = { std::make_tuple(platform::CPUPlace(), LibraryType::kPlain), }; -proto::VarType::Type GetDataTypeOfVar(const Variable* var) { - if (var->IsType()) { - return var->Get().type(); - } else if (var->IsType()) { - return var->Get().value().type(); - } else { - PADDLE_THROW("Var should be LoDTensor or SelectedRows"); - } -} - static DDim GetDimsDebug(const Scope& scope, const std::string& name, bool get_actual_dim = false) { Variable* var = scope.FindVar(name); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 640d4aff5e3..ab956a9474c 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -102,7 +102,6 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) { } } -proto::VarType::Type GetDataTypeOfVar(const Variable* var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var); Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 82ff2c1a72b..be4786fadaf 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -114,9 +114,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout, - library); + return framework::OpKernelType(oper.IndicateVarDataType(ctx, name), + ctx.GetPlace(), layout, library); } class ActivationOp : public framework::OperatorWithKernel { diff --git a/paddle/fluid/operators/add_position_encoding_op.cc b/paddle/fluid/operators/add_position_encoding_op.cc index 2580c5a523e..61a9fa76507 100644 --- a/paddle/fluid/operators/add_position_encoding_op.cc +++ 
b/paddle/fluid/operators/add_position_encoding_op.cc @@ -37,8 +37,9 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -56,9 +57,9 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - platform::CPUPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 1476cfc2c89..6040ed7550d 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -121,9 +121,9 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 9d7100cc3db..c46b4260115 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -80,7 +80,7 @@ class AffineGridOp : public framework::OperatorWithKernel { library = framework::LibraryType::kCUDNN; } #endif - auto data_type = ctx.Input("Theta")->type(); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta"); return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kAnyLayout, library); } @@ -191,9 +191,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("Theta")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Theta"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 22120487865..c2b3c818c65 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -89,8 +89,9 @@ class AssignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index c6d98f1f9a5..53bd2e4c450 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -129,8 +129,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { 
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void AttentionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc index 0922b03b5f5..273df31fc80 100644 --- a/paddle/fluid/operators/average_accumulates_op.cc +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -103,8 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 72c023dd992..546605c8dba 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -115,7 +115,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { framework::OpKernelType BatchNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -432,8 +432,9 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout, + library); } template diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index a6aa35e0569..62cfbfcaae2 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -109,10 +109,11 @@ class BeamSearchOp : public framework::OperatorWithKernel { // Compute on CPU for cases with batch_size > 4. if (batch_size <= 4) { return framework::OpKernelType( - ctx.Input("pre_ids")->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), + ctx.GetPlace()); } else { return framework::OpKernelType( - ctx.Input("pre_ids")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), platform::CPUPlace()); } } diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index 51c4d878142..1ad02713049 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -52,8 +52,9 @@ class BprLossOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -98,8 +99,9 @@ class BprLossGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". 
framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/center_loss_op.cc b/paddle/fluid/operators/center_loss_op.cc index bf766a056a7..0b6ce823973 100644 --- a/paddle/fluid/operators/center_loss_op.cc +++ b/paddle/fluid/operators/center_loss_op.cc @@ -61,8 +61,9 @@ class CenterLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -117,7 +118,8 @@ class CenterLossGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.Input("SampleCenterDiff")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 02f6210ca4c..c661d421598 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -41,8 +41,8 @@ class CAllReduceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cc b/paddle/fluid/operators/collective/c_broadcast_op.cc index 72d330306cc..928fa8549ff 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cc @@ -28,8 +28,8 @@ class CBroadcastOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index c76ccb70f76..daef6310ddf 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -102,7 +102,6 @@ class ConcatOp : public framework::OperatorWithKernel { if (flag == 0) { PADDLE_THROW("All Inputs of Concat OP are Empty!"); } - #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), @@ -175,9 +174,9 @@ class ConcatOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git 
a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 514fde453bb..cf720cc627c 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -135,7 +135,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - auto input_data_type = ctx.Input("Input")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); std::string data_format = "AnyLayout"; // todo enable data layout when it's ready framework::DataLayout layout = framework::StringToDataLayout(data_format); @@ -527,9 +527,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( } #endif - auto type = framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_, - customized_type_value); + auto type = framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_, customized_type_value); #ifdef PADDLE_WITH_CUDA if (library_ == framework::LibraryType::kCUDNN) { std::vector& configs = kernel_configs_map_[type]; @@ -704,9 +704,9 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( customized_type_value = kConvMKLDNNFP32; } #endif - auto type = framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_, - customized_type_value); + auto type = framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_, customized_type_value); #ifdef PADDLE_WITH_CUDA if (library_ == framework::LibraryType::kCUDNN) { std::vector& configs = kernel_configs_map_[type]; diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 6dddd4848e3..4ba330447e2 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -132,8 +132,9 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void Conv2DTransposeOpMaker::Make() { @@ -384,8 +385,9 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( } framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } class ConvTransposeGradOpDescMaker : public framework::SingleGradOpDescMaker { diff --git a/paddle/fluid/operators/crf_decoding_op.cc b/paddle/fluid/operators/crf_decoding_op.cc index 4676bd04646..746f96dcac0 100644 --- a/paddle/fluid/operators/crf_decoding_op.cc +++ b/paddle/fluid/operators/crf_decoding_op.cc @@ -160,8 +160,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Emission")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), + platform::CPUPlace()); } }; } // namespace operators diff --git 
a/paddle/fluid/operators/crop_op.cc b/paddle/fluid/operators/crop_op.cc index 2ced5467f1a..f42463582f8 100644 --- a/paddle/fluid/operators/crop_op.cc +++ b/paddle/fluid/operators/crop_op.cc @@ -53,8 +53,9 @@ class CropOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -174,9 +175,9 @@ class CropOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/crop_tensor_op.cc b/paddle/fluid/operators/crop_tensor_op.cc index f8045efb6b3..e4a314cea06 100644 --- a/paddle/fluid/operators/crop_tensor_op.cc +++ b/paddle/fluid/operators/crop_tensor_op.cc @@ -98,8 +98,9 @@ class CropTensorOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -254,9 +255,9 @@ class CropTensorOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 8a80619f663..d6da40ddfe6 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -107,8 +107,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { @@ -157,9 +158,9 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { // is determined by its input "X". 
framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Y"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context()); } virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc index 4abe9509e6d..9982230495d 100644 --- a/paddle/fluid/operators/ctc_align_op.cc +++ b/paddle/fluid/operators/ctc_align_op.cc @@ -39,8 +39,9 @@ class CTCAlignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc index 53ed86ade48..7675a6acf7e 100644 --- a/paddle/fluid/operators/cvm_op.cc +++ b/paddle/fluid/operators/cvm_op.cc @@ -52,8 +52,9 @@ class CVMOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -93,8 +94,9 @@ class CVMGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 5dc83ac7b30..6d1168c3ae8 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -81,7 +81,7 @@ class DataNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). 
@@ -89,12 +89,14 @@ class DataNormOp : public framework::OperatorWithKernel { if (input_data_type == framework::proto::VarType::FP64) { dn_param_type = framework::proto::VarType::FP64; } - PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input("BatchSize")->type(), + PADDLE_ENFORCE_EQ(dn_param_type, + OperatorWithKernel::IndicateVarDataType(ctx, "BatchSize"), "BatchSize input should be of float type"); - PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input("BatchSum")->type(), - "BatchSum input should be of float type"); PADDLE_ENFORCE_EQ(dn_param_type, - ctx.Input("BatchSquareSum")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "BatchSum"), + "BatchSum input should be of float type"); + PADDLE_ENFORCE_EQ(dn_param_type, OperatorWithKernel::IndicateVarDataType( + ctx, "BatchSquareSum"), "BatchSquareSum input should be of float type"); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready @@ -276,8 +278,9 @@ class DataNormGradOp : public framework::OperatorWithKernel { } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout, library); } }; diff --git a/paddle/fluid/operators/deformable_conv_op.cc b/paddle/fluid/operators/deformable_conv_op.cc index c000787545d..1eedcc010f2 100644 --- a/paddle/fluid/operators/deformable_conv_op.cc +++ b/paddle/fluid/operators/deformable_conv_op.cc @@ -216,8 +216,9 @@ class DeformableConvOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -275,8 +276,9 @@ class DeformableConvGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cc b/paddle/fluid/operators/deformable_conv_v1_op.cc index 6129e296550..8bef1a0b748 100644 --- a/paddle/fluid/operators/deformable_conv_v1_op.cc +++ b/paddle/fluid/operators/deformable_conv_v1_op.cc @@ -199,8 +199,9 @@ class DeformableConvV1Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -253,8 +254,9 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cc b/paddle/fluid/operators/deformable_psroi_pooling_op.cc index d17f22b9b4f..dd2f700901b 
100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cc +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cc @@ -199,8 +199,9 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -247,8 +248,9 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Trans")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Trans"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 97f49dbcb08..0ed3293418f 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -25,8 +25,9 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType( framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void DeQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cc b/paddle/fluid/operators/detection/anchor_generator_op.cc index 4a333b559f8..d3287249166 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cc +++ b/paddle/fluid/operators/detection/anchor_generator_op.cc @@ -53,7 +53,8 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/bipartite_match_op.cc b/paddle/fluid/operators/detection/bipartite_match_op.cc index af7797a6d7c..785a2072631 100644 --- a/paddle/fluid/operators/detection/bipartite_match_op.cc +++ b/paddle/fluid/operators/detection/bipartite_match_op.cc @@ -45,8 +45,9 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("DistMat")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "DistMat"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc index 0603072835e..8c53eb5da2f 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc @@ -68,7 +68,7 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = - 
framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]); + OperatorWithKernel::IndicateVarDataType(ctx, "MultiLevelRois"); return framework::OpKernelType(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc index cacd47ed4a8..f9ea1dc67d9 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cc +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -66,7 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 4cc989b6325..ce37e73b750 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -46,7 +46,7 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "FpnRois"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/generate_mask_labels_op.cc b/paddle/fluid/operators/detection/generate_mask_labels_op.cc index 0d77c7f3a79..bd18d77174f 100644 --- a/paddle/fluid/operators/detection/generate_mask_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_mask_labels_op.cc @@ -80,7 +80,7 @@ class GenerateMaskLabelsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Rois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Rois"); return framework::OpKernelType(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index 451e0ca8550..873d44b27e2 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -87,7 +87,7 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("RpnRois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "RpnRois"); return framework::OpKernelType(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index 06e48f1262a..bcbd7e1e203 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -60,8 +60,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return 
framework::OpKernelType(ctx.Input("Anchors")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/mine_hard_examples_op.cc b/paddle/fluid/operators/detection/mine_hard_examples_op.cc index c68fe2439ca..c8701d28101 100644 --- a/paddle/fluid/operators/detection/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/detection/mine_hard_examples_op.cc @@ -255,7 +255,8 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("ClsLoss")->type(), platform::CPUPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "ClsLoss"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index f5b9be14ad6..28380a04ba1 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -80,7 +80,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Scores")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index da6e132498d..8d821739f6f 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -69,7 +69,8 @@ class PriorBoxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_input_type = ctx.Input("Input")->type(); + auto input_input_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Input"); framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; diff --git a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc index 4a6dfec12e6..a79a7608ea9 100644 --- a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc +++ b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc @@ -94,8 +94,7 @@ class RetinanetDetectionOutputOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = - framework::GetDataTypeOfVar(ctx.MultiInputVar("Scores")[0]); - + OperatorWithKernel::IndicateVarDataType(ctx, "Scores"); return framework::OpKernelType(input_data_type, platform::CPUPlace()); // ctx.GetPlace()); } diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index ce10de40a96..74756a2a22a 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -525,8 +525,9 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), 
- ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -545,8 +546,9 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 338954346c5..67aab192fbe 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -77,7 +77,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Anchor")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } }; @@ -726,7 +726,7 @@ class RetinanetTargetAssignOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Anchor")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc index 50ff3cb120e..eb59c943e4b 100644 --- a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc @@ -63,8 +63,9 @@ class SigmoidFocalLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -116,8 +117,9 @@ class SigmoidFocalLossGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/target_assign_op.cc b/paddle/fluid/operators/detection/target_assign_op.cc index c057c82ce0f..b2487b13523 100644 --- a/paddle/fluid/operators/detection/target_assign_op.cc +++ b/paddle/fluid/operators/detection/target_assign_op.cc @@ -57,8 +57,9 @@ class TargetAssignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index e0d7e25d944..602efd7b80a 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc 
@@ -65,8 +65,8 @@ class YoloBoxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc index 5732b180526..d6cd3171ee3 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -98,8 +98,9 @@ class Yolov3LossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -255,8 +256,9 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index dff97f7c77f..cfd159a2cca 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -73,7 +73,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("DetectRes")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "DetectRes"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.cc b/paddle/fluid/operators/distributed_ops/allreduce_op.cc index 57d68eb931f..86f1c28a9dd 100644 --- a/paddle/fluid/operators/distributed_ops/allreduce_op.cc +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.cc @@ -29,8 +29,8 @@ class AllReduceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc index 3e354791ea9..c34fb7b96f2 100644 --- a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc +++ b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc @@ -72,7 +72,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc 
index 1b0b4dd3169..712ff56c8c2 100644 --- a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc @@ -108,7 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.MultiInput("X").front()->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc index 7e16e6ff66b..6bf70844491 100644 --- a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc +++ b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc @@ -42,7 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.MultiInput("X")[0]->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/split_ids_op.cc b/paddle/fluid/operators/distributed_ops/split_ids_op.cc index d46b57e7e15..603f6975922 100644 --- a/paddle/fluid/operators/distributed_ops/split_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/split_ids_op.cc @@ -66,8 +66,7 @@ class SplitIdsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("Ids").front()), - ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 273015f9763..0e060c3a1a3 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -121,9 +121,9 @@ class DropoutOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 3c460242f3d..82cc6df4a66 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -153,7 +153,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = ctx.Input("DDX")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 8da447adaa7..67babe64044 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -82,7 +82,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( 
const framework::ExecutionContext &ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -236,8 +236,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - ctx.Input(framework::GradVarName("Out"))->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -274,7 +274,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("DOut")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -306,13 +306,13 @@ class ElementwiseOpDoubleGradWithoutDXDY if (ctx.HasInput("DDX") == false) { PADDLE_ENFORCE_EQ(ctx.HasInput("DDY"), true, "Input(DDY) should not be null"); - input_data_type = ctx.Input("DDY")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY"); } else if (ctx.HasInput("DDY") == false) { PADDLE_ENFORCE_EQ(ctx.HasInput("DDX"), true, "Input(DDX) should not be null"); - input_data_type = ctx.Input("DDX")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); } else { - input_data_type = ctx.Input("DDX")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); } #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 677130f2f92..41147b77ee0 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -65,8 +65,9 @@ class ExpandOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -180,9 +181,9 @@ class ExpandGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 034f3c7dceb..53cdcc99226 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -190,8 +190,9 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } 
}; @@ -241,8 +242,8 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -303,8 +304,9 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -375,8 +377,9 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -450,8 +453,8 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index f27fbd6726c..f831d1173e2 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -80,8 +80,9 @@ class FCOp : public framework::OperatorWithKernel { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout, library); } }; diff --git a/paddle/fluid/operators/filter_by_instag_op.cc b/paddle/fluid/operators/filter_by_instag_op.cc index ebf44e5b9a5..a48c901f9e6 100644 --- a/paddle/fluid/operators/filter_by_instag_op.cc +++ b/paddle/fluid/operators/filter_by_instag_op.cc @@ -48,7 +48,7 @@ class FilterByInstagOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Ins")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Ins"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -101,8 +101,8 @@ class FilterByInstagOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 9f2a122203b..c27bb1606b3 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -69,8 +69,9 @@ class FlattenOp : public 
framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -130,8 +131,9 @@ class FlattenGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -221,9 +223,9 @@ class Flatten2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fsp_op.cc b/paddle/fluid/operators/fsp_op.cc index fbe8e56a616..0706f9ce376 100644 --- a/paddle/fluid/operators/fsp_op.cc +++ b/paddle/fluid/operators/fsp_op.cc @@ -49,8 +49,9 @@ class FSPOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(), + layout_, library_); } }; @@ -107,9 +108,9 @@ class FSPOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 1cd6c40aa05..9a156147aa4 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -140,8 +140,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx.Input("X")->type(), ctx.Input("Y")->type(), "The element's type of input should be the same."); - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -328,8 +328,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc 
b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 4c13d39406b..9124e0c4c9b 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -114,7 +114,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - ctx.Input("Embeddings")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Embeddings"), ctx.device_context()); } diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index 89bf0aa18f9..727e135c356 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -57,7 +57,7 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -126,7 +126,7 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc index 964335ed2bc..e18ac13d345 100644 --- a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc @@ -58,7 +58,8 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index 5c895099073..f9ade705ead 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -93,8 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 32f0e37a64b..c256e581ee8 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -117,8 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - 
ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc index 4c11482f507..d98e782562a 100644 --- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc @@ -60,8 +60,8 @@ void FusionRepeatedFCReluOp::InferShape( framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionRepeatedFCReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc index 519670cc6a7..1e25a9490b8 100644 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc @@ -61,8 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape( framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionSeqConvEltAddReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 95a08d3b0f0..d79bf7cdcc8 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -67,8 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionSeqExpandConcatFCOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc index b14ee88aa53..7ca02a2541d 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc @@ -47,7 +47,7 @@ void FusionSeqPoolConcatOp::InferShape( framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSeqPoolConcatOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc index 14e327bb37d..0a245bb0505 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc @@ -52,7 +52,7 @@ void FusionSeqPoolCVMConcatOp::InferShape( framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) 
const { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSeqPoolCVMConcatOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc index 2d10056044e..2d4a3977980 100644 --- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc @@ -53,8 +53,8 @@ void FusionSquaredMatSubOp::InferShape( framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSquaredMatSubOpMaker::Make() { diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc index cbeefa0a7f6..b2a4029c8f6 100644 --- a/paddle/fluid/operators/gather_nd_op.cc +++ b/paddle/fluid/operators/gather_nd_op.cc @@ -61,7 +61,7 @@ class GatherNdOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); - const auto& x_type = x->type(); + const auto& x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType( x_type, x_type == framework::proto::VarType::BOOL @@ -82,9 +82,9 @@ class GatherNdGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index cbabd59cf63..075be1caf48 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -45,8 +45,9 @@ class GatherOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -62,9 +63,9 @@ class GatherGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/gather_tree_op.cc b/paddle/fluid/operators/gather_tree_op.cc index 94fa3b6aa1e..26f9989121e 100644 --- a/paddle/fluid/operators/gather_tree_op.cc +++ b/paddle/fluid/operators/gather_tree_op.cc @@ -40,8 +40,9 @@ class GatherTreeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Ids")->type(), - 
ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc index c0893359af2..d8470bad118 100644 --- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc +++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc @@ -45,7 +45,8 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 57a1fcd42da..5338889363a 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -68,9 +68,9 @@ class GridSampleOp : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; @@ -164,9 +164,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 2b3e2e5c484..a27fcf628cc 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -81,8 +81,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -224,8 +224,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc index 6375c92de2d..bb6b37e64e4 100644 --- a/paddle/fluid/operators/instance_norm_op.cc +++ b/paddle/fluid/operators/instance_norm_op.cc @@ -70,7 +70,7 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const { framework::OpKernelType InstanceNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and 
var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -236,8 +236,8 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType( if (t == nullptr) { PADDLE_THROW("cannot find Y@GRAD"); } - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } template @@ -396,8 +396,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType( if (t == nullptr) { PADDLE_THROW("cannot find Y@GRAD"); } - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } std::unique_ptr InstanceNormDoubleGradMaker::Apply() const { diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 612f770bb7c..cbe9865673b 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -204,8 +204,8 @@ class InterpolateOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( @@ -407,9 +407,9 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/is_empty_op.cc b/paddle/fluid/operators/is_empty_op.cc index 092a6eae6f5..109e96fb7ba 100644 --- a/paddle/fluid/operators/is_empty_op.cc +++ b/paddle/fluid/operators/is_empty_op.cc @@ -35,7 +35,8 @@ class IsEmptyOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *x = ctx.Input("X"); - return framework::OpKernelType(x->type(), x->place()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), x->place()); } }; diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index 983ab3dba6e..d5976e7f4ae 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -58,8 +58,8 @@ class KLDivLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -136,8 +136,8 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git 
a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc old mode 100755 new mode 100644 index b78a6ceb519..b6758c8975b --- a/paddle/fluid/operators/linear_chain_crf_op.cc +++ b/paddle/fluid/operators/linear_chain_crf_op.cc @@ -224,8 +224,9 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { // is determined by its input "Emission". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Emission")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), + platform::CPUPlace()); } }; @@ -263,7 +264,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input(framework::GradVarName("LogLikelihood"))->type(), + OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("LogLikelihood")), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc index f4aeb062d8d..7ea3b06e02e 100644 --- a/paddle/fluid/operators/linspace_op.cc +++ b/paddle/fluid/operators/linspace_op.cc @@ -52,8 +52,8 @@ class LinspaceOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; return framework::OpKernelType( - ctx.Input("Start")->type(), ctx.device_context(), - layout_, library_); + OperatorWithKernel::IndicateVarDataType(ctx, "Start"), + ctx.device_context(), layout_, library_); } }; diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 409f8397eb2..190a7cdf12f 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -46,8 +46,9 @@ class LoDResetOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -172,9 +173,9 @@ class LoDResetGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 5285e3cae9a..c1d45bb7a0d 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -64,7 +64,7 @@ class LookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -166,8 +166,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: 
framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc index 511f50a83d5..f0cffa4e1fe 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.cc +++ b/paddle/fluid/operators/lookup_table_v2_op.cc @@ -58,7 +58,7 @@ class LookupTableV2Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -154,8 +154,8 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 5ad94cfde90..d5b092ec99d 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -130,26 +130,6 @@ struct LRNGradFunctor { template struct LRNGradFunctor; template struct LRNGradFunctor; -namespace { -framework::OpKernelType GetExpectedLRNKernel( - const framework::ExecutionContext& ctx) { - framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; - } -#endif - - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout_, library_); -} -} // namespace - class LRNOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -175,7 +155,20 @@ class LRNOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return GetExpectedLRNKernel(ctx); + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; @@ -281,7 +274,20 @@ class LRNOpGrad : public framework::OperatorWithKernel { framework::OpKernelType 
GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return GetExpectedLRNKernel(ctx); + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; } // namespace operators diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc index bf68c57e67f..43af877085b 100644 --- a/paddle/fluid/operators/lstm_op.cc +++ b/paddle/fluid/operators/lstm_op.cc @@ -97,7 +97,8 @@ class LSTMOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -261,7 +262,8 @@ class LSTMGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc index b9f42237180..68e204983e8 100644 --- a/paddle/fluid/operators/lstmp_op.cc +++ b/paddle/fluid/operators/lstmp_op.cc @@ -109,7 +109,8 @@ class LSTMPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -347,7 +348,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("BatchGate")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "BatchGate"), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/mean_iou_op.cc b/paddle/fluid/operators/mean_iou_op.cc index bb290046f3a..615b9ea4847 100644 --- a/paddle/fluid/operators/mean_iou_op.cc +++ b/paddle/fluid/operators/mean_iou_op.cc @@ -44,8 +44,9 @@ class MeanIoUOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Predictions")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Predictions"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 2b2f8450768..e19ac59ee57 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -64,8 +64,8 @@ class MeanGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto 
input_data_type = - ctx.Input(framework::GradVarName("Out"))->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/accuracy_op.cc b/paddle/fluid/operators/metrics/accuracy_op.cc index d6360c83f09..bedcfc6a387 100644 --- a/paddle/fluid/operators/metrics/accuracy_op.cc +++ b/paddle/fluid/operators/metrics/accuracy_op.cc @@ -68,8 +68,8 @@ class AccuracyOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Out")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index e0eebad08bb..3543a334935 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -53,8 +53,9 @@ class AucOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Predict")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Predict"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/metrics/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc index f6d6ffc668c..58b948b5a43 100644 --- a/paddle/fluid/operators/metrics/precision_recall_op.cc +++ b/paddle/fluid/operators/metrics/precision_recall_op.cc @@ -92,8 +92,9 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("MaxProbs")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc index 7cb213e8995..843f0a68e1d 100644 --- a/paddle/fluid/operators/multiplex_op.cc +++ b/paddle/fluid/operators/multiplex_op.cc @@ -55,8 +55,9 @@ class MultiplexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -125,9 +126,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index e78fda11139..0f26e3953fe 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -92,8 +92,9 @@ class NCEOp : public framework::OperatorWithKernel { protected: 
framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + platform::CPUPlace()); } }; @@ -246,8 +247,9 @@ class NCEOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/one_hot_op.cc b/paddle/fluid/operators/one_hot_op.cc index 6042b97bf57..e4d50db30a3 100644 --- a/paddle/fluid/operators/one_hot_op.cc +++ b/paddle/fluid/operators/one_hot_op.cc @@ -51,8 +51,9 @@ class OneHotOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/one_hot_v2_op.cc b/paddle/fluid/operators/one_hot_v2_op.cc index 7a75afca09c..62f85496f9b 100644 --- a/paddle/fluid/operators/one_hot_v2_op.cc +++ b/paddle/fluid/operators/one_hot_v2_op.cc @@ -48,8 +48,9 @@ class OneHotV2Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc index 01c0f1bb2d4..bde7131379a 100644 --- a/paddle/fluid/operators/optimizers/adadelta_op.cc +++ b/paddle/fluid/operators/optimizers/adadelta_op.cc @@ -75,8 +75,8 @@ class AdadeltaOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc index 0310fe2eba8..b3aff1eff8c 100644 --- a/paddle/fluid/operators/optimizers/adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/adagrad_op.cc @@ -64,8 +64,8 @@ class AdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index fc851e56cbf..c5a6fe5875b 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -78,7 +78,7 @@ void 
AdamOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType AdamOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - auto input_data_type = ctx.Input("Param")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } diff --git a/paddle/fluid/operators/optimizers/adamax_op.cc b/paddle/fluid/operators/optimizers/adamax_op.cc index a0152906235..9ede7a56d0b 100644 --- a/paddle/fluid/operators/optimizers/adamax_op.cc +++ b/paddle/fluid/operators/optimizers/adamax_op.cc @@ -81,8 +81,8 @@ class AdamaxOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc index b44a84ccf71..5c6c38da928 100644 --- a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc @@ -69,8 +69,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/dpsgd_op.cc b/paddle/fluid/operators/optimizers/dpsgd_op.cc index f263e67593b..9a7b2112d4e 100644 --- a/paddle/fluid/operators/optimizers/dpsgd_op.cc +++ b/paddle/fluid/operators/optimizers/dpsgd_op.cc @@ -55,8 +55,8 @@ class DpsgdOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/ftrl_op.cc b/paddle/fluid/operators/optimizers/ftrl_op.cc index 98b71175624..3f0cd8aa3c8 100644 --- a/paddle/fluid/operators/optimizers/ftrl_op.cc +++ b/paddle/fluid/operators/optimizers/ftrl_op.cc @@ -71,7 +71,8 @@ class FTRLOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("Param")->type(); + auto input_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index bb77d2ea6cd..10b72524efd 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -85,7 +85,8 @@ class MomentumOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + auto input_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, 
ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc index 9dd9b8afbd4..3e2f12137af 100644 --- a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc @@ -58,8 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cc b/paddle/fluid/operators/optimizers/proximal_gd_op.cc index fccfc2b4584..cf3c3e2ccb9 100644 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cc @@ -46,8 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc index bbd78db51a9..dcc6ce41b27 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -48,7 +48,7 @@ class SGDOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(data_type, ctx.device_context()); } diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc index 3069d560144..461db5fdc94 100644 --- a/paddle/fluid/operators/pad2d_op.cc +++ b/paddle/fluid/operators/pad2d_op.cc @@ -520,8 +520,8 @@ class Pad2dOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -621,9 +621,9 @@ class Pad2dOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc index 31ed0a686f7..1c4bf7035ef 100644 --- a/paddle/fluid/operators/pad_constant_like_op.cc +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -56,8 +56,9 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.device_context()); + return 
framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context()); } }; @@ -186,8 +187,9 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index f19433115a7..6ece163f697 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -134,8 +134,9 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const { @@ -164,7 +165,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( } #endif - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, "float16 can only be used when CUDNN is used"); diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc index 91bd2a902f7..d8c2ccaa960 100644 --- a/paddle/fluid/operators/pool_with_index_op.cc +++ b/paddle/fluid/operators/pool_with_index_op.cc @@ -76,8 +76,9 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -96,8 +97,9 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc index e917e778e41..b0677ff10f6 100644 --- a/paddle/fluid/operators/positive_negative_pair_op.cc +++ b/paddle/fluid/operators/positive_negative_pair_op.cc @@ -95,8 +95,9 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Score")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Score"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index ccb08b245a4..364f3689f9f 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -56,8 +56,9 @@ class PReluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const 
framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -112,8 +113,9 @@ class PReluGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/prroi_pool_op.cc b/paddle/fluid/operators/prroi_pool_op.cc
index 5c559bda339..c11d09350a4 100644
--- a/paddle/fluid/operators/prroi_pool_op.cc
+++ b/paddle/fluid/operators/prroi_pool_op.cc
@@ -114,8 +114,9 @@ class PRROIPoolOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -135,8 +136,9 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/psroi_pool_op.cc b/paddle/fluid/operators/psroi_pool_op.cc
index c241cf461ac..a9128fbd28d 100644
--- a/paddle/fluid/operators/psroi_pool_op.cc
+++ b/paddle/fluid/operators/psroi_pool_op.cc
@@ -131,8 +131,9 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -151,8 +152,9 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc
index 8532649614c..3af3fb49675 100644
--- a/paddle/fluid/operators/pull_box_sparse_op.cc
+++ b/paddle/fluid/operators/pull_box_sparse_op.cc
@@ -104,10 +104,9 @@ class PushBoxSparseOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.MultiInput(framework::GradVarName("Out"))[0]
-            ->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc
index d8e20f4c4ae..69264e3a45e 100644
--- a/paddle/fluid/operators/quantize_op.cc
+++ b/paddle/fluid/operators/quantize_op.cc
@@ -25,8 +25,9 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(
   framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
   framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
-  return framework::OpKernelType(ctx.Input("Input")->type(),
-                                 ctx.GetPlace(), layout_, library_);
+  return framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
+      layout_, library_);
 }
 void QuantOpMaker::Make() {
diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc
index 65a8d603fce..15911a51a2c 100644
--- a/paddle/fluid/operators/random_crop_op.cc
+++ b/paddle/fluid/operators/random_crop_op.cc
@@ -22,8 +22,9 @@ class RandomCropOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h
index 4ed5bd1c70b..5cd2627870d 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -267,9 +267,9 @@ class ReduceGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc
index d156ae20776..c17b6ef8842 100644
--- a/paddle/fluid/operators/requantize_op.cc
+++ b/paddle/fluid/operators/requantize_op.cc
@@ -25,8 +25,9 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(
   framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
   framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
-  return framework::OpKernelType(ctx.Input("Input")->type(),
-                                 ctx.GetPlace(), layout_, library_);
+  return framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
+      layout_, library_);
 }
 void ReQuantOpMaker::Make() {
diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index cf97eac3fbe..9a32a02c11b 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -203,8 +203,9 @@ class ReshapeOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
  framework::OpKernelType GetKernelTypeForVar(
@@ -305,8 +306,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -475,9 +477,9 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
  framework::OpKernelType GetKernelTypeForVar(
@@ -511,8 +513,9 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("DDX")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "DDX"),
+        ctx.device_context());
   }
  framework::OpKernelType GetKernelTypeForVar(
diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc
index 0914ad81c77..a57266690b0 100644
--- a/paddle/fluid/operators/roi_align_op.cc
+++ b/paddle/fluid/operators/roi_align_op.cc
@@ -65,8 +65,9 @@ class ROIAlignOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -85,8 +86,9 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("ROIs")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"),
+        ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc
index cfac7e09e12..0515768a630 100644
--- a/paddle/fluid/operators/roi_pool_op.cc
+++ b/paddle/fluid/operators/roi_pool_op.cc
@@ -70,8 +70,9 @@ class ROIPoolOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -90,8 +91,9 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc
index 8ce2d52273d..962b5dbc505 100644
--- a/paddle/fluid/operators/sample_logits_op.cc
+++ b/paddle/fluid/operators/sample_logits_op.cc
@@ -162,7 +162,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits"));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Logits");
     framework::OpKernelType kt =
         framework::OpKernelType(data_type, ctx.device_context());
     return kt;
@@ -201,8 +201,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(
-        ctx.InputVar(framework::GradVarName("SampledLogits")));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("SampledLogits"));
     framework::OpKernelType kt =
         framework::OpKernelType(data_type, ctx.device_context());
     return kt;
diff --git a/paddle/fluid/operators/scatter_nd_add_op.cc b/paddle/fluid/operators/scatter_nd_add_op.cc
index 41f18eaeaf8..ba65832dcea 100644
--- a/paddle/fluid/operators/scatter_nd_add_op.cc
+++ b/paddle/fluid/operators/scatter_nd_add_op.cc
@@ -69,8 +69,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx.Input("X")->type(),
-                      ctx.Input("Updates")->type(),
+    PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+                      OperatorWithKernel::IndicateVarDataType(ctx, "Updates"),
                       "Ref and Updates must have same type");
     return framework::OpKernelType(ctx.Input("X")->type(),
                                    ctx.device_context());
@@ -95,9 +95,9 @@ class ScatterNdAddGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc
index 4eb5b7ad9d1..b3f43a28dff 100644
--- a/paddle/fluid/operators/scatter_op.cc
+++ b/paddle/fluid/operators/scatter_op.cc
@@ -48,8 +48,9 @@ class ScatterOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -71,9 +72,9 @@ class ScatterGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc
index 67fca18000a..f71d844d9e9 100644
--- a/paddle/fluid/operators/selu_op.cc
+++ b/paddle/fluid/operators/selu_op.cc
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/selu_op.h"
+
+#include
 #include
+
 namespace paddle {
 namespace operators {
@@ -39,7 +42,7 @@ class SeluOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.GetPlace());
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
 };
@@ -115,7 +118,7 @@ class SeluGradOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
-        framework::GetDataTypeOfVar(ctx.InputVar("Out")), ctx.GetPlace());
+        OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc
index d652f9216f8..118c8ce0b11 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc
@@ -102,9 +102,9 @@ class SeqConcatGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc
index e1f6c3e3d59..c7284d0950c 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc
@@ -75,8 +75,8 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
 };
@@ -153,9 +153,9 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc
index b7c0420636a..90f794ab5ff 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc
@@ -100,8 +100,8 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
 };
@@ -208,9 +208,9 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc
index a7225adbf9f..cd0170dd1b4 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc
@@ -40,8 +40,9 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
index fcc49096e2c..de5a0aa45bc 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc
@@ -93,7 +93,7 @@ class SequencePadOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
@@ -199,8 +199,8 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(
-        ctx.InputVar(framework::GradVarName("Out")));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
     return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
index 51e354dcd17..dcc762f7907 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
@@ -122,9 +122,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
index 5a22212edf2..7f9dbbf7eca 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc
@@ -113,8 +113,9 @@ class SequenceScatterOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
   }
 };
@@ -132,9 +133,9 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        platform::CPUPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   platform::CPUPlace());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
index 4b2ec6e7cad..537184d8b52 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc
@@ -51,8 +51,9 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -71,9 +72,9 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
index c0f993ac7aa..12729f43247 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc
@@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
     }
     std::string data_format = ctx.Attr("data_format");
     return framework::OpKernelType(
-        ctx.Input("X")->type(), ctx.GetPlace(),
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
         framework::StringToDataLayout(data_format), library_);
   }
 };
@@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
     }
     std::string data_format = ctx.Attr("data_format");
     return framework::OpKernelType(
-        ctx.Input("Out")->type(), ctx.GetPlace(),
+        OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace(),
         framework::StringToDataLayout(data_format), library_);
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc
index 232f324de77..06b16152ef9 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc
@@ -90,7 +90,7 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
index 8256460858b..558d180cfba 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
+++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc
@@ -67,7 +67,7 @@ class SequenceUnpadOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
@@ -132,8 +132,8 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(
-        ctx.InputVar(framework::GradVarName("Out")));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
     return framework::OpKernelType(data_type, ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/shard_index_op.cc b/paddle/fluid/operators/shard_index_op.cc
index 578dcd37bb4..a02d0367159 100644
--- a/paddle/fluid/operators/shard_index_op.cc
+++ b/paddle/fluid/operators/shard_index_op.cc
@@ -41,8 +41,9 @@ class ShardIndexOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc
index ad6fb3510f0..48da765416f 100644
--- a/paddle/fluid/operators/shuffle_channel_op.cc
+++ b/paddle/fluid/operators/shuffle_channel_op.cc
@@ -35,8 +35,9 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -83,9 +84,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/similarity_focus_op.cc b/paddle/fluid/operators/similarity_focus_op.cc
index 21871d76569..e49ce7c4872 100644
--- a/paddle/fluid/operators/similarity_focus_op.cc
+++ b/paddle/fluid/operators/similarity_focus_op.cc
@@ -70,8 +70,9 @@ class SimilarityFocusOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
   }
 };
diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc
index 4cd7b33a4a8..9adb4de01cb 100644
--- a/paddle/fluid/operators/slice_op.cc
+++ b/paddle/fluid/operators/slice_op.cc
@@ -128,8 +128,9 @@ class SliceOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("Input")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
   }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
@@ -243,9 +244,9 @@ class SliceOpGrad : public framework::OperatorWithKernel {
   }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 9d73a19197c..09c08f23307 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -76,7 +76,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     }
 #endif
-    auto input_data_type = ctx.Input("X")->type();
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     if (input_data_type == framework::proto::VarType::FP16) {
       PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                      "float16 can only be used on GPU place");
@@ -187,8 +187,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    auto input_data_type =
-        ctx.Input(framework::GradVarName("Out"))->type();
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
     if (input_data_type == framework::proto::VarType::FP16) {
       PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                      "float16 can only be used on GPU place");
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
index 8cde72921cb..727d67c2fba 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
@@ -171,8 +171,9 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("Logits")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
+        ctx.device_context());
   }
 };
@@ -232,9 +233,9 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Loss"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Loss")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/space_to_depth_op.cc b/paddle/fluid/operators/space_to_depth_op.cc
index 3d66613248c..e2c2998095e 100644
--- a/paddle/fluid/operators/space_to_depth_op.cc
+++ b/paddle/fluid/operators/space_to_depth_op.cc
@@ -167,9 +167,9 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc
index 5690265573f..71049c58e16 100644
--- a/paddle/fluid/operators/spectral_norm_op.cc
+++ b/paddle/fluid/operators/spectral_norm_op.cc
@@ -77,8 +77,8 @@ class SpectralNormOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("Weight")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace());
   }
 };
@@ -209,8 +209,8 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("Weight")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/squared_l2_distance_op.cc b/paddle/fluid/operators/squared_l2_distance_op.cc
index 6e82bf40749..17538c98fe5 100644
--- a/paddle/fluid/operators/squared_l2_distance_op.cc
+++ b/paddle/fluid/operators/squared_l2_distance_op.cc
@@ -152,8 +152,9 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("sub_result")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "sub_result"),
+        ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc
index b056d2feacc..a7e10457fd7 100644
--- a/paddle/fluid/operators/squeeze_op.cc
+++ b/paddle/fluid/operators/squeeze_op.cc
@@ -104,8 +104,9 @@ class SqueezeOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -122,8 +123,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -230,9 +232,9 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc
index 7c81d71562f..5cd7a786363 100644
--- a/paddle/fluid/operators/strided_slice_op.cc
+++ b/paddle/fluid/operators/strided_slice_op.cc
@@ -124,8 +124,9 @@ class StridedSliceOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("Input")->type(),
-                                   ctx.Input("Input")->place());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.Input("Input")->place());
   }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
@@ -230,9 +231,9 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
   }
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc
index 7f95d16f09b..7823b9d8501 100644
--- a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc
+++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc
@@ -55,8 +55,9 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -125,8 +126,9 @@ class TeacherStudentSigmoidLossGradientOp
  // is determined by its input "X".
 framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc
index a438832b5dc..6663d3f5571 100644
--- a/paddle/fluid/operators/temporal_shift_op.cc
+++ b/paddle/fluid/operators/temporal_shift_op.cc
@@ -56,8 +56,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
 };
@@ -139,9 +139,9 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc
index db763a051d1..fdf5148eb87 100644
--- a/paddle/fluid/operators/top_k_op.cc
+++ b/paddle/fluid/operators/top_k_op.cc
@@ -53,8 +53,9 @@ class TopkOp : public framework::OperatorWithKernel {
      const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context(), layout_, library_);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(),
+        layout_, library_);
   }
 };
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index 226aad03845..eab6d437d47 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -78,8 +78,9 @@ class TransposeOp : public framework::OperatorWithKernel {
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
+        layout_, library_);
   }
 };
@@ -164,9 +165,9 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace(), layout_, library_);
   }
 };
@@ -210,8 +211,9 @@ class Transpose2Op : public TransposeOp {
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
+        layout_, library_);
   }
 };
@@ -268,9 +270,9 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace(), layout_, library_);
   }
 };
diff --git a/paddle/fluid/operators/tree_conv_op.cc b/paddle/fluid/operators/tree_conv_op.cc
index 566939afaa4..0c72275c5bb 100644
--- a/paddle/fluid/operators/tree_conv_op.cc
+++ b/paddle/fluid/operators/tree_conv_op.cc
@@ -104,8 +104,9 @@ class TreeConvOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("NodesVector")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
+        ctx.device_context());
   }
 };
@@ -153,8 +154,9 @@ class TreeConvGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input("NodesVector")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
+        ctx.device_context());
   }
 };
 }  // namespace operators
diff --git a/paddle/fluid/operators/unfold_op.cc b/paddle/fluid/operators/unfold_op.cc
index d21340b478b..99907e066b2 100644
--- a/paddle/fluid/operators/unfold_op.cc
+++ b/paddle/fluid/operators/unfold_op.cc
@@ -120,8 +120,9 @@ class UnfoldOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 };
@@ -141,9 +142,9 @@ class UnfoldGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Y"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Y")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc
index fae5041c932..0693df843eb 100644
--- a/paddle/fluid/operators/unpool_op.cc
+++ b/paddle/fluid/operators/unpool_op.cc
@@ -74,8 +74,9 @@ class UnpoolOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 public:
@@ -117,8 +118,9 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
   }
 public:
diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc
index fc849e73c57..e55de4508bc 100644
--- a/paddle/fluid/operators/unsqueeze_op.cc
+++ b/paddle/fluid/operators/unsqueeze_op.cc
@@ -215,9 +215,9 @@ class Unsqueeze2GradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
   }
 };
diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc
index df9212f9c93..d7f6714710f 100644
--- a/paddle/fluid/operators/warpctc_op.cc
+++ b/paddle/fluid/operators/warpctc_op.cc
@@ -60,8 +60,9 @@ class WarpCTCOp : public framework::OperatorWithKernel {
      const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
-    return framework::OpKernelType(ctx.Input("Logits")->type(),
-                                   ctx.device_context(), layout_, library_);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
+        ctx.device_context(), layout_, library_);
   }
 };
@@ -173,8 +174,9 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input("Logits")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
+        ctx.device_context());
   }
 };
-- 
GitLab
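
Appendix (editor's illustrative sketch, not part of the applied patch): every hunk above follows one pattern, replacing direct reads such as ctx.Input(...)->type() or framework::GetDataTypeOfVar(ctx.InputVar(...)) with OperatorWithKernel::IndicateVarDataType(ctx, name). The standalone C++ sketch below uses hypothetical mock types (MockTensor, MockContext, IndicateVarDataTypeSketch) rather than Paddle's real classes, and only illustrates the contract the new interface is meant to enforce: validate the input variable before reading its dtype and report which input is at fault, instead of dereferencing a possibly null or uninitialized tensor.

// Illustrative sketch only -- not the actual Paddle implementation.
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

enum class DataType { FP32, FP64, INT64 };

struct MockTensor {
  bool initialized = false;
  DataType dtype = DataType::FP32;
};

struct MockContext {
  std::map<std::string, const MockTensor*> inputs;
  // Returns nullptr when the named input is absent, mimicking a lookup
  // that the old pattern never checked.
  const MockTensor* Input(const std::string& name) const {
    auto it = inputs.find(name);
    return it == inputs.end() ? nullptr : it->second;
  }
};

// Old style: dereferences the pointer unconditionally; a missing or
// uninitialized input crashes or yields a meaningless dtype.
DataType RiskyGetType(const MockContext& ctx, const std::string& name) {
  return ctx.Input(name)->dtype;  // may dereference nullptr
}

// New style (sketch of the contract): validate first, then fail with an
// error message that names the offending input.
DataType IndicateVarDataTypeSketch(const MockContext& ctx,
                                   const std::string& name) {
  const MockTensor* t = ctx.Input(name);
  if (t == nullptr || !t->initialized) {
    throw std::runtime_error("Input(" + name +
                             ") is not initialized; cannot infer kernel dtype");
  }
  return t->dtype;
}

int main() {
  MockTensor x;
  x.initialized = true;
  x.dtype = DataType::FP64;
  MockContext ctx;
  ctx.inputs["X"] = &x;
  std::cout << static_cast<int>(IndicateVarDataTypeSketch(ctx, "X")) << "\n";
  // IndicateVarDataTypeSketch(ctx, "Y") would throw a readable error,
  // whereas RiskyGetType(ctx, "Y") would dereference a null pointer.
  return 0;
}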