Unverified commit c0bcff00, authored by Zeng Jinle, committed by GitHub

Merge pull request #14962 from sneaxiy/rewrite_variable_type

Rewrite variable type
@@ -68,18 +68,23 @@ cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memor
 cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
 cc_test(reader_test SRCS reader_test.cc DEPS reader)
-cc_test(variable_test SRCS variable_test.cc)
 cc_library(threadpool SRCS threadpool.cc DEPS enforce)
 cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
-cc_library(scope SRCS scope.cc DEPS glog threadpool)
+cc_library(var_type_traits SRCS var_type_traits.cc DEPS lod_tensor selected_rows framework_proto)
+if (WITH_GPU)
+  target_link_libraries(var_type_traits dynload_cuda)
+endif()
+cc_test(var_type_traits_test SRCS var_type_traits_test.cc DEPS var_type_traits)
+cc_library(scope SRCS scope.cc DEPS glog threadpool var_type_traits)
 cc_library(scope_pool SRCS scope_pool.cc DEPS scope)
 cc_test(scope_test SRCS scope_test.cc DEPS scope)
+cc_test(variable_test SRCS variable_test.cc DEPS tensor var_type_traits)
 cc_library(data_device_transform SRCS data_device_transform.cc DEPS tensor)
 nv_test(data_device_transform_test SRCS data_device_transform_test.cu
-        DEPS operator op_registry device_context math_function)
+        DEPS operator op_registry device_context math_function scope)
 if(WITH_GPU)
   if (WIN32)
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/device_context.h"
......
@@ -88,7 +88,7 @@ void EagerDeletionOpHandle::RunImpl() {
       }
     } else {
       PADDLE_THROW("Type %s of %s is not supported eager deletion",
-                   var->Type().name(), name);
+                   framework::ToTypeName(var->Type()), name);
     }
   }
......
@@ -24,7 +24,7 @@ static void VisitVariable(Variable* var, Func* func) {
   } else if (var->IsType<SelectedRows>()) {
     (*func)(var->GetMutable<SelectedRows>());
   } else {
-    PADDLE_THROW("Not supported type %s", var->Type().name());
+    PADDLE_THROW("Not supported type %s", ToTypeName(var->Type()));
   }
 }

@@ -35,7 +35,7 @@ static void VisitVariable(const Variable& var, Func* func) {
   } else if (var.IsType<SelectedRows>()) {
     (*func)(var.Get<SelectedRows>());
   } else {
-    PADDLE_THROW("Not supported type %s", var.Type().name());
+    PADDLE_THROW("Not supported type %s", ToTypeName(var.Type()));
   }
 }
......
@@ -119,7 +119,7 @@ static void DeleteUnusedTensors(
       }
     } else {
       PADDLE_THROW("Type %s of %s is not supported eager deletion",
-                   var->Type().name(), name);
+                   framework::ToTypeName(var->Type()), name);
     }
   }
 }
......
@@ -380,7 +380,7 @@ const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
     return &(var.Get<SelectedRows>().value());
   } else {
     PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
-                 var.Type().name());
+                 ToTypeName(var.Type()));
   }
 }

@@ -391,7 +391,7 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
     return var->GetMutable<SelectedRows>()->mutable_value();
   } else {
     PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
-                 var->Type().name());
+                 ToTypeName(var->Type()));
   }
 }

@@ -485,7 +485,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
                    PADDLE_ENFORCE(
                        var->IsType<LoDTensor>(),
                        "should be LoDTensor, but the received type is %s",
-                       var->Type().name());
+                       ToTypeName(var->Type()));
                    return &(var->Get<LoDTensor>());
                  });
   return res;

@@ -504,7 +504,7 @@ const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
                    PADDLE_ENFORCE(
                        var->IsType<LoDTensor>(),
                        "%s should be LoDTensor, but the received type is %s",
-                       sub_name, var->Type().name());
+                       sub_name, ToTypeName(var->Type()));
                    return &(var->Get<LoDTensor>());
                  });
   return res;

@@ -533,7 +533,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
                    PADDLE_ENFORCE(
                        var->IsType<LoDTensor>(),
                        "%s should be LoDTensor, but the received type is %s",
-                       sub_name, var->Type().name());
+                       sub_name, ToTypeName(var->Type()));
                    return var->GetMutable<LoDTensor>();
                  });
   return res;

@@ -775,7 +775,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
       PADDLE_THROW(
           "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
           "type_id is %s.",
-          var->Type().name());
+          ToTypeName(var->Type()));
     }
   }

@@ -798,7 +798,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
       var->GetMutable<SelectedRows>()->set_height(dim[0]);
     } else {
       PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
-                   var->Type().name());
+                   ToTypeName(var->Type()));
     }
   }
......
@@ -165,11 +165,9 @@ std::string Scope::Rename(const std::string& origin_name) const {
 Variable* Scope::VarInternal(const std::string& name) {
   auto* v = FindVarLocally(name);
   if (v != nullptr) return v;

   v = new Variable();
-  vars_[name].reset(v);
+  vars_.emplace(name, std::unique_ptr<Variable>(v));
   VLOG(3) << "Create variable " << name;
-  v->name_ = &(vars_.find(name)->first);
   return v;
 }
......
@@ -19,52 +19,50 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/reader.h"
 #include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/framework/var_type_traits.h"
 #include "paddle/fluid/framework/variable.h"

 namespace paddle {
 namespace framework {

 template <typename T>
-inline bool IsType(const std::type_index& type_index) {
-  return type_index == std::type_index(typeid(T));
+inline bool IsType(const std::type_index& type) {
+  return type == typeid(T);
 }

-inline proto::VarType::Type ToVarType(std::type_index type) {
-  if (IsType<LoDTensor>(type)) {
-    return proto::VarType_Type_LOD_TENSOR;
-  } else if (IsType<LoDRankTable>(type)) {
-    return proto::VarType_Type_LOD_RANK_TABLE;
-  } else if (IsType<LoDTensorArray>(type)) {
-    return proto::VarType_Type_LOD_TENSOR_ARRAY;
-  } else if (IsType<SelectedRows>(type)) {
-    return proto::VarType_Type_SELECTED_ROWS;
-  } else if (IsType<ReaderHolder>(type)) {
-    return proto::VarType_Type_READER;
-  } else {
-    PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
+inline proto::VarType::Type ToVarType(int type) {
+  switch (type) {
+    case proto::VarType::LOD_TENSOR:
+    case proto::VarType::SELECTED_ROWS:
+    case proto::VarType::LOD_RANK_TABLE:
+    case proto::VarType::LOD_TENSOR_ARRAY:
+    case proto::VarType::READER:
+      return static_cast<proto::VarType::Type>(type);
+    default:
+      PADDLE_THROW("ToVarType:Unsupported type %d", type);
   }
 }

 template <typename Visitor>
 inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
-  switch (ToVarType(var.Type())) {
-    case proto::VarType_Type_LOD_TENSOR:
+  switch (var.Type()) {
+    case proto::VarType::LOD_TENSOR:
       visitor(var.Get<LoDTensor>());
       return;
-    case proto::VarType_Type_LOD_RANK_TABLE:
+    case proto::VarType::LOD_RANK_TABLE:
       visitor(var.Get<LoDRankTable>());
       return;
-    case proto::VarType_Type_LOD_TENSOR_ARRAY:
+    case proto::VarType::LOD_TENSOR_ARRAY:
       visitor(var.Get<LoDTensorArray>());
       return;
-    case proto::VarType_Type_SELECTED_ROWS:
+    case proto::VarType::SELECTED_ROWS:
       visitor(var.Get<SelectedRows>());
       return;
-    case proto::VarType_Type_READER:
+    case proto::VarType::READER:
       visitor(var.Get<ReaderHolder>());
       return;
     default:
-      PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
+      PADDLE_THROW("Not supported visit type, %s", ToTypeName(var.Type()));
   }
 }
......
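To see how the rewritten dispatch is meant to be used: a minimal sketch of a VisitVarType caller, assuming the headers from this diff. The NumelVisitor helper below is illustrative, not part of the patch.

// Hypothetical visitor: illustrates the integer-id dispatch above.
#include "paddle/fluid/framework/var_type.h"

namespace paddle {
namespace framework {

struct NumelVisitor {
  explicit NumelVisitor(int64_t *out) : out_(out) {}
  void operator()(const LoDTensor &t) const { *out_ = t.numel(); }
  void operator()(const SelectedRows &s) const { *out_ = s.value().numel(); }
  template <typename T>
  void operator()(const T &) const { *out_ = -1; }  // rank tables, readers, ...
  int64_t *out_;
};

inline int64_t GetNumel(const Variable &var) {
  int64_t numel = 0;
  // Dispatches on var.Type() (now an int id) instead of an RTTI type_index.
  VisitVarType(var, NumelVisitor(&numel));
  return numel;
}

}  // namespace framework
}  // namespace paddle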
@@ -108,7 +108,7 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
   op->InferVarType(prog.MutableBlock(0));

-  ASSERT_EQ(proto::VarType_Type_LOD_TENSOR,
+  ASSERT_EQ(proto::VarType::LOD_TENSOR,
             prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/macros.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#endif
#include <cudnn.h>
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#endif
namespace paddle {
namespace framework {
// Besides registering variable type id, it is helpful to register a
// var_id -> std::type_index map (for example, get type names according to id)
namespace detail {
template <int kStart, int kEnd, bool kStop>
struct VarIdToTypeIndexMapInitializerImpl {
template <typename MapType1, typename MapType2>
static void Init(MapType1 *id_to_type, MapType2 *type_to_id) {
using Type =
typename std::tuple_element<kStart, VarTypeRegistry::ArgTuple>::type;
static_assert(!std::is_same<Type, void>::value, "Type cannot be void");
constexpr int kId = VarTypeTrait<Type>::kId;
auto type = std::type_index(typeid(Type));
PADDLE_ENFORCE(id_to_type->count(kId) == 0,
"Registered duplicate type id %d for type %s", kId,
type.name());
PADDLE_ENFORCE(type_to_id->count(type) == 0,
"Registered duplicate type_index %s for id %d", type.name(),
kId);
id_to_type->emplace(kId, type);
type_to_id->emplace(type, kId);
VarIdToTypeIndexMapInitializerImpl<kStart + 1, kEnd,
kStart + 1 == kEnd>::Init(id_to_type,
type_to_id);
}
};
template <int kStart, int kEnd>
struct VarIdToTypeIndexMapInitializerImpl<kStart, kEnd, true> {
template <typename MapType1, typename MapType2>
static void Init(MapType1 *, MapType2 *) {}
};
// VarIdToTypeIndexMapInitializer is designed to initialize var_id ->
// std::type_index map and std::type_index -> var_id map
using VarIdToTypeIndexMapInitializer =
VarIdToTypeIndexMapInitializerImpl<0, VarTypeRegistry::kRegisteredTypeNum,
VarTypeRegistry::kRegisteredTypeNum ==
0>;
struct VarIdToTypeIndexMapHolder {
DISABLE_COPY_AND_ASSIGN(VarIdToTypeIndexMapHolder);
public:
static const std::type_index &ToTypeIndex(int var_id) {
auto it = Instance().id_to_type_map_.find(var_id);
PADDLE_ENFORCE(it != Instance().id_to_type_map_.end(),
"VarId %d is not registered.", var_id);
return it->second;
}
static int ToTypeId(const std::type_index &type) {
auto it = Instance().type_to_id_map_.find(type);
PADDLE_ENFORCE(it != Instance().type_to_id_map_.end(),
"VarType %s is not registered.", type.name());
return it->second;
}
private:
VarIdToTypeIndexMapHolder() {
VarIdToTypeIndexMapInitializer::Init(&id_to_type_map_, &type_to_id_map_);
}
static const VarIdToTypeIndexMapHolder &Instance() {
static const VarIdToTypeIndexMapHolder instance;
return instance;
}
std::unordered_map<int, std::type_index> id_to_type_map_;
std::unordered_map<std::type_index, int> type_to_id_map_;
};
} // namespace detail
const std::type_index &ToTypeIndex(int var_id) {
return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
}
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); }
int ToTypeId(const std::type_index &type) {
return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
}
} // namespace framework
} // namespace paddle
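A minimal usage sketch for the mapping functions defined above, assuming the headers from this diff. The main() harness is illustrative, and ToTypeName() returns the implementation-defined std::type_info name (often mangled).

#include <cassert>
#include <iostream>
#include "paddle/fluid/framework/var_type_traits.h"

int main() {
  using namespace paddle::framework;  // NOLINT
  // For LoDTensor the id is pinned to the protobuf enum by
  // REG_PROTO_VAR_TYPE_TRAIT (see var_type_traits.h).
  constexpr int id = VarTypeTrait<LoDTensor>::kId;
  std::cout << ToTypeName(id) << std::endl;  // mangled name of LoDTensor
  // The two maps built by VarIdToTypeIndexMapHolder are mutual inverses:
  assert(ToTypeId(ToTypeIndex(id)) == id);
  assert(ToTypeIndex(id) == std::type_index(typeid(LoDTensor)));
  return 0;
}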
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <tuple>
#include <typeindex>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include <cudnn.h>
#ifndef _WIN32
#include <nccl.h>
#endif
#endif
// Users should add forward declarations here
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
class Communicator;
#endif
#endif
} // namespace platform
namespace framework {
class Tensor;
class LoDTensor;
class SelectedRows;
class LoDRankTable;
class ReaderHolder;
class Scope;
} // namespace framework
namespace operators {
template <typename T>
class AlgorithmsCache;
class CudnnRNNCache;
namespace reader {
class LoDTensorBlockingQueueHolder;
} // namespace reader
} // namespace operators
} // namespace paddle
namespace paddle {
namespace framework {
const char *ToTypeName(int var_id);
const std::type_index &ToTypeIndex(int var_id);
int ToTypeId(const std::type_index &type);
namespace detail {
template <bool kStop, int kStart, int kEnd, typename T1, typename T2,
typename... Args>
struct TypePosFinderImpl {
static constexpr int kPos =
std::is_same<T1, T2>::value
? kStart
: TypePosFinderImpl<kStart + 2 == kEnd, kStart + 1, kEnd, T1,
Args...>::kPos;
};
template <int kStart, int kEnd, typename T1, typename T2>
struct TypePosFinderImpl<true, kStart, kEnd, T1, T2> {
static constexpr int kPos = std::is_same<T1, T2>::value ? kStart : -1;
};
// TypePosFinder helps to find the position in which T is inside Args...
// If T is not inside Args..., kPos would be -1
template <typename T, typename... Args>
struct TypePosFinder {
static constexpr int kPos =
TypePosFinderImpl<sizeof...(Args) == 1, 0, sizeof...(Args), T,
Args...>::kPos;
};
template <typename... Args>
struct VarTypeRegistryImpl {
static constexpr size_t kRegisteredTypeNum = sizeof...(Args);
using ArgTuple = std::tuple<Args...>;
// TypePos() returns the position in which T is inside Args...
// If T is not inside Args..., return -1
template <typename T>
static constexpr int TypePos() {
return TypePosFinder<T, Args...>::kPos;
}
// IsRegistered() returns whether T is registered inside RegistryImpl
template <typename T>
static constexpr bool IsRegistered() {
return TypePos<T>() >= 0;
}
};
} // namespace detail
#define REG_PROTO_VAR_TYPE_TRAIT(type, proto_id) \
template <> \
struct VarTypeTrait<type> { \
static_assert(VarTypeRegistry::IsRegistered<type>(), \
"Must be registered type"); \
using Type = type; \
static constexpr int kId = static_cast<int>(proto_id); \
}
/**
* The following codes are designed to register variable types.
* Only registered types can be stored in Variable.
* This registry mechanism is designed to speed up Variable.
*
* Caution: If you want to add more var types, please consider carefully
* whether you really need to add it.
*/
// Users should add other variable types below.
// Paddle would generate unique Ids for each registered variable types.
using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor, LoDTensor, SelectedRows, std::vector<Scope *>, LoDRankTable,
LoDTensorArray, platform::PlaceList, ReaderHolder, std::string, Scope *,
std::map<size_t, Tensor>, operators::reader::LoDTensorBlockingQueueHolder,
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
ncclUniqueId, platform::Communicator,
#endif
operators::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>,
operators::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>,
operators::CudnnRNNCache,
#endif
int, float>;
template <typename T>
struct VarTypeTrait {
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
using Type = T;
/**
* Unique VarType Id generation.
*
* The auto-generated id should not be the same as any protobuf id defined in
* framework.proto. Therefore, we generate id by adding the type pos and
* maximum protobuf id (i.e., proto::VarType::TUPLE).
*
* However, we may need more protobuf id in the future.
* To avoid changing this auto id generation algorithm frequently, we
* generate id by adding the type pos and twice of maximum protobuf id (i.e.,
* proto::VarType::TUPLE).
*/
static constexpr int kId = VarTypeRegistry::TypePos<T>() +
static_cast<int>(proto::VarType::TUPLE) * 2;
};
// Users should set some of variable type ids to be what is defined in
// framework.proto below
REG_PROTO_VAR_TYPE_TRAIT(LoDTensor, proto::VarType::LOD_TENSOR);
REG_PROTO_VAR_TYPE_TRAIT(SelectedRows, proto::VarType::SELECTED_ROWS);
REG_PROTO_VAR_TYPE_TRAIT(std::vector<Scope *>, proto::VarType::STEP_SCOPES);
REG_PROTO_VAR_TYPE_TRAIT(LoDRankTable, proto::VarType::LOD_RANK_TABLE);
REG_PROTO_VAR_TYPE_TRAIT(LoDTensorArray, proto::VarType::LOD_TENSOR_ARRAY);
REG_PROTO_VAR_TYPE_TRAIT(platform::PlaceList, proto::VarType::PLACE_LIST);
REG_PROTO_VAR_TYPE_TRAIT(ReaderHolder, proto::VarType::READER);
REG_PROTO_VAR_TYPE_TRAIT(int, proto::VarType::INT32);
REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32);
/** End of variable type registration */
template <typename T>
inline constexpr bool IsRegisteredVarType() {
return VarTypeRegistry::IsRegistered<T>();
}
#undef REG_PROTO_VAR_TYPE_TRAIT
} // namespace framework
} // namespace paddle
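Because the registry is resolved entirely at compile time, its invariants can be checked with static_assert. A sketch, assuming the header above; the exact value of proto::VarType::TUPLE comes from framework.proto.

// Compile-time invariants of the registry above (illustrative, not in the PR).
#include "paddle/fluid/framework/var_type_traits.h"

namespace paddle {
namespace framework {

// Tensor heads the type list, so it occupies position 0.
static_assert(VarTypeRegistry::IsRegistered<Tensor>(), "Tensor registered");
static_assert(VarTypeRegistry::TypePos<Tensor>() == 0, "Tensor is first");
// Unregistered types report position -1.
static_assert(VarTypeRegistry::TypePos<double>() == -1, "double unregistered");
// Types without a REG_PROTO_VAR_TYPE_TRAIT specialization get auto ids of
// TypePos + 2 * TUPLE, which can never collide with any protobuf-defined id.
static_assert(VarTypeTrait<Tensor>::kId ==
                  static_cast<int>(proto::VarType::TUPLE) * 2,
              "auto-generated ids start at 2 * TUPLE");

}  // namespace framework
}  // namespace paddle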
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#endif
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"
#endif
namespace paddle {
namespace framework {
template <int kPos, int kEnd, bool kStop>
struct TypeIndexChecker {
template <typename SetType1, typename SetType2>
static void Check(SetType1 *var_id_set, SetType2 *type_index_set) {
using Type =
typename std::tuple_element<kPos, VarTypeRegistry::ArgTuple>::type;
static_assert(std::is_same<typename VarTypeTrait<Type>::Type, Type>::value,
"Type must be the same");
constexpr auto kId = VarTypeTrait<Type>::kId;
std::type_index actual_type(typeid(Type));
EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
EXPECT_EQ(ToTypeIndex(kId), actual_type);
EXPECT_EQ(ToTypeId(actual_type), kId);
EXPECT_EQ(ToTypeIndex(ToTypeId(actual_type)), actual_type);
EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId);
EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT
EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT
var_id_set->insert(kId);
type_index_set->insert(std::type_index(typeid(Type)));
TypeIndexChecker<kPos + 1, kEnd, kPos + 1 == kEnd>::Check(var_id_set,
type_index_set);
}
};
template <int kPos, int kEnd>
struct TypeIndexChecker<kPos, kEnd, true> {
template <typename SetType1, typename SetType2>
static void Check(SetType1 *, SetType2 *) {}
};
TEST(var_type_traits, check_no_duplicate_registry) {
constexpr size_t kRegisteredNum = VarTypeRegistry::kRegisteredTypeNum;
std::unordered_set<int> var_id_set;
std::unordered_set<std::type_index> type_index_set;
TypeIndexChecker<0, kRegisteredNum, kRegisteredNum == 0>::Check(
&var_id_set, &type_index_set);
}
template <typename T>
bool CheckVarId(int proto_id) {
static_assert(std::is_same<typename VarTypeTrait<T>::Type, T>::value,
"Type must be the same");
return VarTypeTrait<T>::kId == proto_id;
}
TEST(var_type_traits, check_proto_type_id) {
ASSERT_TRUE(CheckVarId<LoDTensor>(proto::VarType::LOD_TENSOR));
ASSERT_TRUE(CheckVarId<SelectedRows>(proto::VarType::SELECTED_ROWS));
ASSERT_TRUE(CheckVarId<std::vector<Scope *>>(proto::VarType::STEP_SCOPES));
ASSERT_TRUE(CheckVarId<LoDRankTable>(proto::VarType::LOD_RANK_TABLE));
ASSERT_TRUE(CheckVarId<LoDTensorArray>(proto::VarType::LOD_TENSOR_ARRAY));
ASSERT_TRUE(CheckVarId<platform::PlaceList>(proto::VarType::PLACE_LIST));
ASSERT_TRUE(CheckVarId<ReaderHolder>(proto::VarType::READER));
ASSERT_TRUE(CheckVarId<int>(proto::VarType::INT32));
ASSERT_TRUE(CheckVarId<float>(proto::VarType::FP32));
ASSERT_EQ(proto::VarType_Type_LOD_TENSOR, proto::VarType::LOD_TENSOR);
ASSERT_EQ(proto::VarType_Type_SELECTED_ROWS, proto::VarType::SELECTED_ROWS);
ASSERT_EQ(proto::VarType_Type_STEP_SCOPES, proto::VarType::STEP_SCOPES);
ASSERT_EQ(proto::VarType_Type_LOD_RANK_TABLE, proto::VarType::LOD_RANK_TABLE);
ASSERT_EQ(proto::VarType_Type_LOD_TENSOR_ARRAY,
proto::VarType::LOD_TENSOR_ARRAY);
ASSERT_EQ(proto::VarType_Type_PLACE_LIST, proto::VarType::PLACE_LIST);
ASSERT_EQ(proto::VarType_Type_READER, proto::VarType::READER);
ASSERT_EQ(proto::VarType_Type_FEED_MINIBATCH, proto::VarType::FEED_MINIBATCH);
ASSERT_EQ(proto::VarType_Type_FETCH_LIST, proto::VarType::FETCH_LIST);
ASSERT_EQ(proto::VarType_Type_RAW, proto::VarType::RAW);
ASSERT_EQ(proto::VarType_Type_TUPLE, proto::VarType::TUPLE);
ASSERT_EQ(proto::VarType_Type_INT32, proto::VarType::INT32);
ASSERT_EQ(proto::VarType_Type_FP32, proto::VarType::FP32);
}
TEST(var_type_traits, test_registry) {
using Registry = detail::VarTypeRegistryImpl<int8_t, int32_t, size_t, double>;
ASSERT_TRUE(Registry::TypePos<int8_t>() == 0);
ASSERT_TRUE(Registry::TypePos<int32_t>() == 1);
ASSERT_TRUE(Registry::TypePos<size_t>() == 2);
ASSERT_TRUE(Registry::TypePos<double>() == 3);
ASSERT_TRUE(Registry::TypePos<float>() == -1);
}
} // namespace framework
} // namespace paddle
@@ -18,7 +18,7 @@
 #include <typeindex>
 #include <typeinfo>

-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/framework/var_type_traits.h"

 namespace paddle {
 namespace framework {

@@ -27,10 +27,14 @@ class Variable {
  public:
   template <typename T>
   const T& Get() const {
+    static_assert(
+        IsRegisteredVarType<T>(),
+        "Not registered type. Please register T inside var_type_traits.h");
     PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing");
-    PADDLE_ENFORCE(IsType<T>(),
+    PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
                    "Variable must be type %s, the holding type is %s",
-                   typeid(T).name(), holder_->Type().name());
+                   ToTypeName(VarTypeTrait<T>::kId),
+                   ToTypeName(holder_->Type()));
     return *static_cast<const T*>(holder_->Ptr());
   }

@@ -39,61 +43,61 @@ class Variable {
   template <typename T>
   T* GetMutable() {
     if (!holder_) {
-      holder_.reset(new PlaceholderImpl<T>(new T()));
+      holder_.reset(new PlaceholderImpl<T>());
     } else {
-      PADDLE_ENFORCE(IsType<T>(),
+      PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
                      "Variable must be type %s, the holding type is %s",
-                     typeid(T).name(), holder_->Type().name());
+                     ToTypeName(VarTypeTrait<T>::kId),
+                     ToTypeName(holder_->Type()));
     }
     return static_cast<T*>(holder_->Ptr());
   }

   template <typename T>
   bool IsType() const {
-    return holder_ != nullptr &&
-           std::type_index(typeid(T)) == std::type_index(holder_->Type());
+    return holder_ && holder_->Type() == VarTypeTrait<T>::kId;
   }

   void Clear() { holder_.reset(); }

-  std::type_index Type() const {
+  int Type() const {
     PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
     return holder_->Type();
   }

  private:
   struct Placeholder {
-    virtual ~Placeholder() {}
-    virtual const std::type_info& Type() const = 0;
-    virtual void* Ptr() const = 0;
+    virtual ~Placeholder() = default;
+
+    inline int Type() const { return type_; }
+    inline const void* Ptr() const { return ptr_; }
+    inline void* Ptr() { return ptr_; }
+
+   protected:
+    inline void Init(void* p, int type) {
+      ptr_ = p;
+      type_ = type;
+    }
+
+    void* ptr_;
+    int type_;
   };

   // Placeholder hides type T, so it doesn't appear as a template
   // parameter of Variable.
   template <typename T>
   struct PlaceholderImpl : public Placeholder {
-    explicit PlaceholderImpl(T* ptr) : ptr_(ptr), type_(typeid(T)) {}
-
-    virtual const std::type_info& Type() const { return type_; }
-    virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
-
-    std::unique_ptr<T> ptr_;
-    const std::type_info& type_;
+    static_assert(
+        IsRegisteredVarType<T>(),
+        "Not registered type. Please register T inside var_type_traits.h");
+    PlaceholderImpl() { this->Init(&obj_, VarTypeTrait<T>::kId); }
+
+   private:
+    T obj_;
   };

-  std::unique_ptr<Placeholder>
-      holder_;  // pointers to a PlaceholderImpl object indeed.
-
-  // name_ is only meaningful with a Scope and accessible by it.
-  //
-  // NOTE: Please don't expose name_ by adding methods like
-  // Variable::Name or Scope::VarName! A variable could have a human
-  // readable name or an auto-generated scope-unique name. In the
-  // former case, the caller knows the name and doesn't need to access
-  // the name; in the latter case, the variable should be identified
-  // by its address but not the unreadable name.
-  friend class Scope;
-  const std::string* name_;
+  // pointers to a PlaceholderImpl object indeed.
+  std::unique_ptr<Placeholder> holder_;
 };

 }  // namespace framework
......
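The effect of the hunk above: the held object now lives inline in PlaceholderImpl as obj_, so a single heap allocation creates holder and value together, and every type check is one integer comparison instead of an RTTI type_index comparison. A minimal usage sketch, assuming the rewritten class:

// Illustrative only; assumes the rewritten Variable above.
#include <cassert>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/variable.h"

void Sketch() {
  paddle::framework::Variable var;
  auto *t = var.GetMutable<paddle::framework::LoDTensor>();  // allocates holder
  (void)t;
  // Type() now returns an int id; for LoDTensor it equals the protobuf enum.
  assert(var.IsType<paddle::framework::LoDTensor>());
  assert(var.Type() == paddle::framework::proto::VarType::LOD_TENSOR);
  // var.GetMutable<std::string>() would now throw: the holder id mismatches.
}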
@@ -16,27 +16,28 @@
 #include <string>

 #include "gtest/gtest.h"
+#include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/variable.h"

-TEST(Variable, GetMutable) {
-  using paddle::framework::Variable;
-
-  struct Tensor {
-    int content_;
-  };
+namespace paddle {
+namespace framework {

+TEST(Variable, GetMutable) {
   std::unique_ptr<Variable> v(new Variable());

-  Tensor* t = v->GetMutable<Tensor>();
-  t->content_ = 1234;
+  auto* t = v->GetMutable<std::string>();
+  *t = "1234";

-  const Tensor& tt = v->Get<Tensor>();
-  EXPECT_EQ(1234, tt.content_);
+  const auto& tt = v->Get<std::string>();
+  EXPECT_EQ("1234", tt);

   try {
-    v->GetMutable<std::string>();
+    v->GetMutable<Tensor>();
   } catch (std::exception& e) {
     return;
   }

   EXPECT_TRUE(false);
 }
+
+}  // namespace framework
+}  // namespace paddle
@@ -25,7 +25,7 @@ void TensorArrayBatchCleaner::CollectTensorArrays(framework::Scope *scope) {
       // TODO(Superjomn) should avoid the case when a TensorArray is a
       // parameter.
       if (var_name == "feed" || var_name == "fetch") continue;
-      if (var->Type() == typeid(framework::LoDTensorArray)) {
+      if (var->IsType<framework::LoDTensorArray>()) {
         VLOG(4) << "collect " << var_name;
         arrays_.push_back(var->GetMutable<framework::LoDTensorArray>());
       }
......
@@ -27,8 +27,11 @@ namespace details {
 // training phase.
 struct TensorArrayBatchCleaner {
   TensorArrayBatchCleaner() {
-    valid_types_.insert(typeid(framework::Tensor));
-    valid_types_.insert(typeid(framework::LoDTensor));
+    constexpr auto kTensorId = framework::VarTypeTrait<framework::Tensor>::kId;
+    constexpr auto kLoDTensorId =
+        framework::VarTypeTrait<framework::LoDTensor>::kId;
+    valid_types_.insert(kTensorId);
+    valid_types_.insert(kLoDTensorId);
   }
   // Collect the variables that are not Tensor or LoDTensor, and reset them to a
   // bool(trick), because some of them are containers, and some operators just

@@ -46,7 +49,7 @@ struct TensorArrayBatchCleaner {
   bool no_tensor_flag_{true};
   std::vector<framework::LoDTensorArray *> arrays_;
-  std::unordered_set<std::type_index> valid_types_;
+  std::unordered_set<int> valid_types_;
   std::unordered_set<framework::Variable *> no_tensor_vars_;
 };
......
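With valid_types_ keyed by integer ids, a membership test elsewhere in the cleaner reduces to a hash lookup on var->Type(). A sketch, with a hypothetical helper name:

// Hypothetical helper mirroring the int-keyed set above (not in the PR).
#include <unordered_set>
#include "paddle/fluid/framework/variable.h"

inline bool IsTensorLike(const paddle::framework::Variable &var,
                         const std::unordered_set<int> &valid_types) {
  // var.Type() is the integer id assigned by VarTypeRegistry.
  return valid_types.count(var.Type()) > 0;
}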
@@ -64,7 +64,7 @@ class ClipByNormKernel : public framework::OpKernel<T> {
       output->mutable_data<T>(context.GetPlace());
     } else {
       PADDLE_THROW("Unexpected branch, input variable type is %s",
-                   in_var->Type().name());
+                   framework::ToTypeName(in_var->Type()));
     }
     PADDLE_ENFORCE_NOT_NULL(input);
......
@@ -175,14 +175,13 @@ class WhileGradOp : public framework::OperatorBase {
         auto &og_inside =
             detail::Ref(cur_scope.Var(inside_og_name),
                         "Cannot find inside gradient %s", inside_og_name);
-        if (framework::IsType<framework::LoDTensor>(og_outside.Type())) {
+        if (og_outside.IsType<framework::LoDTensor>()) {
           auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
           auto &inside_tensor =
               detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
           inside_tensor.set_lod(outside_tensor.lod());
           inside_tensor.ShareDataWith(outside_tensor);
-        } else if (framework::IsType<framework::LoDTensorArray>(
-                       og_outside.Type())) {
+        } else if (og_outside.IsType<framework::LoDTensorArray>()) {
           auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
           auto &inside_array =
               detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());

@@ -256,7 +255,7 @@ class WhileGradOp : public framework::OperatorBase {
             var->IsType<LoDTensor>(),
             "Currently the type of var only can be LoDTensorArray, "
             "or LoDTensor, but the received var[%s] is %s.",
-            inside_grad_name, var->Type().name());
+            inside_grad_name, framework::ToTypeName(var->Type()));

         if (var->IsType<LoDTensor>()) {
           auto &inside_tensor = var->Get<framework::LoDTensor>();
......
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/cudnn_rnn_cache.h"
 #include "paddle/fluid/operators/math/math_function.h"
-#include "paddle/fluid/platform/cudnn_helper.h"

 namespace paddle {
 namespace operators {

@@ -22,239 +22,6 @@ namespace operators {
 using LoDTensor = framework::LoDTensor;
 using Tensor = framework::Tensor;
struct CudnnRNNCache {
CudnnRNNCache() {
x_desc_ = NULL;
y_desc_ = NULL;
dx_desc_ = NULL;
dy_desc_ = NULL;
}
~CudnnRNNCache() { release(); }
cudnnRNNDescriptor_t rnn_desc_;
cudnnTensorDescriptor_t *x_desc_;
cudnnTensorDescriptor_t *y_desc_;
cudnnTensorDescriptor_t *dx_desc_;
cudnnTensorDescriptor_t *dy_desc_;
cudnnTensorDescriptor_t hx_desc_;
cudnnTensorDescriptor_t cx_desc_;
cudnnTensorDescriptor_t hy_desc_;
cudnnTensorDescriptor_t cy_desc_;
cudnnTensorDescriptor_t dhx_desc_;
cudnnTensorDescriptor_t dcx_desc_;
cudnnTensorDescriptor_t dhy_desc_;
cudnnTensorDescriptor_t dcy_desc_;
cudnnTensorDescriptor_t output_x_desc_;
cudnnTensorDescriptor_t output_y_desc_;
cudnnDropoutDescriptor_t dropout_desc_;
size_t weights_size_;
cudnnFilterDescriptor_t w_desc_;
cudnnFilterDescriptor_t dw_desc_;
size_t workspace_size_;
size_t reserve_size_;
Tensor reserve_data_;
Tensor workspace_data_;
Tensor dropout_state_;
size_t max_length_;
float dropout_prob_;
bool is_bidirec_;
int batch_size_;
int input_size_;
int hidden_size_;
int num_layers_;
int seed_;
void init(cudnnHandle_t handle, const framework::ExecutionContext &ctx,
size_t max_len, int batch_size, int input_size, int hidden_size,
int num_layers, float dropout_prob, bool is_bidirec, int seed,
int weight_numel) {
max_length_ = max_len;
batch_size_ = batch_size;
input_size_ = input_size;
hidden_size_ = hidden_size;
num_layers_ = num_layers;
dropout_prob_ = dropout_prob;
is_bidirec_ = is_bidirec;
seed_ = seed;
x_desc_ = new cudnnTensorDescriptor_t[max_length_];
y_desc_ = new cudnnTensorDescriptor_t[max_length_];
dx_desc_ = new cudnnTensorDescriptor_t[max_length_];
dy_desc_ = new cudnnTensorDescriptor_t[max_length_];
int dim_a[3];
int stride_a[3];
for (size_t i = 0; i < max_length_; ++i) {
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&dx_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&dy_desc_[i]));
dim_a[0] = batch_size_;
dim_a[1] = input_size_;
dim_a[2] = 1;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
x_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dx_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
dim_a[0] = batch_size_;
dim_a[1] = is_bidirec_ ? hidden_size_ * 2 : hidden_size_;
dim_a[2] = 1;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
y_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dy_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
}
dim_a[0] = num_layers_ * (is_bidirec_ ? 2 : 1);
dim_a[1] = batch_size_;
dim_a[2] = hidden_size_;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
hx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
cx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
hy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
cy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dhx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dcx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dhy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dcy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateDropoutDescriptor(&dropout_desc_));
size_t state_size;
    CUDNN_ENFORCE(
        platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
    dropout_state_.Resize({static_cast<int64_t>(state_size)});
auto *dropout_state_data =
dropout_state_.mutable_data<uint8_t>(ctx.GetPlace());
CUDNN_ENFORCE(platform::dynload::cudnnSetDropoutDescriptor(
dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size,
seed_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
#if CUDNN_VERSION >= 6000
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
#else
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_DATA_FLOAT));
#endif
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_, x_desc_[0], &weights_size_, CUDNN_DATA_FLOAT));
PADDLE_ENFORCE_EQ(weights_size_, sizeof(float) * weight_numel,
"cudnn lstm weight size should be SAME");
int dim_w[3];
dim_w[0] = weights_size_ / sizeof(float);
dim_w[1] = 1;
dim_w[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
w_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
dw_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_, max_length_, x_desc_, &workspace_size_));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_, max_length_, x_desc_, &reserve_size_));
reserve_data_.Resize({static_cast<int64_t>(reserve_size_)});
reserve_data_.mutable_data<uint8_t>(ctx.GetPlace());
workspace_data_.Resize({static_cast<int64_t>(workspace_size_)});
workspace_data_.mutable_data<uint8_t>(ctx.GetPlace());
}
void release() {
for (size_t i = 0; i < max_length_; ++i) {
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(y_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(dx_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(dy_desc_[i]));
}
delete[] x_desc_;
delete[] y_desc_;
delete[] dx_desc_;
delete[] dy_desc_;
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcy_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyDropoutDescriptor(dropout_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyRNNDescriptor(rnn_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(w_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(dw_desc_));
}
};
 template <typename T>
 class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
  public:

@@ -315,9 +82,9 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
       auto input_w_numel = w->numel();
       auto batch_size = x->dims()[1];
-      cudnn_rnn_cache->init(handle, ctx, max_len, batch_size, input_size,
-                            hidden_size, num_layers, dropout_prob, is_bidirec,
-                            seed, input_w_numel);
+      cudnn_rnn_cache->init(handle, ctx.GetPlace(), max_len, batch_size,
+                            input_size, hidden_size, num_layers, dropout_prob,
+                            is_bidirec, seed, input_w_numel);
     }

     auto run_seq_len = x->dims()[0];
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
struct CudnnRNNCache {
CudnnRNNCache() {
x_desc_ = NULL;
y_desc_ = NULL;
dx_desc_ = NULL;
dy_desc_ = NULL;
}
~CudnnRNNCache() { release(); }
cudnnRNNDescriptor_t rnn_desc_;
cudnnTensorDescriptor_t *x_desc_;
cudnnTensorDescriptor_t *y_desc_;
cudnnTensorDescriptor_t *dx_desc_;
cudnnTensorDescriptor_t *dy_desc_;
cudnnTensorDescriptor_t hx_desc_;
cudnnTensorDescriptor_t cx_desc_;
cudnnTensorDescriptor_t hy_desc_;
cudnnTensorDescriptor_t cy_desc_;
cudnnTensorDescriptor_t dhx_desc_;
cudnnTensorDescriptor_t dcx_desc_;
cudnnTensorDescriptor_t dhy_desc_;
cudnnTensorDescriptor_t dcy_desc_;
cudnnTensorDescriptor_t output_x_desc_;
cudnnTensorDescriptor_t output_y_desc_;
cudnnDropoutDescriptor_t dropout_desc_;
size_t weights_size_;
cudnnFilterDescriptor_t w_desc_;
cudnnFilterDescriptor_t dw_desc_;
size_t workspace_size_;
size_t reserve_size_;
framework::Tensor reserve_data_;
framework::Tensor workspace_data_;
framework::Tensor dropout_state_;
size_t max_length_;
float dropout_prob_;
bool is_bidirec_;
int batch_size_;
int input_size_;
int hidden_size_;
int num_layers_;
int seed_;
void init(cudnnHandle_t handle, const platform::Place &place, size_t max_len,
int batch_size, int input_size, int hidden_size, int num_layers,
float dropout_prob, bool is_bidirec, int seed, int weight_numel) {
max_length_ = max_len;
batch_size_ = batch_size;
input_size_ = input_size;
hidden_size_ = hidden_size;
num_layers_ = num_layers;
dropout_prob_ = dropout_prob;
is_bidirec_ = is_bidirec;
seed_ = seed;
x_desc_ = new cudnnTensorDescriptor_t[max_length_];
y_desc_ = new cudnnTensorDescriptor_t[max_length_];
dx_desc_ = new cudnnTensorDescriptor_t[max_length_];
dy_desc_ = new cudnnTensorDescriptor_t[max_length_];
int dim_a[3];
int stride_a[3];
for (size_t i = 0; i < max_length_; ++i) {
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&dx_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateTensorDescriptor(&dy_desc_[i]));
dim_a[0] = batch_size_;
dim_a[1] = input_size_;
dim_a[2] = 1;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
x_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dx_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
dim_a[0] = batch_size_;
dim_a[1] = is_bidirec_ ? hidden_size_ * 2 : hidden_size_;
dim_a[2] = 1;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
y_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dy_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
}
dim_a[0] = num_layers_ * (is_bidirec_ ? 2 : 1);
dim_a[1] = batch_size_;
dim_a[2] = hidden_size_;
stride_a[0] = dim_a[2] * dim_a[1];
stride_a[1] = dim_a[2];
stride_a[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&hy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&cy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dhy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateTensorDescriptor(&dcy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
hx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
cx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
hy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
cy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dhx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dcx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dhy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
dcy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a));
CUDNN_ENFORCE(
platform::dynload::cudnnCreateDropoutDescriptor(&dropout_desc_));
size_t state_size;
    CUDNN_ENFORCE(
        platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size));
    dropout_state_.Resize({static_cast<int64_t>(state_size)});
auto *dropout_state_data = dropout_state_.mutable_data<uint8_t>(place);
CUDNN_ENFORCE(platform::dynload::cudnnSetDropoutDescriptor(
dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size,
seed_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));
#if CUDNN_VERSION >= 6000
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor_v6(
handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT));
#else
CUDNN_ENFORCE(platform::dynload::cudnnSetRNNDescriptor(
rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
CUDNN_DATA_FLOAT));
#endif
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_, x_desc_[0], &weights_size_, CUDNN_DATA_FLOAT));
PADDLE_ENFORCE_EQ(weights_size_, sizeof(float) * weight_numel,
"cudnn lstm weight size should be SAME");
int dim_w[3];
dim_w[0] = weights_size_ / sizeof(float);
dim_w[1] = 1;
dim_w[2] = 1;
CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
w_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
CUDNN_ENFORCE(platform::dynload::cudnnSetFilterNdDescriptor(
dw_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_, max_length_, x_desc_, &workspace_size_));
CUDNN_ENFORCE(platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_, max_length_, x_desc_, &reserve_size_));
reserve_data_.Resize({static_cast<int64_t>(reserve_size_)});
reserve_data_.mutable_data<uint8_t>(place);
workspace_data_.Resize({static_cast<int64_t>(workspace_size_)});
workspace_data_.mutable_data<uint8_t>(place);
}
void release() {
for (size_t i = 0; i < max_length_; ++i) {
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(y_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(dx_desc_[i]));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyTensorDescriptor(dy_desc_[i]));
}
delete[] x_desc_;
delete[] y_desc_;
delete[] dx_desc_;
delete[] dy_desc_;
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(hy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(cy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcx_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dhy_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(dcy_desc_));
CUDNN_ENFORCE(
platform::dynload::cudnnDestroyDropoutDescriptor(dropout_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyRNNDescriptor(rnn_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(w_desc_));
CUDNN_ENFORCE(platform::dynload::cudnnDestroyFilterDescriptor(dw_desc_));
}
};
} // namespace operators
} // namespace paddle
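Extracting this header is what lets CudnnRNNCache be forward-declared and registered in var_type_traits.h, so a kernel can keep it directly in a scope variable. A sketch for CUDA builds; the variable name is illustrative:

// Mirrors how an op kernel can cache the descriptors in a scope variable
// now that CudnnRNNCache is a registered variable type.
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/cudnn_rnn_cache.h"

void SketchCacheLookup(paddle::framework::Scope *scope) {
  auto *cache_var = scope->Var("hypothetical_cudnn_rnn_cache");  // name is ours
  auto *cudnn_rnn_cache =
      cache_var->GetMutable<paddle::operators::CudnnRNNCache>();
  (void)cudnn_rnn_cache;  // ->init(...) would follow, as in the kernel above
}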
@@ -116,7 +116,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
     } else {
       PADDLE_THROW(
           "%s should be LoDTensor or SelectedRows, but the received type is %s",
-          ctx.Inputs("Ids")[0], ids_var->Type().name());
+          ctx.Inputs("Ids")[0], framework::ToTypeName(ids_var->Type()));
     }
   }
 };
......
@@ -83,7 +83,7 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
       z = ctx.Output<framework::LoDTensor>("Out");
     } else {
       PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
-                   x_var->Type().name());
+                   framework::ToTypeName(x_var->Type()));
     }
     z->mutable_data<T>(ctx.GetPlace());
......
@@ -27,12 +27,14 @@ class AdadeltaOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
+                   ctx.Inputs("Param").front(),
+                   framework::ToTypeName(param_var->Type()));
     const auto* grad_var = ctx.InputVar("Grad");
     PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Grad").front(), grad_var->Type().name());
+                   ctx.Inputs("Grad").front(),
+                   framework::ToTypeName(grad_var->Type()));

     auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
     auto avg_squared_grad_out_tensor =
......
@@ -50,7 +50,8 @@ class AdagradOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
+                   ctx.Inputs("Param").front(),
+                   framework::ToTypeName(param_var->Type()));
     auto *param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
     auto *moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
...
@@ -347,7 +347,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
+                   ctx.Inputs("Param").front(),
+                   framework::ToTypeName(param_var->Type()));
     using paddle::framework::LoDTensor;
     using paddle::operators::detail::Ref;
...
@@ -27,12 +27,14 @@ class AdamaxOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
+                   ctx.Inputs("Param").front(),
+                   framework::ToTypeName(param_var->Type()));
     const auto* grad_var = ctx.InputVar("Grad");
     PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Grad").front(), grad_var->Type().name());
+                   ctx.Inputs("Grad").front(),
+                   framework::ToTypeName(grad_var->Type()));
     auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
     auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
...
@@ -27,12 +27,14 @@ class DecayedAdagradOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
+                   ctx.Inputs("Param").front(),
+                   framework::ToTypeName(param_var->Type()));
     const auto* grad_var = ctx.InputVar("Grad");
     PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Grad").front(), grad_var->Type().name());
+                   ctx.Inputs("Grad").front(),
+                   framework::ToTypeName(grad_var->Type()));
     auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
     auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
...
@@ -32,12 +32,14 @@ class FTRLOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
+                   ctx.Inputs("Param").front(),
+                   framework::ToTypeName(param_var->Type()));
     const auto* grad_var = ctx.InputVar("Grad");
     PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Grad").front(), grad_var->Type().name());
+                   ctx.Inputs("Grad").front(),
+                   framework::ToTypeName(grad_var->Type()));
     auto* param_out = ctx.Output<Tensor>("ParamOut");
     auto* sq_accum_out = ctx.Output<Tensor>("SquaredAccumOut");
...
@@ -395,7 +395,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
       PADDLE_THROW(
           string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows "
                           "gradient, but the received Variable Type is %s",
-                          grad_var->Type().name()));
+                          framework::ToTypeName(grad_var->Type())));
     }
   }
 };
...
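The MomentumOp hunk above is the one place in this group where the input may legitimately be dense or sparse, so the kernel dispatches on the runtime variable type before reaching the error branch that this PR rewrites. A minimal sketch of that shape; the two branch bodies are elided in the diff and stubbed here as comments:

// Dispatch on the concrete variable type; only the error branch changes
// in this PR (readable type name instead of a mangled one).
if (grad_var->IsType<framework::LoDTensor>()) {
  // dense-gradient update path (elided in the diff)
} else if (grad_var->IsType<framework::SelectedRows>()) {
  // sparse-gradient update path (elided in the diff)
} else {
  PADDLE_THROW(
      string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows "
                      "gradient, but the received Variable Type is %s",
                      framework::ToTypeName(grad_var->Type())));
}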
@@ -60,7 +60,8 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                    "The Var(%s)'s type should be LoDTensor, "
                    "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
+                   ctx.Inputs("Param").front(),
+                   framework::ToTypeName(param_var->Type()));
     auto* param = ctx.Input<framework::Tensor>("Param");
     auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
...
@@ -245,7 +245,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       }
     } else {
       PADDLE_THROW("Unexpected branch, output variable type is %s",
-                   out_var->Type().name());
+                   framework::ToTypeName(out_var->Type()));
     }
   }
 };
...
@@ -126,7 +126,7 @@ class SumOp : public framework::OperatorWithKernel {
       PADDLE_THROW("Cannot find the input data type by all input data");
     }
     PADDLE_THROW("Unexpected branch. Input type is %s",
-                 x_vars[0]->Type().name());
+                 framework::ToTypeName(x_vars[0]->Type()));
   }
 };
...
@@ -163,7 +163,7 @@ class SumKernel : public framework::OpKernel<T> {
       }
     } else {
       PADDLE_THROW("Unexpected branch, output variable type is %s",
-                   out_var->Type().name());
+                   framework::ToTypeName(out_var->Type()));
     }
   }
 };
...