From 3a2afbf02e9bcc3d0a690564b8ea811b6cb10685 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Tue, 25 Dec 2018 04:24:44 +0000
Subject: [PATCH] polish code test=develop

---
 paddle/fluid/framework/operator.h              | 12 ------------
 paddle/fluid/framework/var_type.h              | 10 +++++-----
 .../fluid/framework/var_type_inference_test.cc |  2 +-
 paddle/fluid/framework/var_type_traits.h       |  6 +++---
 paddle/fluid/framework/var_type_traits_test.cc | 17 +++++++++++++++++
 paddle/fluid/framework/variable.h              | 10 ++++++----
 paddle/fluid/operators/affine_grid_op.cc       |  4 ++--
 paddle/fluid/operators/conv_op.cc              |  4 ++--
 paddle/fluid/operators/grid_sampler_op.cc      |  4 ++--
 paddle/fluid/operators/pool_op.cc              |  4 ++--
 paddle/fluid/operators/softmax_op.cc           |  4 ++--
 paddle/fluid/operators/warpctc_op.cc           |  2 +-
 paddle/fluid/platform/cudnn_helper.h           | 13 +++++++++++++
 13 files changed, 56 insertions(+), 36 deletions(-)

diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 4492470e2..39190d07b 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -310,18 +310,6 @@ class ExecutionContext {
   const RuntimeContext& ctx_;
 };
 
-inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
-#ifdef PADDLE_WITH_CUDA
-  if (use_cudnn) {
-    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
-  }
-#endif
-  return use_cudnn;
-}
-
 template <>
 const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
 
diff --git a/paddle/fluid/framework/var_type.h b/paddle/fluid/framework/var_type.h
index f1cbaf3fd..73be446f7 100644
--- a/paddle/fluid/framework/var_type.h
+++ b/paddle/fluid/framework/var_type.h
@@ -46,19 +46,19 @@ inline proto::VarType::Type ToVarType(int type) {
 template <typename Visitor>
 inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
   switch (var.Type()) {
-    case proto::VarType_Type_LOD_TENSOR:
+    case proto::VarType::LOD_TENSOR:
       visitor(var.Get<LoDTensor>());
       return;
-    case proto::VarType_Type_LOD_RANK_TABLE:
+    case proto::VarType::LOD_RANK_TABLE:
       visitor(var.Get<LoDRankTable>());
       return;
-    case proto::VarType_Type_LOD_TENSOR_ARRAY:
+    case proto::VarType::LOD_TENSOR_ARRAY:
       visitor(var.Get<LoDTensorArray>());
       return;
-    case proto::VarType_Type_SELECTED_ROWS:
+    case proto::VarType::SELECTED_ROWS:
       visitor(var.Get<SelectedRows>());
       return;
-    case proto::VarType_Type_READER:
+    case proto::VarType::READER:
       visitor(var.Get<ReaderHolder>());
       return;
     default:
diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/paddle/fluid/framework/var_type_inference_test.cc
index 7842168f6..2a75394fc 100644
--- a/paddle/fluid/framework/var_type_inference_test.cc
+++ b/paddle/fluid/framework/var_type_inference_test.cc
@@ -108,7 +108,7 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
 
   op->InferVarType(prog.MutableBlock(0));
 
-  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
+  ASSERT_EQ(proto::VarType::LOD_TENSOR,
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
 
diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h
index b51b4933e..1b535219c 100644
--- a/paddle/fluid/framework/var_type_traits.h
+++ b/paddle/fluid/framework/var_type_traits.h
@@ -17,7 +17,7 @@
 #include
 #include
 #include
-#include
+#include
 #include
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
@@ -136,8 +136,6 @@ struct VarTypeRegistryImpl {
 
 // Users should add other variable types below.
 // Paddle would generate unique Ids for each registered variable types.
-class Scope;
-
 using VarTypeRegistry = detail::VarTypeRegistryImpl<
     Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
     LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
@@ -171,6 +169,8 @@ REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
 REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
 REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST);
 REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER);
+REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32);
+REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32);
 
 /** End of variable type registration */
 
diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc
index 1c7d9f2ab..00840d634 100644
--- a/paddle/fluid/framework/var_type_traits_test.cc
+++ b/paddle/fluid/framework/var_type_traits_test.cc
@@ -88,6 +88,23 @@ TEST(var_type_traits, check_proto_type_id) {
   ASSERT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY));
   ASSERT_TRUE(CheckVarId<platform::PlaceList>(proto::VarType::PLACE_LIST));
   ASSERT_TRUE(CheckVarId<ReaderHolder>(proto::VarType::READER));
+  ASSERT_TRUE(CheckVarId<int>(proto::VarType::INT32));
+  ASSERT_TRUE(CheckVarId<float>(proto::VarType::FP32));
+
+  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR, proto::VarType::LOD_TENSOR);
+  ASSERT_EQ(proto::VarType_Type_SELECTED_ROWS, proto::VarType::SELECTED_ROWS);
+  ASSERT_EQ(proto::VarType_Type_STEP_SCOPES, proto::VarType::STEP_SCOPES);
+  ASSERT_EQ(proto::VarType_Type_LOD_RANK_TABLE, proto::VarType::LOD_RANK_TABLE);
+  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR_ARRAY,
+            proto::VarType::LOD_TENSOR_ARRAY);
+  ASSERT_EQ(proto::VarType_Type_PLACE_LIST, proto::VarType::PLACE_LIST);
+  ASSERT_EQ(proto::VarType_Type_READER, proto::VarType::READER);
+  ASSERT_EQ(proto::VarType_Type_FEED_MINIBATCH, proto::VarType::FEED_MINIBATCH);
+  ASSERT_EQ(proto::VarType_Type_FETCH_LIST, proto::VarType::FETCH_LIST);
+  ASSERT_EQ(proto::VarType_Type_RAW, proto::VarType::RAW);
+  ASSERT_EQ(proto::VarType_Type_TUPLE, proto::VarType::TUPLE);
+  ASSERT_EQ(proto::VarType_Type_INT32, proto::VarType::INT32);
+  ASSERT_EQ(proto::VarType_Type_FP32, proto::VarType::FP32);
 }
 
 TEST(var_type_traits, test_registry) {
diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h
index 8aa68942a..b9d07da82 100644
--- a/paddle/fluid/framework/variable.h
+++ b/paddle/fluid/framework/variable.h
@@ -67,7 +67,6 @@ class Variable {
 
  private:
  struct Placeholder {
-    explicit Placeholder(int type) : type_(type) {}
     virtual ~Placeholder() = default;
 
     inline int Type() const { return type_; }
@@ -75,6 +74,11 @@ class Variable {
     inline void* Ptr() { return ptr_; }
 
    protected:
+    inline void Init(void* p, int type) {
+      ptr_ = p;
+      type_ = type;
+    }
+
     void* ptr_;
     int type_;
   };
@@ -86,9 +90,7 @@ class Variable {
     static_assert(
         IsRegisteredVarType<T>(),
        "Not registered type. Please register T inside var_type_traits.h");
-    PlaceholderImpl() : Placeholder(VarTypeTrait<T>::kId) {
-      this->ptr_ = &obj_;
-    }
+    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
 
    private:
     T obj_;
diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc
index 0c0487385..1de59a516 100644
--- a/paddle/fluid/operators/affine_grid_op.cc
+++ b/paddle/fluid/operators/affine_grid_op.cc
@@ -74,7 +74,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library = framework::LibraryType::kCUDNN;
     }
 #endif
@@ -184,7 +184,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index c76bde99f..8e0d28249 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -84,7 +84,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   framework::DataLayout layout = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library = framework::LibraryType::kCUDNN;
   }
 #endif
@@ -369,7 +369,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc
index be53a62cc..14a2524bd 100644
--- a/paddle/fluid/operators/grid_sampler_op.cc
+++ b/paddle/fluid/operators/grid_sampler_op.cc
@@ -59,7 +59,7 @@ class GridSampleOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
@@ -155,7 +155,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index 6781cdf9f..5399ae556 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -92,7 +92,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
@@ -122,7 +122,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index ad37967f0..bc889a5a0 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -50,7 +50,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
@@ -157,7 +157,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc
index add03bad1..e2ae7caae 100644
--- a/paddle/fluid/operators/warpctc_op.cc
+++ b/paddle/fluid/operators/warpctc_op.cc
@@ -51,7 +51,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index 74b094237..61a25064d 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -17,6 +17,7 @@ limitations under the License. */
 
 #include
 #include
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
@@ -450,6 +451,18 @@ class ScopedActivationDescriptor {
   DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
 };
 
+inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (use_cudnn) {
+    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
+  return use_cudnn;
+}
+
 #if CUDNN_VERSION >= 7001
 class ScopedCTCLossDescriptor {
  public:
--
GitLab