diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 3a573da5106145755a620478d3522adf495183da..0f8aae2eab3b737415bbb8571a02a2e69cc7d5ba 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -48,16 +48,6 @@ std::vector> kKernelPriority = { std::make_tuple(platform::CPUPlace(), LibraryType::kPlain), }; -proto::VarType::Type GetDataTypeOfVar(const Variable* var) { - if (var->IsType()) { - return var->Get().type(); - } else if (var->IsType()) { - return var->Get().value().type(); - } else { - PADDLE_THROW("Var should be LoDTensor or SelectedRows"); - } -} - static DDim GetDimsDebug(const Scope& scope, const std::string& name, bool get_actual_dim = false) { Variable* var = scope.FindVar(name); @@ -1152,40 +1142,65 @@ Scope* OperatorWithKernel::PrepareData( return new_scope; } +void OperatorWithKernel::ParseInputDataType( + const ExecutionContext& ctx, const std::string& name, + proto::VarType::Type* data_type) const { + proto::VarType::Type dafault_data_type = + static_cast(-1); + const std::vector vars = ctx.MultiInputVar(name); + for (size_t i = 0; i < vars.size(); ++i) { + const Variable* var = vars[i]; + if (var != nullptr) { + const Tensor* t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &(var->Get().value()); + } + if (t != nullptr) { + PADDLE_ENFORCE_EQ(t->IsInitialized(), true, + "The Tensor in the %s Op's Input Variable %s(%s) is " + "not initialized.", + Type(), name, ctx.Inputs(name).at(i)); + proto::VarType::Type tmp = t->type(); + PADDLE_ENFORCE(tmp == *data_type || *data_type == dafault_data_type, + "The DataType of %s Op's duplicable Variable %s must be " + "consistent. The current variable type is (%s), but the " + "previous variable type is (%s).", + Type(), name, DataTypeToString(tmp), + DataTypeToString(*data_type)); + *data_type = tmp; + } + } + } +} + proto::VarType::Type OperatorWithKernel::IndicateDataType( const ExecutionContext& ctx) const { proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; for (auto& input : this->inputs_) { - const std::vector vars = ctx.MultiInputVar(input.first); - for (size_t i = 0; i < vars.size(); ++i) { - const Variable* var = vars[i]; - if (var != nullptr) { - const Tensor* t = nullptr; - if (var->IsType()) { - t = &var->Get(); - } else if (var->IsType()) { - t = &var->Get(); - } else if (var->IsType()) { - t = &(var->Get().value()); - } - if (t != nullptr) { - PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized", - input.first, i); - proto::VarType::Type tmp = t->type(); - PADDLE_ENFORCE( - tmp == data_type || data_type == dafault_data_type, - "DataType of Paddle Op %s %s must be the same. Get (%s) != (%s)", - Type(), input.first, DataTypeToString(data_type), - DataTypeToString(tmp)); - data_type = tmp; - } - } - } + ParseInputDataType(ctx, input.first, &data_type); } - PADDLE_ENFORCE(data_type != dafault_data_type, - "DataType should be indicated by input"); + PADDLE_ENFORCE_NE(data_type, dafault_data_type, + "DataType should be indicated by input Variable."); + return data_type; +} + +proto::VarType::Type OperatorWithKernel::IndicateVarDataType( + const ExecutionContext& ctx, const std::string& name) const { + proto::VarType::Type dafault_data_type = + static_cast(-1); + proto::VarType::Type data_type = dafault_data_type; + ParseInputDataType(ctx, name, &data_type); + PADDLE_ENFORCE_NE( + data_type, dafault_data_type, + "The Input Variable(%s) of %s Op used to determine kernel data type " + "is empty or not LoDTensor or SelectedRows.", + name, Type()); return data_type; } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 5899a14f503fffe603803bfe56533aa40425a252..ab956a9474cd7c3b65e2304867fe11bf28787510 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -102,7 +102,6 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) { } } -proto::VarType::Type GetDataTypeOfVar(const Variable* var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var); Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); @@ -459,6 +458,9 @@ class OperatorWithKernel : public OperatorBase { void RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const override; + proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx, + const std::string& name) const; + virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; std::vector* GetKernelConfig(const OpKernelType& key) const; @@ -470,6 +472,8 @@ class OperatorWithKernel : public OperatorBase { const OpKernelType& expected_kernel_type) const; private: + void ParseInputDataType(const ExecutionContext& ctx, const std::string& name, + proto::VarType::Type* type) const; // indicate kernel DataType by input data. By default all input data must be // same. proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const; diff --git a/paddle/fluid/framework/operator_test.cc b/paddle/fluid/framework/operator_test.cc index fe4804ac253925c112cf7b508efc42c45868a2fa..aeb1daa4ed9fff3dfae465668195f7da1fe46ad3 100644 --- a/paddle/fluid/framework/operator_test.cc +++ b/paddle/fluid/framework/operator_test.cc @@ -315,3 +315,182 @@ TEST(VarNameTest, all) { original_var_name = paddle::framework::GradOriginalVarName(original_var_name); ASSERT_EQ(original_var_name, ""); } + +namespace paddle { +namespace framework { + +class IndicateLoDTensorDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LoDTensor"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("LoDTensor", "Input of Tensor type Variable."); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateSelectedRowsDataTypeTestProtoMaker + : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("SelectedRows", "Input of SelectedRows type Variable."); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +class IndicateOtherDataTypeTest : public OperatorWithKernel { + public: + using OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} + OpKernelType GetExpectedKernelType( + const ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; +class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("Other", "Input of Other type Variable"); + AddComment("This Op is only for IndicateVarDataType inferface test."); + } +}; + +template +class IndicateVarDataTypeKernelTest : public OpKernel { + public: + void Compute(const ExecutionContext& ctx) const {} +}; + +} // namespace framework +} // namespace paddle + +REGISTER_OP_WITHOUT_GRADIENT( + indicate_lod_tensor_data_type_test, + paddle::framework::IndicateLoDTensorDataTypeTest, + paddle::framework::IndicateLoDTensorDataTypeTestProtoMaker); +REGISTER_OP_WITHOUT_GRADIENT( + indicate_selected_rows_data_type_test, + paddle::framework::IndicateSelectedRowsDataTypeTest, + paddle::framework::IndicateSelectedRowsDataTypeTestProtoMaker); +REGISTER_OP_WITHOUT_GRADIENT( + indicate_other_data_type_test, paddle::framework::IndicateOtherDataTypeTest, + paddle::framework::IndicateOtherDataTypeTestProtoMaker); + +REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); +REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); +REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test, + paddle::framework::IndicateVarDataTypeKernelTest< + paddle::platform::CPUDeviceContext, int>); + +TEST(IndicateVarDataTypeTest, lodtensor) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_lod_tensor_data_type_test"); + BuildVar("LoDTensor", {"lodtensor_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("lodtensor_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE( + ex_msg.find( + "The Tensor in the indicate_lod_tensor_data_type_test Op's " + "Input Variable LoDTensor(lodtensor_1) is not initialized") != + std::string::npos); + } + ASSERT_TRUE(caught); +} + +TEST(IndicateVarDataTypeTest, selectedrows) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_selected_rows_data_type_test"); + BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("selected_rows_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE( + ex_msg.find("The Tensor in the indicate_selected_rows_data_type_test " + "Op's Input Variable SelectedRows(selected_rows_1) is not " + "initialized") != std::string::npos); + } + ASSERT_TRUE(caught); +} + +TEST(IndicateVarDataTypeTest, other) { + paddle::framework::InitDevices(true); + paddle::framework::proto::OpDesc op_desc; + op_desc.set_type("indicate_other_data_type_test"); + BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs()); + + paddle::platform::CPUPlace cpu_place; + paddle::framework::Scope scope; + + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto* var = scope.Var("lod_tensor_array_1"); + var->GetMutable(); + + bool caught = false; + try { + op->Run(scope, cpu_place); + } catch (paddle::platform::EnforceNotMet err) { + caught = true; + std::string ex_msg = err.what(); + EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of " + "indicate_other_data_type_test Op used to " + "determine kernel data type " + "is empty or not LoDTensor or SelectedRows") != + std::string::npos); + } + ASSERT_TRUE(caught); +} diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index b9d07da822cf1eb42859e1d7d84437582fada8ff..5d9633a61dd781f6723ee1a25f33c1cd0b2aa563 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -30,9 +30,9 @@ class Variable { static_assert( IsRegisteredVarType(), "Not registered type. Please register T inside var_type_traits.h"); - PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing"); + PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized."); PADDLE_ENFORCE(holder_->Type() == VarTypeTrait::kId, - "Variable must be type %s, the holding type is %s", + "The Variable type must be %s, but the type it holds is %s.", ToTypeName(VarTypeTrait::kId), ToTypeName(holder_->Type())); return *static_cast(holder_->Ptr()); @@ -45,10 +45,10 @@ class Variable { if (!holder_) { holder_.reset(new PlaceholderImpl()); } else { - PADDLE_ENFORCE(holder_->Type() == VarTypeTrait::kId, - "Variable must be type %s, the holding type is %s", - ToTypeName(VarTypeTrait::kId), - ToTypeName(holder_->Type())); + PADDLE_ENFORCE( + holder_->Type() == VarTypeTrait::kId, + "The Variable type must be %s, but the type it holds is %s.", + ToTypeName(VarTypeTrait::kId), ToTypeName(holder_->Type())); } return static_cast(holder_->Ptr()); } @@ -61,7 +61,7 @@ class Variable { void Clear() { holder_.reset(); } int Type() const { - PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory"); + PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized."); return holder_->Type(); } diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 82ff2c1a72b89efcfca4c7d3350f09a4f7216063..be4786fadaf56e2dd0076ba3929d2e03a7364821 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -114,9 +114,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout, - library); + return framework::OpKernelType(oper.IndicateVarDataType(ctx, name), + ctx.GetPlace(), layout, library); } class ActivationOp : public framework::OperatorWithKernel { diff --git a/paddle/fluid/operators/add_position_encoding_op.cc b/paddle/fluid/operators/add_position_encoding_op.cc index 2580c5a523e13fb489bf9810c205257102d8a72e..61a9fa765079d313a5b93c82d97658ae87fa4f8e 100644 --- a/paddle/fluid/operators/add_position_encoding_op.cc +++ b/paddle/fluid/operators/add_position_encoding_op.cc @@ -37,8 +37,9 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -56,9 +57,9 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - platform::CPUPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index 1476cfc2c89130677de22bc6f43cb258cd5e0be2..6040ed7550d9482f1d56c537dbcd3c802e5976c8 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -121,9 +121,9 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 9d7100cc3db91f5bf7dbd993c9f9ba5d4fc98ea6..c46b42601156e2471877b563c52a32a59e2134be 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -80,7 +80,7 @@ class AffineGridOp : public framework::OperatorWithKernel { library = framework::LibraryType::kCUDNN; } #endif - auto data_type = ctx.Input("Theta")->type(); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta"); return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kAnyLayout, library); } @@ -191,9 +191,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("Theta")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Theta"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 221204878659e64488425a6e783d3c0feb269577..c2b3c818c65e3902054ee67c2d291d308db9ee13 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -89,8 +89,9 @@ class AssignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index c6d98f1f9a534aa98923afc1ead0ffc1f83a8b99..53bd2e4c4503f72f9c267c7fac48a726d0e249dc 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -129,8 +129,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void AttentionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc index 0922b03b5f5fbd2a7a62b0a325ebed9600767497..273df31fc80c43873bdb333c971ff37d44bce500 100644 --- a/paddle/fluid/operators/average_accumulates_op.cc +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -103,8 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 72c023dd9924351543a496a70645e5aa876cc639..546605c8dba2fbd455f2f8dfcefb0c1f23a69d33 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -115,7 +115,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { framework::OpKernelType BatchNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -432,8 +432,9 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout, + library); } template diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index a6aa35e0569364d79c15aea6e6dbc6ca670d49f0..62cfbfcaae217d879e0128181e6ea60de86f1640 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -109,10 +109,11 @@ class BeamSearchOp : public framework::OperatorWithKernel { // Compute on CPU for cases with batch_size > 4. if (batch_size <= 4) { return framework::OpKernelType( - ctx.Input("pre_ids")->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), + ctx.GetPlace()); } else { return framework::OpKernelType( - ctx.Input("pre_ids")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"), platform::CPUPlace()); } } diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index 51c4d878142dcd93a170c9ea4211b9c6ec8e4422..1ad0271304929452f665d287e161917ebd9d1b71 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -52,8 +52,9 @@ class BprLossOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -98,8 +99,9 @@ class BprLossGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/center_loss_op.cc b/paddle/fluid/operators/center_loss_op.cc index bf766a056a767f4b5e152800e9305d1f51f6d901..0b6ce82397384eed6ff2263822646f54c3500d97 100644 --- a/paddle/fluid/operators/center_loss_op.cc +++ b/paddle/fluid/operators/center_loss_op.cc @@ -61,8 +61,9 @@ class CenterLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -117,7 +118,8 @@ class CenterLossGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.Input("SampleCenterDiff")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 02f6210ca4c5fcf2dd53aed23db586aed597df43..c661d4215988df57b801e6d0ff860b33f7646933 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -41,8 +41,8 @@ class CAllReduceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cc b/paddle/fluid/operators/collective/c_broadcast_op.cc index 72d330306cc9df2836f27309d4f5617dacced34f..928fa8549ffb9209dea975a049db4beed0add6b6 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cc @@ -28,8 +28,8 @@ class CBroadcastOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index c76ccb70f769644cf0996109c645b94223047510..daef6310ddf5808ab8765362427a005a387096af 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -102,7 +102,6 @@ class ConcatOp : public framework::OperatorWithKernel { if (flag == 0) { PADDLE_THROW("All Inputs of Concat OP are Empty!"); } - #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), @@ -175,9 +174,9 @@ class ConcatOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 514fde453bbfdbd93ed2a9daec002f2f2857ff1e..cf720cc627c60cbb5b733dc1cb6ccf3b8a005690 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -135,7 +135,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType::kDefaultCustomizedTypeValue; framework::LibraryType library{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - auto input_data_type = ctx.Input("Input")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); std::string data_format = "AnyLayout"; // todo enable data layout when it's ready framework::DataLayout layout = framework::StringToDataLayout(data_format); @@ -527,9 +527,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( } #endif - auto type = framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_, - customized_type_value); + auto type = framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_, customized_type_value); #ifdef PADDLE_WITH_CUDA if (library_ == framework::LibraryType::kCUDNN) { std::vector& configs = kernel_configs_map_[type]; @@ -704,9 +704,9 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType( customized_type_value = kConvMKLDNNFP32; } #endif - auto type = framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_, - customized_type_value); + auto type = framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_, customized_type_value); #ifdef PADDLE_WITH_CUDA if (library_ == framework::LibraryType::kCUDNN) { std::vector& configs = kernel_configs_map_[type]; diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 6dddd4848e3e4870bfcc3051fd7cd0b90feb3ae8..4ba330447e22d860b380579275d2ac1921e20629 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -132,8 +132,9 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void Conv2DTransposeOpMaker::Make() { @@ -384,8 +385,9 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( } framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } class ConvTransposeGradOpDescMaker : public framework::SingleGradOpDescMaker { diff --git a/paddle/fluid/operators/crf_decoding_op.cc b/paddle/fluid/operators/crf_decoding_op.cc index 4676bd04646ff04d65539b35d9e45eb8a06188c2..746f96dcac09d507f820e877c483ae533fe07cdc 100644 --- a/paddle/fluid/operators/crf_decoding_op.cc +++ b/paddle/fluid/operators/crf_decoding_op.cc @@ -160,8 +160,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Emission")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), + platform::CPUPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/crop_op.cc b/paddle/fluid/operators/crop_op.cc index 2ced5467f1a006efac9e647ee6daba4815ee5d28..f42463582f86475209352a558b13a960425f665e 100644 --- a/paddle/fluid/operators/crop_op.cc +++ b/paddle/fluid/operators/crop_op.cc @@ -53,8 +53,9 @@ class CropOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -174,9 +175,9 @@ class CropOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/crop_tensor_op.cc b/paddle/fluid/operators/crop_tensor_op.cc index 9b536e98e41f7360867f349769875567c75ad2a7..43fa27ef4b130247765587aac0d4b31982062ae7 100644 --- a/paddle/fluid/operators/crop_tensor_op.cc +++ b/paddle/fluid/operators/crop_tensor_op.cc @@ -87,8 +87,9 @@ class CropTensorOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -243,9 +244,9 @@ class CropTensorOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 8a80619f6636f9f0cab1d0b6332ca05742b9e7f8..d6da40ddfe6fe2c262aeee8e2db1c0b5f9ecf83e 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -107,8 +107,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const { @@ -157,9 +158,9 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Y"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context()); } virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const { diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc index 4abe9509e6d4a5143698fcdf343bc54f6ad207fc..9982230495d7e96110876dce5a98ebf2ead5f133 100644 --- a/paddle/fluid/operators/ctc_align_op.cc +++ b/paddle/fluid/operators/ctc_align_op.cc @@ -39,8 +39,9 @@ class CTCAlignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc index 53ed86ade48ce52d49285495388f93f1bc4f5d9e..7675a6acf7e4458a23ad60faf3e5bcb2d158887e 100644 --- a/paddle/fluid/operators/cvm_op.cc +++ b/paddle/fluid/operators/cvm_op.cc @@ -52,8 +52,9 @@ class CVMOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -93,8 +94,9 @@ class CVMGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 5dc83ac7b3078960b2aa36b3c6c8a77d502f9a05..6d1168c3ae80cce86ee801d9514b7bf5f2e7620c 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -81,7 +81,7 @@ class DataNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -89,12 +89,14 @@ class DataNormOp : public framework::OperatorWithKernel { if (input_data_type == framework::proto::VarType::FP64) { dn_param_type = framework::proto::VarType::FP64; } - PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input("BatchSize")->type(), + PADDLE_ENFORCE_EQ(dn_param_type, + OperatorWithKernel::IndicateVarDataType(ctx, "BatchSize"), "BatchSize input should be of float type"); - PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input("BatchSum")->type(), - "BatchSum input should be of float type"); PADDLE_ENFORCE_EQ(dn_param_type, - ctx.Input("BatchSquareSum")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "BatchSum"), + "BatchSum input should be of float type"); + PADDLE_ENFORCE_EQ(dn_param_type, OperatorWithKernel::IndicateVarDataType( + ctx, "BatchSquareSum"), "BatchSquareSum input should be of float type"); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready @@ -276,8 +278,9 @@ class DataNormGradOp : public framework::OperatorWithKernel { } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout, library); } }; diff --git a/paddle/fluid/operators/deformable_conv_op.cc b/paddle/fluid/operators/deformable_conv_op.cc index c000787545d0dc777cf4f691608cbab32159b4aa..1eedcc010f27a426b9b66fc897e22f679bd4acba 100644 --- a/paddle/fluid/operators/deformable_conv_op.cc +++ b/paddle/fluid/operators/deformable_conv_op.cc @@ -216,8 +216,9 @@ class DeformableConvOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -275,8 +276,9 @@ class DeformableConvGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cc b/paddle/fluid/operators/deformable_conv_v1_op.cc index 6129e29655048ea7001bf1e48846f6801c16459d..8bef1a0b7486dff57c626963c2bc5d5ea3abf5b3 100644 --- a/paddle/fluid/operators/deformable_conv_v1_op.cc +++ b/paddle/fluid/operators/deformable_conv_v1_op.cc @@ -199,8 +199,9 @@ class DeformableConvV1Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -253,8 +254,9 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cc b/paddle/fluid/operators/deformable_psroi_pooling_op.cc index d17f22b9b4f7641f7d69e0056e19762945f2d05c..dd2f700901b177083df8d1b3e5d34c46f82a9950 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cc +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cc @@ -199,8 +199,9 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -247,8 +248,9 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Trans")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Trans"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 97f49dbcb08e4428b4857f4a70ab21399fb35612..0ed3293418fb124873b422036cb6c946823a83bf 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -25,8 +25,9 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType( framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void DeQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cc b/paddle/fluid/operators/detection/anchor_generator_op.cc index 4a333b559f82e6d39d2d4345c8ad58bc8d430c69..d3287249166833f50fd4182c9cfcba6c03906652 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cc +++ b/paddle/fluid/operators/detection/anchor_generator_op.cc @@ -53,7 +53,8 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/bipartite_match_op.cc b/paddle/fluid/operators/detection/bipartite_match_op.cc index af7797a6d7cde6e81c66a3d29ed36154b6e11529..785a207263191e516c47da76ecb6ce771f0242f2 100644 --- a/paddle/fluid/operators/detection/bipartite_match_op.cc +++ b/paddle/fluid/operators/detection/bipartite_match_op.cc @@ -45,8 +45,9 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("DistMat")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "DistMat"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc index 0603072835e8f146e5bb006d5759220900a29e56..8c53eb5da2f69f7a516ad98c9cee182ec232bab8 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cc @@ -68,7 +68,7 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto data_type = - framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]); + OperatorWithKernel::IndicateVarDataType(ctx, "MultiLevelRois"); return framework::OpKernelType(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc index cacd47ed4a80489c59cdd80747d69c70bd5ea286..f9ea1dc67d9dcdfb8de7e91bc348e50f5d03e319 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cc +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -66,7 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 4cc989b6325f4da0cb38dd25a1529178a9af2268..ce37e73b750d669190615d01697d77039bf857c3 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -46,7 +46,7 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "FpnRois"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/generate_mask_labels_op.cc b/paddle/fluid/operators/detection/generate_mask_labels_op.cc index 0d77c7f3a79fc491dfdc54d74c7cfebd85a5992e..bd18d77174f881b5773f775054091489fbdb2363 100644 --- a/paddle/fluid/operators/detection/generate_mask_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_mask_labels_op.cc @@ -80,7 +80,7 @@ class GenerateMaskLabelsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Rois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Rois"); return framework::OpKernelType(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index 451e0ca85501bccd2588dd58d0c8efe7142559d9..873d44b27e245ae7f12a5601903b462b568e3bc9 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -87,7 +87,7 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("RpnRois")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "RpnRois"); return framework::OpKernelType(data_type, platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index 06e48f1262a74dfdfd6d38e71cd02116f3e6eca5..bcbd7e1e2031ff1db0f1f0fabb3be2c339b62e19 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -60,8 +60,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Anchors")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/mine_hard_examples_op.cc b/paddle/fluid/operators/detection/mine_hard_examples_op.cc index c68fe2439cad9bc5a49a742c1a38e704a7618156..c8701d28101f5f70691258d15e33a6b5bb7c44ae 100644 --- a/paddle/fluid/operators/detection/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/detection/mine_hard_examples_op.cc @@ -255,7 +255,8 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("ClsLoss")->type(), platform::CPUPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "ClsLoss"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index f5b9be14ad6819f040b915f42d6e7ffb7dcdc908..28380a04ba194b8735f01d1dc19c0997405a027c 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -80,7 +80,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Scores")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index da6e132498d34e35a198550352069c56c2e343b4..8d821739f6f1a05cda966cb1ab1064194b5953dd 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -69,7 +69,8 @@ class PriorBoxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_input_type = ctx.Input("Input")->type(); + auto input_input_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Input"); framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; diff --git a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc index 4a6dfec12e660431844682694632a3b18d91bf3e..a79a7608ea9c565b4c59cc6e4dfb11fca2f5be2c 100644 --- a/paddle/fluid/operators/detection/retinanet_detection_output_op.cc +++ b/paddle/fluid/operators/detection/retinanet_detection_output_op.cc @@ -94,8 +94,7 @@ class RetinanetDetectionOutputOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = - framework::GetDataTypeOfVar(ctx.MultiInputVar("Scores")[0]); - + OperatorWithKernel::IndicateVarDataType(ctx, "Scores"); return framework::OpKernelType(input_data_type, platform::CPUPlace()); // ctx.GetPlace()); } diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index ce10de40a9682204f9643296be0b02c74300cebe..74756a2a22ac7abe8f782cd1e6ddf51640f581f1 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -525,8 +525,9 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -545,8 +546,9 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 338954346c5af2c04ff6bf09b11873caec4a04dd..67aab192fbedcf74bbbd9194ffc4875493785979 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -77,7 +77,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Anchor")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } }; @@ -726,7 +726,7 @@ class RetinanetTargetAssignOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Anchor")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc index 50ff3cb120e8199f51af1f3aaa71368da0561d3b..eb59c943e4bd4f63442191a900b6d03a6594d118 100644 --- a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cc @@ -63,8 +63,9 @@ class SigmoidFocalLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -116,8 +117,9 @@ class SigmoidFocalLossGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/target_assign_op.cc b/paddle/fluid/operators/detection/target_assign_op.cc index c057c82ce0f5eef67c09d0ed719ddd24382f451d..b2487b13523a467a36e9663bf424c1e81ebc6e4f 100644 --- a/paddle/fluid/operators/detection/target_assign_op.cc +++ b/paddle/fluid/operators/detection/target_assign_op.cc @@ -57,8 +57,9 @@ class TargetAssignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc index e0d7e25d944cf2321799da4c73de9f74d9fd287d..602efd7b80ab137f308a1902f578f10dbe047c93 100644 --- a/paddle/fluid/operators/detection/yolo_box_op.cc +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -65,8 +65,8 @@ class YoloBoxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc index 5732b180526c502efea0ca72af87b38e45bfbec2..d6cd3171ee359fb7994cd1a60e2ca50c22b03c06 100644 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ b/paddle/fluid/operators/detection/yolov3_loss_op.cc @@ -98,8 +98,9 @@ class Yolov3LossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -255,8 +256,9 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index dff97f7c77fc26af4cd4e7794d9092aec14cfa6e..cfd159a2cca529a7d82dfecd095bb420570c4a18 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -73,7 +73,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("DetectRes")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "DetectRes"), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/allreduce_op.cc b/paddle/fluid/operators/distributed_ops/allreduce_op.cc index 57d68eb931f089e46df07f45186246568bc297c8..86f1c28a9dd4f53400418c93f8598b7a9c38f4cc 100644 --- a/paddle/fluid/operators/distributed_ops/allreduce_op.cc +++ b/paddle/fluid/operators/distributed_ops/allreduce_op.cc @@ -29,8 +29,8 @@ class AllReduceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc index 3e354791ea9af4fa833026e3170856d823a5fd78..c34fb7b96f2377a8ca12f3488efa823ac012a5e0 100644 --- a/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc +++ b/paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.cc @@ -72,7 +72,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc index 1b0b4dd31693340bc39c0da8995a2a2d40b13e00..712ff56c8c241b4a7c7a301b037eb8635b71d0f9 100644 --- a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc @@ -108,7 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.MultiInput("X").front()->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc index 7e16e6ff66b603634aa7cd26f71a4f2d3159c4e4..6bf70844491fe1b21c7f55ff6189e2628e84a7a4 100644 --- a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc +++ b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc @@ -42,7 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - ctx.MultiInput("X")[0]->type(), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/split_ids_op.cc b/paddle/fluid/operators/distributed_ops/split_ids_op.cc index d46b57e7e15807756efd85fde765454260ea9d7b..603f697592279aff21cd4a59d4556021aafcae5b 100644 --- a/paddle/fluid/operators/distributed_ops/split_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/split_ids_op.cc @@ -66,8 +66,7 @@ class SplitIdsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("Ids").front()), - ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 273015f9763c2c7375aa0609436a2e8ab190b696..0e060c3a1a339f3ae9e40d6eaadf47d0f8bc566b 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -121,9 +121,9 @@ class DropoutOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 3c460242f3d871cf3415ede203267a7928494678..82cc6df4a6623e6940c391f3e46fe2ed0ecfa900 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -153,7 +153,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = ctx.Input("DDX")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 8da447adaa70e226664d0eb0c64811ac219f3e8b..67babe640441875f882aa2b783eda43b35dacce3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -82,7 +82,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -236,8 +236,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - ctx.Input(framework::GradVarName("Out"))->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -274,7 +274,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("DOut")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { @@ -306,13 +306,13 @@ class ElementwiseOpDoubleGradWithoutDXDY if (ctx.HasInput("DDX") == false) { PADDLE_ENFORCE_EQ(ctx.HasInput("DDY"), true, "Input(DDY) should not be null"); - input_data_type = ctx.Input("DDY")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY"); } else if (ctx.HasInput("DDY") == false) { PADDLE_ENFORCE_EQ(ctx.HasInput("DDX"), true, "Input(DDX) should not be null"); - input_data_type = ctx.Input("DDX")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); } else { - input_data_type = ctx.Input("DDX")->type(); + input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); } #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 677130f2f92e005bdeddd43952bb319f88b41c69..41147b77ee0585e2427e065ffb2cd9d23373e7d2 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -65,8 +65,9 @@ class ExpandOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -180,9 +181,9 @@ class ExpandGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 034f3c7dcebf906e600b9a6a651a1c857ddc4189..53cdcc9922639812450692a6d8f9ebe27a5dd14d 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -190,8 +190,9 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -241,8 +242,8 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -303,8 +304,9 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -375,8 +377,9 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -450,8 +453,8 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index da30fef555e1657d60e9493ab9e70beea838e801..484c4baef94de1a5cff24a9f879a522e98b11183 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -80,8 +80,9 @@ class FCOp : public framework::OperatorWithKernel { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout, library); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout, library); } }; diff --git a/paddle/fluid/operators/filter_by_instag_op.cc b/paddle/fluid/operators/filter_by_instag_op.cc index ebf44e5b9a5b3d0fe421a6d512f70f74a4146d56..a48c901f9e6d4431e8a9aa3efe0a70089bbdd92e 100644 --- a/paddle/fluid/operators/filter_by_instag_op.cc +++ b/paddle/fluid/operators/filter_by_instag_op.cc @@ -48,7 +48,7 @@ class FilterByInstagOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Ins")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Ins"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -101,8 +101,8 @@ class FilterByInstagOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 9f2a122203bf9bed2d8737dc2056b16b4d7b7b8e..c27bb1606b346f9cd607f74e3b9a21013918e5f5 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -69,8 +69,9 @@ class FlattenOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -130,8 +131,9 @@ class FlattenGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -221,9 +223,9 @@ class Flatten2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fsp_op.cc b/paddle/fluid/operators/fsp_op.cc index fbe8e56a6160219175bd573a2ff186eb35e56fdf..0706f9ce3769a7b8e497f7d924d8ac553b12de68 100644 --- a/paddle/fluid/operators/fsp_op.cc +++ b/paddle/fluid/operators/fsp_op.cc @@ -49,8 +49,9 @@ class FSPOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(), + layout_, library_); } }; @@ -107,9 +108,9 @@ class FSPOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 1cd6c40aa0540f5e5c9ea4b3e3e771dcc827eccf..9a156147aa4c65280844a7277e8576daca3cb6e1 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -140,8 +140,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx.Input("X")->type(), ctx.Input("Y")->type(), "The element's type of input should be the same."); - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -328,8 +328,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 4c13d39406be3bb5ed6b6103032b7fe811078ca1..9124e0c4c9b0bceffb08c6b3ac9e8f1ef4b2e383 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -114,7 +114,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - ctx.Input("Embeddings")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "Embeddings"), ctx.device_context()); } diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc index 9110099013a20f2718038a31ec94c1e76583149b..5661877cb07121697e249a4279191068d877a7eb 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc @@ -56,7 +56,7 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -125,7 +125,7 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc index 964335ed2bc8df5e26eee3d1e1ae2b88bcbf25d9..e18ac13d345e8b2f59172300891434e316435665 100644 --- a/paddle/fluid/operators/fused/fusion_conv_inception_op.cc +++ b/paddle/fluid/operators/fused/fusion_conv_inception_op.cc @@ -58,7 +58,8 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index 5c89509907375b5f2089224c21dd1ef67872c2fd..f9ade705ead23fd03ba470c1c72642abd8a1edb4 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -93,8 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 32f0e37a64b98d7e184bd6522504b6821a548af4..c256e581ee892157a9458e2a19f662772401e3ab 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -117,8 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc index 4c11482f5077eeeb2d446dc0cbe9c08f890f390f..d98e782562a427f199d3ae0870d02f95deadf53f 100644 --- a/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_repeated_fc_relu_op.cc @@ -60,8 +60,8 @@ void FusionRepeatedFCReluOp::InferShape( framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionRepeatedFCReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc index 519670cc6a7b73b679645e5ee6d98b74613cdacc..1e25a9490b8c53610879c91583f6b5114a35ba0d 100644 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc @@ -61,8 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape( framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionSeqConvEltAddReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 95a08d3b0f030e7dae6668a788b52cfe66daa250..d79bf7cdcc82f2e3be368023dd87289afeeb8a17 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -67,8 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); } void FusionSeqExpandConcatFCOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc index b14ee88aa53b64791fa09c848e23d4f01826e339..7ca02a2541ddf1421f15adb4cf1affad92443d03 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_concat_op.cc @@ -47,7 +47,7 @@ void FusionSeqPoolConcatOp::InferShape( framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSeqPoolConcatOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc index 14e327bb37d1381affe0189ce220fe13c63eac99..0a245bb05057598bbc41c3753eaf5adaa47f4c0e 100644 --- a/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.cc @@ -52,7 +52,7 @@ void FusionSeqPoolCVMConcatOp::InferShape( framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSeqPoolCVMConcatOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc index 2d10056044efa851898c8cf597fa14e495305fce..2d4a39779801938f3b0a3de9409257248d7eb4af 100644 --- a/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc +++ b/paddle/fluid/operators/fused/fusion_squared_mat_sub_op.cc @@ -53,8 +53,8 @@ void FusionSquaredMatSubOp::InferShape( framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } void FusionSquaredMatSubOpMaker::Make() { diff --git a/paddle/fluid/operators/gather_nd_op.cc b/paddle/fluid/operators/gather_nd_op.cc index cbeefa0a7f6510232cf0758f0184f07ddbb0595b..b2a4029c8f6c86f632b899f264a00a6192cbdad9 100644 --- a/paddle/fluid/operators/gather_nd_op.cc +++ b/paddle/fluid/operators/gather_nd_op.cc @@ -61,7 +61,7 @@ class GatherNdOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); - const auto& x_type = x->type(); + const auto& x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType( x_type, x_type == framework::proto::VarType::BOOL @@ -82,9 +82,9 @@ class GatherNdGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index cbabd59cf634f09c0a55d3822995b4d0f5f170ee..075be1caf48f0e96c2349c2060caf2637f55ce41 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -45,8 +45,9 @@ class GatherOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -62,9 +63,9 @@ class GatherGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/gather_tree_op.cc b/paddle/fluid/operators/gather_tree_op.cc index 94fa3b6aa1e7e9ebb166b4344fbd62e0242de660..26f9989121e9eb32d27817d23fa972e5e8fcf1b7 100644 --- a/paddle/fluid/operators/gather_tree_op.cc +++ b/paddle/fluid/operators/gather_tree_op.cc @@ -40,8 +40,9 @@ class GatherTreeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Ids")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc index c0893359af2f4de4ed8fd88ebff122447e8d84c7..d8470bad1188b5a69e08e4b9c74957339bb04294 100644 --- a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc +++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc @@ -45,7 +45,8 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index 57a1fcd42da04a766ebd8713e3863f259b3784ac..5338889363af1f4f8eade81a6b966bd5fe86b3e0 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -68,9 +68,9 @@ class GridSampleOp : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; @@ -164,9 +164,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 2b3e2e5c484a1f04c03f0c2482072f0452382aa1..a27fcf628cc63f483a7313253fda019c6ca54e5c 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -81,8 +81,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -224,8 +224,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc index 6375c92de2d219e9e66ce8899fed991a1a75d00d..bb6b37e64e4ed0a4d683d69397416d1ef468102a 100644 --- a/paddle/fluid/operators/instance_norm_op.cc +++ b/paddle/fluid/operators/instance_norm_op.cc @@ -70,7 +70,7 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const { framework::OpKernelType InstanceNormOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -236,8 +236,8 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType( if (t == nullptr) { PADDLE_THROW("cannot find Y@GRAD"); } - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } template @@ -396,8 +396,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType( if (t == nullptr) { PADDLE_THROW("cannot find Y@GRAD"); } - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } std::unique_ptr InstanceNormDoubleGradMaker::Apply() const { diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 612f770bb7cee695724a39635bdd2d884813d7ca..cbe9865673b0efd2080157d7562d635fdea51922 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -204,8 +204,8 @@ class InterpolateOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( @@ -407,9 +407,9 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/is_empty_op.cc b/paddle/fluid/operators/is_empty_op.cc index 092a6eae6f5b7edcc5656522377de10a08a01ea8..109e96fb7baa26c25087819c54c21d14b602d2cb 100644 --- a/paddle/fluid/operators/is_empty_op.cc +++ b/paddle/fluid/operators/is_empty_op.cc @@ -35,7 +35,8 @@ class IsEmptyOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto *x = ctx.Input("X"); - return framework::OpKernelType(x->type(), x->place()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), x->place()); } }; diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index 983ab3dba6e2fb8354559ac5a2a688baa39f7c9b..d5976e7f4ae2240722f58d2348503a75692ff553 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -58,8 +58,8 @@ class KLDivLossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -136,8 +136,8 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc old mode 100755 new mode 100644 index b78a6ceb5199bc39d04e3560d350f1bd1b6aee52..b6758c8975bd0bda5e0c89e93c4793956ae0fc58 --- a/paddle/fluid/operators/linear_chain_crf_op.cc +++ b/paddle/fluid/operators/linear_chain_crf_op.cc @@ -224,8 +224,9 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { // is determined by its input "Emission". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Emission")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Emission"), + platform::CPUPlace()); } }; @@ -263,7 +264,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input(framework::GradVarName("LogLikelihood"))->type(), + OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("LogLikelihood")), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc index f4aeb062d8dfae31a72b8ebccb3d377276662da6..7ea3b06e02ed3a50b22463693b7df3372587dc0c 100644 --- a/paddle/fluid/operators/linspace_op.cc +++ b/paddle/fluid/operators/linspace_op.cc @@ -52,8 +52,8 @@ class LinspaceOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; return framework::OpKernelType( - ctx.Input("Start")->type(), ctx.device_context(), - layout_, library_); + OperatorWithKernel::IndicateVarDataType(ctx, "Start"), + ctx.device_context(), layout_, library_); } }; diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 409f8397eb22cfcf11d7485d86b5b4b9bdddd81e..190a7cdf12faf4ae38eb1f3c1a777c4621dbeace 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -46,8 +46,9 @@ class LoDResetOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -172,9 +173,9 @@ class LoDResetGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 5285e3cae9ab1c962944c97c28677c85655a882c..c1d45bb7a0d1e57d1bf6fb06d42d63d074fea200 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -64,7 +64,7 @@ class LookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -166,8 +166,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc index 511f50a83d58292fa08fb603dc616dc7f7e5a626..f0cffa4e1fe1046c580c5e2d0cecfbd7626ccd9b 100644 --- a/paddle/fluid/operators/lookup_table_v2_op.cc +++ b/paddle/fluid/operators/lookup_table_v2_op.cc @@ -58,7 +58,7 @@ class LookupTableV2Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -154,8 +154,8 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 5ad94cfde901bedae4af28e5b2a43bad08e28cf9..d5b092ec99d71b12ec3b922a2bae1916d182dab1 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -130,26 +130,6 @@ struct LRNGradFunctor { template struct LRNGradFunctor; template struct LRNGradFunctor; -namespace { -framework::OpKernelType GetExpectedLRNKernel( - const framework::ExecutionContext& ctx) { - framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; - } -#endif - - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout_, library_); -} -} // namespace - class LRNOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -175,7 +155,20 @@ class LRNOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return GetExpectedLRNKernel(ctx); + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; @@ -281,7 +274,20 @@ class LRNOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return GetExpectedLRNKernel(ctx); + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; } // namespace operators diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc index bf68c57e67fbff9216f51d805c78e49714fdb736..43af877085bd06c654c26dc471eaf0d137347ab0 100644 --- a/paddle/fluid/operators/lstm_op.cc +++ b/paddle/fluid/operators/lstm_op.cc @@ -97,7 +97,8 @@ class LSTMOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -261,7 +262,8 @@ class LSTMGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc index b9f42237180007eecc8b558c6939a7156dfc6e45..68e204983e8e54d5f9ebe46405a3a1fd7a249c04 100644 --- a/paddle/fluid/operators/lstmp_op.cc +++ b/paddle/fluid/operators/lstmp_op.cc @@ -109,7 +109,8 @@ class LSTMPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("Input")->type(), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -347,7 +348,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - ctx.Input("BatchGate")->type(), + OperatorWithKernel::IndicateVarDataType(ctx, "BatchGate"), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/mean_iou_op.cc b/paddle/fluid/operators/mean_iou_op.cc index bb290046f3a62d971dccd95f8550acdd5f68c847..615b9ea48477383f49783884259fdff533ef04b8 100644 --- a/paddle/fluid/operators/mean_iou_op.cc +++ b/paddle/fluid/operators/mean_iou_op.cc @@ -44,8 +44,9 @@ class MeanIoUOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Predictions")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Predictions"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 2b2f8450768b9885381f10b19631a6a200c7f703..e19ac59ee576e83e6328fd6388bc481fe662819e 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -64,8 +64,8 @@ class MeanGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = - ctx.Input(framework::GradVarName("Out"))->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/accuracy_op.cc b/paddle/fluid/operators/metrics/accuracy_op.cc index d6360c83f092602a9780196945e335f2884b5b46..bedcfc6a387f588ea3cf64115162c6d34946308b 100644 --- a/paddle/fluid/operators/metrics/accuracy_op.cc +++ b/paddle/fluid/operators/metrics/accuracy_op.cc @@ -68,8 +68,8 @@ class AccuracyOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Out")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index e0eebad08bb6b9a15d9c0f356215404884bee0e9..3543a33493525be68e8b5b59ac20b4bc55dc5cc1 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -53,8 +53,9 @@ class AucOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Predict")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Predict"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/metrics/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc index f6d6ffc668c9aaa40e12e7289d4f97fc656e2c70..58b948b5a438847711dd90e2d6537b46bde48274 100644 --- a/paddle/fluid/operators/metrics/precision_recall_op.cc +++ b/paddle/fluid/operators/metrics/precision_recall_op.cc @@ -92,8 +92,9 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("MaxProbs")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 80059ff14ca4d475b1a2c625ef1dcfe8912e6947..8d758094282f2380a50a7cbcef8de862c7a12ec6 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -90,7 +90,7 @@ class MulOp : public framework::OperatorWithKernel { framework::DataLayout layout = framework::DataLayout::kAnyLayout; int customized_type_value = framework::OpKernelType::kDefaultCustomizedTypeValue; - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && platform::CanMKLDNNBeUsed(ctx)) { diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc index 7cb213e89958e017c62d7cded261570307d3e64b..843f0a68e1d6832f87acc0b2b838afe61c9a6704 100644 --- a/paddle/fluid/operators/multiplex_op.cc +++ b/paddle/fluid/operators/multiplex_op.cc @@ -55,8 +55,9 @@ class MultiplexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -125,9 +126,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index e78fda111397dfc89023375f1bea175a615b4c03..0f26e3953feeb55fd378d74d2c0dece7bf4d863a 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -92,8 +92,9 @@ class NCEOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + platform::CPUPlace()); } }; @@ -246,8 +247,9 @@ class NCEOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/one_hot_op.cc b/paddle/fluid/operators/one_hot_op.cc index 6042b97bf57a699b566ec2cee955f7db3bb7b2de..e4d50db30a3f0caae6b5733f1015534d2c3cd7e8 100644 --- a/paddle/fluid/operators/one_hot_op.cc +++ b/paddle/fluid/operators/one_hot_op.cc @@ -51,8 +51,9 @@ class OneHotOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/one_hot_v2_op.cc b/paddle/fluid/operators/one_hot_v2_op.cc index 7a75afca09cea13eb07749eb565ea880f8a5acf0..62f85496f9b9f6362112de267bba152ab03589a6 100644 --- a/paddle/fluid/operators/one_hot_v2_op.cc +++ b/paddle/fluid/operators/one_hot_v2_op.cc @@ -48,8 +48,9 @@ class OneHotV2Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc index 01c0f1bb2d4778c3ba4980b9e7d4faef77901c0b..bde7131379a272e31fb1effe2f92204fa27f9a14 100644 --- a/paddle/fluid/operators/optimizers/adadelta_op.cc +++ b/paddle/fluid/operators/optimizers/adadelta_op.cc @@ -75,8 +75,8 @@ class AdadeltaOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc index 0310fe2eba8e9fcd02ac6c229f90a1d75ddea63e..b3aff1eff8c46cc0dc41f1f58087deb831030032 100644 --- a/paddle/fluid/operators/optimizers/adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/adagrad_op.cc @@ -64,8 +64,8 @@ class AdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index fc851e56cbfd2ab6780a3c812309bced2b693acd..c5a6fe5875baa2e1cba160b1a020916c2f42a285 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -78,7 +78,7 @@ void AdamOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType AdamOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - auto input_data_type = ctx.Input("Param")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } diff --git a/paddle/fluid/operators/optimizers/adamax_op.cc b/paddle/fluid/operators/optimizers/adamax_op.cc index a0152906235cbc8a870a05da990409f661338f6e..9ede7a56d0b535b2c9a1c538d442424ca6f3e4b7 100644 --- a/paddle/fluid/operators/optimizers/adamax_op.cc +++ b/paddle/fluid/operators/optimizers/adamax_op.cc @@ -81,8 +81,8 @@ class AdamaxOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc index b44a84ccf71b574663ba5e425c4537d3769fdffe..5c6c38da92808f05c90e7dad2482e7c7364a1f80 100644 --- a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc @@ -69,8 +69,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/dpsgd_op.cc b/paddle/fluid/operators/optimizers/dpsgd_op.cc index f263e67593bbd15f062648e5f09627d5fea64f0d..9a7b2112d4e5c3af02188f0daea34c34acd5f699 100644 --- a/paddle/fluid/operators/optimizers/dpsgd_op.cc +++ b/paddle/fluid/operators/optimizers/dpsgd_op.cc @@ -55,8 +55,8 @@ class DpsgdOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/ftrl_op.cc b/paddle/fluid/operators/optimizers/ftrl_op.cc index 98b71175624e77bf3ea1d402b9ab13c84d93c8a5..3f0cd8aa3c8534e348de0e679b31e68ccbfd7822 100644 --- a/paddle/fluid/operators/optimizers/ftrl_op.cc +++ b/paddle/fluid/operators/optimizers/ftrl_op.cc @@ -71,7 +71,8 @@ class FTRLOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("Param")->type(); + auto input_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index bb77d2ea6cda3beaffaf109f2591f1339502a020..10b72524efd4a8f9174eab4f45e6173dc56f2c27 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -85,7 +85,8 @@ class MomentumOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + auto input_data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc index 9dd9b8afbd4915202df120b02f7e62de79e9e224..3e2f12137afc2368aa12fa836c935f804f8c02d9 100644 --- a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc @@ -58,8 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cc b/paddle/fluid/operators/optimizers/proximal_gd_op.cc index fccfc2b4584a25e5f703750393464bbc3026de42..cf3c3e2ccb92cd588edea6468b61e6d2e5678be5 100644 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cc @@ -46,8 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Param")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc index bbd78db51a9ff856107c014fbdf2109e8801c2d2..dcc6ce41b279803d8a84b7cfb6e94df6f3cd06f8 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -48,7 +48,7 @@ class SGDOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); return framework::OpKernelType(data_type, ctx.device_context()); } diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc index 3069d5601442c8bc1fb1dc0a4d08558da4dfd9f1..461db5fdc9421d2f7728e496a88f57d188bbf5c0 100644 --- a/paddle/fluid/operators/pad2d_op.cc +++ b/paddle/fluid/operators/pad2d_op.cc @@ -520,8 +520,8 @@ class Pad2dOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -621,9 +621,9 @@ class Pad2dOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc index 31ed0a686f712bd286b4accda68716b156037dbc..1c4bf7035ef978abf9c5dcdab8872181c6af477f 100644 --- a/paddle/fluid/operators/pad_constant_like_op.cc +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -56,8 +56,9 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context()); } }; @@ -186,8 +187,9 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Y")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index f19433115a7dbcdc7aa955428ec422e25fd72ad0..6ece163f69770ecfea4d9aa280d19913a2bb2842 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -134,8 +134,9 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const { @@ -164,7 +165,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( } #endif - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, "float16 can only be used when CUDNN is used"); diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc index 91bd2a902f7cc53f76682d99195ed0d2c08352a3..d8c2ccaa96079393aacb53f841dfea0562004fb4 100644 --- a/paddle/fluid/operators/pool_with_index_op.cc +++ b/paddle/fluid/operators/pool_with_index_op.cc @@ -76,8 +76,9 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -96,8 +97,9 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc index e917e778e41ff8994f248e905635da702b428fc2..b0677ff10f6576e3892216201306ca9e12f7b03d 100644 --- a/paddle/fluid/operators/positive_negative_pair_op.cc +++ b/paddle/fluid/operators/positive_negative_pair_op.cc @@ -95,8 +95,9 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Score")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Score"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index ccb08b245a4696865b46f555b1ef2500bd39aadd..364f3689f9fa810d37a8f4746e079e29d47645a4 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -56,8 +56,9 @@ class PReluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -112,8 +113,9 @@ class PReluGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prroi_pool_op.cc b/paddle/fluid/operators/prroi_pool_op.cc index 5c559bda339818a0351d754d5c2cd88e9ca058a4..c11d09350a477e59fdde5d1caa01b87a7d870daa 100644 --- a/paddle/fluid/operators/prroi_pool_op.cc +++ b/paddle/fluid/operators/prroi_pool_op.cc @@ -114,8 +114,9 @@ class PRROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -135,8 +136,9 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/psroi_pool_op.cc b/paddle/fluid/operators/psroi_pool_op.cc index c241cf461ac07249fe28bcb535b74c537b7ac9b5..a9128fbd28db82440225b56833a5528ef2cecce6 100644 --- a/paddle/fluid/operators/psroi_pool_op.cc +++ b/paddle/fluid/operators/psroi_pool_op.cc @@ -131,8 +131,9 @@ class PSROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -151,8 +152,9 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/pull_box_sparse_op.cc b/paddle/fluid/operators/pull_box_sparse_op.cc index 8532649614c867a860774378e4ffd9b251dd76d5..3af3fb4967590068fa2a936cbb2e01db583e565a 100644 --- a/paddle/fluid/operators/pull_box_sparse_op.cc +++ b/paddle/fluid/operators/pull_box_sparse_op.cc @@ -104,10 +104,9 @@ class PushBoxSparseOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.MultiInput(framework::GradVarName("Out"))[0] - ->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index d8e20f4c4ae6059551bfff3603a2ad6c0a7aa86d..69264e3a45e139bf324d9a252878db92782bda29 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -25,8 +25,9 @@ framework::OpKernelType QuantOp::GetExpectedKernelType( framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void QuantOpMaker::Make() { diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index 65a8d603fcee27223182984769df221e3f519b05..15911a51a2c8e56c4279931b4fcebe35ce7b268c 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -22,8 +22,9 @@ class RandomCropOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 4ed5bd1c70b050a185b38b40d5c08d20206b78fb..5cd2627870d4a55d8c8a161c5ed5087bab609fa7 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -267,9 +267,9 @@ class ReduceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index d156ae207763433ea2ed7fb97a08cbe5880da3cd..c17b6ef8842ad1682619ce6bf6c1a4a17fcc67b4 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -25,8 +25,9 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType( framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), + layout_, library_); } void ReQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 5dd9dfba43e46ee50d7043a88792ae240700e4d0..c7f3d888bc04b250e40608b4a421e9137d370f87 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -200,8 +200,9 @@ class ReshapeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -302,8 +303,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -472,9 +474,9 @@ class Reshape2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( @@ -508,8 +510,9 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("DDX")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "DDX"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc index 0914ad81c77ef2f4908868de04ce1bdadf7321e9..a57266690b01f4794d08b001d64f914e2badf1fa 100644 --- a/paddle/fluid/operators/roi_align_op.cc +++ b/paddle/fluid/operators/roi_align_op.cc @@ -65,8 +65,9 @@ class ROIAlignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -85,8 +86,9 @@ class ROIAlignGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("ROIs")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc index cfac7e09e123c43204454adacb87a7c3c158690e..0515768a630423f3bc062b3cac62f2d5907b9969 100644 --- a/paddle/fluid/operators/roi_pool_op.cc +++ b/paddle/fluid/operators/roi_pool_op.cc @@ -70,8 +70,9 @@ class ROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -90,8 +91,9 @@ class ROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sample_logits_op.cc b/paddle/fluid/operators/sample_logits_op.cc index 8ce2d52273d7cc3d523e5d77c2c79b9989b9227f..962b5dbc50593768ad594d841d9646c9c6048b17 100644 --- a/paddle/fluid/operators/sample_logits_op.cc +++ b/paddle/fluid/operators/sample_logits_op.cc @@ -162,7 +162,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Logits"); framework::OpKernelType kt = framework::OpKernelType(data_type, ctx.device_context()); return kt; @@ -201,8 +201,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("SampledLogits"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("SampledLogits")); framework::OpKernelType kt = framework::OpKernelType(data_type, ctx.device_context()); return kt; diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index c660bbb8ed9a4caf564fd75d3c248827ea46d35a..73bac5c2fdb9711d10de8ae759cea7905e33c724 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -31,7 +31,7 @@ class SaveOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/scatter_nd_add_op.cc b/paddle/fluid/operators/scatter_nd_add_op.cc index 41f18eaeaf8bd894282929321a483ef5859c5895..ba65832dceaee53fbda209527c0824f51b282627 100644 --- a/paddle/fluid/operators/scatter_nd_add_op.cc +++ b/paddle/fluid/operators/scatter_nd_add_op.cc @@ -69,8 +69,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(ctx.Input("X")->type(), - ctx.Input("Updates")->type(), + PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"), + OperatorWithKernel::IndicateVarDataType(ctx, "Updates"), "Ref and Updates must have same type"); return framework::OpKernelType(ctx.Input("X")->type(), ctx.device_context()); @@ -95,9 +95,9 @@ class ScatterNdAddGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index 4eb5b7ad9d1fe128ade904cf61e0178d59b374b8..b3f43a28dffd7d0f0d41a57619b78283bee7f02f 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -48,8 +48,9 @@ class ScatterOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -71,9 +72,9 @@ class ScatterGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/selu_op.cc b/paddle/fluid/operators/selu_op.cc index 67fca18000a4fac1e2ca39fc26ebe67649a51bc3..f71d844d9e9a21b1f494dbf648a3622b9aff5ec6 100644 --- a/paddle/fluid/operators/selu_op.cc +++ b/paddle/fluid/operators/selu_op.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/selu_op.h" + +#include #include +#include namespace paddle { namespace operators { @@ -39,7 +42,7 @@ class SeluOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -115,7 +118,7 @@ class SeluGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::GetDataTypeOfVar(ctx.InputVar("Out")), ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc index d652f9216f8faf53deeac2c9ce1f737651c3939b..118c8ce0b1145dfc42e11d7fc91888e92d84b4c6 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.cc @@ -102,9 +102,9 @@ class SeqConcatGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc index e1f6c3e3d599340acfa9bb5b47017b003721e4a3..c7284d0950c295b576af7a692d2db2f6388c2fd4 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.cc @@ -75,8 +75,8 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -153,9 +153,9 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc index b7c0420636ab60e8a3e0a9332cbd3858aacda1b0..90f794ab5ff37780e4cc451f4c0f3bbd5e06561b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_expand_op.cc @@ -100,8 +100,8 @@ class SequenceExpandOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -208,9 +208,9 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc index a7225adbf9fcafdff30ecf0f6c7a5f6a73c4f3e8..cd0170dd1b4c64fdf8b81bb15c41feae6e95af47 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc @@ -40,8 +40,9 @@ class SequenceMaskOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( const std::string& var_name, const Tensor& tensor, diff --git a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc index fcc49096e2c48c264179e95133c9f9b4ec973e1f..de5a0aa45bc3698e29c86c3a5aa00cfa8978c269 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pad_op.cc @@ -93,7 +93,7 @@ class SequencePadOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -199,8 +199,8 @@ class SequencePadGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 51e354dcd175845c3db2cce78dac6039361aed08..dcc762f790738b0765fa181aac204ae93f8ab2ce 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -122,9 +122,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc index 5a22212edf29cc79d28b12029dc7595ae5f1aab3..7f9dbbf7ecae611946bb7c0e488a8614c030c97d 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc @@ -113,8 +113,9 @@ class SequenceScatterOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; @@ -132,9 +133,9 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - platform::CPUPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc index 4b2ec6e7cad7c04e248c0ffbb117951fba1ec877..537184d8b521d2f9f5d7d2d0e9358e98e77467fb 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc @@ -51,8 +51,9 @@ class SequenceSliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -71,9 +72,9 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 027073e5d7d6c767ebb02662c6fd8b2cf9306904..af6b7477ea23f74f3a917b8c38e9088fd53cbd56 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( - ctx.Input("X")->type(), ctx.GetPlace(), + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; @@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( - ctx.Input("X")->type(), ctx.GetPlace(), + OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc index 232f324de77e4808a0731c9ca7d79906d6b69cde..06b16152ef9f0d0c92151a0c4f372360fe948e3b 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_topk_avg_pooling_op.cc @@ -90,7 +90,7 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc index 8256460858b88bf30e4d70e796eebc84bd68c0da..558d180cfba276c355c74b8153bbc28597a1210a 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_unpad_op.cc @@ -67,7 +67,7 @@ class SequenceUnpadOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -132,8 +132,8 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto data_type = framework::GetDataTypeOfVar( - ctx.InputVar(framework::GradVarName("Out"))); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/shard_index_op.cc b/paddle/fluid/operators/shard_index_op.cc index 578dcd37bb42bdc4c69020c2cf500d4a6c203a55..a02d03671591d1a859d168a8195fcc6cda0e3434 100644 --- a/paddle/fluid/operators/shard_index_op.cc +++ b/paddle/fluid/operators/shard_index_op.cc @@ -41,8 +41,9 @@ class ShardIndexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc index ad6fb3510f02ae783c8ae4318f559a8db74a59d1..48da765416f6d47b7eaad8c869acbc31a1329f2e 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cc +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -35,8 +35,9 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -83,9 +84,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/similarity_focus_op.cc b/paddle/fluid/operators/similarity_focus_op.cc index 21871d76569d0ce410824cf4760cb22529535094..e49ce7c487225f6e8e077ba32e1da2b759879b1e 100644 --- a/paddle/fluid/operators/similarity_focus_op.cc +++ b/paddle/fluid/operators/similarity_focus_op.cc @@ -70,8 +70,9 @@ class SimilarityFocusOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 4cd7b33a4a83eeee1977a94afaf90d91b7edb766..9adb4de01cb6802eafb3b975fafa6a7f7c91d8e1 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -128,8 +128,9 @@ class SliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, @@ -243,9 +244,9 @@ class SliceOpGrad : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 9d73a19197c29fae29728cd6ab770bc0cc7a3ab1..09c08f2330725e585c8576f92a3af853867e15c8 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -76,7 +76,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { } #endif - auto input_data_type = ctx.Input("X")->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "float16 can only be used on GPU place"); @@ -187,8 +187,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - auto input_data_type = - ctx.Input(framework::GradVarName("Out"))->type(); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "float16 can only be used on GPU place"); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 8cde72921cb10bd6cbd7522e32bc5fafcaf46bb9..727d67c2fba63a662ccd22bf95b57e4853ba1daa 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -171,8 +171,9 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Logits")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), + ctx.device_context()); } }; @@ -232,9 +233,9 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Loss"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Loss")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/space_to_depth_op.cc b/paddle/fluid/operators/space_to_depth_op.cc index 3d66613248c27f683faf6e3f075c495ed6e71b06..e2c2998095eb2d976377e284fbe84eee03feb095 100644 --- a/paddle/fluid/operators/space_to_depth_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -167,9 +167,9 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index 5690265573fcf6bd0cd902917da87b98877a0cf7..71049c58e1670d787b608616fcd4d51aabb54a71 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -77,8 +77,8 @@ class SpectralNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Weight")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); } }; @@ -209,8 +209,8 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Weight")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/squared_l2_distance_op.cc b/paddle/fluid/operators/squared_l2_distance_op.cc index 6e82bf407496ab2d37d3fe81aacccfc128d57aec..17538c98fe5531dde4f6471840128e58e54b3b9c 100644 --- a/paddle/fluid/operators/squared_l2_distance_op.cc +++ b/paddle/fluid/operators/squared_l2_distance_op.cc @@ -152,8 +152,9 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("sub_result")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "sub_result"), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index b056d2feacce72c3a4f285225b4e6ed6a5f57f8c..a7e10457fd7a96d20005014a14e9533830b6031c 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -104,8 +104,9 @@ class SqueezeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -122,8 +123,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -230,9 +232,9 @@ class Squeeze2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index 7c81d71562f01840c82171daf53acfa80d8a438e..5cd7a78636379f19c6be686198f249f2370ceb40 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -124,8 +124,9 @@ class StridedSliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.Input("Input")->place()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.Input("Input")->place()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, @@ -230,9 +231,9 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc index 7f95d16f09b5182e4da33763751ac87b53f41cf3..7823b9d8501782d84e23916ef54892338f617ca3 100644 --- a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -55,8 +55,9 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -125,8 +126,9 @@ class TeacherStudentSigmoidLossGradientOp // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index a438832b5dc5d6cb7f6a87bda5c227ad94598701..6663d3f5571af0bf2ed020fab7290a1abe7c1afc 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -56,8 +56,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -139,9 +139,9 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index db763a051d1e08b962a40913d290c69e7c61ec32..fdf5148eb8790bb3a9fd89ab3ec107569caf7c9d 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -53,8 +53,9 @@ class TopkOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(), + layout_, library_); } }; diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 226aad03845d7629d7be556b394ebe06abba44d5..eab6d437d4705d95428ef7fa8bb47dcb84391b08 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -78,8 +78,9 @@ class TransposeOp : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; @@ -164,9 +165,9 @@ class TransposeOpGrad : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace(), layout_, library_); } }; @@ -210,8 +211,9 @@ class Transpose2Op : public TransposeOp { layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout_, library_); } }; @@ -268,9 +270,9 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace(), layout_, library_); } }; diff --git a/paddle/fluid/operators/tree_conv_op.cc b/paddle/fluid/operators/tree_conv_op.cc index 566939afaa4b435c58717a49cfdec69d6c616587..0c72275c5bbacceab6391935e286614ac8f0adca 100644 --- a/paddle/fluid/operators/tree_conv_op.cc +++ b/paddle/fluid/operators/tree_conv_op.cc @@ -104,8 +104,9 @@ class TreeConvOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("NodesVector")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"), + ctx.device_context()); } }; @@ -153,8 +154,9 @@ class TreeConvGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("NodesVector")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"), + ctx.device_context()); } }; } // namespace operators diff --git a/paddle/fluid/operators/unfold_op.cc b/paddle/fluid/operators/unfold_op.cc index d21340b478b590259b04ce66a3db129fdb50c7e7..99907e066b27a5f0ab586d0960d9c9c03d306c2f 100644 --- a/paddle/fluid/operators/unfold_op.cc +++ b/paddle/fluid/operators/unfold_op.cc @@ -120,8 +120,9 @@ class UnfoldOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } }; @@ -141,9 +142,9 @@ class UnfoldGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Y"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc index fae5041c9328fe48aed388c1400aefaaf8bea5e7..0693df843eb8411f0ed39adbffb0f36e7021037f 100644 --- a/paddle/fluid/operators/unpool_op.cc +++ b/paddle/fluid/operators/unpool_op.cc @@ -74,8 +74,9 @@ class UnpoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } public: @@ -117,8 +118,9 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); } public: diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index fc849e73c579f3457852e05dec404c001b74b19e..e55de4508bcd46be5d0b4ae1766213eb688b50d3 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -215,9 +215,9 @@ class Unsqueeze2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index df9212f9c930c683045edcb56c85cd661e54f769..d7f6714710fc6528f5bb6c87ce845e34d6a791b8 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -60,8 +60,9 @@ class WarpCTCOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType(ctx.Input("Logits")->type(), - ctx.device_context(), layout_, library_); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), + ctx.device_context(), layout_, library_); } }; @@ -173,8 +174,9 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(ctx.Input("Logits")->type(), - ctx.device_context()); + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Logits"), + ctx.device_context()); } };