diff --git a/doc/design/infer_var_type.md b/doc/design/infer_var_type.md new file mode 100644 index 0000000000000000000000000000000000000000..d9d5397becba2ef1806d9341cd49cd9aabbf4a6a --- /dev/null +++ b/doc/design/infer_var_type.md @@ -0,0 +1,78 @@ +# Design Doc: InferVarType + +## The Problem Posed + +The variable in our design can hold variant types. Such as `LoDTensor` and `SelectedRows`. An operator should be able to inference the variable types of its output. + +For example, a `lookup table` operator takes two `LoDTensor`; one is a float tensor as the embedding table, the other is an int tensor as word ID. The gradient operator of `lookup table` will generate a `SelectedRows` as its output. A `sum` operator can take both `LoDTensor` and `SelectedRows` as its inputs and will generate a `LoDTensor` if any of its inputs is `LoDTensor`, otherwise, the `sum` operator will generate `SelectedRows` as its output. + +The variable type will be constant at runtime. Every variable's type can either be set by the user (input data and parameter) or be inferred by the operator in compile time. + +## Proposed Solution + +The `InferVarType` is a compile-time function which is registered to each operator. The inferface of that function is: + + +```c++ +using InferVarTypeFN = std::function< + void (const OpDescBind& /*op_desc*/, BlockDescBind* /*block*/)>; +``` + +It takes an operator description as its input and will write the output variable type and store them in block description. + +The `InferVarTypeFN` will be registered in `OpInfo`, to replace `infer_var_type_` field. The `OpInfo` should be + +```cpp +struct OpInfo { + InferVarTypeFN infer_var_type_; + ... +}; +``` + +The default `InferVarType` will set output type as `LoDTensor`. It can be done by `GetInferVarType()`. + +```cpp +void DefaultInferVarType(const OpDescBind& op_desc, BlockDescBind* block) { + // set the output type of variable as `LoDTensor`. + // ... +} + +struct OpInfo { + InferVarTypeFN infer_var_type_; + InferVarTypeFN GetInferVarType() const { + if (infer_var_type_) { + return infer_var_type_; + } else { + return DefaultInferVarType; + } + } +}; +``` + +## Register InferVarType + +We provide a thin base class for registering an `InferVarTypeFN`. To use a base class will ease the implementation of registry since we can detect the registry entry is an `InferVarTypeFN` or not. + +```cpp +class VarTypeInferer { +public: + virtual void operator()(const OpDescBind& op_desc, BlockDescBind* block) const = 0; +} +``` + +Operator developers can write the specialize `VarTypeInferer` as follow. + +```cpp +class SpecialVarTypeInferer : public VarTypeInferer { +public: + virtual void operator()(const OpDescBind& op_desc, BlockDescBind* block) const { + // .. own logic + } +} +``` + +Then user can register the `InferVarType` just like `GradOpDescMaker` and `OpInfoMaker`. + +``` +REGISTER_OPERATOR(some_op, OpType, SpecialVarTypeInferer, ...); +``` diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index 8f8a53eec8f947b088124a3f034fedb17fd86a48..5bf5e91f25ab1d920ae368eaf2000fce77d2eb07 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -65,16 +65,12 @@ void Scope::DropKids() { kids_.clear(); } -std::once_flag feed_variable_flag; - framework::Scope& GetGlobalScope() { - static std::unique_ptr g_scope{nullptr}; - std::call_once(feed_variable_flag, [&]() { - g_scope.reset(new framework::Scope()); - g_scope->Var("feed_value"); - g_scope->Var("fetch_value"); - }); - return *(g_scope.get()); + static framework::Scope* g_scope = nullptr; + if (g_scope == nullptr) { + g_scope = new framework::Scope(); + } + return *g_scope; } } // namespace framework diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 5087c02385f7f37d78d134b739f3f22522977fb8..8e561528f0e7e6ff524fc51b4776efc4e5bd28cd 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -14,11 +14,6 @@ limitations under the License. */ #include "paddle/memory/memory.h" -#include // for transform -#include // for memcpy -#include // for unique_ptr -#include // for call_once - #include "glog/logging.h" #include "paddle/memory/detail/buddy_allocator.h" @@ -32,19 +27,14 @@ namespace memory { using BuddyAllocator = detail::BuddyAllocator; -std::once_flag cpu_allocator_flag; -std::once_flag gpu_allocator_flag; - BuddyAllocator* GetCPUBuddyAllocator() { - static std::unique_ptr a{nullptr}; - - std::call_once(cpu_allocator_flag, [&]() { - a.reset(new BuddyAllocator(new detail::CPUAllocator, - platform::CpuMinChunkSize(), - platform::CpuMaxChunkSize())); - }); - - return a.get(); + static detail::BuddyAllocator* a = nullptr; + if (a == nullptr) { + a = new detail::BuddyAllocator(new detail::CPUAllocator, + platform::CpuMinChunkSize(), + platform::CpuMaxChunkSize()); + } + return a; } template <> @@ -65,35 +55,24 @@ size_t Used(platform::CPUPlace place) { #ifdef PADDLE_WITH_CUDA BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { - using BuddyAllocVec = std::vector; - static std::unique_ptr as{ - new BuddyAllocVec, [](BuddyAllocVec* p) { - std::for_each(p->begin(), p->end(), - [](BuddyAllocator* p) { delete p; }); - }}; - - // GPU buddy allocators - auto& allocators = *as.get(); - - // GPU buddy allocator initialization - std::call_once(gpu_allocator_flag, [&]() { + static BuddyAllocator** as = NULL; + if (as == NULL) { int gpu_num = platform::GetCUDADeviceCount(); - allocators.reserve(gpu_num); + as = new BuddyAllocator*[gpu_num]; for (int gpu = 0; gpu < gpu_num; gpu++) { platform::SetDeviceId(gpu); - allocators.emplace_back(new BuddyAllocator(new detail::GPUAllocator, - platform::GpuMinChunkSize(), - platform::GpuMaxChunkSize())); + as[gpu] = new BuddyAllocator(new detail::GPUAllocator, + platform::GpuMinChunkSize(), + platform::GpuMaxChunkSize()); } VLOG(3) << "\n\nNOTE: each GPU device use " << FLAGS_fraction_of_gpu_memory_to_use * 100 << "% of GPU memory.\n" << "You can set environment variable '" << platform::kEnvFractionGpuMemoryToUse << "' to change the fraction of GPU usage.\n\n"; - }); - + } platform::SetDeviceId(gpu_id); - return allocators[gpu_id]; + return as[gpu_id]; } template <>