From d29e9aa4c48c0b3df7e5e08b41e57b9c5f31220f Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 29 Oct 2019 16:57:40 +0800 Subject: [PATCH] [Cherry-pick to 1.6] Block part of "tensor should not be null" error message (#20845) * Add IndicateVarDataType interface to block tensor is not initialized problem in OP GetExceptedKernelType (#20044) * add indicate_var_data_type inferface, test=develop * add unittests & polish error message, test=develop * remove needless include, test=develop * extract public function & polish message, test=develop * delete empty var check, test=develop * change data_type to pointer parameter, test=develop * polish details, test=develop * Replace risky GetInputType method with secure IndicateVarDataType interface (#20668) * replace part of the old implementation, test=develop * restore concat op, test=develop * update all ops implemention & delete GetDataTypeOfVar func, test=develop test=release/1.6 --- paddle/fluid/framework/operator.cc | 89 +++++---- paddle/fluid/framework/operator.h | 6 +- paddle/fluid/framework/operator_test.cc | 179 ++++++++++++++++++ paddle/fluid/framework/variable.h | 14 +- paddle/fluid/operators/activation_op.cc | 5 +- .../operators/add_position_encoding_op.cc | 11 +- paddle/fluid/operators/affine_channel_op.cc | 6 +- paddle/fluid/operators/affine_grid_op.cc | 8 +- paddle/fluid/operators/assign_op.cc | 5 +- paddle/fluid/operators/attention_lstm_op.cc | 4 +- .../fluid/operators/average_accumulates_op.cc | 4 +- paddle/fluid/operators/batch_norm_op.cc | 7 +- paddle/fluid/operators/beam_search_op.cc | 5 +- paddle/fluid/operators/bpr_loss_op.cc | 10 +- paddle/fluid/operators/center_loss_op.cc | 8 +- .../operators/collective/c_allreduce_op.h | 4 +- .../operators/collective/c_broadcast_op.cc | 4 +- paddle/fluid/operators/concat_op.cc | 7 +- paddle/fluid/operators/conv_op.cc | 14 +- paddle/fluid/operators/conv_transpose_op.cc | 10 +- paddle/fluid/operators/crf_decoding_op.cc | 5 +- paddle/fluid/operators/crop_op.cc | 11 +- paddle/fluid/operators/crop_tensor_op.cc | 11 +- paddle/fluid/operators/cross_entropy_op.cc | 11 +- paddle/fluid/operators/ctc_align_op.cc | 5 +- paddle/fluid/operators/cvm_op.cc | 10 +- paddle/fluid/operators/data_norm_op.cc | 17 +- paddle/fluid/operators/deformable_conv_op.cc | 10 +- .../fluid/operators/deformable_conv_v1_op.cc | 10 +- .../operators/deformable_psroi_pooling_op.cc | 10 +- paddle/fluid/operators/dequantize_op.cc | 5 +- .../detection/anchor_generator_op.cc | 3 +- .../operators/detection/bipartite_match_op.cc | 5 +- .../detection/collect_fpn_proposals_op.cc | 2 +- .../detection/density_prior_box_op.cc | 2 +- .../detection/distribute_fpn_proposals_op.cc | 2 +- .../detection/generate_mask_labels_op.cc | 2 +- .../detection/generate_proposal_labels_op.cc | 2 +- .../detection/generate_proposals_op.cc | 5 +- .../detection/mine_hard_examples_op.cc | 3 +- .../operators/detection/multiclass_nms_op.cc | 2 +- .../fluid/operators/detection/prior_box_op.cc | 3 +- .../retinanet_detection_output_op.cc | 3 +- .../detection/roi_perspective_transform_op.cc | 10 +- .../detection/rpn_target_assign_op.cc | 4 +- .../detection/sigmoid_focal_loss_op.cc | 10 +- .../operators/detection/target_assign_op.cc | 5 +- .../fluid/operators/detection/yolo_box_op.cc | 4 +- .../operators/detection/yolov3_loss_op.cc | 10 +- paddle/fluid/operators/detection_map_op.cc | 2 +- .../operators/distributed_ops/allreduce_op.cc | 4 +- .../distributed_lookup_table_op.cc | 2 +- .../operators/distributed_ops/merge_ids_op.cc | 2 +- .../distributed_ops/ref_by_trainer_id_op.cc | 2 +- .../operators/distributed_ops/split_ids_op.cc | 3 +- paddle/fluid/operators/dropout_op.cc | 6 +- .../elementwise/elementwise_div_op.h | 2 +- .../operators/elementwise/elementwise_op.h | 14 +- paddle/fluid/operators/expand_op.cc | 11 +- paddle/fluid/operators/fake_quantize_op.cc | 23 ++- paddle/fluid/operators/fc_op.cc | 5 +- paddle/fluid/operators/filter_by_instag_op.cc | 6 +- paddle/fluid/operators/flatten_op.cc | 16 +- paddle/fluid/operators/fsp_op.cc | 11 +- .../fused/fused_elemwise_activation_op.cc | 8 +- .../fused/fused_embedding_fc_lstm_op.cc | 2 +- .../fused/fused_embedding_seq_pool_op.cc | 4 +- .../fused/fusion_conv_inception_op.cc | 3 +- paddle/fluid/operators/fused/fusion_gru_op.cc | 4 +- .../fluid/operators/fused/fusion_lstm_op.cc | 4 +- .../fused/fusion_repeated_fc_relu_op.cc | 4 +- .../fused/fusion_seqconv_eltadd_relu_op.cc | 4 +- .../fused/fusion_seqexpand_concat_fc_op.cc | 4 +- .../fused/fusion_seqpool_concat_op.cc | 2 +- .../fused/fusion_seqpool_cvm_concat_op.cc | 2 +- .../fused/fusion_squared_mat_sub_op.cc | 4 +- paddle/fluid/operators/gather_nd_op.cc | 8 +- paddle/fluid/operators/gather_op.cc | 11 +- paddle/fluid/operators/gather_tree_op.cc | 5 +- .../get_tensor_from_selected_rows_op.cc | 3 +- paddle/fluid/operators/grid_sampler_op.cc | 12 +- .../operators/hierarchical_sigmoid_op.cc | 8 +- paddle/fluid/operators/instance_norm_op.cc | 10 +- paddle/fluid/operators/interpolate_op.cc | 10 +- paddle/fluid/operators/is_empty_op.cc | 3 +- paddle/fluid/operators/kldiv_loss_op.cc | 8 +- paddle/fluid/operators/linear_chain_crf_op.cc | 8 +- paddle/fluid/operators/linspace_op.cc | 4 +- paddle/fluid/operators/lod_reset_op.cc | 11 +- paddle/fluid/operators/lookup_table_op.cc | 6 +- paddle/fluid/operators/lookup_table_v2_op.cc | 6 +- paddle/fluid/operators/lrn_op.cc | 50 ++--- paddle/fluid/operators/lstm_op.cc | 6 +- paddle/fluid/operators/lstmp_op.cc | 5 +- paddle/fluid/operators/mean_iou_op.cc | 5 +- paddle/fluid/operators/mean_op.cc | 4 +- paddle/fluid/operators/metrics/accuracy_op.cc | 4 +- paddle/fluid/operators/metrics/auc_op.cc | 5 +- .../operators/metrics/precision_recall_op.cc | 5 +- paddle/fluid/operators/mul_op.cc | 2 +- paddle/fluid/operators/multiplex_op.cc | 11 +- paddle/fluid/operators/nce_op.cc | 10 +- paddle/fluid/operators/one_hot_op.cc | 5 +- paddle/fluid/operators/one_hot_v2_op.cc | 5 +- .../fluid/operators/optimizers/adadelta_op.cc | 4 +- .../fluid/operators/optimizers/adagrad_op.cc | 4 +- paddle/fluid/operators/optimizers/adam_op.cc | 2 +- .../fluid/operators/optimizers/adamax_op.cc | 4 +- .../optimizers/decayed_adagrad_op.cc | 4 +- paddle/fluid/operators/optimizers/dpsgd_op.cc | 4 +- paddle/fluid/operators/optimizers/ftrl_op.cc | 3 +- .../fluid/operators/optimizers/momentum_op.h | 3 +- .../optimizers/proximal_adagrad_op.cc | 4 +- .../operators/optimizers/proximal_gd_op.cc | 4 +- paddle/fluid/operators/optimizers/sgd_op.cc | 2 +- paddle/fluid/operators/pad2d_op.cc | 10 +- .../fluid/operators/pad_constant_like_op.cc | 10 +- paddle/fluid/operators/pool_op.cc | 7 +- paddle/fluid/operators/pool_with_index_op.cc | 10 +- .../operators/positive_negative_pair_op.cc | 5 +- paddle/fluid/operators/prelu_op.cc | 10 +- paddle/fluid/operators/prroi_pool_op.cc | 10 +- paddle/fluid/operators/psroi_pool_op.cc | 10 +- paddle/fluid/operators/pull_box_sparse_op.cc | 7 +- paddle/fluid/operators/quantize_op.cc | 5 +- paddle/fluid/operators/random_crop_op.cc | 5 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 6 +- paddle/fluid/operators/requantize_op.cc | 5 +- paddle/fluid/operators/reshape_op.cc | 21 +- paddle/fluid/operators/roi_align_op.cc | 10 +- paddle/fluid/operators/roi_pool_op.cc | 10 +- paddle/fluid/operators/sample_logits_op.cc | 6 +- paddle/fluid/operators/save_op.cc | 2 +- paddle/fluid/operators/scatter_nd_add_op.cc | 10 +- paddle/fluid/operators/scatter_op.cc | 11 +- paddle/fluid/operators/selu_op.cc | 7 +- .../sequence_ops/sequence_concat_op.cc | 6 +- .../sequence_ops/sequence_expand_as_op.cc | 10 +- .../sequence_ops/sequence_expand_op.cc | 10 +- .../sequence_ops/sequence_mask_op.cc | 5 +- .../operators/sequence_ops/sequence_pad_op.cc | 6 +- .../sequence_ops/sequence_pool_op.cc | 6 +- .../sequence_ops/sequence_scatter_op.cc | 11 +- .../sequence_ops/sequence_slice_op.cc | 11 +- .../sequence_ops/sequence_softmax_op.cc | 4 +- .../sequence_topk_avg_pooling_op.cc | 2 +- .../sequence_ops/sequence_unpad_op.cc | 6 +- paddle/fluid/operators/shard_index_op.cc | 5 +- paddle/fluid/operators/shuffle_channel_op.cc | 11 +- paddle/fluid/operators/similarity_focus_op.cc | 5 +- paddle/fluid/operators/slice_op.cc | 11 +- paddle/fluid/operators/softmax_op.cc | 6 +- .../softmax_with_cross_entropy_op.cc | 11 +- paddle/fluid/operators/space_to_depth_op.cc | 6 +- paddle/fluid/operators/spectral_norm_op.cc | 8 +- .../fluid/operators/squared_l2_distance_op.cc | 5 +- paddle/fluid/operators/squeeze_op.cc | 16 +- paddle/fluid/operators/strided_slice_op.cc | 11 +- .../teacher_student_sigmoid_loss_op.cc | 10 +- paddle/fluid/operators/temporal_shift_op.cc | 10 +- paddle/fluid/operators/top_k_op.cc | 5 +- paddle/fluid/operators/transpose_op.cc | 22 ++- paddle/fluid/operators/tree_conv_op.cc | 10 +- paddle/fluid/operators/unfold_op.cc | 11 +- paddle/fluid/operators/unpool_op.cc | 10 +- paddle/fluid/operators/unsqueeze_op.cc | 6 +- paddle/fluid/operators/warpctc_op.cc | 10 +- 167 files changed, 882 insertions(+), 565 deletions(-) mode change 100755 => 100644 paddle/fluid/operators/linear_chain_crf_op.cc diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 3a573da510..0f8aae2eab 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -48,16 +48,6 @@ std::vector> kKernelPriority = { std::make_tuple(platform::CPUPlace(), LibraryType::kPlain), }; -proto::VarType::Type GetDataTypeOfVar(const Variable* var) { - if (var->IsType()) { - return var->Get().type(); - } else if (var->IsType()) { - return var->Get().value().type(); - } else { - PADDLE_THROW("Var should be LoDTensor or SelectedRows"); - } -} - static DDim GetDimsDebug(const Scope& scope, const std::string& name, bool get_actual_dim = false) { Variable* var = scope.FindVar(name); @@ -1152,40 +1142,65 @@ Scope* OperatorWithKernel::PrepareData( return new_scope; } +void OperatorWithKernel::ParseInputDataType( + const ExecutionContext& ctx, const std::string& name, + proto::VarType::Type* data_type) const { + proto::VarType::Type dafault_data_type = + static_cast(-1); + const std::vector vars = ctx.MultiInputVar(name); + for (size_t i = 0; i < vars.size(); ++i) { + const Variable* var = vars[i]; + if (var != nullptr) { + const Tensor* t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &(var->Get().value()); + } + if (t != nullptr) { + PADDLE_ENFORCE_EQ(t->IsInitialized(), true, + "The Tensor in the %s Op's Input Variable %s(%s) is " + "not initialized.", + Type(), name, ctx.Inputs(name).at(i)); + proto::VarType::Type tmp = t->type(); + PADDLE_ENFORCE(tmp == *data_type || *data_type == dafault_data_type, + "The DataType of %s Op's duplicable Variable %s must be " + "consistent. The current variable type is (%s), but the " + "previous variable type is (%s).", + Type(), name, DataTypeToString(tmp), + DataTypeToString(*data_type)); + *data_type = tmp; + } + } + } +} + proto::VarType::Type OperatorWithKernel::IndicateDataType( const ExecutionContext& ctx) const { proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; for (auto& input : this->inputs_) { - const std::vector vars = ctx.MultiInputVar(input.first); - for (size_t i = 0; i < vars.size(); ++i) { - const Variable* var = vars[i]; - if (var != nullptr) { - const Tensor* t = nullptr; - if (var->IsType()) { - t = &var->Get(); - } else if (var->IsType()) { - t = &var->Get(); - } else if (var->IsType()) { - t = &(var->Get().value()); - } - if (t != nullptr) { - PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized", - input.first, i); - proto::VarType::Type tmp = t->type(); - PADDLE_ENFORCE( - tmp == data_type || data_type == dafault_data_type, - "DataType of Paddle Op %s %s must be the same. Get (%s) != (%s)", - Type(), input.first, DataTypeToString(data_type), - DataTypeToString(tmp)); - data_type = tmp; - } - } - } + ParseInputDataType(ctx, input.first, &data_type); } - PADDLE_ENFORCE(data_type != dafault_data_type, - "DataType should be indicated by input"); + PADDLE_ENFORCE_NE(data_type, dafault_data_type, + "DataType should be indicated by input Variable."); + return data_type; +} + +proto::VarType::Type OperatorWithKernel::IndicateVarDataType( + const ExecutionContext& ctx, const std::string& name) const { + proto::VarType::Type dafault_data_type = + static_cast(-1); + proto::VarType::Type data_type = dafault_data_type; + ParseInputDataType(ctx, name, &data_type); + PADDLE_ENFORCE_NE( + data_type, dafault_data_type, + "The Input Variable(%s) of %s Op used to determine kernel data type " + "is empty or not LoDTensor or SelectedRows.", + name, Type()); return data_type; } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 5899a14f50..ab956a9474 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -102,7 +102,6 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) { } } -proto::VarType::Type GetDataTypeOfVar(const Variable* var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var); Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); @@ -459,6 +458,9 @@ class OperatorWithKernel : public OperatorBase { void RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const override; + proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx, + const std::string& name) const; + virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; std::vector* GetKernelConfig(const OpKernelType& key) const; @@ -470,6 +472,8 @@ class OperatorWithKernel : public OperatorBase { const OpKernelType& expected_kernel_type) const; private: + void ParseInputDataType(const ExecutionContext& ctx, const std::string& name, + proto::VarType::Type* type) const; // indicate kernel DataType by input data. By default all input data must be // same. proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index fe4804ac25..aeb1daa4ed 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -315,3 +315,182 @@ TEST(VarNameTest, all) { original_var_name = paddle::framework::GradOriginalVarName(original_var_name); ASSERT_EQ(original_var_name, ""); } + +namespace paddle { +namespace framework { + +class IndicateLoDTensorDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LoDTensor"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("LoDTensor", "Input of Tensor type Variable."); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateSelectedRowsDataTypeTestProtoMaker + : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("SelectedRows", "Input of SelectedRows type Variable."); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +class IndicateOtherDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("Other", "Input of Other type Variable"); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +template +class IndicateVarDataTypeKernelTest : public OpKernel { + public: + void Compute(const ExecutionContext& ctx) const {} +}; + +} // namespace framework +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT( + indicate_lod_tensor_data_type_test, + paddle::framework::IndicateLoDTensorDataTypeTest, + paddle::framework::IndicateLoDTensorDataTypeTestProtoMaker); +REGISTER_OP_WITHOUT_GRADIENT( + indicate_selected_rows_data_type_test, + paddle::framework::IndicateSelectedRowsDataTypeTest, + paddle::framework::IndicateSelectedRowsDataTypeTestProtoMaker); +REGISTER_OP_WITHOUT_GRADIENT( + indicate_other_data_type_test, paddle::framework::IndicateOtherDataTypeTest, + paddle::framework::IndicateOtherDataTypeTestProtoMaker); + +REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); +REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); +REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); + +TEST(IndicateVarDataTypeTest, lodtensor) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_lod_tensor_data_type_test"); + BuildVar("LoDTensor", {"lodtensor_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("lodtensor_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE( + ex_msg.find( + "The Tensor in the indicate_lod_tensor_data_type_test Op's " + "Input Variable LoDTensor(lodtensor_1) is not initialized") != + std::string::npos); + } + ASSERT_TRUE(caught); +} + +TEST(IndicateVarDataTypeTest, selectedrows) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_selected_rows_data_type_test"); + BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("selected_rows_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE( + ex_msg.find("The Tensor in the indicate_selected_rows_data_type_test " + "Op's Input Variable SelectedRows(selected_rows_1) is not " + "initialized") != std::string::npos); + } + ASSERT_TRUE(caught); +} + +TEST(IndicateVarDataTypeTest, other) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_other_data_type_test"); + BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("lod_tensor_array_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of " + "indicate_other_data_type_test Op used to " + "determine kernel data type " + "is empty or not LoDTensor or SelectedRows") != + std::string::npos); + } + ASSERT_TRUE(caught); +} diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index b9d07da822..5d9633a61d 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -30,9 +30,9 @@ class Variable { static_assert( IsRegisteredVarType(), "Not registered type. Please register T inside var_type_traits.h"); - PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing"); + PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized."); PADDLE_ENFORCE(holder_->Type() == VarTypeTrait::kId, - "Variable must be type %s, the holding type is %s", + "The Variable type must be %s, but the type it holds is %s.", ToTypeName(VarTypeTrait::kId), ToTypeName(holder_->Type())); return *static_cast(holder_->Ptr()); @@ -45,10 +45,10 @@ class Variable { if (!holder_) { holder_.reset(new PlaceholderImpl()); } else { - PADDLE_ENFORCE(holder_->Type() == VarTypeTrait::kId, - "Variable must be type %s, the holding type is %s", - ToTypeName(VarTypeTrait::kId), - ToTypeName(holder_->Type())); + PADDLE_ENFORCE( + holder_->Type() == VarTypeTrait::kId, + "The Variable type must be %s, but the type it holds is %s.", + ToTypeName(VarTypeTrait::kId), ToTypeName(holder_->Type())); } return static_cast(holder_->Ptr()); } @@ -61,7 +61,7 @@ class Variable { void Clear() { holder_.reset(); } int Type() const { - PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory"); + PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized."); return holder_->Type(); } diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 82ff2c1a72..be4786fada 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -114,9 +114,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout, - library); + return framework::OpKernelType(oper.IndicateVarDataType(ctx, name), + ctx.GetPlace(), layout, library); } class ActivationOp : public framework::OperatorWithKernel { diff --git a/paddle/fluid/operators/add_position_encoding_op.cc b/paddle/fluid/operators/add_position_encoding_op.cc index 2580c5a523..61a9fa7650 100644 --- a/paddle/fluid/operators/add_position_encoding_op.cc +++ b/paddle/fluid/operators/add_position_encoding_op.cc @@ -37,8 +37,9 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -56,9 +57,9 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - platform::CPUPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 1476cfc2c8..6040ed7550 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -121,9 +121,9 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 9d7100cc3d..c46b426011 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -80,7 +80,7 @@ class AffineGridOp : public framework::OperatorWithKernel { library = framework::LibraryType::kCUDNN; } #endif - auto data_type = ctx.Input("Theta")->type(); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta"); return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kAnyLayout, library); } @@ -191,9 +191,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("Theta")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Theta"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 2212048786..c2b3c818c6 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -89,8 +89,9 @@ class AssignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index c6d98f1f9a..53bd2e4c45 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -129,8 +129,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void AttentionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc index 0922b03b5f..273df31fc8 100644 --- a/paddle/fluid/operators/average_accumulates_op.cc +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -103,8 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 72c023dd99..546605c8db 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -115,7 +115,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { framework::OpKernelType BatchNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -432,8 +432,9 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout, + library); } template diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index a6aa35e056..62cfbfcaae 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -109,10 +109,11 @@ class BeamSearchOp : public framework::OperatorWithKernel { // Compute on CPU for cases with batch_size > 4. if (batch_size <= 4) { return framework::OpKernelType( - ctx.Input("pre_ids")->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), + ctx.GetPlace()); } else { return framework::OpKernelType( - ctx.Input("pre_ids")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), platform::CPUPlace()); } } diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index 51c4d87814..1ad0271304 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -52,8 +52,9 @@ class BprLossOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -98,8 +99,9 @@ class BprLossGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/center_loss_op.cc b/paddle/fluid/operators/center_loss_op.cc index bf766a056a..0b6ce82397 100644 --- a/paddle/fluid/operators/center_loss_op.cc +++ b/paddle/fluid/operators/center_loss_op.cc @@ -61,8 +61,9 @@ class CenterLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -117,7 +118,8 @@ class CenterLossGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.Input("SampleCenterDiff")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 02f6210ca4..c661d42159 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -41,8 +41,8 @@ class CAllReduceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cc b/paddle/fluid/operators/collective/c_broadcast_op.cc index 72d330306c..928fa8549f 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cc @@ -28,8 +28,8 @@ class CBroadcastOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index c76ccb70f7..daef6310dd 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -102,7 +102,6 @@ class ConcatOp : public framework::OperatorWithKernel { if (flag == 0) { PADDLE_THROW("All Inputs of Concat OP are Empty!"); } - #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), @@ -175,9 +174,9 @@ class ConcatOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 514fde453b..cf720cc627 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -135,7 +135,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - auto input_data_type = ctx.Input("Input")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); std::string data_format = "AnyLayout"; // todo enable data layout when it's ready framework::DataLayout layout = framework::StringToDataLayout(data_format); @@ -527,9 +527,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( } #endif - auto type = framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_, - customized_type_value); + auto type = framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_, customized_type_value); #ifdef PADDLE_WITH_CUDA if (library_ == framework::LibraryType::kCUDNN) { std::vector& configs = kernel_configs_map_[type]; @@ -704,9 +704,9 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( customized_type_value = kConvMKLDNNFP32; } #endif - auto type = framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_, - customized_type_value); + auto type = framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_, customized_type_value); #ifdef PADDLE_WITH_CUDA if (library_ == framework::LibraryType::kCUDNN) { std::vector& configs = kernel_configs_map_[type]; diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 6dddd4848e..4ba330447e 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -132,8 +132,9 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void Conv2DTransposeOpMaker::Make() { @@ -384,8 +385,9 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( } framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } class ConvTransposeGradOpDescMaker : public framework::SingleGradOpDescMaker { diff --git a/paddle/fluid/operators/crf_decoding_op.cc b/paddle/fluid/operators/crf_decoding_op.cc index 4676bd0464..746f96dcac 100644 --- a/paddle/fluid/operators/crf_decoding_op.cc +++ b/paddle/fluid/operators/crf_decoding_op.cc @@ -160,8 +160,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Emission")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), + platform::CPUPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/crop_op.cc b/paddle/fluid/operators/crop_op.cc index 2ced5467f1..f42463582f 100644 --- a/paddle/fluid/operators/crop_op.cc +++ b/paddle/fluid/operators/crop_op.cc @@ -53,8 +53,9 @@ class CropOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -174,9 +175,9 @@ class CropOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/crop_tensor_op.cc b/paddle/fluid/operators/crop_tensor_op.cc index 9b536e98e4..43fa27ef4b 100644 --- a/paddle/fluid/operators/crop_tensor_op.cc +++ b/paddle/fluid/operators/crop_tensor_op.cc @@ -87,8 +87,9 @@ class CropTensorOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -243,9 +244,9 @@ class CropTensorOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 8a80619f66..d6da40ddfe 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -107,8 +107,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { @@ -157,9 +158,9 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Y"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context()); } virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc index 4abe9509e6..9982230495 100644 --- a/paddle/fluid/operators/ctc_align_op.cc +++ b/paddle/fluid/operators/ctc_align_op.cc @@ -39,8 +39,9 @@ class CTCAlignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc index 53ed86ade4..7675a6acf7 100644 --- a/paddle/fluid/operators/cvm_op.cc +++ b/paddle/fluid/operators/cvm_op.cc @@ -52,8 +52,9 @@ class CVMOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -93,8 +94,9 @@ class CVMGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 5dc83ac7b3..6d1168c3ae 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -81,7 +81,7 @@ class DataNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -89,12 +89,14 @@ class DataNormOp : public framework::OperatorWithKernel { if (input_data_type == framework::proto::VarType::FP64) { dn_param_type = framework::proto::VarType::FP64; } - PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input("BatchSize")->type(), + PADDLE_ENFORCE_EQ(dn_param_type, + OperatorWithKernel::IndicateVarDataType(ctx, "BatchSize"), "BatchSize input should be of float type"); - PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input("BatchSum")->type(), - "BatchSum input should be of float type"); PADDLE_ENFORCE_EQ(dn_param_type, - ctx.Input("BatchSquareSum")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "BatchSum"), + "BatchSum input should be of float type"); + PADDLE_ENFORCE_EQ(dn_param_type, OperatorWithKernel::IndicateVarDataType( + ctx, "BatchSquareSum"), "BatchSquareSum input should be of float type"); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready @@ -276,8 +278,9 @@ class DataNormGradOp : public framework::OperatorWithKernel { } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout, library); } }; diff --git a/paddle/fluid/operators/deformable_conv_op.cc b/paddle/fluid/operators/deformable_conv_op.cc index c000787545..1eedcc010f 100644 --- a/paddle/fluid/operators/deformable_conv_op.cc +++ b/paddle/fluid/operators/deformable_conv_op.cc @@ -216,8 +216,9 @@ class DeformableConvOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -275,8 +276,9 @@ class DeformableConvGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cc b/paddle/fluid/operators/deformable_conv_v1_op.cc index 6129e29655..8bef1a0b74 100644 --- a/paddle/fluid/operators/deformable_conv_v1_op.cc +++ b/paddle/fluid/operators/deformable_conv_v1_op.cc @@ -199,8 +199,9 @@ class DeformableConvV1Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -253,8 +254,9 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cc b/paddle/fluid/operators/deformable_psroi_pooling_op.cc index d17f22b9b4..dd2f700901 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cc +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cc @@ -199,8 +199,9 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -247,8 +248,9 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Trans")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Trans"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 97f49dbcb0..0ed3293418 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -25,8 +25,9 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType( framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void DeQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cc b/paddle/fluid/operators/detection/anchor_generator_op.cc index 4a333b559f..d328724916 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cc +++ b/paddle/fluid/operators/detection/anchor_generator_op.cc @@ -53,7 +53,8 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/bipartite_match_op.cc b/paddle/fluid/operators/detection/bipartite_match_op.cc index af7797a6d7..785a207263 100644 --- a/paddle/fluid/operators/detection/bipartite_match_op.cc +++ b/paddle/fluid/operators/detection/bipartite_match_op.cc @@ -45,8 +45,9 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("DistMat")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "DistMat"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc index 0603072835..8c53eb5da2 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc @@ -68,7 +68,7 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = - framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]); + OperatorWithKernel::IndicateVarDataType(ctx, "MultiLevelRois"); return framework::OpKernelType(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc index cacd47ed4a..f9ea1dc67d 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cc +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -66,7 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 4cc989b632..ce37e73b75 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -46,7 +46,7 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "FpnRois"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/generate_mask_labels_op.cc b/paddle/fluid/operators/detection/generate_mask_labels_op.cc index 0d77c7f3a7..bd18d77174 100644 --- a/paddle/fluid/operators/detection/generate_mask_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_mask_labels_op.cc @@ -80,7 +80,7 @@ class GenerateMaskLabelsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Rois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Rois"); return framework::OpKernelType(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index 451e0ca855..873d44b27e 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -87,7 +87,7 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("RpnRois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "RpnRois"); return framework::OpKernelType(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index 06e48f1262..bcbd7e1e20 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -60,8 +60,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Anchors")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/mine_hard_examples_op.cc b/paddle/fluid/operators/detection/mine_hard_examples_op.cc index c68fe2439c..c8701d2810 100644 --- a/paddle/fluid/operators/detection/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/detection/mine_hard_examples_op.cc @@ -255,7 +255,8 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("ClsLoss")->type(), platform::CPUPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "ClsLoss"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index f5b9be14ad..28380a04ba 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -80,7 +80,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Scores")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index da6e132498..8d821739f6 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -69,7 +69,8 @@ class PriorBoxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_input_type = ctx.Input("Input")->type(); + auto input_input_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Input"); framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; diff --git a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc index 4a6dfec12e..a79a7608ea 100644 --- a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc +++ b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc @@ -94,8 +94,7 @@ class RetinanetDetectionOutputOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = - framework::GetDataTypeOfVar(ctx.MultiInputVar("Scores")[0]); - + OperatorWithKernel::IndicateVarDataType(ctx, "Scores"); return framework::OpKernelType(input_data_type, platform::CPUPlace()); // ctx.GetPlace()); } diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index ce10de40a9..74756a2a22 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -525,8 +525,9 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -545,8 +546,9 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 338954346c..67aab192fb 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -77,7 +77,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Anchor")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } }; @@ -726,7 +726,7 @@ class RetinanetTargetAssignOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Anchor")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc index 50ff3cb120..eb59c943e4 100644 --- a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc @@ -63,8 +63,9 @@ class SigmoidFocalLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -116,8 +117,9 @@ class SigmoidFocalLossGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/target_assign_op.cc b/paddle/fluid/operators/detection/target_assign_op.cc index c057c82ce0..b2487b1352 100644 --- a/paddle/fluid/operators/detection/target_assign_op.cc +++ b/paddle/fluid/operators/detection/target_assign_op.cc @@ -57,8 +57,9 @@ class TargetAssignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index e0d7e25d94..602efd7b80 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -65,8 +65,8 @@ class YoloBoxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc index 5732b18052..d6cd3171ee 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -98,8 +98,9 @@ class Yolov3LossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -255,8 +256,9 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index dff97f7c77..cfd159a2cc 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -73,7 +73,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("DetectRes")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "DetectRes"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.cc b/paddle/fluid/operators/distributed_ops/allreduce_op.cc index 57d68eb931..86f1c28a9d 100644 --- a/paddle/fluid/operators/distributed_ops/allreduce_op.cc +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.cc @@ -29,8 +29,8 @@ class AllReduceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc index 3e354791ea..c34fb7b96f 100644 --- a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc +++ b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc @@ -72,7 +72,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc index 1b0b4dd316..712ff56c8c 100644 --- a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc @@ -108,7 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.MultiInput("X").front()->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc index 7e16e6ff66..6bf7084449 100644 --- a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc +++ b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc @@ -42,7 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.MultiInput("X")[0]->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/split_ids_op.cc b/paddle/fluid/operators/distributed_ops/split_ids_op.cc index d46b57e7e1..603f697592 100644 --- a/paddle/fluid/operators/distributed_ops/split_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/split_ids_op.cc @@ -66,8 +66,7 @@ class SplitIdsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("Ids").front()), - ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 273015f976..0e060c3a1a 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -121,9 +121,9 @@ class DropoutOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 3c460242f3..82cc6df4a6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -153,7 +153,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = ctx.Input("DDX")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 8da447adaa..67babe6404 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -82,7 +82,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -236,8 +236,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - ctx.Input(framework::GradVarName("Out"))->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -274,7 +274,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("DOut")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -306,13 +306,13 @@ class ElementwiseOpDoubleGradWithoutDXDY if (ctx.HasInput("DDX") == false) { PADDLE_ENFORCE_EQ(ctx.HasInput("DDY"), true, "Input(DDY) should not be null"); - input_data_type = ctx.Input("DDY")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY"); } else if (ctx.HasInput("DDY") == false) { PADDLE_ENFORCE_EQ(ctx.HasInput("DDX"), true, "Input(DDX) should not be null"); - input_data_type = ctx.Input("DDX")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); } else { - input_data_type = ctx.Input("DDX")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); } #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 677130f2f9..41147b77ee 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -65,8 +65,9 @@ class ExpandOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -180,9 +181,9 @@ class ExpandGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 034f3c7dce..53cdcc9922 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -190,8 +190,9 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -241,8 +242,8 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -303,8 +304,9 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -375,8 +377,9 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -450,8 +453,8 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index da30fef555..484c4baef9 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -80,8 +80,9 @@ class FCOp : public framework::OperatorWithKernel { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout, library); } }; diff --git a/paddle/fluid/operators/filter_by_instag_op.cc b/paddle/fluid/operators/filter_by_instag_op.cc index ebf44e5b9a..a48c901f9e 100644 --- a/paddle/fluid/operators/filter_by_instag_op.cc +++ b/paddle/fluid/operators/filter_by_instag_op.cc @@ -48,7 +48,7 @@ class FilterByInstagOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Ins")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Ins"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -101,8 +101,8 @@ class FilterByInstagOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 9f2a122203..c27bb1606b 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -69,8 +69,9 @@ class FlattenOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -130,8 +131,9 @@ class FlattenGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -221,9 +223,9 @@ class Flatten2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fsp_op.cc b/paddle/fluid/operators/fsp_op.cc index fbe8e56a61..0706f9ce37 100644 --- a/paddle/fluid/operators/fsp_op.cc +++ b/paddle/fluid/operators/fsp_op.cc @@ -49,8 +49,9 @@ class FSPOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(), + layout_, library_); } }; @@ -107,9 +108,9 @@ class FSPOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 1cd6c40aa0..9a156147aa 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -140,8 +140,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx.Input("X")->type(), ctx.Input("Y")->type(), "The element's type of input should be the same."); - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -328,8 +328,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 4c13d39406..9124e0c4c9 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -114,7 +114,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - ctx.Input("Embeddings")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Embeddings"), ctx.device_context()); } diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index 9110099013..5661877cb0 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -56,7 +56,7 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -125,7 +125,7 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc index 964335ed2b..e18ac13d34 100644 --- a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc @@ -58,7 +58,8 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index 5c89509907..f9ade705ea 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -93,8 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 32f0e37a64..c256e581ee 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -117,8 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc index 4c11482f50..d98e782562 100644 --- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc @@ -60,8 +60,8 @@ void FusionRepeatedFCReluOp::InferShape( framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionRepeatedFCReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc index 519670cc6a..1e25a9490b 100644 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc @@ -61,8 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape( framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionSeqConvEltAddReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 95a08d3b0f..d79bf7cdcc 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -67,8 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionSeqExpandConcatFCOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc index b14ee88aa5..7ca02a2541 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc @@ -47,7 +47,7 @@ void FusionSeqPoolConcatOp::InferShape( framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSeqPoolConcatOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc index 14e327bb37..0a245bb050 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc @@ -52,7 +52,7 @@ void FusionSeqPoolCVMConcatOp::InferShape( framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSeqPoolCVMConcatOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc index 2d10056044..2d4a397798 100644 --- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc @@ -53,8 +53,8 @@ void FusionSquaredMatSubOp::InferShape( framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSquaredMatSubOpMaker::Make() { diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc index cbeefa0a7f..b2a4029c8f 100644 --- a/paddle/fluid/operators/gather_nd_op.cc +++ b/paddle/fluid/operators/gather_nd_op.cc @@ -61,7 +61,7 @@ class GatherNdOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); - const auto& x_type = x->type(); + const auto& x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType( x_type, x_type == framework::proto::VarType::BOOL @@ -82,9 +82,9 @@ class GatherNdGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index cbabd59cf6..075be1caf4 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -45,8 +45,9 @@ class GatherOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -62,9 +63,9 @@ class GatherGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/gather_tree_op.cc b/paddle/fluid/operators/gather_tree_op.cc index 94fa3b6aa1..26f9989121 100644 --- a/paddle/fluid/operators/gather_tree_op.cc +++ b/paddle/fluid/operators/gather_tree_op.cc @@ -40,8 +40,9 @@ class GatherTreeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Ids")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc index c0893359af..d8470bad11 100644 --- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc +++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc @@ -45,7 +45,8 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 57a1fcd42d..5338889363 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -68,9 +68,9 @@ class GridSampleOp : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; @@ -164,9 +164,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 2b3e2e5c48..a27fcf628c 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -81,8 +81,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -224,8 +224,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc index 6375c92de2..bb6b37e64e 100644 --- a/paddle/fluid/operators/instance_norm_op.cc +++ b/paddle/fluid/operators/instance_norm_op.cc @@ -70,7 +70,7 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const { framework::OpKernelType InstanceNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -236,8 +236,8 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType( if (t == nullptr) { PADDLE_THROW("cannot find Y@GRAD"); } - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } template @@ -396,8 +396,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType( if (t == nullptr) { PADDLE_THROW("cannot find Y@GRAD"); } - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } std::unique_ptr InstanceNormDoubleGradMaker::Apply() const { diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 612f770bb7..cbe9865673 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -204,8 +204,8 @@ class InterpolateOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( @@ -407,9 +407,9 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/is_empty_op.cc b/paddle/fluid/operators/is_empty_op.cc index 092a6eae6f..109e96fb7b 100644 --- a/paddle/fluid/operators/is_empty_op.cc +++ b/paddle/fluid/operators/is_empty_op.cc @@ -35,7 +35,8 @@ class IsEmptyOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *x = ctx.Input("X"); - return framework::OpKernelType(x->type(), x->place()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), x->place()); } }; diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index 983ab3dba6..d5976e7f4a 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -58,8 +58,8 @@ class KLDivLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -136,8 +136,8 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc old mode 100755 new mode 100644 index b78a6ceb51..b6758c8975 --- a/paddle/fluid/operators/linear_chain_crf_op.cc +++ b/paddle/fluid/operators/linear_chain_crf_op.cc @@ -224,8 +224,9 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { // is determined by its input "Emission". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Emission")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), + platform::CPUPlace()); } }; @@ -263,7 +264,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input(framework::GradVarName("LogLikelihood"))->type(), + OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("LogLikelihood")), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc index f4aeb062d8..7ea3b06e02 100644 --- a/paddle/fluid/operators/linspace_op.cc +++ b/paddle/fluid/operators/linspace_op.cc @@ -52,8 +52,8 @@ class LinspaceOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; return framework::OpKernelType( - ctx.Input("Start")->type(), ctx.device_context(), - layout_, library_); + OperatorWithKernel::IndicateVarDataType(ctx, "Start"), + ctx.device_context(), layout_, library_); } }; diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 409f8397eb..190a7cdf12 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -46,8 +46,9 @@ class LoDResetOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -172,9 +173,9 @@ class LoDResetGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 5285e3cae9..c1d45bb7a0 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -64,7 +64,7 @@ class LookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -166,8 +166,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc index 511f50a83d..f0cffa4e1f 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.cc +++ b/paddle/fluid/operators/lookup_table_v2_op.cc @@ -58,7 +58,7 @@ class LookupTableV2Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -154,8 +154,8 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 5ad94cfde9..d5b092ec99 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -130,26 +130,6 @@ struct LRNGradFunctor { template struct LRNGradFunctor; template struct LRNGradFunctor; -namespace { -framework::OpKernelType GetExpectedLRNKernel( - const framework::ExecutionContext& ctx) { - framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; - } -#endif - - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout_, library_); -} -} // namespace - class LRNOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -175,7 +155,20 @@ class LRNOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return GetExpectedLRNKernel(ctx); + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; @@ -281,7 +274,20 @@ class LRNOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return GetExpectedLRNKernel(ctx); + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; } // namespace operators diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc index bf68c57e67..43af877085 100644 --- a/paddle/fluid/operators/lstm_op.cc +++ b/paddle/fluid/operators/lstm_op.cc @@ -97,7 +97,8 @@ class LSTMOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -261,7 +262,8 @@ class LSTMGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc index b9f4223718..68e204983e 100644 --- a/paddle/fluid/operators/lstmp_op.cc +++ b/paddle/fluid/operators/lstmp_op.cc @@ -109,7 +109,8 @@ class LSTMPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -347,7 +348,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("BatchGate")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "BatchGate"), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/mean_iou_op.cc b/paddle/fluid/operators/mean_iou_op.cc index bb290046f3..615b9ea484 100644 --- a/paddle/fluid/operators/mean_iou_op.cc +++ b/paddle/fluid/operators/mean_iou_op.cc @@ -44,8 +44,9 @@ class MeanIoUOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Predictions")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Predictions"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 2b2f845076..e19ac59ee5 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -64,8 +64,8 @@ class MeanGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = - ctx.Input(framework::GradVarName("Out"))->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/accuracy_op.cc b/paddle/fluid/operators/metrics/accuracy_op.cc index d6360c83f0..bedcfc6a38 100644 --- a/paddle/fluid/operators/metrics/accuracy_op.cc +++ b/paddle/fluid/operators/metrics/accuracy_op.cc @@ -68,8 +68,8 @@ class AccuracyOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Out")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index e0eebad08b..3543a33493 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -53,8 +53,9 @@ class AucOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Predict")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Predict"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/metrics/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc index f6d6ffc668..58b948b5a4 100644 --- a/paddle/fluid/operators/metrics/precision_recall_op.cc +++ b/paddle/fluid/operators/metrics/precision_recall_op.cc @@ -92,8 +92,9 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("MaxProbs")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 80059ff14c..8d75809428 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -90,7 +90,7 @@ class MulOp : public framework::OperatorWithKernel { framework::DataLayout layout = framework::DataLayout::kAnyLayout; int customized_type_value = framework::OpKernelType::kDefaultCustomizedTypeValue; - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc index 7cb213e899..843f0a68e1 100644 --- a/paddle/fluid/operators/multiplex_op.cc +++ b/paddle/fluid/operators/multiplex_op.cc @@ -55,8 +55,9 @@ class MultiplexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -125,9 +126,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index e78fda1113..0f26e3953f 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -92,8 +92,9 @@ class NCEOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + platform::CPUPlace()); } }; @@ -246,8 +247,9 @@ class NCEOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/one_hot_op.cc b/paddle/fluid/operators/one_hot_op.cc index 6042b97bf5..e4d50db30a 100644 --- a/paddle/fluid/operators/one_hot_op.cc +++ b/paddle/fluid/operators/one_hot_op.cc @@ -51,8 +51,9 @@ class OneHotOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/one_hot_v2_op.cc b/paddle/fluid/operators/one_hot_v2_op.cc index 7a75afca09..62f85496f9 100644 --- a/paddle/fluid/operators/one_hot_v2_op.cc +++ b/paddle/fluid/operators/one_hot_v2_op.cc @@ -48,8 +48,9 @@ class OneHotV2Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc index 01c0f1bb2d..bde7131379 100644 --- a/paddle/fluid/operators/optimizers/adadelta_op.cc +++ b/paddle/fluid/operators/optimizers/adadelta_op.cc @@ -75,8 +75,8 @@ class AdadeltaOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc index 0310fe2eba..b3aff1eff8 100644 --- a/paddle/fluid/operators/optimizers/adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/adagrad_op.cc @@ -64,8 +64,8 @@ class AdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index fc851e56cb..c5a6fe5875 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -78,7 +78,7 @@ void AdamOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType AdamOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - auto input_data_type = ctx.Input("Param")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } diff --git a/paddle/fluid/operators/optimizers/adamax_op.cc b/paddle/fluid/operators/optimizers/adamax_op.cc index a015290623..9ede7a56d0 100644 --- a/paddle/fluid/operators/optimizers/adamax_op.cc +++ b/paddle/fluid/operators/optimizers/adamax_op.cc @@ -81,8 +81,8 @@ class AdamaxOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc index b44a84ccf7..5c6c38da92 100644 --- a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc @@ -69,8 +69,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/dpsgd_op.cc b/paddle/fluid/operators/optimizers/dpsgd_op.cc index f263e67593..9a7b2112d4 100644 --- a/paddle/fluid/operators/optimizers/dpsgd_op.cc +++ b/paddle/fluid/operators/optimizers/dpsgd_op.cc @@ -55,8 +55,8 @@ class DpsgdOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/ftrl_op.cc b/paddle/fluid/operators/optimizers/ftrl_op.cc index 98b7117562..3f0cd8aa3c 100644 --- a/paddle/fluid/operators/optimizers/ftrl_op.cc +++ b/paddle/fluid/operators/optimizers/ftrl_op.cc @@ -71,7 +71,8 @@ class FTRLOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("Param")->type(); + auto input_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index bb77d2ea6c..10b72524ef 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -85,7 +85,8 @@ class MomentumOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + auto input_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc index 9dd9b8afbd..3e2f12137a 100644 --- a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc @@ -58,8 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cc b/paddle/fluid/operators/optimizers/proximal_gd_op.cc index fccfc2b458..cf3c3e2ccb 100644 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cc @@ -46,8 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc index bbd78db51a..dcc6ce41b2 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -48,7 +48,7 @@ class SGDOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(data_type, ctx.device_context()); } diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc index 3069d56014..461db5fdc9 100644 --- a/paddle/fluid/operators/pad2d_op.cc +++ b/paddle/fluid/operators/pad2d_op.cc @@ -520,8 +520,8 @@ class Pad2dOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -621,9 +621,9 @@ class Pad2dOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc index 31ed0a686f..1c4bf7035e 100644 --- a/paddle/fluid/operators/pad_constant_like_op.cc +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -56,8 +56,9 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context()); } }; @@ -186,8 +187,9 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index f19433115a..6ece163f69 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -134,8 +134,9 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const { @@ -164,7 +165,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( } #endif - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, "float16 can only be used when CUDNN is used"); diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc index 91bd2a902f..d8c2ccaa96 100644 --- a/paddle/fluid/operators/pool_with_index_op.cc +++ b/paddle/fluid/operators/pool_with_index_op.cc @@ -76,8 +76,9 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -96,8 +97,9 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc index e917e778e4..b0677ff10f 100644 --- a/paddle/fluid/operators/positive_negative_pair_op.cc +++ b/paddle/fluid/operators/positive_negative_pair_op.cc @@ -95,8 +95,9 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Score")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Score"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index ccb08b245a..364f3689f9 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -56,8 +56,9 @@ class PReluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -112,8 +113,9 @@ class PReluGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prroi_pool_op.cc b/paddle/fluid/operators/prroi_pool_op.cc index 5c559bda33..c11d09350a 100644 --- a/paddle/fluid/operators/prroi_pool_op.cc +++ b/paddle/fluid/operators/prroi_pool_op.cc @@ -114,8 +114,9 @@ class PRROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -135,8 +136,9 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/psroi_pool_op.cc b/paddle/fluid/operators/psroi_pool_op.cc index c241cf461a..a9128fbd28 100644 --- a/paddle/fluid/operators/psroi_pool_op.cc +++ b/paddle/fluid/operators/psroi_pool_op.cc @@ -131,8 +131,9 @@ class PSROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -151,8 +152,9 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc index 8532649614..3af3fb4967 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.cc +++ b/paddle/fluid/operators/pull_box_sparse_op.cc @@ -104,10 +104,9 @@ class PushBoxSparseOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.MultiInput(framework::GradVarName("Out"))[0] - ->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index d8e20f4c4a..69264e3a45 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -25,8 +25,9 @@ framework::OpKernelType QuantOp::GetExpectedKernelType( framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void QuantOpMaker::Make() { diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index 65a8d603fc..15911a51a2 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -22,8 +22,9 @@ class RandomCropOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 4ed5bd1c70..5cd2627870 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -267,9 +267,9 @@ class ReduceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index d156ae2077..c17b6ef884 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -25,8 +25,9 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType( framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void ReQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 5dd9dfba43..c7f3d888bc 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -200,8 +200,9 @@ class ReshapeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -302,8 +303,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -472,9 +474,9 @@ class Reshape2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -508,8 +510,9 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("DDX")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "DDX"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc index 0914ad81c7..a57266690b 100644 --- a/paddle/fluid/operators/roi_align_op.cc +++ b/paddle/fluid/operators/roi_align_op.cc @@ -65,8 +65,9 @@ class ROIAlignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -85,8 +86,9 @@ class ROIAlignGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("ROIs")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc index cfac7e09e1..0515768a63 100644 --- a/paddle/fluid/operators/roi_pool_op.cc +++ b/paddle/fluid/operators/roi_pool_op.cc @@ -70,8 +70,9 @@ class ROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -90,8 +91,9 @@ class ROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc index 8ce2d52273..962b5dbc50 100644 --- a/paddle/fluid/operators/sample_logits_op.cc +++ b/paddle/fluid/operators/sample_logits_op.cc @@ -162,7 +162,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Logits"); framework::OpKernelType kt = framework::OpKernelType(data_type, ctx.device_context()); return kt; @@ -201,8 +201,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("SampledLogits"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("SampledLogits")); framework::OpKernelType kt = framework::OpKernelType(data_type, ctx.device_context()); return kt; diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index c660bbb8ed..73bac5c2fd 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -31,7 +31,7 @@ class SaveOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/scatter_nd_add_op.cc b/paddle/fluid/operators/scatter_nd_add_op.cc index 41f18eaeaf..ba65832dce 100644 --- a/paddle/fluid/operators/scatter_nd_add_op.cc +++ b/paddle/fluid/operators/scatter_nd_add_op.cc @@ -69,8 +69,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(ctx.Input("X")->type(), - ctx.Input("Updates")->type(), + PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + OperatorWithKernel::IndicateVarDataType(ctx, "Updates"), "Ref and Updates must have same type"); return framework::OpKernelType(ctx.Input("X")->type(), ctx.device_context()); @@ -95,9 +95,9 @@ class ScatterNdAddGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index 4eb5b7ad9d..b3f43a28df 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -48,8 +48,9 @@ class ScatterOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -71,9 +72,9 @@ class ScatterGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc index 67fca18000..f71d844d9e 100644 --- a/paddle/fluid/operators/selu_op.cc +++ b/paddle/fluid/operators/selu_op.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/selu_op.h" + +#include #include +#include namespace paddle { namespace operators { @@ -39,7 +42,7 @@ class SeluOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -115,7 +118,7 @@ class SeluGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar("Out")), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc index d652f9216f..118c8ce0b1 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc @@ -102,9 +102,9 @@ class SeqConcatGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc index e1f6c3e3d5..c7284d0950 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc @@ -75,8 +75,8 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -153,9 +153,9 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc index b7c0420636..90f794ab5f 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc @@ -100,8 +100,8 @@ class SequenceExpandOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -208,9 +208,9 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc index a7225adbf9..cd0170dd1b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc @@ -40,8 +40,9 @@ class SequenceMaskOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( const std::string& var_name, const Tensor& tensor, diff --git a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc index fcc49096e2..de5a0aa45b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc @@ -93,7 +93,7 @@ class SequencePadOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -199,8 +199,8 @@ class SequencePadGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 51e354dcd1..dcc762f790 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -122,9 +122,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc index 5a22212edf..7f9dbbf7ec 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc @@ -113,8 +113,9 @@ class SequenceScatterOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -132,9 +133,9 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - platform::CPUPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc index 4b2ec6e7ca..537184d8b5 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc @@ -51,8 +51,9 @@ class SequenceSliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -71,9 +72,9 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 027073e5d7..af6b7477ea 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( - ctx.Input("X")->type(), ctx.GetPlace(), + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; @@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( - ctx.Input("X")->type(), ctx.GetPlace(), + OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc index 232f324de7..06b16152ef 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc @@ -90,7 +90,7 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc index 8256460858..558d180cfb 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc @@ -67,7 +67,7 @@ class SequenceUnpadOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -132,8 +132,8 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/shard_index_op.cc b/paddle/fluid/operators/shard_index_op.cc index 578dcd37bb..a02d036715 100644 --- a/paddle/fluid/operators/shard_index_op.cc +++ b/paddle/fluid/operators/shard_index_op.cc @@ -41,8 +41,9 @@ class ShardIndexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc index ad6fb3510f..48da765416 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cc +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -35,8 +35,9 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -83,9 +84,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/similarity_focus_op.cc b/paddle/fluid/operators/similarity_focus_op.cc index 21871d7656..e49ce7c487 100644 --- a/paddle/fluid/operators/similarity_focus_op.cc +++ b/paddle/fluid/operators/similarity_focus_op.cc @@ -70,8 +70,9 @@ class SimilarityFocusOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 4cd7b33a4a..9adb4de01c 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -128,8 +128,9 @@ class SliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, @@ -243,9 +244,9 @@ class SliceOpGrad : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 9d73a19197..09c08f2330 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -76,7 +76,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { } #endif - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "float16 can only be used on GPU place"); @@ -187,8 +187,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - auto input_data_type = - ctx.Input(framework::GradVarName("Out"))->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "float16 can only be used on GPU place"); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 8cde72921c..727d67c2fb 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -171,8 +171,9 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Logits")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), + ctx.device_context()); } }; @@ -232,9 +233,9 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Loss"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/space_to_depth_op.cc b/paddle/fluid/operators/space_to_depth_op.cc index 3d66613248..e2c2998095 100644 --- a/paddle/fluid/operators/space_to_depth_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -167,9 +167,9 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 5690265573..71049c58e1 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -77,8 +77,8 @@ class SpectralNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Weight")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); } }; @@ -209,8 +209,8 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Weight")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/squared_l2_distance_op.cc b/paddle/fluid/operators/squared_l2_distance_op.cc index 6e82bf4074..17538c98fe 100644 --- a/paddle/fluid/operators/squared_l2_distance_op.cc +++ b/paddle/fluid/operators/squared_l2_distance_op.cc @@ -152,8 +152,9 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("sub_result")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "sub_result"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index b056d2feac..a7e10457fd 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -104,8 +104,9 @@ class SqueezeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -122,8 +123,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -230,9 +232,9 @@ class Squeeze2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index 7c81d71562..5cd7a78636 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -124,8 +124,9 @@ class StridedSliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.Input("Input")->place()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.Input("Input")->place()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, @@ -230,9 +231,9 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc index 7f95d16f09..7823b9d850 100644 --- a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -55,8 +55,9 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -125,8 +126,9 @@ class TeacherStudentSigmoidLossGradientOp // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index a438832b5d..6663d3f557 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -56,8 +56,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -139,9 +139,9 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index db763a051d..fdf5148eb8 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -53,8 +53,9 @@ class TopkOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(), + layout_, library_); } }; diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 226aad0384..eab6d437d4 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -78,8 +78,9 @@ class TransposeOp : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; @@ -164,9 +165,9 @@ class TransposeOpGrad : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace(), layout_, library_); } }; @@ -210,8 +211,9 @@ class Transpose2Op : public TransposeOp { layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; @@ -268,9 +270,9 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace(), layout_, library_); } }; diff --git a/paddle/fluid/operators/tree_conv_op.cc b/paddle/fluid/operators/tree_conv_op.cc index 566939afaa..0c72275c5b 100644 --- a/paddle/fluid/operators/tree_conv_op.cc +++ b/paddle/fluid/operators/tree_conv_op.cc @@ -104,8 +104,9 @@ class TreeConvOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("NodesVector")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"), + ctx.device_context()); } }; @@ -153,8 +154,9 @@ class TreeConvGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("NodesVector")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/unfold_op.cc b/paddle/fluid/operators/unfold_op.cc index d21340b478..99907e066b 100644 --- a/paddle/fluid/operators/unfold_op.cc +++ b/paddle/fluid/operators/unfold_op.cc @@ -120,8 +120,9 @@ class UnfoldOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -141,9 +142,9 @@ class UnfoldGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Y"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc index fae5041c93..0693df843e 100644 --- a/paddle/fluid/operators/unpool_op.cc +++ b/paddle/fluid/operators/unpool_op.cc @@ -74,8 +74,9 @@ class UnpoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } public: @@ -117,8 +118,9 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } public: diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index fc849e73c5..e55de4508b 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -215,9 +215,9 @@ class Unsqueeze2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index df9212f9c9..d7f6714710 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -60,8 +60,9 @@ class WarpCTCOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("Logits")->type(), - ctx.device_context(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), + ctx.device_context(), layout_, library_); } }; @@ -173,8 +174,9 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Logits")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), + ctx.device_context()); } }; -- GitLab