diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 4492470e2ad4cabc28768152881415f9d2fb6077..39190d07b4ccdd5ffd03e2d50bb0e577ac00af75 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -310,18 +310,6 @@ class ExecutionContext { const RuntimeContext& ctx_; }; -inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { - bool use_cudnn = ctx.Attr("use_cudnn"); - use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); -#ifdef PADDLE_WITH_CUDA - if (use_cudnn) { - auto& dev_ctx = ctx.device_context(); - use_cudnn &= dev_ctx.cudnn_handle() != nullptr; - } -#endif - return use_cudnn; -} - template <> const Tensor* ExecutionContext::Input(const std::string& name) const; diff --git a/paddle/fluid/framework/var_type.h b/paddle/fluid/framework/var_type.h index f1cbaf3fdc22cc115c66ac12623982c530b08125..73be446f71f193bea203c986b482e6b98a9826c5 100644 --- a/paddle/fluid/framework/var_type.h +++ b/paddle/fluid/framework/var_type.h @@ -46,19 +46,19 @@ inline proto::VarType::Type ToVarType(int type) { template inline void VisitVarType(const framework::Variable& var, Visitor visitor) { switch (var.Type()) { - case proto::VarType_Type_LOD_TENSOR: + case proto::VarType::LOD_TENSOR: visitor(var.Get()); return; - case proto::VarType_Type_LOD_RANK_TABLE: + case proto::VarType::LOD_RANK_TABLE: visitor(var.Get()); return; - case proto::VarType_Type_LOD_TENSOR_ARRAY: + case proto::VarType::LOD_TENSOR_ARRAY: visitor(var.Get()); return; - case proto::VarType_Type_SELECTED_ROWS: + case proto::VarType::SELECTED_ROWS: visitor(var.Get()); return; - case proto::VarType_Type_READER: + case proto::VarType::READER: visitor(var.Get()); return; default: diff --git a/paddle/fluid/framework/var_type_inference_test.cc b/paddle/fluid/framework/var_type_inference_test.cc index 7842168f603885ce7dc87d2a01dfa4f544389faa..2a75394fca719196a9d53894b080598e942baa45 100644 --- a/paddle/fluid/framework/var_type_inference_test.cc +++ b/paddle/fluid/framework/var_type_inference_test.cc @@ -108,7 +108,7 @@ TEST(InferVarType, sum_op_without_infer_var_type) { op->InferVarType(prog.MutableBlock(0)); - ASSERT_EQ(proto::VarType_Type_LOD_TENSOR, + ASSERT_EQ(proto::VarType::LOD_TENSOR, prog.MutableBlock(0)->Var("test2_out")->GetType()); } diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index b51b4933e6c9ef5eaaa9cdcb11d0c3deea4f6d65..1b535219c1510fff0553cfbbae4c0d069c89dfbb 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor_array.h" @@ -136,8 +136,6 @@ struct VarTypeRegistryImpl { // Users should add other variable types below. // Paddle would generate unique Ids for each registered variable types. -class Scope; - using VarTypeRegistry = detail::VarTypeRegistryImpl< Tensor, LoDTensor, SelectedRows, std::vector, LoDRankTable, LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *, @@ -171,6 +169,8 @@ REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE); REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY); REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST); REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER); +REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32); +REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32); /** End of variable type registration */ diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index 1c7d9f2abed20f6d7f2f98a22e2e05aace50dad5..00840d634d802cfe17fbff127a75606cb5e2cf79 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -88,6 +88,23 @@ TEST(var_type_traits, check_proto_type_id) { ASSERT_TRUE(CheckVarId(proto::VarType::LOD_TENSOR_ARRAY)); ASSERT_TRUE(CheckVarId(proto::VarType::PLACE_LIST)); ASSERT_TRUE(CheckVarId(proto::VarType::READER)); + ASSERT_TRUE(CheckVarId(proto::VarType::INT32)); + ASSERT_TRUE(CheckVarId(proto::VarType::FP32)); + + ASSERT_EQ(proto::VarType_Type_LOD_TENSOR, proto::VarType::LOD_TENSOR); + ASSERT_EQ(proto::VarType_Type_SELECTED_ROWS, proto::VarType::SELECTED_ROWS); + ASSERT_EQ(proto::VarType_Type_STEP_SCOPES, proto::VarType::STEP_SCOPES); + ASSERT_EQ(proto::VarType_Type_LOD_RANK_TABLE, proto::VarType::LOD_RANK_TABLE); + ASSERT_EQ(proto::VarType_Type_LOD_TENSOR_ARRAY, + proto::VarType::LOD_TENSOR_ARRAY); + ASSERT_EQ(proto::VarType_Type_PLACE_LIST, proto::VarType::PLACE_LIST); + ASSERT_EQ(proto::VarType_Type_READER, proto::VarType::READER); + ASSERT_EQ(proto::VarType_Type_FEED_MINIBATCH, proto::VarType::FEED_MINIBATCH); + ASSERT_EQ(proto::VarType_Type_FETCH_LIST, proto::VarType::FETCH_LIST); + ASSERT_EQ(proto::VarType_Type_RAW, proto::VarType::RAW); + ASSERT_EQ(proto::VarType_Type_TUPLE, proto::VarType::TUPLE); + ASSERT_EQ(proto::VarType_Type_INT32, proto::VarType::INT32); + ASSERT_EQ(proto::VarType_Type_FP32, proto::VarType::FP32); } TEST(var_type_traits, test_registry) { diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h index 8aa68942ad16fca94b84fa420a5d2092aae9408d..b9d07da822cf1eb42859e1d7d84437582fada8ff 100644 --- a/paddle/fluid/framework/variable.h +++ b/paddle/fluid/framework/variable.h @@ -67,7 +67,6 @@ class Variable { private: struct Placeholder { - explicit Placeholder(int type) : type_(type) {} virtual ~Placeholder() = default; inline int Type() const { return type_; } @@ -75,6 +74,11 @@ class Variable { inline void* Ptr() { return ptr_; } protected: + inline void Init(void* p, int type) { + ptr_ = p; + type_ = type; + } + void* ptr_; int type_; }; @@ -86,9 +90,7 @@ class Variable { static_assert( IsRegisteredVarType(), "Not registered type. Please register T inside var_type_traits.h"); - PlaceholderImpl() : Placeholder(VarTypeTrait::kId) { - this->ptr_ = &obj_; - } + PlaceholderImpl() { this->Init(&obj_, VarTypeTrait::kId); } private: T obj_; diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 0c048738522ca051d92d3d056209736489eb2113..1de59a5165c83a314a0ff8f4e4351aa3326beb67 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -74,7 +74,7 @@ class AffineGridOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library = framework::LibraryType::kCUDNN; } #endif @@ -184,7 +184,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index c76bde99f4a0acfccc9eb0c470276ca72c400b70..8e0d2824953a372b96d5819be658636f9a3d78ba 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -84,7 +84,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( framework::DataLayout layout = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library = framework::LibraryType::kCUDNN; } #endif @@ -369,7 +369,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index be53a62cc9ccf0e45e66b932b701f9aa63565a35..14a2524bd8f4a9f7685c84f1d9767f5f7eedf0e7 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -59,7 +59,7 @@ class GridSampleOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif @@ -155,7 +155,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 6781cdf9f34481aaf95b8e3bc39011b40e0146f0..5399ae556e7f38a551d680704d8d825e2fdba88a 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -92,7 +92,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif @@ -122,7 +122,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index ad37967f0ac3c968d2b954d4ee34452b364e9d2e..bc889a5a042a27838ba6ba0fccb187ec11b5f0c5 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -50,7 +50,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif @@ -157,7 +157,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index add03bad13dfdd0fdcaa3da44727eaa2b5012d7e..e2ae7caae1ebe46b30c811ae4537f718ca587939 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -51,7 +51,7 @@ class WarpCTCOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_CUDA - if (framework::CanCUDNNBeUsed(ctx)) { + if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; } #endif diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 74b0942379014ae3c796310bd28141dec5bacd39..61a25064d17994e3ce5853017263f24a859c69be 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" @@ -450,6 +451,18 @@ class ScopedActivationDescriptor { DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor); }; +inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { + bool use_cudnn = ctx.Attr("use_cudnn"); + use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); +#ifdef PADDLE_WITH_CUDA + if (use_cudnn) { + auto& dev_ctx = ctx.device_context(); + use_cudnn &= dev_ctx.cudnn_handle() != nullptr; + } +#endif + return use_cudnn; +} + #if CUDNN_VERSION >= 7001 class ScopedCTCLossDescriptor { public: