diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f6a43804ef2fd73c4a2c2c3b3dfbb90bff1c451b..a3b4a8c0829ae3324e933309b2eaea35fe571997 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -35,6 +35,17 @@ std::vector> kKernelPriority = { std::make_tuple(platform::CPUPlace(), LibraryType::kPlain), }; +proto::VarType::Type GetDataTypeOfVar(const Variable* var) { + if (var->IsType()) { + return framework::ToDataType(var->Get().type()); + } else if (var->IsType()) { + return framework::ToDataType( + var->Get().value().type()); + } else { + PADDLE_THROW("Var should be LoDTensor or SelectedRows"); + } +} + static DDim GetDims(const Scope& scope, const std::string& name) { Variable* var = scope.FindVar(name); if (var == nullptr) { diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 41214b41cb68cbd7049552f39195ae5257e0d06f..b7a7c69b4c8493f945926c75797c49d327a3197e 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -61,6 +61,8 @@ inline std::string GradVarName(const std::string& var_name) { return var_name + kGradVarSuffix; } +proto::VarType::Type GetDataTypeOfVar(const Variable* var); + class OperatorBase; class ExecutionContext; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index deabcdc99f819851b2df9bb0c7b05a5b339568f3..bf33be310686640fa187a07cf46a157b7f433340 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -18,22 +18,6 @@ limitations under the License. */ namespace paddle { namespace operators { -static inline framework::OpKernelType ExpectedKernelType( - const framework::ExecutionContext& ctx) { - auto* table_var = ctx.InputVar("W"); - if (table_var->IsType()) { - return framework::OpKernelType( - framework::ToDataType(table_var->Get().type()), - ctx.device_context()); - } else if (table_var->IsType()) { - return framework::OpKernelType( - framework::ToDataType(table_var->Get().value().type()), - ctx.device_context()); - } else { - PADDLE_THROW("W should be LoDTensor or SelectedRows"); - } -} - class LookupTableOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -67,7 +51,8 @@ class LookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return ExpectedKernelType(ctx); + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -138,7 +123,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return ExpectedKernelType(ctx); + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sgd_op.cc b/paddle/fluid/operators/sgd_op.cc index 9cdc5b3f1e01bdddf56488571eadaeb9d2dff0b2..074fa9e00f2ec531f324ff10113d95144687d500 100644 --- a/paddle/fluid/operators/sgd_op.cc +++ b/paddle/fluid/operators/sgd_op.cc @@ -43,19 +43,8 @@ class SGDOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto* table_var = ctx.InputVar("Param"); - if (table_var->IsType()) { - return framework::OpKernelType( - framework::ToDataType(table_var->Get().type()), - ctx.device_context()); - } else if (table_var->IsType()) { - return framework::OpKernelType( - framework::ToDataType( - table_var->Get().value().type()), - ctx.device_context()); - } else { - PADDLE_THROW("Param should be LoDTensor or SelectedRows"); - } + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + return framework::OpKernelType(data_type, ctx.device_context()); } };