Commit 3a2afbf0 authored by S sneaxiy

polish code

test=develop

Parent 7f6e513b
@@ -310,18 +310,6 @@ class ExecutionContext {
   const RuntimeContext& ctx_;
 };
 
-inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
-#ifdef PADDLE_WITH_CUDA
-  if (use_cudnn) {
-    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
-  }
-#endif
-  return use_cudnn;
-}
-
 template <>
 const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
......
@@ -46,19 +46,19 @@ inline proto::VarType::Type ToVarType(int type) {
 template <typename Visitor>
 inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
   switch (var.Type()) {
-    case proto::VarType_Type_LOD_TENSOR:
+    case proto::VarType::LOD_TENSOR:
       visitor(var.Get<LoDTensor>());
       return;
-    case proto::VarType_Type_LOD_RANK_TABLE:
+    case proto::VarType::LOD_RANK_TABLE:
       visitor(var.Get<LoDRankTable>());
       return;
-    case proto::VarType_Type_LOD_TENSOR_ARRAY:
+    case proto::VarType::LOD_TENSOR_ARRAY:
       visitor(var.Get<LoDTensorArray>());
       return;
-    case proto::VarType_Type_SELECTED_ROWS:
+    case proto::VarType::SELECTED_ROWS:
       visitor(var.Get<SelectedRows>());
       return;
-    case proto::VarType_Type_READER:
+    case proto::VarType::READER:
       visitor(var.Get<ReaderHolder>());
       return;
     default:
......
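For context, `VisitVarType` dispatches on a `Variable`'s runtime type id and hands the concrete payload to a caller-supplied visitor; this hunk only switches the `case` labels from the generated C-style enum aliases to the scoped `proto::VarType::` names. A minimal sketch of a visitor, assuming the usual Paddle headers and namespaces (the `ShapePrinter` functor and its use are illustrative, not part of this commit):

```cpp
#include "paddle/fluid/framework/var_type.h"

namespace paddle {
namespace framework {

// Hypothetical visitor: logs the shape of LoDTensor variables and ignores
// every other variable type via the catch-all template overload. The
// catch-all is required because VisitVarType instantiates visitor(...)
// for every case in its switch.
struct ShapePrinter {
  void operator()(const LoDTensor& t) const {
    VLOG(3) << "LoDTensor with dims " << t.dims();
  }
  template <typename T>
  void operator()(const T&) const {}  // non-LoDTensor types: do nothing
};

inline void PrintShapeIfTensor(const Variable& var) {
  VisitVarType(var, ShapePrinter());
}

}  // namespace framework
}  // namespace paddle
```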
@@ -108,7 +108,7 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
   op->InferVarType(prog.MutableBlock(0));
 
-  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
+  ASSERT_EQ(proto::VarType::LOD_TENSOR,
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
......
@@ -17,7 +17,7 @@
 #include <map>
 #include <string>
 #include <tuple>
-#include <typeinfo>
+#include <typeindex>
 #include <vector>
 
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
@@ -136,8 +136,6 @@ struct VarTypeRegistryImpl {
 // Users should add other variable types below.
 // Paddle would generate unique Ids for each registered variable types.
-class Scope;
-
 using VarTypeRegistry = detail::VarTypeRegistryImpl<
     Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
     LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
@@ -171,6 +169,8 @@ REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
 REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
 REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST);
 REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER);
+REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32);
+REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32);
 
 /** End of variable type registration */
......
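For context, each `REG_PROTO_VAR_TYPE_TRAIT` line binds a C++ type to its `proto::VarType` enum value, so this commit makes plain `int` and `float` registered variable types whose ids are queryable at compile time. A hedged sketch of how the trait can be consumed (the `ProtoTypeOf` helper is illustrative, not part of Paddle):

```cpp
#include "paddle/fluid/framework/var_type_traits.h"

namespace paddle {
namespace framework {

// After this commit these hold by construction:
static_assert(VarTypeTrait<int>::kId == proto::VarType::INT32,
              "int must map to proto::VarType::INT32");
static_assert(VarTypeTrait<float>::kId == proto::VarType::FP32,
              "float must map to proto::VarType::FP32");

// Hypothetical helper: recover the proto type id of a registered type.
template <typename T>
constexpr proto::VarType::Type ProtoTypeOf() {
  static_assert(IsRegisteredVarType<T>(), "T must be registered");
  return static_cast<proto::VarType::Type>(VarTypeTrait<T>::kId);
}

}  // namespace framework
}  // namespace paddle
```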
@@ -88,6 +88,23 @@ TEST(var_type_traits, check_proto_type_id) {
   ASSERT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY));
   ASSERT_TRUE(CheckVarId<platform::PlaceList>(proto::VarType::PLACE_LIST));
   ASSERT_TRUE(CheckVarId<ReaderHolder>(proto::VarType::READER));
+  ASSERT_TRUE(CheckVarId<int>(proto::VarType::INT32));
+  ASSERT_TRUE(CheckVarId<float>(proto::VarType::FP32));
+
+  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR, proto::VarType::LOD_TENSOR);
+  ASSERT_EQ(proto::VarType_Type_SELECTED_ROWS, proto::VarType::SELECTED_ROWS);
+  ASSERT_EQ(proto::VarType_Type_STEP_SCOPES, proto::VarType::STEP_SCOPES);
+  ASSERT_EQ(proto::VarType_Type_LOD_RANK_TABLE, proto::VarType::LOD_RANK_TABLE);
+  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR_ARRAY,
+            proto::VarType::LOD_TENSOR_ARRAY);
+  ASSERT_EQ(proto::VarType_Type_PLACE_LIST, proto::VarType::PLACE_LIST);
+  ASSERT_EQ(proto::VarType_Type_READER, proto::VarType::READER);
+  ASSERT_EQ(proto::VarType_Type_FEED_MINIBATCH, proto::VarType::FEED_MINIBATCH);
+  ASSERT_EQ(proto::VarType_Type_FETCH_LIST, proto::VarType::FETCH_LIST);
+  ASSERT_EQ(proto::VarType_Type_RAW, proto::VarType::RAW);
+  ASSERT_EQ(proto::VarType_Type_TUPLE, proto::VarType::TUPLE);
+  ASSERT_EQ(proto::VarType_Type_INT32, proto::VarType::INT32);
+  ASSERT_EQ(proto::VarType_Type_FP32, proto::VarType::FP32);
 }
 
 TEST(var_type_traits, test_registry) {
......
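The `CheckVarId<T>` helper used above is defined earlier in this test file; presumably it just compares the registered compile-time id against the proto enum value. A hedged re-implementation for reference (an assumption, not the verbatim helper):

```cpp
// Sketch: true iff T's registered id equals the given proto type value.
template <typename T>
bool CheckVarId(int proto_id) {
  static_assert(paddle::framework::IsRegisteredVarType<T>(),
                "T must be registered in var_type_traits.h");
  return paddle::framework::VarTypeTrait<T>::kId == proto_id;
}
```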
@@ -67,7 +67,6 @@ class Variable {
  private:
   struct Placeholder {
-    explicit Placeholder(int type) : type_(type) {}
     virtual ~Placeholder() = default;
 
     inline int Type() const { return type_; }
@@ -75,6 +74,11 @@ class Variable {
     inline void* Ptr() { return ptr_; }
 
    protected:
+    inline void Init(void* p, int type) {
+      ptr_ = p;
+      type_ = type;
+    }
+
     void* ptr_;
     int type_;
   };
@@ -86,9 +90,7 @@ class Variable {
     static_assert(
         IsRegisteredVarType<T>(),
         "Not registered type. Please register T inside var_type_traits.h");
-    PlaceholderImpl() : Placeholder(VarTypeTrait<T>::kId) {
-      this->ptr_ = &obj_;
-    }
+    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
 
    private:
     T obj_;
......
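For context, this refactor drops the `Placeholder(int type)` base constructor (which set `type_` while the derived constructor separately set `ptr_`) in favor of a single protected `Init` that sets both fields together. A standalone sketch of the resulting pattern, with names mirroring the Paddle code but the type-id template parameter and `main` added for illustration:

```cpp
#include <cassert>

// Base class erases the payload type: it stores a raw pointer to the
// contained object plus an integer type id, both set via Init().
struct Placeholder {
  virtual ~Placeholder() = default;
  int Type() const { return type_; }
  void* Ptr() { return ptr_; }

 protected:
  void Init(void* p, int type) {
    ptr_ = p;
    type_ = type;
  }
  void* ptr_;
  int type_;
};

// Derived class owns the object and registers it in one Init() call,
// instead of splitting the work across two constructors.
template <typename T, int kTypeId>
struct PlaceholderImpl final : Placeholder {
  PlaceholderImpl() { this->Init(&obj_, kTypeId); }

 private:
  T obj_;
};

int main() {
  PlaceholderImpl<double, /*kTypeId=*/7> holder;
  Placeholder& base = holder;
  assert(base.Type() == 7);       // id recorded by Init()
  assert(base.Ptr() != nullptr);  // points at the contained object
  return 0;
}
```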
@@ -74,7 +74,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library = framework::LibraryType::kCUDNN;
     }
 #endif
@@ -184,7 +184,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -84,7 +84,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   framework::DataLayout layout = framework::StringToDataLayout(data_format);
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library = framework::LibraryType::kCUDNN;
   }
 #endif
@@ -369,7 +369,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
......
@@ -59,7 +59,7 @@ class GridSampleOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
@@ -155,7 +155,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -92,7 +92,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
@@ -122,7 +122,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
......
@@ -50,7 +50,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
@@ -157,7 +157,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
......
@@ -51,7 +51,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
@@ -450,6 +451,18 @@ class ScopedActivationDescriptor {
   DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
 };
 
+inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (use_cudnn) {
+    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
+  return use_cudnn;
+}
+
 #if CUDNN_VERSION >= 7001
 class ScopedCTCLossDescriptor {
  public:
......
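With `CanCUDNNBeUsed` now living in `platform` (hence the new include of `framework/operator.h` for `ExecutionContext`), the operator call sites above change only their namespace qualifier. A schematic sketch of the call pattern inside an operator's `GetExpectedKernelType`, assuming a CUDA build; the op class, inherited constructor, and the FP32/kAnyLayout arguments are illustrative placeholders, not from this commit:

```cpp
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

// Hypothetical operator: picks the cuDNN kernel library only when the op
// has use_cudnn=true, is placed on a GPU, and a cuDNN handle is available;
// otherwise it falls back to the plain library.
class FooOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
    if (platform::CanCUDNNBeUsed(ctx)) {
      library = framework::LibraryType::kCUDNN;
    }
#endif
    return framework::OpKernelType(framework::proto::VarType::FP32,
                                   ctx.GetPlace(),
                                   framework::DataLayout::kAnyLayout, library);
  }
};

}  // namespace operators
}  // namespace paddle
```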