未验证 提交 225eec2b 编写于 作者: L Li Xinqi 提交者: GitHub

Reduce usage intrusive macro (#6543)

* remove most usage of macros INTRUSIVE_*

* rename most INTRUSVE_XXX macros to REFLECTIVE_XXX

* move intrusive::Base to intrusive/base.h

* 1) remove OFFSET_STRUCT_FIELD; 2) mv test cases of HeadFreeList into head_free_list_test.cpp
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 e6540a60
......@@ -25,10 +25,8 @@ namespace oneflow {
class Device;
// clang-format off
// Helps VirtualMachine building instruction edges
INTRUSIVE_BEGIN(LocalDepObject);
class LocalDepObject final : public intrusive::Base {
public:
// Getters
const vm::LogicalObject& logical_object() const {
......@@ -65,18 +63,24 @@ INTRUSIVE_BEGIN(LocalDepObject);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
LocalDepObject() : intrusive_ref_(), logical_object_(), mirrored_object_(), pool_hook_(), stored_hook_(), lifetime_hook_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
LocalDepObject()
: intrusive_ref_(),
logical_object_(),
mirrored_object_(),
pool_hook_(),
stored_hook_(),
lifetime_hook_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<vm::LogicalObject>, logical_object_);
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<vm::MirroredObject>, mirrored_object_);
intrusive::shared_ptr<vm::LogicalObject> logical_object_;
intrusive::shared_ptr<vm::MirroredObject> mirrored_object_;
public:
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, pool_hook_);
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, stored_hook_);
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, lifetime_hook_);
INTRUSIVE_END(LocalDepObject);
// clang-format on
intrusive::ListHook pool_hook_;
intrusive::ListHook stored_hook_;
intrusive::ListHook lifetime_hook_;
};
Maybe<LocalDepObject*> GetLocalDepObjectFromDevicePool(Symbol<Device> device);
Maybe<void> PutLocalDepObjectToDevicePool(Symbol<Device> device, LocalDepObject* local_dep_object);
......
### 概念与数据结构
本子系统可以方便用户定义可侵入式类型。内建支持侵入式智能指针intrusive::shared_ptr和侵入式容器。
本子系统可以方便用户定义可侵入式类型。内建支持侵入式智能指针`intrusive::shared_ptr`和侵入式容器。
目前有主要有两类侵入式容器:
1. intrusive::List,双链表。基于此,还提供了intrusive::MutexedList和intrusive::Channel
2. intrusive::SkipList,跳表,等同于map。
1. `intrusive::List`,双链表。基于此,还提供了`intrusive::MutexedList``intrusive::Channel`
2. `intrusive::SkipList`,跳表,等同于map。
为了管理元素CURD所带来的生命周期,侵入式容器需要intrusive::shared_ptr来实现内存生命周期的管理,它与std::shared_ptr的不同在于其引用计数嵌入在目标结构体里。
为了管理元素CURD所带来的生命周期,侵入式容器需要`intrusive::shared_ptr`来实现内存生命周期的管理,它与`std::shared_ptr`的不同在于其引用计数嵌入在目标结构体里。
### 接口
由于侵入式容器要求其元素类型T必须满足std::is_standard_layout<T>::value,为了减少麻烦的编译问题及字段访问权限相关的悖论,我们使用一组宏定义可侵入类型。
1. INTRUSIVE_BEGIN,开始定义可侵入类型。
2. INTRUSIVE_END,结束定义可侵入类型。
3. INTRUSIVE_DEFINE_FIELD,定义可侵入类型的字段。
4. INTRUSIVE_FIELD,描述可侵入类型的某个字段,属性包括容器类型,字段类型,字段偏移,常用于侵入式容器的构建。
需要使用`intrusive::shared_ptr`来管理生命周期的类必须拥有`intrusive::Ref* mut_intrusive_ref();`方法
由于侵入式容器支持比标准容器更为强大的迭代方式,同时为了性能起见,我们提供三类迭代宏:
1. INTRUSIVE_FOR_EACH,支持迭代过程中删除当前元素,同时使用intrusive::shared_ptr管理好当前元素生命周期
2. INTRUSIVE_FOR_EACH_PTR,支持迭代过程中删除当前元素,类型直接为裸指针,即不负责当前元素生命周期的管理
3. INTRUSIVE_UNSAFE_FOR_EACH_PTR,不支持迭代中删除元素,不负责当前元素生命周期的管理。
1. `INTRUSIVE_FOR_EACH`,支持迭代过程中删除当前元素,同时使用`intrusive::shared_ptr`管理好当前元素生命周期
2. `INTRUSIVE_FOR_EACH_PTR`,支持迭代过程中删除当前元素,类型直接为裸指针,即不负责当前元素生命周期的管理
3. `INTRUSIVE_UNSAFE_FOR_EACH_PTR`,不支持迭代中删除元素,不负责当前元素生命周期的管理。
### 特点
本组件与boost::intrusive最大不同在于实现了完整的生命周期管理,另外提供了其他更能减少内存分配的容器定义方式(详见intrusive::HeadFreeList)。
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_INTRUSIVE_BASE_H_
#define ONEFLOW_CORE_INTRUSIVE_BASE_H_
namespace oneflow {
namespace intrusive {
class Base {
public:
void __Init__() {}
void __Delete__() {}
};
} // namespace intrusive
} // namespace oneflow
#endif // ONEFLOW_CORE_INTRUSIVE_BASE_H_
......@@ -62,7 +62,7 @@ class Channel {
return kChannelStatusSuccess;
}
ChannelStatus MoveFrom(intrusive::List<HookField>* src) {
ChannelStatus MoveFrom(List<HookField>* src) {
std::unique_lock<std::mutex> lock(*mut_mutex());
if (is_closed_) { return kChannelStatusErrorClosed; }
src->MoveToDstBack(&list_head_);
......@@ -70,7 +70,7 @@ class Channel {
return kChannelStatusSuccess;
}
ChannelStatus MoveTo(intrusive::List<HookField>* dst) {
ChannelStatus MoveTo(List<HookField>* dst) {
std::unique_lock<std::mutex> lock(*mut_mutex());
mut_cond()->wait(lock, [this]() { return (!list_head_.empty()) || is_closed_; });
if (list_head_.empty()) { return kChannelStatusErrorClosed; }
......@@ -78,7 +78,7 @@ class Channel {
return kChannelStatusSuccess;
}
ChannelStatus TryMoveTo(intrusive::List<HookField>* dst) {
ChannelStatus TryMoveTo(List<HookField>* dst) {
std::unique_lock<std::mutex> lock(*mut_mutex());
if (list_head_.empty()) { return kChannelStatusSuccess; }
mut_cond()->wait(lock, [this]() { return (!list_head_.empty()) || is_closed_; });
......@@ -97,7 +97,7 @@ class Channel {
std::mutex* mut_mutex() { return &mutex_; }
std::condition_variable* mut_cond() { return &cond_; }
intrusive::List<HookField> list_head_;
List<HookField> list_head_;
std::mutex mutex_;
std::condition_variable cond_;
bool is_closed_;
......
......@@ -26,8 +26,7 @@ namespace test {
namespace {
// clang-format off
INTRUSIVE_BEGIN(Foo);
class Foo final : public intrusive::Base {
public:
int x() const { return x_; }
void set_x(int val) { x_ = val; }
......@@ -36,13 +35,14 @@ INTRUSIVE_BEGIN(Foo);
Foo() : intrusive_ref_(), x_(), hook_() {}
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(int, x_);
int x_;
public:
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, hook_);
INTRUSIVE_END(Foo);
// clang-format on
intrusive::ListHook hook_;
};
using ChannelFoo = intrusive::Channel<INTRUSIVE_FIELD(Foo, hook_)>;
......
......@@ -244,8 +244,8 @@ struct FlatMsgViewUtil {
template<typename FlatMsgViewT, typename ValueType, typename ContainerT, typename Enabled = void>
struct FlatMsgViewContainerUtil {
using FlatMsgOneofField =
StructField<ValueType, typename ValueType::__OneofType, ValueType::__kDssFieldOffset>;
using FlatMsgOneofField = intrusive::OffsetStructField<ValueType, typename ValueType::__OneofType,
ValueType::__kDssFieldOffset>;
static bool Match(FlatMsgViewT* self, const ContainerT& container) {
return FlatMsgViewUtil<FlatMsgViewT, FlatMsgOneofField, typename ContainerT::value_type>::Match(
self, container.data(), container.size());
......@@ -254,8 +254,8 @@ struct FlatMsgViewContainerUtil {
template<typename FlatMsgViewT, typename ValueType, typename Enabled>
struct FlatMsgViewContainerUtil<FlatMsgViewT, ValueType, std::vector<FlatMsg<ValueType>>, Enabled> {
using FlatMsgOneofField =
StructField<ValueType, typename ValueType::__OneofType, ValueType::__kDssFieldOffset>;
using FlatMsgOneofField = intrusive::OffsetStructField<ValueType, typename ValueType::__OneofType,
ValueType::__kDssFieldOffset>;
static_assert(sizeof(ValueType) == sizeof(FlatMsg<ValueType>), "");
static_assert(alignof(ValueType) == alignof(FlatMsg<ValueType>), "");
static bool Match(FlatMsgViewT* self, const std::vector<FlatMsg<ValueType>>& container) {
......
......@@ -33,25 +33,27 @@ namespace intrusive {
// details
#define _INTRUSIVE_FOR_EACH(container_type, elem, container) \
for (intrusive::shared_ptr<typename container_type::value_type> elem, \
*end_if_not_null = nullptr; \
end_if_not_null == nullptr; end_if_not_null = nullptr, ++end_if_not_null) \
LIST_HOOK_FOR_EACH_WITH_EXPR( \
(StructField<typename container_type, intrusive::ListHook, \
container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \
#define _INTRUSIVE_FOR_EACH(container_type, elem, container) \
for (intrusive::shared_ptr<typename container_type::value_type> elem, \
*end_if_not_null = nullptr; \
end_if_not_null == nullptr; end_if_not_null = nullptr, ++end_if_not_null) \
LIST_HOOK_FOR_EACH_WITH_EXPR( \
(intrusive::OffsetStructField< \
typename container_type, intrusive::ListHook, \
container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \
container_type::iterator_struct_field, elem_ptr, (elem.Reset(elem_ptr), true))
#define _INTRUSIVE_FOR_EACH_PTR(container_type, elem, container) \
LIST_HOOK_FOR_EACH( \
(StructField<typename container_type, intrusive::ListHook, \
container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \
container_type::iterator_struct_field, elem)
#define _INTRUSIVE_UNSAFE_FOR_EACH_PTR(container_type, elem, container) \
LIST_HOOK_UNSAFE_FOR_EACH( \
(StructField<typename container_type, intrusive::ListHook, \
container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \
#define _INTRUSIVE_FOR_EACH_PTR(container_type, elem, container) \
LIST_HOOK_FOR_EACH((intrusive::OffsetStructField< \
typename container_type, intrusive::ListHook, \
container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \
container_type::iterator_struct_field, elem)
#define _INTRUSIVE_UNSAFE_FOR_EACH_PTR(container_type, elem, container) \
LIST_HOOK_UNSAFE_FOR_EACH( \
(intrusive::OffsetStructField< \
typename container_type, intrusive::ListHook, \
container_type::IteratorHookOffset()>::FieldPtr4StructPtr(container)), \
container_type::iterator_struct_field, elem)
} // namespace intrusive
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/intrusive/ref.h"
#include "oneflow/core/intrusive/list_hook.h"
#include "oneflow/core/intrusive/struct_traits.h"
#include "oneflow/core/intrusive/reflective.h"
namespace oneflow {
namespace intrusive {
......@@ -50,13 +51,13 @@ class HeadFreeList {
void __Init__() {
list_head_.__Init__();
static_assert(
std::is_same<HeadFreeList,
INTRUSIVE_FIELD_TYPE(typename value_type, field_number_in_countainter)>::value,
std::is_same<HeadFreeList, REFLECTIVE_FIELD_TYPE(typename value_type,
field_number_in_countainter)>::value,
"It's invalid to define fields between definition of head-free list type and definition of "
"head-free list field.");
using ThisInContainer =
StructField<value_type, HeadFreeList,
INTRUSIVE_FIELD_OFFSET(value_type, field_number_in_countainter)>;
OffsetStructField<value_type, HeadFreeList,
REFLECTIVE_FIELD_OFFSET(value_type, field_number_in_countainter)>;
container_ = ThisInContainer::StructPtr4FieldPtr(this);
}
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// include sstream first to avoid some compiling error
// caused by the following trick
// reference: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=65899
#include <sstream>
#define private public
#include "oneflow/core/common/util.h"
#include "oneflow/core/intrusive/intrusive.h"
namespace oneflow {
namespace test {
namespace {
// clang-format off
REFLECTIVE_CLASS_BEGIN(SelfLoopContainer);
public:
void __Init__() { clear_deleted(); }
// Getters
bool has_deleted() const { return deleted_ != nullptr; }
bool deleted() const { return *deleted_; }
bool is_hook_empty() const { return hook_.empty(); }
// Setters
bool* mut_deleted() { return deleted_; }
void set_deleted(bool* val) { deleted_ = val; }
void clear_deleted() { deleted_ = nullptr; }
// methods
void __Init__(bool* deleted) {
__Init__();
set_deleted(deleted);
}
void __Delete__() { *mut_deleted() = true; }
size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }
private:
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
SelfLoopContainer() : intrusive_ref_(), deleted_(), hook_(), head_() {}
REFLECTIVE_CLASS_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
// fields
REFLECTIVE_CLASS_DEFINE_FIELD(bool*, deleted_);
// list hooks
REFLECTIVE_CLASS_DEFINE_FIELD(intrusive::ListHook, hook_);
public:
// Do not insert other REFLECTIVE_CLASS_DEFINE_FIELD between `using SelfLoopContainerList = ...;` and `REFLECTIVE_CLASS_DEFINE_FIELD(SelfLoopContainerList, ...);`
using SelfLoopContainerList =
intrusive::HeadFreeList<REFLECTIVE_FIELD(SelfLoopContainer, hook_), REFLECTIVE_FIELD_COUNTER>;
const SelfLoopContainerList& head() const { return head_; }
SelfLoopContainerList* mut_head() { return &head_; }
private:
REFLECTIVE_CLASS_DEFINE_FIELD(SelfLoopContainerList, head_);
REFLECTIVE_CLASS_END(SelfLoopContainer);
// clang-format on
TEST(HeadFreeList, __Init__) {
bool deleted = false;
auto self_loop_head = intrusive::make_shared<SelfLoopContainer>(&deleted);
ASSERT_EQ(self_loop_head->mut_head()->container_, self_loop_head.Mutable());
}
TEST(HeadFreeList, PushBack) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
ASSERT_EQ(self_loop_head0->head().size(), 1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
ASSERT_EQ(self_loop_head1->ref_cnt(), 2);
ASSERT_EQ(self_loop_head0->head().size(), 2);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(HeadFreeList, PushFront) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
self_loop_head0->mut_head()->PushFront(self_loop_head0.Mutable());
ASSERT_EQ(self_loop_head0->head().size(), 1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
self_loop_head0->mut_head()->PushFront(self_loop_head1.Mutable());
ASSERT_EQ(self_loop_head1->ref_cnt(), 2);
ASSERT_EQ(self_loop_head0->head().size(), 2);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(HeadFreeList, EmplaceBack) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
self_loop_head0->mut_head()->EmplaceBack(
intrusive::shared_ptr<SelfLoopContainer>(self_loop_head0));
ASSERT_EQ(self_loop_head0->head().size(), 1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
self_loop_head0->mut_head()->EmplaceBack(
intrusive::shared_ptr<SelfLoopContainer>(self_loop_head1));
ASSERT_EQ(self_loop_head1->ref_cnt(), 2);
ASSERT_EQ(self_loop_head0->head().size(), 2);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(HeadFreeList, EmplaceFront) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
self_loop_head0->mut_head()->EmplaceFront(
intrusive::shared_ptr<SelfLoopContainer>(self_loop_head0));
ASSERT_EQ(self_loop_head0->head().size(), 1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
self_loop_head0->mut_head()->EmplaceFront(
intrusive::shared_ptr<SelfLoopContainer>(self_loop_head1));
ASSERT_EQ(self_loop_head1->ref_cnt(), 2);
ASSERT_EQ(self_loop_head0->head().size(), 2);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(HeadFreeList, Erase) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->Erase(self_loop_head0.Mutable());
self_loop_head0->mut_head()->Erase(self_loop_head1.Mutable());
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(HeadFreeList, PopBack) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->PopBack();
self_loop_head0->mut_head()->PopBack();
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(HeadFreeList, PopFront) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->PopFront();
self_loop_head0->mut_head()->PopFront();
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(HeadFreeList, MoveTo) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->MoveTo(self_loop_head1->mut_head());
ASSERT_EQ(self_loop_head0->ref_cnt(), 2);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(HeadFreeList, Clear) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->Clear();
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
} // namespace
} // namespace test
} // namespace oneflow
......@@ -16,13 +16,15 @@ limitations under the License.
#ifndef ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_
#define ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_
#include "oneflow/core/intrusive/intrusive_core.h"
#include "oneflow/core/intrusive/struct_traits.h"
#include "oneflow/core/intrusive/base.h"
#include "oneflow/core/intrusive/ref.h"
#include "oneflow/core/intrusive/shared_ptr.h"
#include "oneflow/core/intrusive/list.h"
#include "oneflow/core/intrusive/head_free_list.h"
#include "oneflow/core/intrusive/skiplist.h"
#include "oneflow/core/intrusive/for_each.h"
#include "oneflow/core/intrusive/reflective.h"
#include "oneflow/core/intrusive/force_standard_layout.h"
#endif // ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_H_
......@@ -44,8 +44,7 @@ TEST(Ref, ref_cnt) {
ASSERT_EQ(foo.DecreaseRefCount(), 0);
}
// clang-format off
INTRUSIVE_BEGIN(IntrusiveFoo)
class IntrusiveFoo final : public intrusive::Base {
public:
void __Init__() { clear_is_deleted(); }
void __Delete__();
......@@ -74,14 +73,13 @@ INTRUSIVE_BEGIN(IntrusiveFoo)
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
IntrusiveFoo() : intrusive_ref_(), x_(), foo_(), bar_(), foobar_(), is_deleted_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
INTRUSIVE_DEFINE_FIELD(int8_t, x_);
INTRUSIVE_DEFINE_FIELD(int32_t, foo_);
INTRUSIVE_DEFINE_FIELD(int16_t, bar_);
INTRUSIVE_DEFINE_FIELD(int64_t, foobar_);
INTRUSIVE_DEFINE_FIELD(std::string*, is_deleted_);
INTRUSIVE_END(IntrusiveFoo)
// clang-format on
intrusive::Ref intrusive_ref_;
int8_t x_;
int32_t foo_;
int16_t bar_;
int64_t foobar_;
std::string* is_deleted_;
};
void IntrusiveFoo::__Delete__() {
if (mut_is_deleted()) { *mut_is_deleted() = "deleted"; }
......@@ -104,11 +102,10 @@ TEST(intrusive, __delete__) {
ASSERT_TRUE(is_deleted == "deleted");
}
// clang-format off
INTRUSIVE_BEGIN(IntrusiveBar)
class IntrusiveBar final : public intrusive::Base {
public:
void __Init__() { clear_is_deleted(); }
void __Delete__(){
void __Delete__() {
if (mut_is_deleted()) { *mut_is_deleted() = "bar_deleted"; }
}
......@@ -137,11 +134,10 @@ INTRUSIVE_BEGIN(IntrusiveBar)
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
IntrusiveBar() : intrusive_ref_(), foo_(), is_deleted_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<IntrusiveFoo>, foo_);
INTRUSIVE_DEFINE_FIELD(std::string*, is_deleted_);
INTRUSIVE_END(IntrusiveBar)
// clang-format on
intrusive::Ref intrusive_ref_;
intrusive::shared_ptr<IntrusiveFoo> foo_;
std::string* is_deleted_;
};
TEST(intrusive, nested_objects) {
auto bar = intrusive::make_shared<IntrusiveBar>();
......@@ -174,8 +170,7 @@ FLAT_MSG_BEGIN(FlatMsgDemo)
FLAT_MSG_END(FlatMsgDemo)
// clang-format on
// clang-format off
INTRUSIVE_BEGIN(IntrusiveContainerDemo)
class IntrusiveContainerDemo final : public intrusive::Base {
public:
// Getters
const FlatMsgDemo& flat_field() const { return flat_field_.Get(); }
......@@ -187,10 +182,9 @@ INTRUSIVE_BEGIN(IntrusiveContainerDemo)
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
IntrusiveContainerDemo() : intrusive_ref_(), flat_field_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
INTRUSIVE_DEFINE_FIELD(FlatMsg<FlatMsgDemo>, flat_field_);
INTRUSIVE_END(IntrusiveContainerDemo)
// clang-format on
intrusive::Ref intrusive_ref_;
FlatMsg<FlatMsgDemo> flat_field_;
};
TEST(intrusive, flat_msg_field) {
auto obj = intrusive::make_shared<IntrusiveContainerDemo>();
......@@ -201,44 +195,44 @@ TEST(intrusive, flat_msg_field) {
}
// clang-format off
INTRUSIVE_BEGIN(TestIntrusiveField);
REFLECTIVE_CLASS_BEGIN(TestIntrusiveField);
TestIntrusiveField() = default;
static_assert(INTRUSIVE_FIELD_COUNTER == 0, "");
static_assert(INTRUSIVE_FIELD_COUNTER == 0, "");
INTRUSIVE_DEFINE_FIELD(int32_t, a);
static_assert(INTRUSIVE_FIELD_COUNTER == 1, "");
static_assert(INTRUSIVE_FIELD_COUNTER == 1, "");
INTRUSIVE_DEFINE_FIELD(int64_t, b);
static_assert(INTRUSIVE_FIELD_COUNTER == 2, "");
static_assert(INTRUSIVE_FIELD_COUNTER == 2, "");
INTRUSIVE_DEFINE_FIELD(int8_t, c);
static_assert(INTRUSIVE_FIELD_COUNTER == 3, "");
static_assert(INTRUSIVE_FIELD_COUNTER == 3, "");
INTRUSIVE_DEFINE_FIELD(int64_t, d);
static_assert(INTRUSIVE_FIELD_COUNTER == 4, "");
static_assert(INTRUSIVE_FIELD_COUNTER == 4, "");
INTRUSIVE_END(TestIntrusiveField);
static_assert(REFLECTIVE_FIELD_COUNTER == 0, "");
static_assert(REFLECTIVE_FIELD_COUNTER == 0, "");
REFLECTIVE_CLASS_DEFINE_FIELD(int32_t, a);
static_assert(REFLECTIVE_FIELD_COUNTER == 1, "");
static_assert(REFLECTIVE_FIELD_COUNTER == 1, "");
REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, b);
static_assert(REFLECTIVE_FIELD_COUNTER == 2, "");
static_assert(REFLECTIVE_FIELD_COUNTER == 2, "");
REFLECTIVE_CLASS_DEFINE_FIELD(int8_t, c);
static_assert(REFLECTIVE_FIELD_COUNTER == 3, "");
static_assert(REFLECTIVE_FIELD_COUNTER == 3, "");
REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, d);
static_assert(REFLECTIVE_FIELD_COUNTER == 4, "");
static_assert(REFLECTIVE_FIELD_COUNTER == 4, "");
REFLECTIVE_CLASS_END(TestIntrusiveField);
// clang-format on
TEST(intrusive, intrusive_field_number) {
static_assert(INTRUSIVE_FIELD_NUMBER(TestIntrusiveField, a) == 1, "");
static_assert(INTRUSIVE_FIELD_NUMBER(TestIntrusiveField, b) == 2, "");
static_assert(INTRUSIVE_FIELD_NUMBER(TestIntrusiveField, c) == 3, "");
static_assert(INTRUSIVE_FIELD_NUMBER(TestIntrusiveField, d) == 4, "");
static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, a) == 1, "");
static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, b) == 2, "");
static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, c) == 3, "");
static_assert(REFLECTIVE_FIELD_NUMBER(TestIntrusiveField, d) == 4, "");
}
TEST(intrusive, intrusive_field_type) {
static_assert(std::is_same<INTRUSIVE_FIELD_TYPE(TestIntrusiveField, 1), int32_t>::value, "");
static_assert(std::is_same<INTRUSIVE_FIELD_TYPE(TestIntrusiveField, 2), int64_t>::value, "");
static_assert(std::is_same<INTRUSIVE_FIELD_TYPE(TestIntrusiveField, 3), int8_t>::value, "");
static_assert(std::is_same<INTRUSIVE_FIELD_TYPE(TestIntrusiveField, 4), int64_t>::value, "");
static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(TestIntrusiveField, 1), int32_t>::value, "");
static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(TestIntrusiveField, 2), int64_t>::value, "");
static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(TestIntrusiveField, 3), int8_t>::value, "");
static_assert(std::is_same<REFLECTIVE_FIELD_TYPE(TestIntrusiveField, 4), int64_t>::value, "");
}
TEST(intrusive, intrusive_field_offset) {
static_assert(INTRUSIVE_FIELD_OFFSET(TestIntrusiveField, 1) == 0, "");
static_assert(INTRUSIVE_FIELD_OFFSET(TestIntrusiveField, 2) == 8, "");
static_assert(INTRUSIVE_FIELD_OFFSET(TestIntrusiveField, 3) == 16, "");
static_assert(INTRUSIVE_FIELD_OFFSET(TestIntrusiveField, 4) == 24, "");
static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 1) == 0, "");
static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 2) == 8, "");
static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 3) == 16, "");
static_assert(REFLECTIVE_FIELD_OFFSET(TestIntrusiveField, 4) == 24, "");
}
} // namespace
......
......@@ -23,21 +23,20 @@ namespace oneflow {
namespace intrusive {
template<typename ValueHookField>
template<typename HookField>
class List {
public:
static_assert(std::is_same<typename ValueHookField::field_type, intrusive::ListHook>::value, "");
List(const List&) = delete;
List(List&&) = delete;
List() { this->__Init__(); }
~List() { this->Clear(); }
using value_type = typename ValueHookField::struct_type;
using iterator_struct_field = ValueHookField;
using value_type = typename HookField::struct_type;
using iterator_struct_field = HookField;
template<typename Enabled = void>
static constexpr int IteratorHookOffset() {
return offsetof(List, list_head_) + intrusive::ListHead<ValueHookField>::IteratorHookOffset();
return offsetof(List, list_head_) + intrusive::ListHead<HookField>::IteratorHookOffset();
}
std::size_t size() const { return list_head_.size(); }
......@@ -127,7 +126,7 @@ class List {
}
private:
intrusive::ListHead<ValueHookField> list_head_;
intrusive::ListHead<HookField> list_head_;
};
} // namespace intrusive
......
......@@ -44,7 +44,7 @@ class TestListHead : public intrusive::ListHead<ItemField> {
TestListHead() { this->__Init__(); }
};
using BarListHead = TestListHead<STRUCT_FIELD(ListItemBar, bar_list)>;
using BarListHead = TestListHead<INTRUSIVE_FIELD(ListItemBar, bar_list)>;
TEST(TestListHook, init) {
TestListHook list_iterator;
......
......@@ -27,8 +27,7 @@ namespace test {
namespace {
// clang-format off
INTRUSIVE_BEGIN(TestListItem)
class TestListItem : public intrusive::Base {
public:
void __Init__() { clear_cnt(); }
void __Delete__() {
......@@ -47,16 +46,16 @@ INTRUSIVE_BEGIN(TestListItem)
size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }
intrusive::ListHook foo_list_;
private:
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
TestListItem() : intrusive_ref_(), cnt_(), foo_list_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
INTRUSIVE_DEFINE_FIELD(int*, cnt_);
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, foo_list_);
INTRUSIVE_END(TestListItem)
// clang-format on
intrusive::Ref intrusive_ref_;
int* cnt_;
};
using TestList = intrusive::List<INTRUSIVE_FIELD(TestListItem, foo_list_)>;
......@@ -251,8 +250,7 @@ TEST(List, FOR_EACH) {
ASSERT_EQ(item1->ref_cnt(), 1);
}
// clang-format off
INTRUSIVE_BEGIN(TestIntrusiveListHead);
class TestIntrusiveListHead final : public intrusive::Base {
public:
// types
using FooList = intrusive::List<INTRUSIVE_FIELD(TestListItem, foo_list_)>;
......@@ -266,10 +264,9 @@ INTRUSIVE_BEGIN(TestIntrusiveListHead);
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
TestIntrusiveListHead() : intrusive_ref_(), foo_list_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
INTRUSIVE_DEFINE_FIELD(FooList, foo_list_);
INTRUSIVE_END(TestIntrusiveListHead);
// clang-format on
intrusive::Ref intrusive_ref_;
FooList foo_list_;
};
TEST(List, intrusive_list_for_each) {
auto foo_list_head = intrusive::make_shared<TestIntrusiveListHead>();
......@@ -297,8 +294,7 @@ TEST(List, intrusive_list_for_each) {
ASSERT_EQ(item1->ref_cnt(), 1);
}
// clang-format off
INTRUSIVE_BEGIN(TestIntrusiveListHeadWrapper);
class TestIntrusiveListHeadWrapper final : public intrusive::Base {
public:
// Getters
const TestIntrusiveListHead& head() const {
......@@ -320,10 +316,9 @@ INTRUSIVE_BEGIN(TestIntrusiveListHeadWrapper);
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
TestIntrusiveListHeadWrapper() : intrusive_ref_(), head_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<TestIntrusiveListHead>, head_);
INTRUSIVE_END(TestIntrusiveListHeadWrapper);
// clang-format on
intrusive::Ref intrusive_ref_;
intrusive::shared_ptr<TestIntrusiveListHead> head_;
};
TEST(List, nested_list_delete) {
auto foo_list_head = intrusive::make_shared<TestIntrusiveListHeadWrapper>();
......@@ -373,219 +368,6 @@ TEST(List, MoveTo) {
ASSERT_EQ(item1->ref_cnt(), 2);
}
// clang-format off
INTRUSIVE_BEGIN(SelfLoopContainer);
public:
void __Init__() { clear_deleted(); }
// Getters
bool has_deleted() const { return deleted_ != nullptr; }
bool deleted() const { return *deleted_; }
bool is_hook_empty() const { return hook_.empty(); }
// Setters
bool* mut_deleted() { return deleted_; }
void set_deleted(bool* val) { deleted_ = val; }
void clear_deleted() { deleted_ = nullptr; }
// methods
void __Init__(bool* deleted) {
__Init__();
set_deleted(deleted);
}
void __Delete__() { *mut_deleted() = true; }
size_t ref_cnt() const { return intrusive_ref_.ref_cnt(); }
private:
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
SelfLoopContainer() : intrusive_ref_(), deleted_(), hook_(), head_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
// fields
INTRUSIVE_DEFINE_FIELD(bool*, deleted_);
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, hook_);
public:
// Do not insert other INTRUSIVE_DEFINE_FIELDs between `using SelfLoopContainerList = ...;` and
// `INTRUSIVE_DEFINE_FIELD(SelfLoopContainerList, ...);`
using SelfLoopContainerList = intrusive::HeadFreeList<INTRUSIVE_FIELD(SelfLoopContainer, hook_), INTRUSIVE_FIELD_COUNTER>;
const SelfLoopContainerList& head() const { return head_; }
SelfLoopContainerList* mut_head() { return &head_; }
INTRUSIVE_DEFINE_FIELD(SelfLoopContainerList, head_);
INTRUSIVE_END(SelfLoopContainer);
// clang-format on
TEST(IntrusiveSelfLoopList, __Init__) {
bool deleted = false;
auto self_loop_head = intrusive::make_shared<SelfLoopContainer>(&deleted);
ASSERT_EQ(self_loop_head->mut_head()->container_, self_loop_head.Mutable());
}
TEST(IntrusiveSelfLoopList, PushBack) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
ASSERT_EQ(self_loop_head0->head().size(), 1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
ASSERT_EQ(self_loop_head1->ref_cnt(), 2);
ASSERT_EQ(self_loop_head0->head().size(), 2);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(IntrusiveSelfLoopList, PushFront) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
self_loop_head0->mut_head()->PushFront(self_loop_head0.Mutable());
ASSERT_EQ(self_loop_head0->head().size(), 1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
self_loop_head0->mut_head()->PushFront(self_loop_head1.Mutable());
ASSERT_EQ(self_loop_head1->ref_cnt(), 2);
ASSERT_EQ(self_loop_head0->head().size(), 2);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(IntrusiveSelfLoopList, EmplaceBack) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
self_loop_head0->mut_head()->EmplaceBack(
intrusive::shared_ptr<SelfLoopContainer>(self_loop_head0));
ASSERT_EQ(self_loop_head0->head().size(), 1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
self_loop_head0->mut_head()->EmplaceBack(
intrusive::shared_ptr<SelfLoopContainer>(self_loop_head1));
ASSERT_EQ(self_loop_head1->ref_cnt(), 2);
ASSERT_EQ(self_loop_head0->head().size(), 2);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(IntrusiveSelfLoopList, EmplaceFront) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
self_loop_head0->mut_head()->EmplaceFront(
intrusive::shared_ptr<SelfLoopContainer>(self_loop_head0));
ASSERT_EQ(self_loop_head0->head().size(), 1);
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
self_loop_head0->mut_head()->EmplaceFront(
intrusive::shared_ptr<SelfLoopContainer>(self_loop_head1));
ASSERT_EQ(self_loop_head1->ref_cnt(), 2);
ASSERT_EQ(self_loop_head0->head().size(), 2);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(IntrusiveSelfLoopList, Erase) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->Erase(self_loop_head0.Mutable());
self_loop_head0->mut_head()->Erase(self_loop_head1.Mutable());
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(IntrusiveSelfLoopList, PopBack) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->PopBack();
self_loop_head0->mut_head()->PopBack();
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(IntrusiveSelfLoopList, PopFront) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->PopFront();
self_loop_head0->mut_head()->PopFront();
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(IntrusiveSelfLoopList, MoveTo) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->MoveTo(self_loop_head1->mut_head());
ASSERT_EQ(self_loop_head0->ref_cnt(), 2);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
TEST(IntrusiveSelfLoopList, Clear) {
bool deleted0 = false;
bool deleted1 = false;
{
auto self_loop_head0 = intrusive::make_shared<SelfLoopContainer>(&deleted0);
auto self_loop_head1 = intrusive::make_shared<SelfLoopContainer>(&deleted1);
self_loop_head0->mut_head()->PushBack(self_loop_head0.Mutable());
self_loop_head0->mut_head()->PushBack(self_loop_head1.Mutable());
self_loop_head0->mut_head()->Clear();
ASSERT_EQ(self_loop_head0->ref_cnt(), 1);
ASSERT_EQ(self_loop_head1->ref_cnt(), 1);
}
ASSERT_TRUE(deleted0);
ASSERT_TRUE(deleted1);
}
} // namespace
} // namespace test
......
......@@ -27,6 +27,7 @@ template<typename HookField>
class MutexedList {
public:
using value_type = typename HookField::struct_type;
using list_type = List<HookField>;
MutexedList(const MutexedList&) = delete;
MutexedList(MutexedList&&) = delete;
......@@ -67,12 +68,12 @@ class MutexedList {
return list_head_.PopFront();
}
void MoveFrom(List<HookField>* src) {
void MoveFrom(list_type* src) {
std::unique_lock<std::mutex> lock(mutex_);
src->MoveToDstBack(&list_head_);
}
void MoveTo(List<HookField>* dst) {
void MoveTo(list_type* dst) {
std::unique_lock<std::mutex> lock(mutex_);
list_head_.MoveToDstBack(dst);
}
......@@ -83,7 +84,7 @@ class MutexedList {
}
private:
List<HookField> list_head_;
list_type list_head_;
mutable std::mutex mutex_;
};
......
......@@ -17,6 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_INTRUSIVE_REF_H_
#include <atomic>
#include <glog/logging.h>
namespace oneflow {
......
......@@ -13,20 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_CORE_H_
#define ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_CORE_H_
#ifndef ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_
#define ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_
#include <cstring>
#include <memory>
#include <type_traits>
#include <glog/logging.h>
#include "oneflow/core/intrusive/dss.h"
#include "oneflow/core/intrusive/static_counter.h"
#include "oneflow/core/intrusive/struct_traits.h"
#include "oneflow/core/intrusive/base.h"
namespace oneflow {
#define INTRUSIVE_BEGIN(class_name) \
#define REFLECTIVE_CLASS_BEGIN(class_name) \
struct class_name final : public intrusive::Base { \
public: \
using self_type = class_name; \
......@@ -36,7 +33,7 @@ namespace oneflow {
DEFINE_STATIC_COUNTER(field_counter); \
DSS_BEGIN(STATIC_COUNTER(field_counter), class_name);
#define INTRUSIVE_END(class_name) \
#define REFLECTIVE_CLASS_END(class_name) \
static_assert(__has_intrusive_ref__, "this class is not intrusive-referenced"); \
\
public: \
......@@ -48,53 +45,43 @@ namespace oneflow {
} \
;
#define INTRUSIVE_DEFINE_FIELD(field_type, field_name) \
private: \
#define REFLECTIVE_CLASS_DEFINE_FIELD(field_type, field_name) \
static_assert(__has_intrusive_ref__, "this class is not intrusive-referenced"); \
field_type field_name; \
INCREASE_STATIC_COUNTER(field_counter); \
DSS_DEFINE_FIELD(STATIC_COUNTER(field_counter), "intrusive-referenced class", field_type, \
field_name);
#define INTRUSIVE_FIELD(struct_type, field_name) \
StructField<struct_type, struct_type::OF_PP_CAT(field_name, DssFieldType), \
struct_type::OF_PP_CAT(field_name, kDssFieldOffset)>
#define REFLECTIVE_FIELD(struct_type, field_name) \
intrusive::OffsetStructField<struct_type, struct_type::OF_PP_CAT(field_name, DssFieldType), \
struct_type::OF_PP_CAT(field_name, kDssFieldOffset)>
// Get field number by field name
// note: field numbers start from 1 instead of 0.
#define INTRUSIVE_FIELD_NUMBER(cls, field_name) cls::OF_PP_CAT(field_name, kDssFieldNumber)
#define REFLECTIVE_FIELD_NUMBER(cls, field_name) cls::OF_PP_CAT(field_name, kDssFieldNumber)
// Get field type by field number
#define INTRUSIVE_FIELD_TYPE(cls, field_number) cls::template __DssFieldType__<field_number>::type
#define REFLECTIVE_FIELD_TYPE(cls, field_number) cls::template __DssFieldType__<field_number>::type
// Get field offset by field number
#define INTRUSIVE_FIELD_OFFSET(cls, field_number) \
#define REFLECTIVE_FIELD_OFFSET(cls, field_number) \
cls::template __DssFieldOffset4FieldIndex__<field_number>::value
// Get current defined field counter inside a intrusive-referenced class.
// note: not used outside INTRUSIVE_BEGIN ... INTRUSIVE_END
// note: not used outside REFLECTIVE_CLASS_BEGIN ... REFLECTIVE_CLASS_END
// e.g.:
// INTRUSIVE_BEGIN(Foo);
// static_assert(INTRUSIVE_FIELD_COUNTER == 0, "");
// INTRUSIVE_DEFINE_FIELD(int64_t, a);
// static_assert(INTRUSIVE_FIELD_COUNTER == 1, "");
// INTRUSIVE_DEFINE_FIELD(int64_t, b);
// static_assert(INTRUSIVE_FIELD_COUNTER == 2, "");
// INTRUSIVE_DEFINE_FIELD(int8_t, c);
// static_assert(INTRUSIVE_FIELD_COUNTER == 3, "");
// INTRUSIVE_DEFINE_FIELD(int64_t, d);
// INTRUSIVE_END(Foo);
#define INTRUSIVE_FIELD_COUNTER STATIC_COUNTER(field_counter)
namespace intrusive {
struct Base {
void __Init__() {}
void __Delete__() {}
};
} // namespace intrusive
// REFLECTIVE_CLASS_BEGIN(Foo);
// static_assert(REFLECTIVE_FIELD_COUNTER == 0, "");
// REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, a);
// static_assert(REFLECTIVE_FIELD_COUNTER == 1, "");
// REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, b);
// static_assert(REFLECTIVE_FIELD_COUNTER == 2, "");
// REFLECTIVE_CLASS_DEFINE_FIELD(int8_t, c);
// static_assert(REFLECTIVE_FIELD_COUNTER == 3, "");
// REFLECTIVE_CLASS_DEFINE_FIELD(int64_t, d);
// REFLECTIVE_CLASS_END(Foo);
#define REFLECTIVE_FIELD_COUNTER STATIC_COUNTER(field_counter)
} // namespace oneflow
#endif // ONEFLOW_CORE_INTRUSIVE_INTRUSIVE_CORE_H_
#endif // ONEFLOW_CORE_INTRUSIVE_REFLECTIVE_CORE_H_
......@@ -25,7 +25,6 @@ namespace intrusive {
template<typename T>
class shared_ptr final {
public:
static_assert(T::__has_intrusive_ref__, "T is not a intrusive-referenced class");
using value_type = T;
shared_ptr() : ptr_(nullptr) {}
shared_ptr(value_type* ptr) : ptr_(nullptr) { Reset(ptr); }
......
......@@ -35,10 +35,9 @@ class SkipList {
using value_type = typename ElemKeyField::struct_type;
using key_type = typename ElemKeyField::field_type::key_type;
using elem_key_level0_hook_struct_field =
StructField<typename ElemKeyField::field_type, intrusive::ListHook,
ElemKeyField::field_type::LevelZeroHookOffset()>;
using iterator_struct_field =
typename ComposeStructField<ElemKeyField, elem_key_level0_hook_struct_field>::type;
OffsetStructField<typename ElemKeyField::field_type, intrusive::ListHook,
ElemKeyField::field_type::LevelZeroHookOffset()>;
using iterator_struct_field = ComposeStructField<ElemKeyField, elem_key_level0_hook_struct_field>;
template<typename Enabled = void>
static constexpr int IteratorHookOffset() {
return offsetof(SkipList, skiplist_head_)
......
......@@ -65,7 +65,8 @@ struct ListHookArray final {
}
static ListHookArray* ThisPtr4HookPtr(ListHook* slist_ptr, int level) {
auto* hooks_ptr = (std::array<intrusive::ListHook, max_level>*)(slist_ptr - level);
return StructField<self_type, decltype(hooks_), HooksOffset()>::StructPtr4FieldPtr(hooks_ptr);
return OffsetStructField<self_type, decltype(hooks_), HooksOffset()>::StructPtr4FieldPtr(
hooks_ptr);
}
void CheckEmpty() const {
for (const auto& hook : hooks_) { CHECK(hook.empty()); }
......@@ -152,7 +153,7 @@ struct SkipListHook {
}
static SkipListHook* ThisPtr4HookPtr(ListHook* list_hook_ptr, int level) {
auto* skip_list_ptr = hook_type::ThisPtr4HookPtr(list_hook_ptr, level);
using FieldUtil = StructField<self_type, hook_type, SkipListIteratorOffset()>;
using FieldUtil = OffsetStructField<self_type, hook_type, SkipListIteratorOffset()>;
return FieldUtil::StructPtr4FieldPtr(skip_list_ptr);
}
......@@ -195,10 +196,10 @@ class SkipListHead {
using key_hook_type = typename ValueHookField::field_type;
using key_type = typename key_hook_type::key_type;
using value_key_level0_hook_struct_field =
StructField<typename ValueHookField::field_type, intrusive::ListHook,
ValueHookField::field_type::LevelZeroHookOffset()>;
OffsetStructField<typename ValueHookField::field_type, intrusive::ListHook,
ValueHookField::field_type::LevelZeroHookOffset()>;
using value_level0_hook_struct_field =
typename ComposeStructField<ValueHookField, value_key_level0_hook_struct_field>::type;
ComposeStructField<ValueHookField, value_key_level0_hook_struct_field>;
static const int max_level = key_hook_type::max_level;
template<typename Enabled = void>
static constexpr int IteratorHookOffset() {
......
......@@ -40,7 +40,7 @@ struct FooSkipListElem {
SkipListHook<int> key;
};
using FooSkipList = TestSkipListHead<STRUCT_FIELD(FooSkipListElem, key)>;
using FooSkipList = TestSkipListHead<INTRUSIVE_FIELD(FooSkipListElem, key)>;
TEST(SkipListHook, empty) {
FooSkipList skiplist;
......
......@@ -24,8 +24,7 @@ namespace test {
namespace {
// clang-format off
INTRUSIVE_BEGIN(SkipListFoo);
class SkipListFoo final : public intrusive::Base {
public:
void __Init__() { clear_is_deleted(); }
void __Delete__() {
......@@ -49,14 +48,14 @@ INTRUSIVE_BEGIN(SkipListFoo);
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
SkipListFoo() : intrusive_ref_(), is_deleted_(), foo_map_key_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
INTRUSIVE_DEFINE_FIELD(int*, is_deleted_);
INTRUSIVE_DEFINE_FIELD(intrusive::SkipListHook<int32_t>, foo_map_key_);
INTRUSIVE_END(SkipListFoo);
// clang-format on
intrusive::Ref intrusive_ref_;
int* is_deleted_;
// clang-format off
INTRUSIVE_BEGIN(SkipListFooContainer);
public:
intrusive::SkipListHook<int32_t> foo_map_key_;
};
class SkipListFooContainer final : public intrusive::Base {
public:
// types
using Key2SkipListFoo = intrusive::SkipList<INTRUSIVE_FIELD(SkipListFoo, foo_map_key_)>;
......@@ -70,11 +69,10 @@ INTRUSIVE_BEGIN(SkipListFooContainer);
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
SkipListFooContainer() : intrusive_ref_(), foo_map_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
intrusive::Ref intrusive_ref_;
// maps
INTRUSIVE_DEFINE_FIELD(Key2SkipListFoo, foo_map_);
INTRUSIVE_END(SkipListFooContainer);
// clang-format on
Key2SkipListFoo foo_map_;
};
using Key2SkipListFoo = intrusive::SkipList<INTRUSIVE_FIELD(SkipListFoo, foo_map_key_)>;
TEST(SkipList, empty) {
......
......@@ -46,10 +46,10 @@ static_assert(STATIC_COUNTER(static_counter) == 2, "");
TEST(StaticCounter, eq2) { static_assert(STATIC_COUNTER(static_counter) == 2, ""); }
// clang-format off
INTRUSIVE_BEGIN(FooBar);
REFLECTIVE_CLASS_BEGIN(FooBar);
FooBar() = default;
static_assert(STATIC_COUNTER(field_counter) == 0, "");
INTRUSIVE_END(FooBar);
REFLECTIVE_CLASS_END(FooBar);
// clang-format on
} // namespace
......
......@@ -21,15 +21,24 @@ limitations under the License.
#include "oneflow/core/common/preprocessor.h"
namespace oneflow {
namespace intrusive {
#define STRUCT_FIELD(T, field) \
StructField<T, STRUCT_FIELD_TYPE(T, field), STRUCT_FIELD_OFFSET(T, field)>
#define STRUCT_FIELD_TYPE(T, field) decltype(((T*)nullptr)->field)
#define STRUCT_FIELD_OFFSET(T, field) offsetof(T, field)
template<typename T, typename F, F T::*ptr2member>
struct PtrStructField {
using struct_type = T;
using field_type = F;
static T* StructPtr4FieldPtr(const F* field_ptr) {
int offset_value = reinterpret_cast<long long>(&(((T*)nullptr)->*ptr2member));
return (T*)(((char*)field_ptr) - offset_value);
}
static F* FieldPtr4StructPtr(const T* struct_ptr) {
return &(const_cast<T*>(struct_ptr)->*ptr2member);
}
};
// details
template<typename T, typename F, int offset>
struct StructField {
struct OffsetStructField {
using struct_type = T;
using field_type = F;
static const int offset_value = offset;
......@@ -42,12 +51,22 @@ struct StructField {
}
};
#define INTRUSIVE_FIELD(struct_type, field_name) \
intrusive::PtrStructField<struct_type, decltype(((struct_type*)nullptr)->field_name), \
&struct_type::field_name>
template<typename X, typename Y>
struct ComposeStructField {
static_assert(std::is_same<typename X::field_type, typename Y::struct_type>::value,
"invalid type");
using type = StructField<typename X::struct_type, typename Y::field_type,
X::offset_value + Y::offset_value>;
using struct_type = typename X::struct_type;
using field_type = typename Y::field_type;
static struct_type* StructPtr4FieldPtr(const field_type* field_ptr) {
return X::StructPtr4FieldPtr(Y::StructPtr4FieldPtr(field_ptr));
}
static field_type* FieldPtr4StructPtr(const struct_type* struct_ptr) {
return Y::FieldPtr4StructPtr(X::FieldPtr4StructPtr(struct_ptr));
}
};
template<typename T>
......@@ -75,6 +94,7 @@ struct ConstRefOrPtrStruct<T*> {
template<typename T>
using ConstRefOrPtr = typename ConstRefOrPtrStruct<T>::type;
} // namespace intrusive
} // namespace oneflow
#endif // ONEFLOW_CORE_INTRUSIVE_STRUCT_MACRO_TRAITS_H_
......@@ -33,8 +33,8 @@ struct OneflowTestNamespaceFoo {
TEST(StructField, mutable_struct_mutable_field) {
OneflowTestNamespaceFoo foo;
auto* bar = &foo.bar;
auto* struct_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo);
auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo);
ASSERT_EQ(struct_ptr, &foo);
ASSERT_EQ(field_ptr, bar);
}
......@@ -42,8 +42,8 @@ TEST(StructField, mutable_struct_mutable_field) {
TEST(StructField, mutable_struct_const_field) {
OneflowTestNamespaceFoo foo;
auto* bar = &foo.const_bar;
auto* struct_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo);
auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo);
ASSERT_EQ(struct_ptr, &foo);
ASSERT_EQ(field_ptr, bar);
}
......@@ -51,8 +51,8 @@ TEST(StructField, mutable_struct_const_field) {
TEST(StructField, const_struct_mutable_field) {
const OneflowTestNamespaceFoo foo;
auto* bar = &foo.bar;
auto* struct_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo);
auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, bar)::FieldPtr4StructPtr(&foo);
ASSERT_EQ(struct_ptr, &foo);
ASSERT_EQ(field_ptr, bar);
}
......@@ -60,8 +60,8 @@ TEST(StructField, const_struct_mutable_field) {
TEST(StructField, const_struct_const_field) {
const OneflowTestNamespaceFoo foo;
auto* bar = &foo.const_bar;
auto* struct_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = STRUCT_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo);
auto* struct_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::StructPtr4FieldPtr(bar);
auto* field_ptr = INTRUSIVE_FIELD(OneflowTestNamespaceFoo, const_bar)::FieldPtr4StructPtr(&foo);
ASSERT_EQ(struct_ptr, &foo);
ASSERT_EQ(field_ptr, bar);
}
......@@ -77,7 +77,7 @@ struct Y {
};
TEST(StructField, compose) {
using BFieldInY = typename ComposeStructField<STRUCT_FIELD(Y, d), STRUCT_FIELD(X, b)>::type;
using BFieldInY = intrusive::ComposeStructField<INTRUSIVE_FIELD(Y, d), INTRUSIVE_FIELD(X, b)>;
Y y;
int* field_b = &y.d.b;
ASSERT_EQ(BFieldInY::FieldPtr4StructPtr(&y), field_b);
......
......@@ -17,6 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_VM_ACCESS_BLOB_ARG_CB_PHY_INSTR_OPERAND_H_
#include <functional>
#include <memory>
#include "oneflow/core/vm/phy_instr_operand.h"
namespace oneflow {
......
......@@ -34,8 +34,7 @@ limitations under the License.
namespace oneflow {
namespace vm {
// clang-format off
INTRUSIVE_BEGIN(InstructionOperandList);
class InstructionOperandList final : public intrusive::Base {
public:
void __Init__() {}
// Getters
......@@ -48,11 +47,11 @@ INTRUSIVE_BEGIN(InstructionOperandList);
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
InstructionOperandList() : intrusive_ref_(), operand_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
INTRUSIVE_DEFINE_FIELD(std::vector<FlatMsg<InstructionOperand>>, operand_);
INTRUSIVE_END(InstructionOperandList);
intrusive::Ref intrusive_ref_;
std::vector<FlatMsg<InstructionOperand>> operand_;
};
INTRUSIVE_BEGIN(InstructionMsg);
class InstructionMsg final : public intrusive::Base {
public:
// Getters
bool has_parallel_desc_symbol_id() const { return 0 != parallel_desc_symbol_id_; }
......@@ -84,7 +83,7 @@ INTRUSIVE_BEGIN(InstructionMsg);
void __Init__();
void __Init__(const std::string& instr_type_name);
void __Init__(const InstructionProto& proto);
void __Init__(const cfg::InstructionProto& proto);
void __Init__(const cfg::InstructionProto& proto);
void __Init__(const InstructionMsg& instr_msg);
void ToProto(InstructionProto* proto) const;
......@@ -95,16 +94,22 @@ INTRUSIVE_BEGIN(InstructionMsg);
intrusive::shared_ptr<InstructionMsg> add_bool_operand(bool bool_operand);
intrusive::shared_ptr<InstructionMsg> add_separator();
intrusive::shared_ptr<InstructionMsg> add_const_operand(ObjectId logical_object_id);
intrusive::shared_ptr<InstructionMsg> add_const_operand(ObjectId logical_object_id, const SoleMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_const_operand(ObjectId logical_object_id, const AllMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_const_operand(ObjectId logical_object_id,
const SoleMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_const_operand(ObjectId logical_object_id,
const AllMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_symbol_operand(ObjectId logical_object_id);
intrusive::shared_ptr<InstructionMsg> add_mut_operand(ObjectId logical_object_id);
intrusive::shared_ptr<InstructionMsg> add_mut_operand(ObjectId logical_object_id, const SoleMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_mut_operand(ObjectId logical_object_id, const AllMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_mut_operand(ObjectId logical_object_id,
const SoleMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_mut_operand(ObjectId logical_object_id,
const AllMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_init_symbol_operand(ObjectId logical_object_id);
intrusive::shared_ptr<InstructionMsg> add_mut2_operand(ObjectId logical_object_id);
intrusive::shared_ptr<InstructionMsg> add_mut2_operand(ObjectId logical_object_id, const SoleMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_mut2_operand(ObjectId logical_object_id, const AllMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_mut2_operand(ObjectId logical_object_id,
const SoleMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_mut2_operand(ObjectId logical_object_id,
const AllMirroredObject&);
intrusive::shared_ptr<InstructionMsg> add_del_operand(ObjectId logical_object_id);
const std::vector<FlatMsg<InstructionOperand>>& operand() const {
return operand_list().operand();
......@@ -120,20 +125,29 @@ INTRUSIVE_BEGIN(InstructionMsg);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
InstructionMsg() : intrusive_ref_(), instr_type_id_(), instr_type_name_(), parallel_desc_symbol_id_(), parallel_desc_(), operand_list_(), phy_instr_operand_(), instr_msg_hook_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
InstructionMsg()
: intrusive_ref_(),
instr_type_id_(),
instr_type_name_(),
parallel_desc_symbol_id_(),
parallel_desc_(),
operand_list_(),
phy_instr_operand_(),
instr_msg_hook_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(InstrTypeId, instr_type_id_);
InstrTypeId instr_type_id_;
// instr_type_name is a necessary reduandant field for method ToProto
INTRUSIVE_DEFINE_FIELD(std::string, instr_type_name_);
INTRUSIVE_DEFINE_FIELD(int64_t, parallel_desc_symbol_id_);
INTRUSIVE_DEFINE_FIELD(std::shared_ptr<const ParallelDesc>, parallel_desc_);
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<InstructionOperandList>, operand_list_);
INTRUSIVE_DEFINE_FIELD(std::shared_ptr<PhyInstrOperand>, phy_instr_operand_);
std::string instr_type_name_;
int64_t parallel_desc_symbol_id_;
std::shared_ptr<const ParallelDesc> parallel_desc_;
intrusive::shared_ptr<InstructionOperandList> operand_list_;
std::shared_ptr<PhyInstrOperand> phy_instr_operand_;
public:
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, instr_msg_hook_);
INTRUSIVE_END(InstructionMsg);
// clang-format on
intrusive::ListHook instr_msg_hook_;
};
using InstructionMsgList = intrusive::List<INTRUSIVE_FIELD(InstructionMsg, instr_msg_hook_)>;
......@@ -153,25 +167,24 @@ FLAT_MSG_END(InstructionStatusBuffer);
// clang-format on
struct Instruction;
// clang-format off
INTRUSIVE_BEGIN(InstructionEdge);
class InstructionEdge final : public intrusive::Base {
public:
void __Init__() {
clear_src_instruction();
clear_dst_instruction();
}
// Getters
bool has_src_instruction() const { return src_instruction_ != nullptr; }
bool has_dst_instruction() const { return dst_instruction_ != nullptr; }
bool has_src_instruction() const { return src_instruction_ != nullptr; }
bool has_dst_instruction() const { return dst_instruction_ != nullptr; }
const Instruction& src_instruction() const { return *src_instruction_; }
const Instruction& dst_instruction() const { return *dst_instruction_; }
const Instruction& dst_instruction() const { return *dst_instruction_; }
// Setters
void set_src_instruction(Instruction* val) { src_instruction_ = val; }
void set_dst_instruction(Instruction* val) { dst_instruction_ = val; }
void clear_src_instruction() { src_instruction_ = nullptr; }
void clear_dst_instruction() { dst_instruction_ = nullptr; }
Instruction* mut_src_instruction() { return src_instruction_; }
Instruction* mut_dst_instruction() { return dst_instruction_; }
void set_src_instruction(Instruction* val) { src_instruction_ = val; }
void set_dst_instruction(Instruction* val) { dst_instruction_ = val; }
void clear_src_instruction() { src_instruction_ = nullptr; }
void clear_dst_instruction() { dst_instruction_ = nullptr; }
Instruction* mut_src_instruction() { return src_instruction_; }
Instruction* mut_dst_instruction() { return dst_instruction_; }
// methods
void __Init__(Instruction* src_instruction, Instruction* dst_instruction) {
__Init__();
......@@ -183,20 +196,25 @@ INTRUSIVE_BEGIN(InstructionEdge);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
InstructionEdge() : intrusive_ref_(), src_instruction_(), dst_instruction_(), in_edge_hook_(), out_edge_hook_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
InstructionEdge()
: intrusive_ref_(),
src_instruction_(),
dst_instruction_(),
in_edge_hook_(),
out_edge_hook_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(Instruction*, src_instruction_);
INTRUSIVE_DEFINE_FIELD(Instruction*, dst_instruction_);
Instruction* src_instruction_;
Instruction* dst_instruction_;
public:
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, in_edge_hook_);
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, out_edge_hook_);
INTRUSIVE_END(InstructionEdge);
// clang-format on
intrusive::ListHook in_edge_hook_;
intrusive::ListHook out_edge_hook_;
};
struct Stream;
// clang-format off
INTRUSIVE_BEGIN(Instruction);
class Instruction final : public intrusive::Base {
public:
// types
using InEdgeList = intrusive::List<INTRUSIVE_FIELD(InstructionEdge, in_edge_hook_)>;
......@@ -208,8 +226,8 @@ INTRUSIVE_BEGIN(Instruction);
// Getters
void __Init__() { clear_stream(); }
bool has_stream() const { return stream_ != nullptr; }
const Stream& stream() const { return *stream_; }
bool has_stream() const { return stream_ != nullptr; }
const Stream& stream() const { return *stream_; }
const InstructionMsg& instr_msg() const {
if (instr_msg_) { return instr_msg_.Get(); }
static const auto default_val = intrusive::make_shared<InstructionMsg>();
......@@ -218,15 +236,22 @@ INTRUSIVE_BEGIN(Instruction);
const std::shared_ptr<const ParallelDesc>& parallel_desc() const { return parallel_desc_; }
const InstructionStatusBuffer& status_buffer() const { return status_buffer_.Get(); }
const intrusive::ListHook& instruction_hook() const { return instruction_hook_; }
const intrusive::ListHook& dispatched_instruction_hook() const { return dispatched_instruction_hook_; }
const intrusive::ListHook& vm_stat_running_instruction_hook() const { return vm_stat_running_instruction_hook_; }
const intrusive::ListHook& dispatched_instruction_hook() const {
return dispatched_instruction_hook_;
}
const intrusive::ListHook& vm_stat_running_instruction_hook() const {
return vm_stat_running_instruction_hook_;
}
const intrusive::ListHook& pending_instruction_hook() const { return pending_instruction_hook_; }
const intrusive::ListHook& front_seq_compute_instr_hook() const { return front_seq_compute_instr_hook_; }
const intrusive::ListHook& front_seq_compute_instr_hook() const {
return front_seq_compute_instr_hook_;
}
const InEdgeList& in_edges() const { return in_edges_; }
const OutEdgeList& out_edges() const { return out_edges_; }
const RwMutexedObjectAccessList& access_list() const { return access_list_; }
const MirroredObjectId2RwMutexedObjectAccess& mirrored_object_id2access() const {
return mirrored_object_id2access_; }
return mirrored_object_id2access_;
}
// Setters
void set_stream(Stream* val) { stream_ = val; }
......@@ -244,32 +269,33 @@ INTRUSIVE_BEGIN(Instruction);
OutEdgeList* mut_out_edges() { return &out_edges_; }
RwMutexedObjectAccessList* mut_access_list() { return &access_list_; }
MirroredObjectId2RwMutexedObjectAccess* mut_mirrored_object_id2access() {
return &mirrored_object_id2access_;
return &mirrored_object_id2access_;
}
// methods
void __Init__(InstructionMsg* instr_msg, Stream* stream, const std::shared_ptr<const ParallelDesc>& parallel_desc);
void __Init__(InstructionMsg* instr_msg, Stream* stream,
const std::shared_ptr<const ParallelDesc>& parallel_desc);
void __Delete__();
bool Done() const;
void set_has_event_record(bool val);
const StreamType& stream_type() const;
template<OperandMemZoneModifier mem_zone_modifier>
const RwMutexedObject* operand_type(const Operand& operand) const {
const RwMutexedObject* operand_type(const Operand& operand) const {
CheckOperand<mem_zone_modifier>(operand);
return operand_type(operand, GetOperandDefaultGlobalDeviceId());
}
template<OperandMemZoneModifier mem_zone_modifier>
const RwMutexedObject* operand_value(const Operand& operand) const {
const RwMutexedObject* operand_value(const Operand& operand) const {
CheckOperand<mem_zone_modifier>(operand);
return operand_value(operand, GetOperandDefaultGlobalDeviceId());
}
template<OperandMemZoneModifier mem_zone_modifier>
RwMutexedObject* mut_operand_type(const Operand& operand) {
RwMutexedObject* mut_operand_type(const Operand& operand) {
CheckOperand<mem_zone_modifier>(operand);
return mut_operand_type(operand, GetOperandDefaultGlobalDeviceId());
}
template<OperandMemZoneModifier mem_zone_modifier>
RwMutexedObject* mut_operand_value(const Operand& operand) {
RwMutexedObject* mut_operand_value(const Operand& operand) {
CheckOperand<mem_zone_modifier>(operand);
return mut_operand_value(operand, GetOperandDefaultGlobalDeviceId());
}
......@@ -294,11 +320,11 @@ INTRUSIVE_BEGIN(Instruction);
return mut_operand_value<mem_zone_modifier>(operand.operand());
}
template<InterpretType interpret_type>
MirroredObject* MutMirroredObject(const MutOperand& mut_operand) {
MirroredObject* MutMirroredObject(const MutOperand& mut_operand) {
return MirroredObjectUtil<interpret_type>::Mut(this, mut_operand);
}
template<InterpretType interpret_type>
const MirroredObject* GetMirroredObject(const ConstOperand& const_operand) const {
const MirroredObject* GetMirroredObject(const ConstOperand& const_operand) const {
return MirroredObjectUtil<interpret_type>::Get(*this, const_operand);
}
MirroredObject* mut_type_mirrored_object(const MutOperand& mut_operand);
......@@ -307,22 +333,18 @@ INTRUSIVE_BEGIN(Instruction);
intrusive::Ref::RefCntType ref_cnt() const { return intrusive_ref_.ref_cnt(); }
private:
template<int64_t(*TransformLogicalObjectId)(int64_t)>
MirroredObject* MutMirroredObject(const Operand& operand,
int64_t default_global_device_id);
template<int64_t(*TransformLogicalObjectId)(int64_t)>
const MirroredObject* GetMirroredObject(const Operand& operand,
int64_t default_global_device_id) const;
template<int64_t (*TransformLogicalObjectId)(int64_t)>
MirroredObject* MutMirroredObject(const Operand& operand, int64_t default_global_device_id);
template<int64_t (*TransformLogicalObjectId)(int64_t)>
const MirroredObject* GetMirroredObject(const Operand& operand,
int64_t default_global_device_id) const;
const RwMutexedObject* operand_type(const Operand& operand,
int64_t default_global_device_id) const;
int64_t default_global_device_id) const;
const RwMutexedObject* operand_value(const Operand& operand,
int64_t default_global_device_id) const;
RwMutexedObject* mut_operand_type(const Operand& operand,
int64_t default_global_device_id);
RwMutexedObject* mut_operand_value(const Operand& operand,
int64_t default_global_device_id);
MirroredObject* MutMirroredObject(const Operand& operand,
int64_t default_global_device_id) {
int64_t default_global_device_id) const;
RwMutexedObject* mut_operand_type(const Operand& operand, int64_t default_global_device_id);
RwMutexedObject* mut_operand_value(const Operand& operand, int64_t default_global_device_id);
MirroredObject* MutMirroredObject(const Operand& operand, int64_t default_global_device_id) {
return MutMirroredObject<&IdUtil::GetValueId>(operand, default_global_device_id);
}
int64_t GetOperandDefaultGlobalDeviceId() const;
......@@ -335,31 +357,47 @@ INTRUSIVE_BEGIN(Instruction);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
Instruction() : intrusive_ref_(), status_buffer_(), instr_msg_(), parallel_desc_(), stream_(), instruction_hook_(), dispatched_instruction_hook_(), vm_stat_running_instruction_hook_(), pending_instruction_hook_(), front_seq_infer_instr_hook_(), front_seq_compute_instr_hook_(), mirrored_object_id2access_(), access_list_(), in_edges_(), out_edges_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
Instruction()
: intrusive_ref_(),
status_buffer_(),
instr_msg_(),
parallel_desc_(),
stream_(),
mirrored_object_id2access_(),
access_list_(),
in_edges_(),
out_edges_(),
instruction_hook_(),
dispatched_instruction_hook_(),
vm_stat_running_instruction_hook_(),
pending_instruction_hook_(),
front_seq_infer_instr_hook_(),
front_seq_compute_instr_hook_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(FlatMsg<InstructionStatusBuffer>, status_buffer_);
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<InstructionMsg>, instr_msg_);
INTRUSIVE_DEFINE_FIELD(std::shared_ptr<const ParallelDesc>, parallel_desc_);
INTRUSIVE_DEFINE_FIELD(Stream*, stream_);
FlatMsg<InstructionStatusBuffer> status_buffer_;
intrusive::shared_ptr<InstructionMsg> instr_msg_;
std::shared_ptr<const ParallelDesc> parallel_desc_;
Stream* stream_;
// maps
MirroredObjectId2RwMutexedObjectAccess mirrored_object_id2access_;
// lists
RwMutexedObjectAccessList access_list_;
InEdgeList in_edges_;
OutEdgeList out_edges_;
public:
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, instruction_hook_);
intrusive::ListHook instruction_hook_;
// dispatched to Stream
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, dispatched_instruction_hook_);
// `vm_stat_running_instruction_hook` valid from instruction ready to instruction done
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, vm_stat_running_instruction_hook_);
intrusive::ListHook dispatched_instruction_hook_;
// `vm_stat_running_instruction_hook` valid from instruction ready to instruction done
intrusive::ListHook vm_stat_running_instruction_hook_;
// pending to ThreadCtx
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, pending_instruction_hook_);
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, front_seq_infer_instr_hook_);
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, front_seq_compute_instr_hook_);
// maps
INTRUSIVE_DEFINE_FIELD(MirroredObjectId2RwMutexedObjectAccess, mirrored_object_id2access_);
// lists
INTRUSIVE_DEFINE_FIELD(RwMutexedObjectAccessList, access_list_);
INTRUSIVE_DEFINE_FIELD(InEdgeList, in_edges_);
INTRUSIVE_DEFINE_FIELD(OutEdgeList, out_edges_);
INTRUSIVE_END(Instruction);
// clang-format on
intrusive::ListHook pending_instruction_hook_;
intrusive::ListHook front_seq_infer_instr_hook_;
intrusive::ListHook front_seq_compute_instr_hook_;
};
} // namespace vm
} // namespace oneflow
......
......@@ -17,6 +17,7 @@ limitations under the License.
#define ONEFLOW_CORE_VM_RELEASE_TENSOR_ARG_PHY_INSTR_OPERAND_H_
#include <functional>
#include <memory>
#include "oneflow/core/vm/phy_instr_operand.h"
namespace oneflow {
......
......@@ -25,11 +25,11 @@ namespace vm {
struct ThreadCtx;
// clang-format off
INTRUSIVE_BEGIN(Stream);
class Stream final : public intrusive::Base {
public:
// types
using DispatchedInstructionList = intrusive::List<INTRUSIVE_FIELD(Instruction, dispatched_instruction_hook_)>;
using DispatchedInstructionList =
intrusive::List<INTRUSIVE_FIELD(Instruction, dispatched_instruction_hook_)>;
// Getters
int64_t max_device_num_per_machine() const { return max_device_num_per_machine_; }
......@@ -38,8 +38,12 @@ INTRUSIVE_BEGIN(Stream);
const std::unique_ptr<DeviceCtx>& device_ctx() const { return device_ctx_; }
const intrusive::ListHook& active_stream_hook() const { return active_stream_hook_; }
const DispatchedInstructionList& free_instruction_list() const { return free_instruction_list_; }
const DispatchedInstructionList& zombie_instruction_list() const { return zombie_instruction_list_; }
const DispatchedInstructionList& running_instruction_list() const { return running_instruction_list_; }
const DispatchedInstructionList& zombie_instruction_list() const {
return zombie_instruction_list_;
}
const DispatchedInstructionList& running_instruction_list() const {
return running_instruction_list_;
}
const StreamId& stream_id() const { return stream_id_.key(); }
// Setters
......@@ -55,8 +59,10 @@ INTRUSIVE_BEGIN(Stream);
// methods
void __Init__();
void __Init__(ThreadCtx* thread_ctx, const StreamId& stream_id, const int64_t max_device_num_per_machine);
intrusive::shared_ptr<Instruction> NewInstruction(InstructionMsg* instr_msg, const std::shared_ptr<const ParallelDesc>& parallel_desc);
void __Init__(ThreadCtx* thread_ctx, const StreamId& stream_id,
const int64_t max_device_num_per_machine);
intrusive::shared_ptr<Instruction> NewInstruction(
InstructionMsg* instr_msg, const std::shared_ptr<const ParallelDesc>& parallel_desc);
void DeleteInstruction(intrusive::shared_ptr<Instruction>&&);
int64_t global_device_id() const { return stream_id().global_device_id(); }
int64_t machine_id() const;
......@@ -71,23 +77,34 @@ INTRUSIVE_BEGIN(Stream);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
Stream() : intrusive_ref_(), thread_ctx_(), device_ctx_(), max_device_num_per_machine_(), active_stream_hook_(), thread_ctx_stream_hook_(), stream_id_(), free_instruction_list_(), zombie_instruction_list_(), running_instruction_list_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
Stream()
: intrusive_ref_(),
thread_ctx_(),
device_ctx_(),
max_device_num_per_machine_(),
free_instruction_list_(),
zombie_instruction_list_(),
running_instruction_list_(),
stream_id_(),
active_stream_hook_(),
thread_ctx_stream_hook_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(ThreadCtx*, thread_ctx_);
INTRUSIVE_DEFINE_FIELD(std::unique_ptr<DeviceCtx>, device_ctx_);
INTRUSIVE_DEFINE_FIELD(int64_t, max_device_num_per_machine_);
ThreadCtx* thread_ctx_;
std::unique_ptr<DeviceCtx> device_ctx_;
int64_t max_device_num_per_machine_;
// lists
DispatchedInstructionList free_instruction_list_;
DispatchedInstructionList zombie_instruction_list_;
DispatchedInstructionList running_instruction_list_;
public:
// skiplist hooks
intrusive::SkipListHook<StreamId, 10> stream_id_;
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, active_stream_hook_);
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, thread_ctx_stream_hook_);
using StreamIdKey = intrusive::SkipListHook<StreamId, 10>;
INTRUSIVE_DEFINE_FIELD(StreamIdKey, stream_id_);
// lists
INTRUSIVE_DEFINE_FIELD(DispatchedInstructionList, free_instruction_list_);
INTRUSIVE_DEFINE_FIELD(DispatchedInstructionList, zombie_instruction_list_);
INTRUSIVE_DEFINE_FIELD(DispatchedInstructionList, running_instruction_list_);
INTRUSIVE_END(Stream);
// clang-format on
intrusive::ListHook active_stream_hook_;
intrusive::ListHook thread_ctx_stream_hook_;
};
} // namespace vm
} // namespace oneflow
......
......@@ -56,8 +56,7 @@ class StreamId final {
int64_t global_device_id_;
};
// clang-format off
INTRUSIVE_BEGIN(StreamDesc);
class StreamDesc final : public intrusive::Base {
public:
// Getters
int32_t num_machines() const { return num_machines_; }
......@@ -72,26 +71,31 @@ INTRUSIVE_BEGIN(StreamDesc);
// methods
void __Init__() {}
void __Init__(const StreamTypeId& stream_type_id, int32_t num_machines, int32_t num_streams_per_machine,
int32_t num_streams_per_thread);
void __Init__(const StreamTypeId& stream_type_id, int32_t num_machines,
int32_t num_streams_per_machine, int32_t num_streams_per_thread);
int32_t num_threads() const;
int32_t parallel_num() const { return num_machines() * num_streams_per_machine(); }
private:
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
StreamDesc() : intrusive_ref_(), num_machines_(), num_streams_per_machine_(), num_streams_per_thread_(), stream_type_id_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
StreamDesc()
: intrusive_ref_(),
num_machines_(),
num_streams_per_machine_(),
num_streams_per_thread_(),
stream_type_id_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(int32_t, num_machines_);
INTRUSIVE_DEFINE_FIELD(int32_t, num_streams_per_machine_);
INTRUSIVE_DEFINE_FIELD(int32_t, num_streams_per_thread_);
int32_t num_machines_;
int32_t num_streams_per_machine_;
int32_t num_streams_per_thread_;
public:
// skiplist hooks
using StreamTypeIdKey = intrusive::SkipListHook<FlatMsg<StreamTypeId>, 7>;
INTRUSIVE_DEFINE_FIELD(StreamTypeIdKey, stream_type_id_);
INTRUSIVE_END(StreamDesc);
// clang-format on
intrusive::SkipListHook<FlatMsg<StreamTypeId>, 7> stream_type_id_;
};
} // namespace vm
} // namespace oneflow
......
......@@ -26,8 +26,7 @@ class StreamType;
struct StreamDesc;
// Rt is short for Runtime
// clang-format off
INTRUSIVE_BEGIN(StreamRtDesc);
class StreamRtDesc final : public intrusive::Base {
public:
// types
using StreamId2Stream = intrusive::SkipList<INTRUSIVE_FIELD(Stream, stream_id_)>;
......@@ -40,7 +39,7 @@ INTRUSIVE_BEGIN(StreamRtDesc);
const StreamTypeId& stream_type_id() const { return stream_type_id_.key().Get(); }
const StreamId2Stream& stream_id2stream() const { return stream_id2stream_; }
// Setters
StreamDesc* mut_stream_desc() {
StreamDesc* mut_stream_desc() {
if (!stream_desc_) { stream_desc_ = intrusive::make_shared<StreamDesc>(); }
return stream_desc_.Mutable();
}
......@@ -56,16 +55,17 @@ INTRUSIVE_BEGIN(StreamRtDesc);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
StreamRtDesc() : intrusive_ref_(), stream_desc_(), stream_type_id_(), stream_id2stream_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
StreamRtDesc() : intrusive_ref_(), stream_desc_(), stream_id2stream_(), stream_type_id_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<StreamDesc>, stream_desc_);
// list hooks
using StreamTypeIdKey = intrusive::SkipListHook<FlatMsg<StreamTypeId>, 7>;
INTRUSIVE_DEFINE_FIELD(StreamTypeIdKey, stream_type_id_);
INTRUSIVE_DEFINE_FIELD(StreamId2Stream, stream_id2stream_);
INTRUSIVE_END(StreamRtDesc);
// clang-format on
intrusive::shared_ptr<StreamDesc> stream_desc_;
// maps
StreamId2Stream stream_id2stream_;
public:
// skiplist hooks
intrusive::SkipListHook<FlatMsg<StreamTypeId>, 7> stream_type_id_;
};
} // namespace vm
} // namespace oneflow
......
......@@ -43,8 +43,7 @@ TEST(StreamTypeId, logical_compare) {
LookupInferStreamTypeId(stream_type_id0);
}
// clang-format off
INTRUSIVE_BEGIN(StreamTypeIdItem);
class StreamTypeIdItem final : public intrusive::Base {
public:
// Getters
const StreamTypeId& stream_type_id() const { return stream_type_id_.key().Get(); }
......@@ -56,11 +55,12 @@ INTRUSIVE_BEGIN(StreamTypeIdItem);
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
StreamTypeIdItem() : intrusive_ref_(), stream_type_id_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
using StreamTypeIdKey = intrusive::SkipListHook<FlatMsg<StreamTypeId>, 20>;
INTRUSIVE_DEFINE_FIELD(StreamTypeIdKey, stream_type_id_);
INTRUSIVE_END(StreamTypeIdItem);
// clang-format on
intrusive::Ref intrusive_ref_;
public:
// skiplist hooks
intrusive::SkipListHook<FlatMsg<StreamTypeId>, 20> stream_type_id_;
};
using StreamTypeIdSet = intrusive::SkipList<INTRUSIVE_FIELD(StreamTypeIdItem, stream_type_id_)>;
TEST(StreamTypeId, map_key) {
......
......@@ -25,8 +25,7 @@ limitations under the License.
namespace oneflow {
namespace vm {
// clang-format off
INTRUSIVE_BEGIN(ThreadCtx);
class ThreadCtx final : public intrusive::Base {
public:
void __Init__() { clear_stream_rt_desc(); }
......@@ -53,24 +52,30 @@ INTRUSIVE_BEGIN(ThreadCtx);
}
void LoopRun(const std::function<void(ThreadCtx*)>& Initializer);
intrusive::ChannelStatus TryReceiveAndRun();
private:
intrusive::ChannelStatus ReceiveAndRun();
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
ThreadCtx() : intrusive_ref_(), stream_rt_desc_(), thread_ctx_hook_(), stream_list_(), pending_instruction_list_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
ThreadCtx()
: intrusive_ref_(),
stream_rt_desc_(),
stream_list_(),
pending_instruction_list_(),
thread_ctx_hook_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(const StreamRtDesc*, stream_rt_desc_);
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, thread_ctx_hook_);
const StreamRtDesc* stream_rt_desc_;
// lists
INTRUSIVE_DEFINE_FIELD(StreamList, stream_list_);
INTRUSIVE_DEFINE_FIELD(PendingInstructionChannel, pending_instruction_list_);
INTRUSIVE_END(ThreadCtx);
// clang-format on
StreamList stream_list_;
PendingInstructionChannel pending_instruction_list_;
public:
// list hooks
intrusive::ListHook thread_ctx_hook_;
};
} // namespace vm
} // namespace oneflow
......
......@@ -34,8 +34,7 @@ namespace oneflow {
namespace vm {
struct VmDesc;
// clang-format off
INTRUSIVE_BEGIN(VirtualMachine);
class VirtualMachine final : public intrusive::Base {
public:
// types
using ActiveStreamList = intrusive::List<INTRUSIVE_FIELD(Stream, active_stream_hook_)>;
......@@ -47,8 +46,9 @@ INTRUSIVE_BEGIN(VirtualMachine);
using FrontSeqInstructionList =
intrusive::List<INTRUSIVE_FIELD(Instruction, front_seq_compute_instr_hook_)>;
using InstructionMsgMutextList =
intrusive::MutexedList<INTRUSIVE_FIELD(InstructionMsg, instr_msg_hook_)>;
using StreamTypeId2StreamRtDesc = intrusive::SkipList<INTRUSIVE_FIELD(StreamRtDesc, stream_type_id_)>;
intrusive::MutexedList<INTRUSIVE_FIELD(InstructionMsg, InstructionMsg::instr_msg_hook_)>;
using StreamTypeId2StreamRtDesc =
intrusive::SkipList<INTRUSIVE_FIELD(StreamRtDesc, stream_type_id_)>;
using Id2LogicalObject = intrusive::SkipList<INTRUSIVE_FIELD(LogicalObject, logical_object_id_)>;
// Getters
......@@ -61,14 +61,22 @@ INTRUSIVE_BEGIN(VirtualMachine);
const std::atomic<int64_t>& flying_instruction_cnt() const { return flying_instruction_cnt_; }
const ActiveStreamList& active_stream_list() const { return active_stream_list_; }
const ThreadCtxList& thread_ctx_list() const { return thread_ctx_list_; }
const LogicalObjectDeleteList& delete_logical_object_list() const { return delete_logical_object_list_; }
const LogicalObjectDeleteList& delete_logical_object_list() const {
return delete_logical_object_list_;
}
const InstructionList& waiting_instruction_list() const { return waiting_instruction_list_; }
const VmStatRunningInstructionList& vm_stat_running_instruction_list() const { return vm_stat_running_instruction_list_; }
const FrontSeqInstructionList& front_seq_compute_instr_list() const { return front_seq_compute_instr_list_; }
const VmStatRunningInstructionList& vm_stat_running_instruction_list() const {
return vm_stat_running_instruction_list_;
}
const FrontSeqInstructionList& front_seq_compute_instr_list() const {
return front_seq_compute_instr_list_;
}
const InstructionMsgMutextList& pending_msg_list() const { return pending_msg_list_; }
const StreamTypeId2StreamRtDesc& stream_type_id2stream_rt_desc() const { return stream_type_id2stream_rt_desc_; }
const StreamTypeId2StreamRtDesc& stream_type_id2stream_rt_desc() const {
return stream_type_id2stream_rt_desc_;
}
const Id2LogicalObject& id2logical_object() const { return id2logical_object_; }
//Setters
// Setters
VmResourceDesc* mut_vm_resource_desc() {
if (!vm_resource_desc_) { vm_resource_desc_ = intrusive::make_shared<VmResourceDesc>(); }
return vm_resource_desc_.Mutable();
......@@ -79,10 +87,16 @@ INTRUSIVE_BEGIN(VirtualMachine);
ThreadCtxList* mut_thread_ctx_list() { return &thread_ctx_list_; }
LogicalObjectDeleteList* mut_delete_logical_object_list() { return &delete_logical_object_list_; }
InstructionList* mut_waiting_instruction_list() { return &waiting_instruction_list_; }
VmStatRunningInstructionList* mut_vm_stat_running_instruction_list() { return &vm_stat_running_instruction_list_; }
FrontSeqInstructionList* mut_front_seq_compute_instr_list() { return &front_seq_compute_instr_list_; }
VmStatRunningInstructionList* mut_vm_stat_running_instruction_list() {
return &vm_stat_running_instruction_list_;
}
FrontSeqInstructionList* mut_front_seq_compute_instr_list() {
return &front_seq_compute_instr_list_;
}
InstructionMsgMutextList* mut_pending_msg_list() { return &pending_msg_list_; }
StreamTypeId2StreamRtDesc* mut_stream_type_id2stream_rt_desc() { return &stream_type_id2stream_rt_desc_; }
StreamTypeId2StreamRtDesc* mut_stream_type_id2stream_rt_desc() {
return &stream_type_id2stream_rt_desc_;
}
Id2LogicalObject* mut_id2logical_object() { return &id2logical_object_; }
// methods
......@@ -94,8 +108,7 @@ INTRUSIVE_BEGIN(VirtualMachine);
bool Empty() const;
Maybe<const ParallelDesc> GetInstructionParallelDesc(const InstructionMsg&);
MirroredObject* MutMirroredObject(int64_t logical_object_id, int64_t global_device_id);
const MirroredObject* GetMirroredObject(int64_t logical_object_id,
int64_t global_device_id);
const MirroredObject* GetMirroredObject(int64_t logical_object_id, int64_t global_device_id);
int64_t this_machine_id() const;
int64_t this_start_global_device_id() const {
......@@ -105,50 +118,50 @@ INTRUSIVE_BEGIN(VirtualMachine);
private:
using TmpPendingInstrMsgList = intrusive::List<INTRUSIVE_FIELD(InstructionMsg, instr_msg_hook_)>;
using NewInstructionList = InstructionList;
using ReadyInstructionList = intrusive::List<INTRUSIVE_FIELD(Instruction, dispatched_instruction_hook_)>;
using ReadyInstructionList =
intrusive::List<INTRUSIVE_FIELD(Instruction, dispatched_instruction_hook_)>;
ReadyInstructionList* mut_ready_instruction_list() { return &ready_instruction_list_; }
void TryRunFrontSeqInstruction();
void ReleaseInstruction(Instruction* instruction);
void TryReleaseFinishedInstructions(Stream* stream);
void FilterAndRunInstructionsInAdvance(TmpPendingInstrMsgList* instr_msg_list);
void MakeInstructions(TmpPendingInstrMsgList*, /*out*/ NewInstructionList* ret_instruction_list);
template<int64_t (*TransformLogicalObjectId)(int64_t), typename DoEachT>
void ForEachMirroredObject(Id2LogicalObject* id2logical_object,
const Operand& operand,
void ForEachMirroredObject(Id2LogicalObject* id2logical_object, const Operand& operand,
int64_t global_device_id, const DoEachT& DoEach);
template<OperandMemZoneModifier mem_zone_modifier, typename DoEachT>
void ForEachConstMirroredObject(const InterpretType interpret_type,
Id2LogicalObject* id2logical_object,
const ModifiedOperand<kConstModifier, mem_zone_modifier>& const_operand,
int64_t global_device_id, const DoEachT& DoEach);
void ForEachConstMirroredObject(
const InterpretType interpret_type, Id2LogicalObject* id2logical_object,
const ModifiedOperand<kConstModifier, mem_zone_modifier>& const_operand,
int64_t global_device_id, const DoEachT& DoEach);
template<OperandMemZoneModifier mem_zone_modifier, typename DoEachT>
void ForEachConstMirroredObject(const InterpretType interpret_type,
Id2LogicalObject* id2logical_object,
const ModifiedOperand<kDataMutableModifier, mem_zone_modifier>& mutable_operand,
int64_t global_device_id, const DoEachT& DoEach);
void ForEachConstMirroredObject(
const InterpretType interpret_type, Id2LogicalObject* id2logical_object,
const ModifiedOperand<kDataMutableModifier, mem_zone_modifier>& mutable_operand,
int64_t global_device_id, const DoEachT& DoEach);
template<OperandMemZoneModifier mem_zone_modifier, typename DoEachT>
void ForEachMutMirroredObject(const InterpretType interpret_type,
Id2LogicalObject* id2logical_object,
const ModifiedOperand<kDataMutableModifier, mem_zone_modifier>& mutable_operand,
int64_t global_device_id, const DoEachT& DoEach);
void ForEachMutMirroredObject(
const InterpretType interpret_type, Id2LogicalObject* id2logical_object,
const ModifiedOperand<kDataMutableModifier, mem_zone_modifier>& mutable_operand,
int64_t global_device_id, const DoEachT& DoEach);
template<OperandMemZoneModifier mem_zone_modifier, typename DoEachT>
void ForEachMutMirroredObject(const InterpretType interpret_type,
Id2LogicalObject* id2logical_object,
const ModifiedOperand<kTypeAndDataMutableModifier, mem_zone_modifier>& mut2_operand,
int64_t global_device_id, const DoEachT& DoEach);
void ForEachMutMirroredObject(
const InterpretType interpret_type, Id2LogicalObject* id2logical_object,
const ModifiedOperand<kTypeAndDataMutableModifier, mem_zone_modifier>& mut2_operand,
int64_t global_device_id, const DoEachT& DoEach);
template<OperandMemZoneModifier mem_zone_modifier, typename DoEachT>
void ForEachMutMirroredObject(const InterpretType interpret_type,
Id2LogicalObject* id2logical_object,
const ModifiedOperand<kDeleteModifier, mem_zone_modifier>& mut2_operand,
int64_t global_device_id, const DoEachT& DoEach);
void ForEachMutMirroredObject(
const InterpretType interpret_type, Id2LogicalObject* id2logical_object,
const ModifiedOperand<kDeleteModifier, mem_zone_modifier>& mut2_operand,
int64_t global_device_id, const DoEachT& DoEach);
void ConnectInstruction(Instruction* src_instruction, Instruction* dst_instruction);
RwMutexedObjectAccess* ConsumeMirroredObject(OperandAccessType access_type, MirroredObject* mirrored_object,
Instruction* instrution);
RwMutexedObjectAccess* ConsumeMirroredObject(OperandAccessType access_type,
MirroredObject* mirrored_object,
Instruction* instrution);
void ConsumeMirroredObjects(Id2LogicalObject* id2logical_object,
NewInstructionList* new_instruction_list);
void DispatchAndPrescheduleInstructions();
......@@ -178,27 +191,24 @@ INTRUSIVE_BEGIN(VirtualMachine);
ready_instruction_list_(),
vm_stat_running_instruction_list_(),
front_seq_compute_instr_list_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<VmResourceDesc>, vm_resource_desc_);
INTRUSIVE_DEFINE_FIELD(Range, machine_id_range_);
INTRUSIVE_DEFINE_FIELD(std::atomic<int64_t>, flying_instruction_cnt_);
intrusive::shared_ptr<VmResourceDesc> vm_resource_desc_;
Range machine_id_range_;
std::atomic<int64_t> flying_instruction_cnt_;
// lists or maps
// Do not change the order of the following fields
INTRUSIVE_DEFINE_FIELD(ActiveStreamList, active_stream_list_);
INTRUSIVE_DEFINE_FIELD(ThreadCtxList, thread_ctx_list_);
INTRUSIVE_DEFINE_FIELD(StreamTypeId2StreamRtDesc, stream_type_id2stream_rt_desc_);
INTRUSIVE_DEFINE_FIELD(Id2LogicalObject, id2logical_object_);
INTRUSIVE_DEFINE_FIELD(LogicalObjectDeleteList, delete_logical_object_list_);
INTRUSIVE_DEFINE_FIELD(InstructionMsgMutextList, pending_msg_list_);
INTRUSIVE_DEFINE_FIELD(InstructionList, waiting_instruction_list_);
INTRUSIVE_DEFINE_FIELD(ReadyInstructionList, ready_instruction_list_);
INTRUSIVE_DEFINE_FIELD(VmStatRunningInstructionList, vm_stat_running_instruction_list_);
// TODO(lixinqi): rename to sequential_instruction_list
INTRUSIVE_DEFINE_FIELD(FrontSeqInstructionList, front_seq_compute_instr_list_);
INTRUSIVE_END(VirtualMachine);
// clang-format on
// Do not change the order of the following fields
ActiveStreamList active_stream_list_;
ThreadCtxList thread_ctx_list_;
StreamTypeId2StreamRtDesc stream_type_id2stream_rt_desc_;
Id2LogicalObject id2logical_object_;
LogicalObjectDeleteList delete_logical_object_list_;
InstructionMsgMutextList pending_msg_list_;
InstructionList waiting_instruction_list_;
ReadyInstructionList ready_instruction_list_;
VmStatRunningInstructionList vm_stat_running_instruction_list_;
FrontSeqInstructionList front_seq_compute_instr_list_;
};
} // namespace vm
......
......@@ -25,12 +25,10 @@ limitations under the License.
namespace oneflow {
namespace vm {
// clang-format off
INTRUSIVE_BEGIN(VmDesc);
class VmDesc final : public intrusive::Base {
public:
// types
using StreamTypeId2StreamDesc =
intrusive::SkipList<INTRUSIVE_FIELD(StreamDesc, stream_type_id_)>;
using StreamTypeId2StreamDesc = intrusive::SkipList<INTRUSIVE_FIELD(StreamDesc, stream_type_id_)>;
// Getters
const VmResourceDesc& vm_resource_desc() const {
if (vm_resource_desc_) { return vm_resource_desc_.Get(); }
......@@ -39,7 +37,7 @@ INTRUSIVE_BEGIN(VmDesc);
}
const Range& machine_id_range() const { return machine_id_range_; }
const StreamTypeId2StreamDesc& stream_type_id2desc() const { return stream_type_id2desc_; }
//Setters
// Setters
VmResourceDesc* mut_vm_resource_desc() {
if (!vm_resource_desc_) { vm_resource_desc_ = intrusive::make_shared<VmResourceDesc>(); }
return vm_resource_desc_.Mutable();
......@@ -48,26 +46,24 @@ INTRUSIVE_BEGIN(VmDesc);
StreamTypeId2StreamDesc* mut_stream_type_id2desc() { return &stream_type_id2desc_; }
// methods
void __Init__(const VmResourceDesc& vm_resource_desc) {
__Init__(vm_resource_desc, Range(0, 1));
}
void __Init__(const VmResourceDesc& vm_resource_desc) { __Init__(vm_resource_desc, Range(0, 1)); }
void __Init__(const VmResourceDesc& vm_resource_desc, const Range& machine_id_range) {
mut_vm_resource_desc()->CopyFrom(vm_resource_desc);
*mut_machine_id_range() = machine_id_range;
}
private:
private:
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
VmDesc() : intrusive_ref_(), vm_resource_desc_(), machine_id_range_(), stream_type_id2desc_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<VmResourceDesc>, vm_resource_desc_);
INTRUSIVE_DEFINE_FIELD(Range, machine_id_range_);
intrusive::shared_ptr<VmResourceDesc> vm_resource_desc_;
Range machine_id_range_;
// maps
INTRUSIVE_DEFINE_FIELD(StreamTypeId2StreamDesc, stream_type_id2desc_);
INTRUSIVE_END(VmDesc);
// clang-format on
StreamTypeId2StreamDesc stream_type_id2desc_;
};
intrusive::shared_ptr<VmDesc> MakeVmDesc(const Resource& resource, int64_t this_machine_id);
intrusive::shared_ptr<VmDesc> MakeVmDesc(const Resource& resource, int64_t this_machine_id,
......
......@@ -38,8 +38,7 @@ enum OperandAccessType {
kMutableOperandAccess,
};
// clang-format off
INTRUSIVE_BEGIN(RwMutexedObjectAccess);
class RwMutexedObjectAccess final : public intrusive::Base {
public:
void __Init__();
// Getters
......@@ -50,7 +49,9 @@ INTRUSIVE_BEGIN(RwMutexedObjectAccess);
const Instruction& instruction() const { return *instruction_; }
const MirroredObject& mirrored_object() const { return *mirrored_object_; }
const RwMutexedObject& rw_mutexed_object() const { return *rw_mutexed_object_; }
const intrusive::ListHook& rw_mutexed_object_access_hook() const { return rw_mutexed_object_access_hook_; }
const intrusive::ListHook& rw_mutexed_object_access_hook() const {
return rw_mutexed_object_access_hook_;
}
const MirroredObjectId& mirrored_object_id() const { return mirrored_object_id_.key().Get(); }
bool is_mirrored_object_id_inserted() const { return !mirrored_object_id_.empty(); }
......@@ -69,7 +70,7 @@ INTRUSIVE_BEGIN(RwMutexedObjectAccess);
// methods
void __Init__(Instruction* instruction, MirroredObject* mirrored_object,
OperandAccessType access_type);
OperandAccessType access_type);
bool is_const_operand() const;
bool is_mut_operand() const;
......@@ -78,28 +79,39 @@ INTRUSIVE_BEGIN(RwMutexedObjectAccess);
private:
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } // NOLINT
RwMutexedObjectAccess() : intrusive_ref_(), access_type_(), instruction_(), mirrored_object_(), rw_mutexed_object_(), instruction_access_hook_(), rw_mutexed_object_access_hook_(), mirrored_object_id_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; } // NOLINT
RwMutexedObjectAccess()
: intrusive_ref_(),
access_type_(),
instruction_(),
mirrored_object_(),
rw_mutexed_object_(),
mirrored_object_id_(),
instruction_access_hook_(),
rw_mutexed_object_access_hook_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(OperandAccessType, access_type_);
INTRUSIVE_DEFINE_FIELD(Instruction*, instruction_);
INTRUSIVE_DEFINE_FIELD(MirroredObject*, mirrored_object_);
INTRUSIVE_DEFINE_FIELD(RwMutexedObject*, rw_mutexed_object_);
OperandAccessType access_type_;
Instruction* instruction_;
MirroredObject* mirrored_object_;
RwMutexedObject* rw_mutexed_object_;
public:
// skiplist hooks
intrusive::SkipListHook<FlatMsg<MirroredObjectId>, 10> mirrored_object_id_;
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, instruction_access_hook_);
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, rw_mutexed_object_access_hook_);
using MirroredObjectIdKey = intrusive::SkipListHook<FlatMsg<MirroredObjectId>, 10>;
INTRUSIVE_DEFINE_FIELD(MirroredObjectIdKey, mirrored_object_id_);
INTRUSIVE_END(RwMutexedObjectAccess); // NOLINT
intrusive::ListHook instruction_access_hook_;
intrusive::ListHook rw_mutexed_object_access_hook_;
}; // NOLINT
struct LogicalObject;
INTRUSIVE_BEGIN(RwMutexedObject);
class RwMutexedObject final : public intrusive::Base {
public:
void __Init__() {}
// types
using RwMutexedObjectAccessList = intrusive::List<INTRUSIVE_FIELD(RwMutexedObjectAccess, rw_mutexed_object_access_hook_)>;
using RwMutexedObjectAccessList =
intrusive::List<INTRUSIVE_FIELD(RwMutexedObjectAccess, rw_mutexed_object_access_hook_)>;
// Getters
const RwMutexedObjectAccessList& access_list() const { return access_list_; }
......@@ -107,26 +119,30 @@ INTRUSIVE_BEGIN(RwMutexedObject);
RwMutexedObjectAccessList* mut_access_list() { return &access_list_; }
// methods
template<typename T> bool Has() const {
template<typename T>
bool Has() const {
return dynamic_cast<const T*>(&object()) != nullptr;
}
template<typename T> Maybe<const T&> Get() const {
template<typename T>
Maybe<const T&> Get() const {
const T* obj = dynamic_cast<const T*>(&object());
const auto &origin_obj = *object_ptr_;
CHECK_NOTNULL_OR_RETURN(obj)
<< "cast to " << typeid(T).name() << "failed. "
<< "type: " << (object_ptr_ ? typeid(origin_obj).name() : "nullptr");
const auto& origin_obj = *object_ptr_;
CHECK_NOTNULL_OR_RETURN(obj) << "cast to " << typeid(T).name() << "failed. "
<< "type: "
<< (object_ptr_ ? typeid(origin_obj).name() : "nullptr");
return *obj;
}
template<typename T> Maybe<T*> Mut() {
template<typename T>
Maybe<T*> Mut() {
T* obj = dynamic_cast<T*>(object_ptr_.get());
const auto &origin_obj = *object_ptr_;
CHECK_NOTNULL_OR_RETURN(obj)
<< "cast to " << typeid(T).name() << "failed. "
<< "type: " << (object_ptr_ ? typeid(origin_obj).name() : "nullptr");
const auto& origin_obj = *object_ptr_;
CHECK_NOTNULL_OR_RETURN(obj) << "cast to " << typeid(T).name() << "failed. "
<< "type: "
<< (object_ptr_ ? typeid(origin_obj).name() : "nullptr");
return obj;
}
template<typename T, typename... Args> T* Init(Args&&... args) {
template<typename T, typename... Args>
T* Init(Args&&... args) {
T* object = dynamic_cast<T*>(object_ptr_.get());
CHECK(object == nullptr);
object = new T(std::forward<Args>(args)...);
......@@ -145,15 +161,15 @@ INTRUSIVE_BEGIN(RwMutexedObject);
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
RwMutexedObject() : intrusive_ref_(), object_ptr_(), access_list_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(std::unique_ptr<Object>, object_ptr_);
std::unique_ptr<Object> object_ptr_;
// list hooks
INTRUSIVE_DEFINE_FIELD(RwMutexedObjectAccessList, access_list_);
INTRUSIVE_END(RwMutexedObject);
RwMutexedObjectAccessList access_list_;
};
INTRUSIVE_BEGIN(MirroredObject);
class MirroredObject final : public intrusive::Base {
public:
// Getters
bool has_deleting_access() const { return deleting_access_ != nullptr; }
......@@ -182,7 +198,6 @@ INTRUSIVE_BEGIN(MirroredObject);
MirroredObjectId* mut_mirrored_object_id() { return mirrored_object_id_.Mutable(); }
void set_global_device_id(int64_t val) { *global_device_id_.mut_key() = val; }
// methods
void __Init__() { clear_deleting_access(); }
void __Init__(LogicalObject* logical_object, int64_t global_device_id);
......@@ -193,19 +208,25 @@ INTRUSIVE_BEGIN(MirroredObject);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
MirroredObject() : intrusive_ref_(), mirrored_object_id_(), rw_mutexed_object_(), deleting_access_(), global_device_id_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
//fields
INTRUSIVE_DEFINE_FIELD(FlatMsg<MirroredObjectId>, mirrored_object_id_);
INTRUSIVE_DEFINE_FIELD(intrusive::shared_ptr<RwMutexedObject>, rw_mutexed_object_);
INTRUSIVE_DEFINE_FIELD(RwMutexedObjectAccess*, deleting_access_);
// map hooks
using Int64Key = intrusive::SkipListHook<int64_t, 10>;
INTRUSIVE_DEFINE_FIELD(Int64Key, global_device_id_);
INTRUSIVE_END(MirroredObject);
MirroredObject()
: intrusive_ref_(),
mirrored_object_id_(),
rw_mutexed_object_(),
deleting_access_(),
global_device_id_() {}
intrusive::Ref intrusive_ref_;
// fields
FlatMsg<MirroredObjectId> mirrored_object_id_;
intrusive::shared_ptr<RwMutexedObject> rw_mutexed_object_;
RwMutexedObjectAccess* deleting_access_;
public:
// skiplist hooks
intrusive::SkipListHook<int64_t, 10> global_device_id_;
};
struct VirtualMachine;
INTRUSIVE_BEGIN(LogicalObject);
class LogicalObject final : public intrusive::Base {
public:
// types
using GlobalDeviceId2MirroredObject =
......@@ -225,12 +246,13 @@ INTRUSIVE_BEGIN(LogicalObject);
}
// methods
void __Init__() { /* Do nothing */ }
void __Init__() { /* Do nothing */
}
void __Init__(const ObjectId& logical_object_id) {
__Init__(logical_object_id, std::shared_ptr<const ParallelDesc>());
}
void __Init__(const ObjectId& logical_object_id,
const std::shared_ptr<const ParallelDesc>& parallel_desc) {
const std::shared_ptr<const ParallelDesc>& parallel_desc) {
set_logical_object_id(logical_object_id);
*mut_parallel_desc() = parallel_desc;
}
......@@ -241,19 +263,24 @@ INTRUSIVE_BEGIN(LogicalObject);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
LogicalObject() : intrusive_ref_(), parallel_desc_(), logical_object_id_(), delete_hook_(), global_device_id2mirrored_object_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
LogicalObject()
: intrusive_ref_(),
parallel_desc_(),
global_device_id2mirrored_object_(),
logical_object_id_(),
delete_hook_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(std::shared_ptr<const ParallelDesc>, parallel_desc_);
// map hooks
using ObjectIdKey = intrusive::SkipListHook<ObjectId, 24>;
INTRUSIVE_DEFINE_FIELD(ObjectIdKey, logical_object_id_);
// list hooks
INTRUSIVE_DEFINE_FIELD(intrusive::ListHook, delete_hook_);
std::shared_ptr<const ParallelDesc> parallel_desc_;
// maps
INTRUSIVE_DEFINE_FIELD(GlobalDeviceId2MirroredObject, global_device_id2mirrored_object_);
INTRUSIVE_END(LogicalObject);
// clang-format on
GlobalDeviceId2MirroredObject global_device_id2mirrored_object_;
public:
// skiplist hooks
intrusive::SkipListHook<ObjectId, 24> logical_object_id_;
// list hooks
intrusive::ListHook delete_hook_;
};
} // namespace vm
......
......@@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_VM_VM_RESOURCE_DESC__H_
#define ONEFLOW_CORE_VM_VM_RESOURCE_DESC__H_
#ifndef ONEFLOW_CORE_VM_VM_RESOURCE_DESC_H_
#define ONEFLOW_CORE_VM_VM_RESOURCE_DESC_H_
#include <unordered_map>
#include "oneflow/core/intrusive/intrusive.h"
......@@ -28,24 +28,21 @@ namespace vm {
using DeviceTag2DeviceNum = std::unordered_map<std::string, int64_t>;
// clang-format off
INTRUSIVE_BEGIN(VmResourceDesc);
class VmResourceDesc final : public intrusive::Base {
public:
void __Init__() {}
// Getters
int64_t machine_num() const { return machine_num_; }
int64_t max_device_num_per_machine() const { return max_device_num_per_machine_; }
const DeviceTag2DeviceNum& device_tag2device_num() const { return device_tag2device_num_.Get(); }
const DeviceTag2DeviceNum& device_tag2device_num() const { return device_tag2device_num_; }
// Setters
void set_machine_num(int64_t val) { machine_num_ = val; }
void set_max_device_num_per_machine(int64_t val) { max_device_num_per_machine_ = val; }
DeviceTag2DeviceNum* mut_device_tag2device_num() { return device_tag2device_num_.Mutable(); }
DeviceTag2DeviceNum* mut_device_tag2device_num() { return &device_tag2device_num_; }
// methods
void __Init__(const Resource& resource);
void __Init__(
int64_t machine_num, const DeviceTag2DeviceNum& device_tag2device_num);
void __Init__(int64_t machine_num, const DeviceTag2DeviceNum& device_tag2device_num);
void CopyFrom(const VmResourceDesc& vm_resource_desc);
int64_t GetGlobalDeviceId(int64_t machine_id, int64_t device_id) const;
......@@ -53,17 +50,17 @@ INTRUSIVE_BEGIN(VmResourceDesc);
friend class intrusive::Ref;
intrusive::Ref* mut_intrusive_ref() { return &intrusive_ref_; }
VmResourceDesc() : intrusive_ref_(), machine_num_(), max_device_num_per_machine_(), device_tag2device_num_() {}
INTRUSIVE_DEFINE_FIELD(intrusive::Ref, intrusive_ref_);
VmResourceDesc()
: intrusive_ref_(), machine_num_(), max_device_num_per_machine_(), device_tag2device_num_() {}
intrusive::Ref intrusive_ref_;
// fields
INTRUSIVE_DEFINE_FIELD(int64_t, machine_num_);
INTRUSIVE_DEFINE_FIELD(int64_t, max_device_num_per_machine_);
int64_t machine_num_;
int64_t max_device_num_per_machine_;
// maps
INTRUSIVE_DEFINE_FIELD(intrusive::ForceStandardLayout<DeviceTag2DeviceNum>, device_tag2device_num_);
INTRUSIVE_END(VmResourceDesc);
// clang-format on
DeviceTag2DeviceNum device_tag2device_num_;
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_VM_VM_RESOURCE_DESC__H_
#endif // ONEFLOW_CORE_VM_VM_RESOURCE_DESC_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册