From e66bd4cb732003d083f981bc5b2c7fe238590aa0 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 4 Apr 2018 23:31:13 +0800 Subject: [PATCH] add GetDataTypeOfVar --- paddle/fluid/framework/operator.cc | 11 +++++++++++ paddle/fluid/framework/operator.h | 2 ++ paddle/fluid/operators/lookup_table_op.cc | 22 ++++------------------ paddle/fluid/operators/sgd_op.cc | 15 ++------------- 4 files changed, 19 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f6a43804e..a3b4a8c08 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -35,6 +35,17 @@ std::vector> kKernelPriority = { std::make_tuple(platform::CPUPlace(), LibraryType::kPlain), }; +proto::VarType::Type GetDataTypeOfVar(const Variable* var) { + if (var->IsType()) { + return framework::ToDataType(var->Get().type()); + } else if (var->IsType()) { + return framework::ToDataType( + var->Get().value().type()); + } else { + PADDLE_THROW("Var should be LoDTensor or SelectedRows"); + } +} + static DDim GetDims(const Scope& scope, const std::string& name) { Variable* var = scope.FindVar(name); if (var == nullptr) { diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 41214b41c..b7a7c69b4 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -61,6 +61,8 @@ inline std::string GradVarName(const std::string& var_name) { return var_name + kGradVarSuffix; } +proto::VarType::Type GetDataTypeOfVar(const Variable* var); + class OperatorBase; class ExecutionContext; diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index deabcdc99..bf33be310 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -18,22 +18,6 @@ limitations under the License. */ namespace paddle { namespace operators { -static inline framework::OpKernelType ExpectedKernelType( - const framework::ExecutionContext& ctx) { - auto* table_var = ctx.InputVar("W"); - if (table_var->IsType()) { - return framework::OpKernelType( - framework::ToDataType(table_var->Get().type()), - ctx.device_context()); - } else if (table_var->IsType()) { - return framework::OpKernelType( - framework::ToDataType(table_var->Get().value().type()), - ctx.device_context()); - } else { - PADDLE_THROW("W should be LoDTensor or SelectedRows"); - } -} - class LookupTableOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -67,7 +51,8 @@ class LookupTableOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return ExpectedKernelType(ctx); + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); } }; @@ -138,7 +123,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return ExpectedKernelType(ctx); + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); + return framework::OpKernelType(data_type, ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sgd_op.cc b/paddle/fluid/operators/sgd_op.cc index 9cdc5b3f1..074fa9e00 100644 --- a/paddle/fluid/operators/sgd_op.cc +++ b/paddle/fluid/operators/sgd_op.cc @@ -43,19 +43,8 @@ class SGDOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto* table_var = ctx.InputVar("Param"); - if (table_var->IsType()) { - return framework::OpKernelType( - framework::ToDataType(table_var->Get().type()), - ctx.device_context()); - } else if (table_var->IsType()) { - return framework::OpKernelType( - framework::ToDataType( - table_var->Get().value().type()), - ctx.device_context()); - } else { - PADDLE_THROW("Param should be LoDTensor or SelectedRows"); - } + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); + return framework::OpKernelType(data_type, ctx.device_context()); } }; -- GitLab