Commit 3a2afbf0 authored by sneaxiy

polish code

test=develop
Parent 7f6e513b
@@ -310,18 +310,6 @@ class ExecutionContext {
   const RuntimeContext& ctx_;
 };
 
-inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
-#ifdef PADDLE_WITH_CUDA
-  if (use_cudnn) {
-    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
-  }
-#endif
-  return use_cudnn;
-}
-
 template <>
 const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
......
@@ -46,19 +46,19 @@ inline proto::VarType::Type ToVarType(int type) {
 template <typename Visitor>
 inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
   switch (var.Type()) {
-    case proto::VarType_Type_LOD_TENSOR:
+    case proto::VarType::LOD_TENSOR:
       visitor(var.Get<LoDTensor>());
       return;
-    case proto::VarType_Type_LOD_RANK_TABLE:
+    case proto::VarType::LOD_RANK_TABLE:
       visitor(var.Get<LoDRankTable>());
       return;
-    case proto::VarType_Type_LOD_TENSOR_ARRAY:
+    case proto::VarType::LOD_TENSOR_ARRAY:
       visitor(var.Get<LoDTensorArray>());
       return;
-    case proto::VarType_Type_SELECTED_ROWS:
+    case proto::VarType::SELECTED_ROWS:
       visitor(var.Get<SelectedRows>());
       return;
-    case proto::VarType_Type_READER:
+    case proto::VarType::READER:
       visitor(var.Get<ReaderHolder>());
       return;
     default:
......
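The hunk above only swaps the flat protobuf enum spellings for the nested ones; `VisitVarType` still dispatches on the variable's runtime type id and calls the visitor with the concrete payload. A minimal conforming visitor, as a hedged sketch (the `NumelVisitor` functor is hypothetical, not part of this commit):

```cpp
#include "paddle/fluid/framework/var_type.h"

// Hypothetical visitor: records the element count when the variable holds a
// LoDTensor and ignores the other registered types. Any functor callable
// with every case above works; a catch-all template keeps it short.
struct NumelVisitor {
  explicit NumelVisitor(int64_t* out) : out_(out) {}
  void operator()(const paddle::framework::LoDTensor& t) { *out_ = t.numel(); }
  template <typename T>
  void operator()(const T&) {}  // LoDRankTable, SelectedRows, ...: no-op
  int64_t* out_;
};

// Usage: int64_t n = -1;
//        paddle::framework::VisitVarType(var, NumelVisitor(&n));
```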
@@ -108,7 +108,7 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
   op->InferVarType(prog.MutableBlock(0));
 
-  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
+  ASSERT_EQ(proto::VarType::LOD_TENSOR,
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
......
@@ -17,7 +17,7 @@
 #include <map>
 #include <string>
 #include <tuple>
-#include <typeinfo>
+#include <typeindex>
 #include <vector>
 
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
......
@@ -136,8 +136,6 @@ struct VarTypeRegistryImpl {
 // Users should add other variable types below.
 // Paddle would generate unique Ids for each registered variable types.
-class Scope;
-
 using VarTypeRegistry = detail::VarTypeRegistryImpl<
     Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
     LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
......
@@ -171,6 +169,8 @@ REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
 REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
 REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST);
 REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER);
+REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32);
+REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32);
 
 /** End of variable type registration */
......
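The two new `REG_PROTO_VAR_TYPE_TRAIT` lines give the POD types `int` and `float` compile-time proto type ids through `VarTypeTrait<T>::kId`, exactly like the class types registered above them. A hedged sketch of what the registration guarantees (it assumes `kId` is usable in constant expressions, which the registry's `static_assert`-based design suggests):

```cpp
#include "paddle/fluid/framework/var_type_traits.h"

namespace pf = paddle::framework;

// Compile-time counterparts of the runtime CheckVarId assertions in the
// test hunk that follows:
static_assert(pf::VarTypeTrait<int>::kId == pf::proto::VarType::INT32,
              "int must be registered as proto::VarType::INT32");
static_assert(pf::VarTypeTrait<float>::kId == pf::proto::VarType::FP32,
              "float must be registered as proto::VarType::FP32");
```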
@@ -88,6 +88,23 @@ TEST(var_type_traits, check_proto_type_id) {
   ASSERT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY));
   ASSERT_TRUE(CheckVarId<platform::PlaceList>(proto::VarType::PLACE_LIST));
   ASSERT_TRUE(CheckVarId<ReaderHolder>(proto::VarType::READER));
+  ASSERT_TRUE(CheckVarId<int>(proto::VarType::INT32));
+  ASSERT_TRUE(CheckVarId<float>(proto::VarType::FP32));
+
+  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR, proto::VarType::LOD_TENSOR);
+  ASSERT_EQ(proto::VarType_Type_SELECTED_ROWS, proto::VarType::SELECTED_ROWS);
+  ASSERT_EQ(proto::VarType_Type_STEP_SCOPES, proto::VarType::STEP_SCOPES);
+  ASSERT_EQ(proto::VarType_Type_LOD_RANK_TABLE, proto::VarType::LOD_RANK_TABLE);
+  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR_ARRAY,
+            proto::VarType::LOD_TENSOR_ARRAY);
+  ASSERT_EQ(proto::VarType_Type_PLACE_LIST, proto::VarType::PLACE_LIST);
+  ASSERT_EQ(proto::VarType_Type_READER, proto::VarType::READER);
+  ASSERT_EQ(proto::VarType_Type_FEED_MINIBATCH, proto::VarType::FEED_MINIBATCH);
+  ASSERT_EQ(proto::VarType_Type_FETCH_LIST, proto::VarType::FETCH_LIST);
+  ASSERT_EQ(proto::VarType_Type_RAW, proto::VarType::RAW);
+  ASSERT_EQ(proto::VarType_Type_TUPLE, proto::VarType::TUPLE);
+  ASSERT_EQ(proto::VarType_Type_INT32, proto::VarType::INT32);
+  ASSERT_EQ(proto::VarType_Type_FP32, proto::VarType::FP32);
 }
 
 TEST(var_type_traits, test_registry) {
......
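The block of `ASSERT_EQ`s holds by construction: for a nested enum `Type` inside message `VarType`, the protobuf compiler emits both flat C++ names (`VarType_Type_LOD_TENSOR`) and nested aliases (`VarType::LOD_TENSOR`) that name the same enumerators, which is why this commit can switch spellings with no behavior change. Roughly, as a simplified sketch of the generated header (enumerator value per framework.proto):

```cpp
// What protoc emits for "message VarType { enum Type { ... } }", simplified:
enum VarType_Type {
  VarType_Type_LOD_TENSOR = 7,
  // ... one flat enumerator per value ...
};

class VarType {  // generated message class
 public:
  typedef VarType_Type Type;
  static const Type LOD_TENSOR = VarType_Type_LOD_TENSOR;  // nested alias
  // ...
};
```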
@@ -67,7 +67,6 @@ class Variable {
  private:
   struct Placeholder {
-    explicit Placeholder(int type) : type_(type) {}
     virtual ~Placeholder() = default;
 
     inline int Type() const { return type_; }
......
@@ -75,6 +74,11 @@ class Variable {
     inline void* Ptr() { return ptr_; }
 
    protected:
+    inline void Init(void* p, int type) {
+      ptr_ = p;
+      type_ = type;
+    }
+
     void* ptr_;
     int type_;
   };
......
@@ -86,9 +90,7 @@ class Variable {
     static_assert(
         IsRegisteredVarType<T>(),
         "Not registered type. Please register T inside var_type_traits.h");
-    PlaceholderImpl() : Placeholder(VarTypeTrait<T>::kId) {
-      this->ptr_ = &obj_;
-    }
+    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
 
    private:
     T obj_;
......
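With the `Placeholder(int type)` constructor removed, `PlaceholderImpl<T>` now sets `ptr_` and `type_` together through the protected `Init`, so the two fields cannot get out of sync between base and derived construction. A hedged sketch of how a type-checked accessor typically sits on top of `Type()` and `Ptr()` (the `holder_` member and the enforce message are assumptions, not code from this commit):

```cpp
// Sketch of a Variable-style accessor over the refactored Placeholder.
// Assumes the enclosing class keeps: std::unique_ptr<Placeholder> holder_;
template <typename T>
T* GetMutable() {
  if (!holder_) {
    // The constructor runs Init(&obj_, VarTypeTrait<T>::kId) internally.
    holder_.reset(new PlaceholderImpl<T>());
  } else {
    PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
                   "Variable already holds a different type");
  }
  return static_cast<T*>(holder_->Ptr());
}
```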
@@ -74,7 +74,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -184,7 +184,7 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -84,7 +84,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   framework::DataLayout layout = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library = framework::LibraryType::kCUDNN;
   }
 #endif
......
@@ -369,7 +369,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
......
@@ -59,7 +59,7 @@ class GridSampleOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -155,7 +155,7 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -92,7 +92,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
......
@@ -122,7 +122,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 
 #ifdef PADDLE_WITH_CUDA
-  if (framework::CanCUDNNBeUsed(ctx)) {
+  if (platform::CanCUDNNBeUsed(ctx)) {
     library_ = framework::LibraryType::kCUDNN;
   }
 #endif
......
@@ -50,7 +50,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -157,7 +157,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -51,7 +51,7 @@ class WarpCTCOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (framework::CanCUDNNBeUsed(ctx)) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
 #endif
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
......
@@ -450,6 +451,18 @@ class ScopedActivationDescriptor {
   DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
 };
 
+inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (use_cudnn) {
+    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
+  return use_cudnn;
+}
+
 #if CUDNN_VERSION >= 7001
 class ScopedCTCLossDescriptor {
  public:
......
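Taken together, the call sites in the operator files above change only their qualifier, from `framework::CanCUDNNBeUsed(ctx)` to `platform::CanCUDNNBeUsed(ctx)`; the check itself is untouched: the op must carry a `use_cudnn` attribute set to true, run on a GPU place, and the CUDA device context must expose a non-null cuDNN handle. A hedged sketch of the resulting selection pattern, factored into a free function (`ChooseLibrary` is hypothetical; the real ops inline this inside `GetExpectedKernelType`):

```cpp
#include "paddle/fluid/framework/operator.h"      // framework::ExecutionContext
#include "paddle/fluid/platform/cudnn_helper.h"   // platform::CanCUDNNBeUsed

namespace paddle {
namespace operators {

// Pick the kernel library the way the call sites in this commit do.
inline framework::LibraryType ChooseLibrary(
    const framework::ExecutionContext& ctx) {
  framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
  if (platform::CanCUDNNBeUsed(ctx)) {  // use_cudnn && GPU place && handle
    library = framework::LibraryType::kCUDNN;
  }
#endif
  return library;
}

}  // namespace operators
}  // namespace paddle
```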