Oneflow-Inc / oneflow
Compare revisions
d338b271d7d2e10190f73c3ce913318793ac2c80...4376aba2d24897b963498da6a720fbc0b441cb94
Source branch: 4376aba2d24897b963498da6a720fbc0b441cb94
Target branch: d338b271d7d2e10190f73c3ce913318793ac2c80
Commits (11)
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/808bf377696620cb40d61520a33e046c6d55d616
add --inplace (#6661)
2021-11-01T14:11:44+08:00
Shenghang Tsai
jackalcooper@gmail.com
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/2e96920bbcf926768c92bbde579a6fe528c3b798
migrate parital fc op from lazy to functor (#6387)
2021-11-01T18:46:31+08:00
Yao Chi
later@usopp.net
* migrate partial_fc
* add test and fix DistributedPariticalFCSample release bug
* fix typos in functional_api.yaml
* initialization
* refine testcase
* skip cpu-only test
* reformat
Co-authored-by: bbuf <1182563586@qq.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/f88c979a345d6fef5bae083daf0e536b4b848882
update speed test threshold (#6664)
2021-11-01T20:31:21+08:00
daquexian
daquexian566@gmail.com
Signed-off-by: daquexian <daquexian566@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/21caffd9d94e70538a9035abdcf215405d045167
just macro: rename local variables to prevent shadowing (#6667)
2021-11-01T21:37:51+08:00
Twice
i@twice.moe
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/8b94ac9b8fd0578aeed91a85c955a7f2a400b6aa
restruct reshape gradient funcs (#6634)
2021-11-02T04:15:23+00:00
Luyang
flowingsun007@163.com
* restruct
* refine
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/2247386074c82dd56248405b52095d9b1609bae8
Fix model update pass adam (#6673)
2021-11-02T18:11:41+08:00
ZZK
42901638+MARD1NO@users.noreply.github.com
* add first version of unary primitive op
* fix
* remove redundant file
* Revert
* fix format
* use has input to check
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/55d32c333c8a298da5307bc1d219f48967fe5490
adjust GILForeignLockHelper order to avoid glog print to stderr (#6671)
2021-11-02T18:50:45+08:00
Xiaoyu Xu
xiaoyulink@gmail.com
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/9ec6871dec4c38acb9badec2813d7617e1b0849f
modify by review
2021-11-02T19:37:21+08:00
leaves-zwx
kunta0932@gmail.com
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/91523d64b615146e9b0594a7c2d67b34863f899b
modify by review
2021-11-02T19:56:40+08:00
leaves-zwx
kunta0932@gmail.com
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/188d97504f5de980588a807ffa4a2649044d512d
fix
2021-11-02T20:27:58+08:00
leaves-zwx
kunta0932@gmail.com
https://gitcode.net/Oneflow-Inc/oneflow/-/commit/4376aba2d24897b963498da6a720fbc0b441cb94
Merge branch 'master' into ref_id_util
2021-11-02T20:31:26+08:00
leaves-zwx
kunta0932@gmail.com
Showing 31 changed files with 365 additions and 168 deletions (+365 -168)
README.md  +1 -1
ci/test/test_speed_multi_client.sh  +4 -4
oneflow/core/autograd/gradient_funcs/partial_fc_sample.cpp  +78 -0
oneflow/core/autograd/gradient_funcs/reshape.cpp  +10 -5
oneflow/core/common/just.h  +48 -48
oneflow/core/device/cpu_stream_index.cpp  +7 -7
oneflow/core/device/cpu_stream_index.h  +13 -13
oneflow/core/device/cuda_stream_index.cpp  +3 -3
oneflow/core/device/cuda_stream_index.h  +10 -10
oneflow/core/device/device_id.h  +21 -16
oneflow/core/device/stream_index.h  +4 -4
oneflow/core/functional/functional_api.yaml  +11 -0
oneflow/core/functional/impl/nn_functor.cpp  +44 -0
oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp  +9 -9
oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp  +4 -3
oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp  +2 -2
oneflow/core/graph/copy_task_node.cpp  +2 -2
oneflow/core/graph/stream_index_getter_registry_manager.cpp  +1 -1
oneflow/core/graph/stream_index_getter_registry_manager.h  +1 -1
oneflow/core/graph/task_graph.cpp  +11 -9
oneflow/core/graph/task_id.cpp  +5 -4
oneflow/core/graph/task_id.h  +7 -6
oneflow/core/graph/task_id_generator.h  +1 -1
oneflow/core/job_rewriter/fuse_update_ops_pass.cpp  +6 -0
oneflow/core/memory/memory_zone.cpp  +3 -3
oneflow/core/memory/memory_zone.h  +1 -1
oneflow/core/stream/stream_id.cpp  +4 -3
oneflow/core/stream/stream_id.h  +11 -10
oneflow/user/kernels/partial_fc_sample_kernel.cu  +4 -1
python/oneflow/__init__.py  +2 -1
python/oneflow/test/modules/test_parital_fc.py  +37 -0
README.md
View file @ 4376aba2
...
...
@@ -121,7 +121,7 @@ docker pull oneflowinc/oneflow:nightly-cuda11.1
 - In the root directory of OneFlow source code, run:
 ```
-python3 docker/package/manylinux/build_wheel.py --python_version=3.6
+python3 docker/package/manylinux/build_wheel.py --inplace --python_version=3.6
 ```
 This should produce `.whl` files in the directory `wheelhouse`
...
...
ci/test/test_speed_multi_client.sh
View file @ 4376aba2
...
...
@@ -18,13 +18,13 @@ function write_to_file_and_print {
 python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.01 | write_to_file_and_print
 python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.05 | write_to_file_and_print
-python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.05 | write_to_file_and_print
-python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.09 | write_to_file_and_print
-python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.95 | write_to_file_and_print
+python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.01 | write_to_file_and_print
+python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.06 | write_to_file_and_print
+python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.94 | write_to_file_and_print
 python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 0.99 | write_to_file_and_print
 python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 0.99 | write_to_file_and_print
-python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.93 | write_to_file_and_print
+python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.91 | write_to_file_and_print
 python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.83 | write_to_file_and_print
 python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.82 | write_to_file_and_print
...
...
oneflow/core/autograd/gradient_funcs/partial_fc_sample.cpp
0 → 100644
View file @ 4376aba2
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct PartialFCSampleState : public AutoGradCaptureState {
  bool requires_grad = false;
  int32_t index_sampled_label = -1;
  int32_t index_weight = -1;
};

class PartialFCSample : public OpExprGradFunction<PartialFCSampleState> {
 public:
  Maybe<void> Init(const OpExpr& op) override;
  Maybe<void> Capture(PartialFCSampleState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override;
  Maybe<void> Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override;

 private:
  AttrMap base_attrs_;
};

Maybe<void> PartialFCSample::Init(const OpExpr& op) {
  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
  CHECK_NOTNULL_OR_RETURN(fw_op_expr);
  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
  return Maybe<void>::Ok();
}

Maybe<void> PartialFCSample::Capture(PartialFCSampleState* ctx, const TensorTuple& inputs,
                                     const TensorTuple& outputs, const AttrMap& attrs) const {
  ctx->requires_grad = inputs.at(0)->requires_grad();
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  ctx->index_sampled_label = ctx->SaveTensorForBackward(outputs.at(1));  // sampled_label
  ctx->index_weight = ctx->SaveTensorForBackward(inputs.at(0));
  return Maybe<void>::Ok();
}

Maybe<void> PartialFCSample::Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads,
                                   TensorTuple* in_grads) const {
  CHECK_EQ_OR_RETURN(out_grads.size(), 3);
  in_grads->resize(1);
  if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
  const auto& diff_sampled_weight = out_grads.at(2);  // diff of sampled_weight
  const auto& sampled_tensor = ctx->SavedTensors().at(ctx->index_sampled_label);
  const auto& weight = ctx->SavedTensors().at(ctx->index_weight);
  const auto& out_tensors_of_op0 = JUST(
      functional::DistributedPariticalFCSampleDisableBoxing(diff_sampled_weight, sampled_tensor));
  const auto& out_tensors_of_op1 = JUST(functional::UnsortedSegmentSumLike(
      out_tensors_of_op0->at(0), out_tensors_of_op0->at(1), weight, 0));
  in_grads->at(0) = out_tensors_of_op1;
  return Maybe<void>::Ok();
}

REGISTER_OP_EXPR_GRAD_FUNCTION("distributed_partial_fc_sample", PartialFCSample);

}  // namespace one
}  // namespace oneflow
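The backward pass above ends with `UnsortedSegmentSumLike`, which scatters the gradient rows back into a tensor shaped like the full weight. As a rough illustration of the general unsorted-segment-sum technique only, here is a single-process sketch over flat row-major buffers. The function name `UnsortedSegmentSumRows` and its parameters are hypothetical and are not OneFlow's API; the exact semantics of OneFlow's functor are not shown in this diff.

```cpp
#include <cstddef>
#include <vector>

// Illustrative only (hypothetical helper, not part of OneFlow).
// `data` holds segment_ids.size() rows of `row_size` values each.
// Row i of `data` is added into output row segment_ids[i];
// output rows that no id points at stay zero.
std::vector<float> UnsortedSegmentSumRows(const std::vector<float>& data,
                                          const std::vector<int>& segment_ids,
                                          std::size_t row_size, std::size_t num_segments) {
  std::vector<float> out(num_segments * row_size, 0.0f);
  for (std::size_t i = 0; i < segment_ids.size(); ++i) {
    const std::size_t dst = static_cast<std::size_t>(segment_ids[i]);
    for (std::size_t j = 0; j < row_size; ++j) {
      out[dst * row_size + j] += data[i * row_size + j];
    }
  }
  return out;
}
```

With three gradient rows and ids `{2, 0, 2}`, rows 0 and 2 of the input are accumulated into output row 2, which is the behaviour one would expect when the same class is referenced more than once.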
oneflow/core/autograd/gradient_funcs/reshape.cpp
View file @ 4376aba2
...
...
@@ -24,7 +24,11 @@ limitations under the License.
 namespace oneflow {
 namespace one {

-class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
+struct ReshapeCaptureState : public AutoGradCaptureState {
+  DimVector input_shape_vec;
+};
+
+class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
  public:
   Maybe<void> Init(const OpExpr& op) override {
     const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
...
...
@@ -32,17 +36,18 @@ class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
     return Maybe<void>::Ok();
   }

-  Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs,
+  Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,
                       const TensorTuple& outputs, const AttrMap& attrs) const override {
-    ctx->SaveTensorForBackward(inputs.at(0));
+    ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();
     return Maybe<void>::Ok();
   }

-  Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
+  Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,
                     TensorTuple* in_grads) const override {
-    const auto& saved_tensors = ctx->SavedTensors();
     in_grads->resize(1);
-    in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), saved_tensors.at(0)));
+    Shape shape(ctx->input_shape_vec);
+    in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));
     return Maybe<void>::Ok();
   }
 };
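The reshape change replaces a saved input tensor with a saved input shape, since the gradient of a reshape only needs to know the shape to restore. A minimal sketch of that idea, assuming a toy flat-buffer tensor; `ToyTensor`, `ToyReshapeState`, `ReshapeForward` and `ReshapeBackward` are hypothetical names introduced here for illustration and do not appear in OneFlow.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical minimal tensor: flat data plus a shape.
struct ToyTensor {
  std::vector<float> data;
  std::vector<int64_t> shape;
};

// Number of elements implied by a shape.
int64_t ElemCount(const std::vector<int64_t>& shape) {
  int64_t n = 1;
  for (int64_t d : shape) { n *= d; }
  return n;
}

// State kept between forward and backward: only the input shape, no tensor data.
struct ToyReshapeState {
  std::vector<int64_t> input_shape;
};

ToyTensor ReshapeForward(const ToyTensor& x, const std::vector<int64_t>& new_shape,
                         ToyReshapeState* state) {
  if (ElemCount(new_shape) != ElemCount(x.shape)) {
    throw std::invalid_argument("reshape must preserve the element count");
  }
  state->input_shape = x.shape;  // capture the shape instead of the whole input
  return ToyTensor{x.data, new_shape};
}

// The gradient of a reshape is the output gradient viewed with the input's shape.
ToyTensor ReshapeBackward(const ToyTensor& out_grad, const ToyReshapeState& state) {
  return ToyTensor{out_grad.data, state.input_shape};
}
```

Keeping only the shape means the forward input does not have to stay alive until the backward pass, which is presumably the motivation for the change, although the commit message does not state it.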
...
...
oneflow/core/common/just.h
View file @ 4376aba2
...
...
@@ -90,62 +90,62 @@ typename std::remove_const<typename std::remove_reference<T>::type>::type&& Remo
 #if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)

-#define JUST(...) \
-  ::oneflow::private_details::RemoveRValConst(({ \
-    auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
-    if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
-      return ::oneflow::private_details::JustErrorAddStackFrame( \
-          ::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \
-          __FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \
-    } \
-    std::forward<decltype(value_to_check_)>(value_to_check_); \
+#define JUST(...) \
+  ::oneflow::private_details::RemoveRValConst(({ \
+    auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
+    if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
+      return ::oneflow::private_details::JustErrorAddStackFrame( \
+          ::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
+          __FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \
+    } \
+    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
   })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()

-#define CHECK_JUST(...) \
-  ([&](const char* func_name) { \
-    auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
-    if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
-      LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
-          ::oneflow::private_details::JustErrorAddStackFrame( \
-              ::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \
-              func_name, OF_PP_STRINGIZE(__VA_ARGS__))); \
-    } \
-    return std::forward<decltype(value_to_check_)>(value_to_check_); \
-  })(__FUNCTION__) \
+#define CHECK_JUST(...) \
+  ([&](const char* _just_closure_func_name_) { \
+    auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
+    if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
+      LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
+          ::oneflow::private_details::JustErrorAddStackFrame( \
+              ::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
+              _just_closure_func_name_, OF_PP_STRINGIZE(__VA_ARGS__))); \
+    } \
+    return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
+  })(__FUNCTION__) \
       .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()

-#define JUST_MSG(value, ...) \
-  ::oneflow::private_details::RemoveRValConst(({ \
-    auto&& value_to_check_ = (value); \
-    if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
-      return ::oneflow::private_details::JustErrorAddMessage( \
-          ::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \
-              .AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \
-          OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \
-    } \
-    std::forward<decltype(value_to_check_)>(value_to_check_); \
+#define JUST_MSG(value, ...) \
+  ::oneflow::private_details::RemoveRValConst(({ \
+    auto&& _just_value_to_check_ = (value); \
+    if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
+      return ::oneflow::private_details::JustErrorAddMessage( \
+          ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
+              .AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \
+          OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \
+    } \
+    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
   })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()

-#define CHECK_JUST_MSG(value, ...) \
-  ([&](const char* func_name) { \
-    auto&& value_to_check_ = (value); \
-    if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
-      LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
-          ::oneflow::private_details::JustErrorAddMessage( \
-              ::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \
-                  .AddStackFrame(__FILE__, __LINE__, func_name), \
-              OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \
-              .error_proto()); \
-    } \
-    return std::forward<decltype(value_to_check_)>(value_to_check_); \
-  })(__FUNCTION__) \
+#define CHECK_JUST_MSG(value, ...) \
+  ([&](const char* _just_closure_func_name_) { \
+    auto&& _just_value_to_check_ = (value); \
+    if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
+      LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
+          ::oneflow::private_details::JustErrorAddMessage( \
+              ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
+                  .AddStackFrame(__FILE__, __LINE__, _just_closure_func_name_), \
+              OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \
+              .error_proto()); \
+    } \
+    return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
+  })(__FUNCTION__) \
       .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()

-#define JUST_OPT(...) \
-  ::oneflow::private_details::RemoveRValConst(({ \
-    auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
-    if (!value_to_check_.has_value()) { return NullOpt; } \
-    std::forward<decltype(value_to_check_)>(value_to_check_); \
+#define JUST_OPT(...) \
+  ::oneflow::private_details::RemoveRValConst(({ \
+    auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
+    if (!_just_value_to_check_.has_value()) { return NullOpt; } \
+    std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
   })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()

 #else
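The renames in `just.h` guard against a macro-local name hiding a caller's variable of the same name inside the macro argument. The sketch below reproduces the problem with two hypothetical macros (`EVAL_WITH_GENERIC_LOCAL` and `EVAL_WITH_PREFIXED_LOCAL`, neither from OneFlow) that evaluate their argument inside a lambda, in the same spirit as `CHECK_JUST`.

```cpp
// Hypothetical macros for illustration only. Each evaluates `expr` inside a
// lambda that also declares a local variable of its own.
#define EVAL_WITH_GENERIC_LOCAL(expr) \
  ([&]() {                            \
    int value = 100;                  \
    (void)value;                      \
    return (expr);                    \
  })()

#define EVAL_WITH_PREFIXED_LOCAL(expr) \
  ([&]() {                             \
    int _eval_value_ = 100;            \
    (void)_eval_value_;                \
    return (expr);                     \
  })()

// The caller's `value` is hidden by the macro's own `value`, so the
// expression silently reads 100 instead of 1.
int ShadowedResult() {
  int value = 1;
  (void)value;
  return EVAL_WITH_GENERIC_LOCAL(value + 1);
}

// With a prefixed macro-local name, `value` refers to the caller's variable.
int PrefixedResult() {
  int value = 1;
  return EVAL_WITH_PREFIXED_LOCAL(value + 1);
}
```

Here `ShadowedResult()` yields 101 while `PrefixedResult()` yields 2. A prefix such as `_just_` makes an accidental collision with user code unlikely, though it cannot rule one out.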
...
...
oneflow/core/device/cpu_stream_index.cpp
View file @ 4376aba2
...
...
@@ -30,19 +30,19 @@ CPUStreamIndexGenerator::CPUStreamIndexGenerator()
   next_stream_index_++;
 }

-StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateComputeStreamIndex() {
+StreamIndexGenerator::stream_index_t CPUStreamIndexGenerator::GenerateComputeStreamIndex() {
   return compute_stream_index_begin_ + (compute_stream_index_counter_++ % compute_stream_num_);
 }

-StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateCommNetStreamIndex() {
+StreamIndexGenerator::stream_index_t CPUStreamIndexGenerator::GenerateCommNetStreamIndex() {
   return comm_net_stream_index_;
 }

-StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateTickTockStreamIndex() {
+StreamIndexGenerator::stream_index_t CPUStreamIndexGenerator::GenerateTickTockStreamIndex() {
   return tick_tock_stream_index_;
 }

-StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateIndependentTaskStreamIndex(
+StreamIndexGenerator::stream_index_t CPUStreamIndexGenerator::GenerateIndependentTaskStreamIndex(
     TaskType task_type) {
   auto max_num_iter = task_type2max_stream_num_.end();
   if (IsClassRegistered<int32_t, IndependentThreadNum4TaskType>(task_type)) {
...
...
@@ -52,8 +52,8 @@ StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateIndependentTaskSt
     max_num_iter = task_type2max_stream_num_.find(task_type);
     if (max_num_iter == task_type2max_stream_num_.end()) {
       task_type2max_stream_num_.emplace(task_type, max_num);
-      CHECK(task_type2allocated_stream_index_vec_.emplace(task_type, std::vector<index_t>{})
-                .second);
+      CHECK(task_type2allocated_stream_index_vec_.emplace(task_type, std::vector<stream_index_t>{})
+                .second);
     } else {
       CHECK_EQ(max_num_iter->second, max_num);
       CHECK(task_type2allocated_stream_index_vec_.find(task_type)
...
...
@@ -61,7 +61,7 @@ StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateIndependentTaskSt
     }
   }

-  index_t index = next_stream_index_;
+  stream_index_t index = next_stream_index_;
   if (max_num_iter != task_type2max_stream_num_.end()) {
     auto& allocated_stream_index_vec = task_type2allocated_stream_index_vec_[task_type];
     if (allocated_stream_index_vec.size() < max_num_iter->second) {
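`GenerateComputeStreamIndex` hands out compute stream indices round-robin: `begin + (counter++ % num)`. A small standalone sketch of that arithmetic follows; the class `RoundRobinIndex` and its members are hypothetical names, and `num` is assumed to be greater than zero.

```cpp
#include <cstdint>

// Hypothetical stand-in for the generator's compute-stream state.
// Cycles through the indices begin, begin + 1, ..., begin + num - 1.
class RoundRobinIndex {
 public:
  // `num` must be non-zero, otherwise Next() would divide by zero.
  RoundRobinIndex(uint32_t begin, uint32_t num) : begin_(begin), num_(num), counter_(0) {}

  // Returns the current slot and advances the counter (post-increment).
  uint32_t Next() { return begin_ + (counter_++ % num_); }

 private:
  uint32_t begin_;
  uint32_t num_;
  uint32_t counter_;
};
```

With `begin = 3` and `num = 2`, successive calls return 3, 4, 3, 4, and so on.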
...
...
oneflow/core/device/cpu_stream_index.h
View file @ 4376aba2
...
...
@@ -27,24 +27,24 @@ class CPUStreamIndexGenerator final : public StreamIndexGenerator {
   OF_DISALLOW_COPY_AND_MOVE(CPUStreamIndexGenerator);
   ~CPUStreamIndexGenerator() = default;

-  index_t GenerateComputeStreamIndex() override;
-  index_t GenerateH2DStreamIndex() override { UNIMPLEMENTED(); }
-  index_t GenerateD2HStreamIndex() override { UNIMPLEMENTED(); }
-  index_t GenerateCommNetStreamIndex();
-  index_t GenerateTickTockStreamIndex();
-  index_t GenerateIndependentTaskStreamIndex(TaskType task_type);
+  stream_index_t GenerateComputeStreamIndex() override;
+  stream_index_t GenerateH2DStreamIndex() override { UNIMPLEMENTED(); }
+  stream_index_t GenerateD2HStreamIndex() override { UNIMPLEMENTED(); }
+  stream_index_t GenerateCommNetStreamIndex();
+  stream_index_t GenerateTickTockStreamIndex();
+  stream_index_t GenerateIndependentTaskStreamIndex(TaskType task_type);

  private:
-  index_t next_stream_index_;
-  index_t compute_stream_index_begin_;
-  index_t compute_stream_num_;
-  index_t comm_net_stream_index_;
-  index_t tick_tock_stream_index_;
+  stream_index_t next_stream_index_;
+  stream_index_t compute_stream_index_begin_;
+  stream_index_t compute_stream_num_;
+  stream_index_t comm_net_stream_index_;
+  stream_index_t tick_tock_stream_index_;
   // for GenerateComputeStreamIndex
-  index_t compute_stream_index_counter_;
+  stream_index_t compute_stream_index_counter_;
   // for GenerateIndependentStreamIndex
   HashMap<TaskType, size_t> task_type2max_stream_num_;
-  HashMap<TaskType, std::vector<index_t>> task_type2allocated_stream_index_vec_;
+  HashMap<TaskType, std::vector<stream_index_t>> task_type2allocated_stream_index_vec_;
   HashMap<TaskType, size_t> task_type2allocated_stream_index_vec_index_;
 };
...
...
oneflow/core/device/cuda_stream_index.cpp
View file @ 4376aba2
...
...
@@ -21,12 +21,12 @@ CudaStreamIndexGenerator::CudaStreamIndexGenerator() { next_stream_index_ = kD2H
 CudaStreamIndexGenerator::~CudaStreamIndexGenerator() = default;

-StreamIndexGenerator::index_t CudaStreamIndexGenerator::GenerateNamedStreamIndex(
+StreamIndexGenerator::stream_index_t CudaStreamIndexGenerator::GenerateNamedStreamIndex(
     const std::string& name) {
   std::lock_guard<std::mutex> lock(named_stream_index_mutex_);
   auto it = named_stream_index_.find(name);
   if (it == named_stream_index_.end()) {
-    index_t index = next_stream_index_;
+    stream_index_t index = next_stream_index_;
     next_stream_index_ += 1;
     named_stream_index_.emplace(name, index);
     return index;
...
...
@@ -35,7 +35,7 @@ StreamIndexGenerator::index_t CudaStreamIndexGenerator::GenerateNamedStreamIndex
   }
 }

-bool CudaStreamIndexGenerator::IsNamedStreamIndex(const std::string& name, index_t index) {
+bool CudaStreamIndexGenerator::IsNamedStreamIndex(const std::string& name, stream_index_t index) {
   std::lock_guard<std::mutex> lock(named_stream_index_mutex_);
   auto it = named_stream_index_.find(name);
   if (it == named_stream_index_.end()) {
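The named-stream logic is a mutex-protected get-or-create lookup: an unseen name is assigned the next free index, and the mapping is remembered. Below is a minimal sketch with the standard library only. `NamedIndexRegistry` is a hypothetical class, the starting index is a constructor parameter because the diff truncates the real initial value, and the behaviour for an already-registered name (return the stored index) and for an unknown name in `IsNamed` (return false) is an assumption, since those branches are elided in the diff.

```cpp
#include <cstdint>
#include <mutex>
#include <string>
#include <unordered_map>

// Hypothetical registry, for illustration only.
class NamedIndexRegistry {
 public:
  explicit NamedIndexRegistry(uint32_t first_free) : next_(first_free) {}

  // Returns the index registered for `name`, allocating the next free one
  // the first time a name is seen.
  uint32_t GetOrCreate(const std::string& name) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = indices_.find(name);
    if (it != indices_.end()) { return it->second; }  // assumed branch
    const uint32_t index = next_;
    next_ += 1;
    indices_.emplace(name, index);
    return index;
  }

  // True only if `name` is registered and maps to `index` (assumed semantics).
  bool IsNamed(const std::string& name, uint32_t index) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = indices_.find(name);
    return it != indices_.end() && it->second == index;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, uint32_t> indices_;
  uint32_t next_;
};
```

Taking the lock in both methods keeps the map and the counter consistent when several threads request names concurrently.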
...
...
oneflow/core/device/cuda_stream_index.h
View file @ 4376aba2
...
...
@@ -25,19 +25,19 @@ class CudaStreamIndexGenerator final : public StreamIndexGenerator {
   OF_DISALLOW_COPY_AND_MOVE(CudaStreamIndexGenerator);
   CudaStreamIndexGenerator();
   ~CudaStreamIndexGenerator();
-  index_t GenerateComputeStreamIndex() override { return kCompute; }
-  index_t GenerateH2DStreamIndex() override { return kH2D; }
-  index_t GenerateD2HStreamIndex() override { return kD2H; }
-  index_t GenerateNamedStreamIndex(const std::string& name);
-  bool IsNamedStreamIndex(const std::string& name, index_t index);
+  stream_index_t GenerateComputeStreamIndex() override { return kCompute; }
+  stream_index_t GenerateH2DStreamIndex() override { return kH2D; }
+  stream_index_t GenerateD2HStreamIndex() override { return kD2H; }
+  stream_index_t GenerateNamedStreamIndex(const std::string& name);
+  bool IsNamedStreamIndex(const std::string& name, stream_index_t index);

  private:
-  static const index_t kCompute = 0;
-  static const index_t kH2D = 1;
-  static const index_t kD2H = 2;
-  HashMap<std::string, index_t> named_stream_index_;
+  static const stream_index_t kCompute = 0;
+  static const stream_index_t kH2D = 1;
+  static const stream_index_t kD2H = 2;
+  HashMap<std::string, stream_index_t> named_stream_index_;
   std::mutex named_stream_index_mutex_;
-  index_t next_stream_index_;
+  stream_index_t next_stream_index_;
 };

 }  // namespace oneflow
...
...
oneflow/core/device/device_id.h
View file @ 4376aba2
...
...
@@ -29,27 +29,32 @@ namespace oneflow {
 class DeviceId {
  public:
-  using index_t = uint32_t;
+  using node_index_t = uint32_t;
+  using device_type_t = uint32_t;
+  using device_index_t = uint32_t;
   constexpr static size_t kNodeIndexBits = 19;
   constexpr static size_t kDeviceTypeBits = 5;
   constexpr static size_t kDeviceIndexBits = 7;
-  constexpr static index_t kMaxNodeIndex = (index_t{1} << kNodeIndexBits) - index_t{1};
-  constexpr static index_t kMaxDeviceTypeVal = (index_t{1} << kDeviceTypeBits) - index_t{1};
-  constexpr static index_t kMaxDeviceIndex = (index_t{1} << kDeviceIndexBits) - index_t{1};
-  DeviceId(index_t node_index, DeviceType device_type, index_t device_index)
+  constexpr static node_index_t kMaxNodeIndex =
+      (node_index_t{1} << kNodeIndexBits) - node_index_t{1};
+  constexpr static device_type_t kMaxDeviceTypeVal =
+      (device_type_t{1} << kDeviceTypeBits) - device_type_t{1};
+  constexpr static device_index_t kMaxDeviceIndex =
+      (device_index_t{1} << kDeviceIndexBits) - device_index_t{1};
+  DeviceId(node_index_t node_index, DeviceType device_type, device_index_t device_index)
       : node_index_(node_index),
-        device_type_(static_cast<index_t>(device_type)),
+        device_type_(static_cast<device_type_t>(device_type)),
         device_index_(device_index) {
     CHECK_LE(node_index_, kMaxNodeIndex);
     CHECK_LE(device_type_, kMaxDeviceTypeVal);
-    CHECK_LE(device_index, kMaxDeviceIndex);
+    CHECK_LE(device_index_, kMaxDeviceIndex);
   }
-  index_t node_index() const { return node_index_; }
+  node_index_t node_index() const { return node_index_; }
   DeviceType device_type() const { return static_cast<DeviceType>(device_type_); }
-  index_t device_index() const { return device_index_; }
+  device_index_t device_index() const { return device_index_; }
   bool operator==(const DeviceId& rhs) const {
     return node_index_ == rhs.node_index_ && device_type_ == rhs.device_type_
...
...
@@ -59,16 +64,16 @@ class DeviceId {
   bool operator!=(const DeviceId& rhs) const { return !(*this == rhs); }
   size_t hash() const {
-    size_t hash = std::hash<index_t>{}(node_index_);
-    HashCombine(&hash, std::hash<index_t>{}(device_type_));
-    HashCombine(&hash, std::hash<index_t>{}(device_index_));
+    size_t hash = std::hash<node_index_t>{}(node_index_);
+    HashCombine(&hash, std::hash<device_type_t>{}(device_type_));
+    HashCombine(&hash, std::hash<device_index_t>{}(device_index_));
     return hash;
   }
  private:
-  index_t node_index_;
-  index_t device_type_;
-  index_t device_index_;
+  node_index_t node_index_;
+  device_type_t device_type_;
+  device_index_t device_index_;
 };
 }  // namespace oneflow
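The `kMax*` constants in `DeviceId` all follow one formula: shift a 1 left by the field width and subtract 1. A small standalone sketch of that arithmetic is given here, using the bit widths shown in the diff (19, 5, 7); the helper name `MaxValueForBits` is hypothetical and does not appear in the commit.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

using node_index_t = uint32_t;
using device_type_t = uint32_t;
using device_index_t = uint32_t;

constexpr std::size_t kNodeIndexBits = 19;
constexpr std::size_t kDeviceTypeBits = 5;
constexpr std::size_t kDeviceIndexBits = 7;

// Hypothetical helper: largest value representable in `bits` bits.
template<typename T>
constexpr T MaxValueForBits(std::size_t bits) {
  return (T{1} << bits) - T{1};
}

constexpr node_index_t kMaxNodeIndex = MaxValueForBits<node_index_t>(kNodeIndexBits);
constexpr device_type_t kMaxDeviceTypeVal = MaxValueForBits<device_type_t>(kDeviceTypeBits);
constexpr device_index_t kMaxDeviceIndex = MaxValueForBits<device_index_t>(kDeviceIndexBits);

static_assert(kMaxNodeIndex == 524287, "2^19 - 1");
static_assert(kMaxDeviceTypeVal == 31, "2^5 - 1");
static_assert(kMaxDeviceIndex == 127, "2^7 - 1");

// The three aliases name the same underlying type, so the rename documents
// intent but does not by itself stop the fields from being mixed up.
static_assert(std::is_same<node_index_t, device_index_t>::value, "plain aliases");
```

Because the aliases are plain `using` declarations, the benefit of the rename is readability at call sites such as `static_cast<DeviceId::node_index_t>(machine_id)`, not added type checking.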
...
...
oneflow/core/device/stream_index.h
View file @
4376aba2
...
...
@@ -25,11 +25,11 @@ namespace oneflow {
 class StreamIndexGenerator {
  public:
   virtual ~StreamIndexGenerator() {}
-  using index_t = StreamId::index_t;
+  using stream_index_t = StreamId::stream_index_t;
-  virtual index_t GenerateComputeStreamIndex() = 0;
-  virtual index_t GenerateH2DStreamIndex() = 0;
-  virtual index_t GenerateD2HStreamIndex() = 0;
+  virtual stream_index_t GenerateComputeStreamIndex() = 0;
+  virtual stream_index_t GenerateH2DStreamIndex() = 0;
+  virtual stream_index_t GenerateD2HStreamIndex() = 0;
 };
 class StreamIndexGeneratorManager final {
...
...
oneflow/core/functional/functional_api.yaml
View file @
4376aba2
...
...
@@ -1511,6 +1511,17 @@
   signature: "TensorTuple (Tensor log_probs, Tensor input_lengths, Bool merge_repeated=True) => CtcGreedyDecoder"
   bind_python: True
+- name: "distributed_partial_fc_sample"
+  signature: "TensorTuple (Tensor weight, Tensor label, Int64 num_sample) => DistributedPariticalFCSample"
+  bind_python: True
+- name: "distributed_partial_fc_sample_disable_boxing"
+  signature: "TensorTuple (Tensor sampled_weight_diff, Tensor sampled_label) => DistributedPariticalFCSampleDisableBoxing"
+  bind_python: False
 - name: "meshgrid"
   signature: "TensorTuple (TensorTuple tensors) => Meshgrid"
   bind_python: True
oneflow/core/functional/impl/nn_functor.cpp
View file @
4376aba2
...
...
@@ -1872,6 +1872,48 @@ class CtcGreedyDecoderFunctor {
   std::shared_ptr<OpExpr> op_;
 };
+class PartialFCSampleFunctor {
+ public:
+  PartialFCSampleFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("distributed_partial_fc_sample")
+                         .Input("weight")
+                         .Input("label")
+                         .Output("mapped_label")
+                         .Output("sampled_label")
+                         .Output("sampled_weight")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& wegiht,
+                                const std::shared_ptr<one::Tensor>& label,
+                                const int64_t& num_sample) const {
+    MutableAttrMap attrs;
+    JUST(attrs.SetAttr<int64_t>("num_sample", num_sample));
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {wegiht, label}, attrs);
+  }
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+class PariticalFCSampleDisableBoxing {
+ public:
+  PariticalFCSampleDisableBoxing() {
+    op_ = CHECK_JUST(one::OpBuilder("distributed_partial_fc_sample_disable_boxing")
+                         .Input("sampled_weight_diff")
+                         .Input("sampled_label")
+                         .Output("boxing_disabled_sampled_weight_diff")
+                         .Output("boxing_disabled_sampled_label")
+                         .Build());
+  }
+  Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& sampled_weight_diff,
+                                const std::shared_ptr<one::Tensor>& sampled_label) const {
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {sampled_weight_diff, sampled_label});
+  }
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
 }  // namespace impl
 ONEFLOW_FUNCTION_LIBRARY(m) {
...
...
@@ -1932,6 +1974,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::FusedBiasAddDropoutFunctor>("FusedBiasAddDropout");
   m.add_functor<impl::FusedScaleTrilFunctor>("FusedScaleTril");
   m.add_functor<impl::CtcGreedyDecoderFunctor>("CtcGreedyDecoder");
+  m.add_functor<impl::PartialFCSampleFunctor>("DistributedPariticalFCSample");
+  m.add_functor<impl::PariticalFCSampleDisableBoxing>("DistributedPariticalFCSampleDisableBoxing");
 };
 }  // namespace functional
...
...
oneflow/core/graph/boxing/collective_boxing_sub_task_graph_builder.cpp
View file @
4376aba2
...
...
@@ -65,8 +65,8 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node,
   const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
   const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
-  DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
-                     static_cast<DeviceId::index_t>(device_index)};
+  DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
+                     static_cast<DeviceId::device_index_t>(device_index)};
   auto* stream_index_generator = dynamic_cast<CudaStreamIndexGenerator*>(
       Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id));
   CHECK_NOTNULL(stream_index_generator);
...
...
@@ -191,8 +191,8 @@ class NcclCollectiveBoxingP2SNoncontinuousSubTskGphBuilder final : public SubTsk
     FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
       const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i));
       const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i));
-      DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
-                         static_cast<DeviceId::index_t>(device_index)};
+      DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
+                         static_cast<DeviceId::device_index_t>(device_index)};
       auto* stream_index_generator =
           Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
       auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
...
...
@@ -293,8 +293,8 @@ class NcclCollectiveBoxingS2BNoncontinuousSubTskGphBuilder final : public SubTsk
     FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
       const int64_t machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(i));
       const int64_t device_index = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(i));
-      DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
-                         static_cast<DeviceId::index_t>(device_index)};
+      DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
+                         static_cast<DeviceId::device_index_t>(device_index)};
       auto* stream_index_generator =
           Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
       auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
...
...
@@ -406,7 +406,7 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su
     SliceBoxingTaskNode* slice_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>();
     // slice on cpu
     const auto in_machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(0));
-    DeviceId device_id{static_cast<DeviceId::index_t>(in_machine_id), DeviceType::kCPU, 0};
+    DeviceId device_id{static_cast<DeviceId::node_index_t>(in_machine_id), DeviceType::kCPU, 0};
     auto* stream_index_generator =
         Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
     auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
...
...
@@ -522,8 +522,8 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde
     FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
       const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i));
       const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i));
-      DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
-                         static_cast<DeviceId::index_t>(device_index)};
+      DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
+                         static_cast<DeviceId::device_index_t>(device_index)};
       auto* stream_index_generator =
           Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
       auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
...
...
oneflow/core/graph/boxing/naive_b2p_sub_task_graph_builder.cpp
View file @
4376aba2
...
...
@@ -58,8 +58,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
   int64_t thrd_id = -1;
   if (out_parallel_desc.device_type() == DeviceType::kGPU) {
 #ifdef WITH_CUDA
-    DeviceId device_id{static_cast<DeviceId::index_t>(out_machine_id), DeviceType::kGPU,
-                       static_cast<DeviceId::index_t>(out_dev_phy_id)};
+    DeviceId device_id{static_cast<DeviceId::node_index_t>(out_machine_id), DeviceType::kGPU,
+                       static_cast<DeviceId::device_index_t>(out_dev_phy_id)};
     auto* stream_index_generator =
         Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
     auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
...
...
@@ -68,7 +68,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
     UNIMPLEMENTED();
 #endif
   } else if (out_parallel_desc.device_type() == DeviceType::kCPU) {
-    DeviceId device_id{static_cast<DeviceId::index_t>(out_machine_id), DeviceType::kCPU, 0};
+    DeviceId device_id{static_cast<DeviceId::node_index_t>(out_machine_id), DeviceType::kCPU,
+                       0};
     auto* stream_index_generator =
         Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
     auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
...
...
oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp
View file @
4376aba2
...
...
@@ -61,8 +61,8 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
     } else {
       dev_id = CHECK_JUST(pd.DeviceId4ParallelId(parallel_id));
     }
-    DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), pd.device_type(),
-                       static_cast<DeviceId::index_t>(dev_id)};
+    DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), pd.device_type(),
+                       static_cast<DeviceId::device_index_t>(dev_id)};
     auto* stream_index_generator =
         Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
     auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
...
...
oneflow/core/graph/copy_task_node.cpp
View file @
4376aba2
...
...
@@ -45,7 +45,7 @@ void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, const DeviceId& device_i
   set_machine_id(device_id.node_index());
   auto* stream_index_generator =
       Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
-  StreamId::index_t stream_index = 0;
+  StreamId::stream_index_t stream_index = 0;
   if (copy_type == CopyHdOpConf::H2D) {
     stream_index = stream_index_generator->GenerateH2DStreamIndex();
   } else if (copy_type == CopyHdOpConf::D2H) {
...
...
@@ -84,7 +84,7 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() {
 void CopyCommNetTaskNode::Init(int64_t machine_id, const LogicalBlobId& lbi) {
   set_machine_id(machine_id);
-  DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kCPU, 0};
+  DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kCPU, 0};
   auto* generator = dynamic_cast<CPUStreamIndexGenerator*>(
       Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id));
   CHECK_NOTNULL(generator);
...
...
oneflow/core/graph/stream_index_getter_registry_manager.cpp
View file @
4376aba2
...
...
@@ -22,7 +22,7 @@ StreamIndexGetterRegistryManager& StreamIndexGetterRegistryManager::Get() {
   return mgr;
 }
-StreamId::index_t StreamIndexGetterRegistryManager::StreamIndex4DeviceIdAndTaskType(
+StreamId::stream_index_t StreamIndexGetterRegistryManager::StreamIndex4DeviceIdAndTaskType(
     DeviceId device_id, TaskType task_type) {
   auto index_getter_fn = StreamIndexGetterRegistryManager::GetStreamIndexGetterFunc(
       device_id.device_type(), task_type);
...
...
oneflow/core/graph/stream_index_getter_registry_manager.h
View file @
4376aba2
...
...
@@ -47,7 +47,7 @@ class StreamIndexGetterRegistryManager final {
   StreamIndexKeyMap<StreamIndexGetterFn>& StreamIndexGetterFuncs();
-  StreamId::index_t StreamIndex4DeviceIdAndTaskType(DeviceId device_id, TaskType task_type);
+  StreamId::stream_index_t StreamIndex4DeviceIdAndTaskType(DeviceId device_id, TaskType task_type);
  private:
   StreamIndexGetterFn GetStreamIndexGetterFunc(DeviceType dev_type, TaskType task_type);
...
...
oneflow/core/graph/task_graph.cpp
View file @
4376aba2
...
...
@@ -284,16 +284,17 @@ void GenSortedCompTaskNodes(const OpNode* op_node, std::vector<CompTaskNode*>* s
     comp_task_node->mut_parallel_ctx()->set_parallel_id(parallel_idx++);
     comp_task_node->mut_parallel_ctx()->set_parallel_num(parallel_num);
-    DeviceId::index_t device_index = parallel_desc.device_type() == DeviceType::kCPU
-                                         ? 0
-                                         : static_cast<DeviceId::index_t>(dev_phy_id);
-    DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), parallel_desc.device_type(),
-                       device_index};
-    StreamId::index_t stream_index{};
+    DeviceId::device_index_t device_index =
+        parallel_desc.device_type() == DeviceType::kCPU
+            ? 0
+            : static_cast<DeviceId::device_index_t>(dev_phy_id);
+    DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id),
+                       parallel_desc.device_type(), device_index};
+    StreamId::stream_index_t stream_index{};
     if (op_node->op().op_conf().has_stream_index_hint()) {
       int32_t stream_index_hint = op_node->op().op_conf().stream_index_hint();
       LOG(INFO) << "set op: " << op_node->op().op_name() << " to stream: " << stream_index_hint;
-      stream_index = static_cast<StreamId::index_t>(stream_index_hint);
+      stream_index = static_cast<StreamId::stream_index_t>(stream_index_hint);
     } else {
       stream_index = StreamIndexGetterRegistryManager::Get().StreamIndex4DeviceIdAndTaskType(
           device_id, comp_task_node->GetTaskType());
...
...
@@ -522,8 +523,9 @@ TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
   const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
   DeviceType device_type = dst_parallel_desc.device_type();
   auto device_index =
-      (device_type == DeviceType::kCPU ? 0 : static_cast<DeviceId::index_t>(dev_id));
-  MemZoneId mem_zone_id{static_cast<MemZoneId::index_t>(dst_machine_id), device_type, device_index};
+      (device_type == DeviceType::kCPU ? 0 : static_cast<DeviceId::device_index_t>(dev_id));
+  MemZoneId mem_zone_id{static_cast<MemZoneId::node_index_t>(dst_machine_id), device_type,
+                        device_index};
   return GetProxyNode(src_node, lbi, mem_zone_id);
 }
...
...
oneflow/core/graph/task_id.cpp
View file @
4376aba2
...
...
@@ -65,10 +65,11 @@ TaskId DecodeTaskIdFromInt64(int64_t task_id_val) {
   int64_t device_index = (task_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;
   int64_t stream_index = (task_id_val & kStreamIndexInt64Mask) >> kStreamIndexShift;
   int64_t task_index = task_id_val & kTaskIndexInt64Mask;
-  StreamId stream_id{static_cast<DeviceId::index_t>(node_index),
-                     static_cast<DeviceType>(device_type),
-                     static_cast<DeviceId::index_t>(device_index),
-                     static_cast<StreamId::index_t>(stream_index)};
-  return TaskId{stream_id, static_cast<TaskId::index_t>(task_index)};
+  StreamId stream_id{static_cast<DeviceId::node_index_t>(node_index),
+                     static_cast<DeviceType>(device_type),
+                     static_cast<DeviceId::device_index_t>(device_index),
+                     static_cast<StreamId::stream_index_t>(stream_index)};
+  return TaskId{stream_id, static_cast<TaskId::task_index_t>(task_index)};
 }
 int64_t MachineId4ActorId(int64_t actor_id) {
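`DecodeTaskIdFromInt64` recovers each field with a mask and a shift. The sketch below illustrates that packing scheme with the field widths that appear elsewhere in this diff (19 node bits, 5 device-type bits, 7 device-index bits, 12 stream bits, 21 task bits, 64 in total). The field order from high to low bits, the struct and function names, and the use of `uint64_t` instead of the `int64_t` seen in the diff are assumptions made so the example stays free of signed-shift pitfalls; only the position of the task index in the lowest bits is visible in the hunk.

```cpp
#include <cassert>
#include <cstdint>

// Field widths taken from the class definitions in this diff.
constexpr int kNodeIndexBits = 19;
constexpr int kDeviceTypeBits = 5;
constexpr int kDeviceIndexBits = 7;
constexpr int kStreamIndexBits = 12;
constexpr int kTaskIndexBits = 21;

// Assumed layout, low to high: task | stream | device index | device type | node.
constexpr int kStreamIndexShift = kTaskIndexBits;
constexpr int kDeviceIndexShift = kStreamIndexShift + kStreamIndexBits;
constexpr int kDeviceTypeShift = kDeviceIndexShift + kDeviceIndexBits;
constexpr int kNodeIndexShift = kDeviceTypeShift + kDeviceTypeBits;
static_assert(kNodeIndexShift + kNodeIndexBits == 64, "fields fill 64 bits");

// Hypothetical plain-data view of a task id.
struct TaskIdFields {
  uint32_t node_index;
  uint32_t device_type;
  uint32_t device_index;
  uint32_t stream_index;
  uint32_t task_index;
};

constexpr uint64_t LowMask(int bits) { return (uint64_t{1} << bits) - 1; }

// Packs the fields; callers must keep each value within its bit width.
inline uint64_t PackTaskId(const TaskIdFields& f) {
  return (uint64_t{f.node_index} << kNodeIndexShift)
         | (uint64_t{f.device_type} << kDeviceTypeShift)
         | (uint64_t{f.device_index} << kDeviceIndexShift)
         | (uint64_t{f.stream_index} << kStreamIndexShift)
         | uint64_t{f.task_index};
}

// Inverse of PackTaskId: shift each field down, then mask to its width.
inline TaskIdFields UnpackTaskId(uint64_t v) {
  TaskIdFields f;
  f.node_index = static_cast<uint32_t>((v >> kNodeIndexShift) & LowMask(kNodeIndexBits));
  f.device_type = static_cast<uint32_t>((v >> kDeviceTypeShift) & LowMask(kDeviceTypeBits));
  f.device_index = static_cast<uint32_t>((v >> kDeviceIndexShift) & LowMask(kDeviceIndexBits));
  f.stream_index = static_cast<uint32_t>((v >> kStreamIndexShift) & LowMask(kStreamIndexBits));
  f.task_index = static_cast<uint32_t>(v & LowMask(kTaskIndexBits));
  return f;
}
```

The `CHECK_LE` calls in the constructors play the role of the "keep each value within its bit width" precondition: a value above `kMax*` would spill into the neighbouring field once packed.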
...
...
oneflow/core/graph/task_id.h
View file @
4376aba2
...
...
@@ -22,18 +22,19 @@ namespace oneflow {
 class TaskId {
  public:
-  using index_t = uint32_t;
+  using task_index_t = uint32_t;
   const static size_t kTaskIndexBits = 21;
-  constexpr static index_t kMaxTaskIndex = (index_t{1} << kTaskIndexBits) - index_t{1};
+  constexpr static task_index_t kMaxTaskIndex =
+      (task_index_t{1} << kTaskIndexBits) - task_index_t{1};
-  TaskId(const StreamId& stream_id, index_t task_index)
+  TaskId(const StreamId& stream_id, task_index_t task_index)
       : stream_id_(stream_id), task_index_(task_index) {
     CHECK_LE(task_index_, kMaxTaskIndex);
   }
   const StreamId& stream_id() const { return stream_id_; }
-  index_t task_index() const { return task_index_; }
+  task_index_t task_index() const { return task_index_; }
   bool operator==(const TaskId& rhs) const {
     return stream_id_ == rhs.stream_id_ && task_index_ == rhs.task_index_;
...
...
@@ -42,13 +43,13 @@ class TaskId {
   size_t hash() const {
     size_t hash = stream_id_.hash();
-    HashCombine(&hash, std::hash<index_t>{}(task_index_));
+    HashCombine(&hash, std::hash<task_index_t>{}(task_index_));
     return hash;
   }
  private:
   StreamId stream_id_;
-  index_t task_index_;
+  task_index_t task_index_;
 };
 int64_t EncodeTaskIdToInt64(const TaskId&);
...
...
oneflow/core/graph/task_id_generator.h
View file @
4376aba2
...
...
@@ -22,7 +22,7 @@ namespace oneflow {
 class TaskIdGenerator final {
  public:
-  using task_index_t = TaskId::index_t;
+  using task_index_t = TaskId::task_index_t;
   TaskIdGenerator() = default;
   OF_DISALLOW_COPY_AND_MOVE(TaskIdGenerator);
...
...
oneflow/core/job_rewriter/fuse_update_ops_pass.cpp
View file @
4376aba2
...
...
@@ -173,6 +173,12 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
           .Attr<float>("beta1", user_op_conf.attr<float>("beta1"))
           .Attr<float>("beta2", user_op_conf.attr<float>("beta2"))
           .Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
+      if (user_op_conf.has_input("bias_correction1", 0)) {
+        fused_op_builder.Input("bias_correction1", user_op_conf.input("bias_correction1", 0));
+      }
+      if (user_op_conf.has_input("bias_correction2", 0)) {
+        fused_op_builder.Input("bias_correction2", user_op_conf.input("bias_correction2", 0));
+      }
     } else if (user_op_conf.op_type_name() == "rmsprop_update") {
       const bool centered = user_op_conf.attr<bool>("centered");
       fused_op_builder.Input("mean_square", user_op_conf.input("mean_square", 0.f))
...
...
oneflow/core/memory/memory_zone.cpp
View file @
4376aba2
...
...
@@ -32,7 +32,7 @@ constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDe
 const MemZoneId kInvalidMemZoneId = MemZoneId{0, DeviceType::kInvalidDevice, 0};
-MemZoneId GetNodeCPUMemZoneId(MemZoneId::index_t node_index) {
+MemZoneId GetNodeCPUMemZoneId(MemZoneId::node_index_t node_index) {
   return MemZoneId{node_index, DeviceType::kCPU, 0};
 }
...
...
@@ -47,9 +47,9 @@ MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) {
   int64_t node_index = (mem_zone_id & kMemZoneIdNodeIndexInt64Mask) >> kMemZoneIdNodeIndexShift;
   int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift;
   int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask;
-  return MemZoneId(static_cast<MemZoneId::index_t>(node_index),
+  return MemZoneId(static_cast<MemZoneId::node_index_t>(node_index),
                    static_cast<DeviceType>(device_type),
-                   static_cast<MemZoneId::index_t>(device_index));
+                   static_cast<MemZoneId::device_index_t>(device_index));
 }
 }  // namespace oneflow
oneflow/core/memory/memory_zone.h
View file @
4376aba2
...
...
@@ -25,7 +25,7 @@ using MemZoneId = DeviceId;
 int64_t EncodeMemZoneIdToInt64(const MemZoneId&);
 MemZoneId DecodeMemZoneIdFromInt64(int64_t);
-MemZoneId GetNodeCPUMemZoneId(MemZoneId::index_t node_index);
+MemZoneId GetNodeCPUMemZoneId(MemZoneId::node_index_t node_index);
 extern const MemZoneId kInvalidMemZoneId;
...
...
oneflow/core/stream/stream_id.cpp
View file @
4376aba2
...
...
@@ -59,9 +59,10 @@ StreamId DecodeStreamIdFromInt64(int64_t stream_id_val) {
   int64_t device_type = (stream_id_val & kDeviceTypeInt64Mask) >> kDeviceTypeShift;
   int64_t device_index = (stream_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;
   int64_t stream_index = (stream_id_val & kStreamIndexInt64Mask);
-  return StreamId{static_cast<DeviceId::index_t>(node_index),
-                  static_cast<DeviceType>(device_type),
-                  static_cast<DeviceId::index_t>(device_index),
-                  static_cast<StreamId::index_t>(stream_index)};
+  return StreamId{static_cast<DeviceId::node_index_t>(node_index),
+                  static_cast<DeviceType>(device_type),
+                  static_cast<DeviceId::device_index_t>(device_index),
+                  static_cast<StreamId::stream_index_t>(stream_index)};
 }
 }  // namespace oneflow
oneflow/core/stream/stream_id.h
View file @
4376aba2
...
...
@@ -22,26 +22,27 @@ namespace oneflow {
 class StreamId {
  public:
-  using index_t = uint32_t;
+  using stream_index_t = uint32_t;
   constexpr static size_t kStreamIndexBits = 12;
-  constexpr static index_t kMaxStreamIndex = (index_t{1} << kStreamIndexBits) - index_t{1};
+  constexpr static stream_index_t kMaxStreamIndex =
+      (stream_index_t{1} << kStreamIndexBits) - stream_index_t{1};
-  StreamId(const DeviceId& device_id, index_t stream_index)
+  StreamId(const DeviceId& device_id, stream_index_t stream_index)
       : device_id_(device_id), stream_index_(stream_index) {
     CHECK_LE(stream_index, kMaxStreamIndex);
   }
-  StreamId(DeviceId::index_t node_index, DeviceType device_type, DeviceId::index_t device_index,
-           index_t stream_index)
+  StreamId(DeviceId::node_index_t node_index, DeviceType device_type,
+           DeviceId::device_index_t device_index, stream_index_t stream_index)
       : device_id_(node_index, device_type, device_index), stream_index_(stream_index) {
     CHECK_LE(stream_index, kMaxStreamIndex);
   }
   const DeviceId& device_id() const { return device_id_; }
-  DeviceId::index_t node_index() const { return device_id_.node_index(); }
+  DeviceId::node_index_t node_index() const { return device_id_.node_index(); }
   DeviceType device_type() const { return device_id_.device_type(); }
-  DeviceId::index_t device_index() const { return device_id_.device_index(); }
-  index_t stream_index() const { return stream_index_; }
+  DeviceId::device_index_t device_index() const { return device_id_.device_index(); }
+  stream_index_t stream_index() const { return stream_index_; }
   bool operator==(const StreamId& rhs) const {
     return device_id_ == rhs.device_id_ && stream_index_ == rhs.stream_index_;
...
...
@@ -51,13 +52,13 @@ class StreamId {
   size_t hash() const {
     size_t hash = device_id_.hash();
-    HashCombine(&hash, std::hash<index_t>{}(stream_index_));
+    HashCombine(&hash, std::hash<stream_index_t>{}(stream_index_));
     return hash;
   }
  private:
   DeviceId device_id_;
-  index_t stream_index_;
+  stream_index_t stream_index_;
 };
 int64_t EncodeStreamIdToInt64(const StreamId&);
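The `hash()` methods in `DeviceId`, `StreamId`, and `TaskId` all seed a hash from one member and fold the remaining members in with `HashCombine`. The diff does not show how `HashCombine` is defined, so the sketch below substitutes the widely used Boost-style mixing step as an assumption; `HashCombineSketch`, `StreamKeySketch`, and `StreamKeyHasher` are hypothetical names introduced only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_set>

// Assumed Boost-style mixing step; the project's HashCombine may differ.
inline void HashCombineSketch(std::size_t* seed, std::size_t hash) {
  *seed ^= hash + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}

// Hypothetical two-field id that mirrors the hash() pattern in the diff.
struct StreamKeySketch {
  uint32_t device_index;
  uint32_t stream_index;

  bool operator==(const StreamKeySketch& rhs) const {
    return device_index == rhs.device_index && stream_index == rhs.stream_index;
  }

  std::size_t hash() const {
    std::size_t h = std::hash<uint32_t>{}(device_index);
    HashCombineSketch(&h, std::hash<uint32_t>{}(stream_index));
    return h;
  }
};

// Adapter so the id can be used as a key in unordered containers.
struct StreamKeyHasher {
  std::size_t operator()(const StreamKeySketch& k) const { return k.hash(); }
};
```

Since every alias in this change resolves to `uint32_t`, swapping `std::hash<index_t>` for `std::hash<stream_index_t>` selects the same specialization, so hash values should be unaffected by the rename.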
...
...
oneflow/user/kernels/partial_fc_sample_kernel.cu
View file @
4376aba2
...
...
@@ -152,7 +152,10 @@ class DistributedPartialFcSampleOpKernelState final : public user_op::OpKernelSt
     SetupKernel<<<BlocksNum4ThreadsNum(num_classes), kCudaThreadsNumPerBlock, 0,
                   ctx->cuda_stream()>>>(seed, curand_states_);
   }
-  ~DistributedPartialFcSampleOpKernelState() { OF_CUDA_CHECK(cudaFree(curand_states_)); };
+  ~DistributedPartialFcSampleOpKernelState() {
+    cudaError_t ret = cudaFree(curand_states_);
+    if (ret != cudaErrorCudartUnloading) { OF_CUDA_CHECK(ret); }
+  };
   int64_t lower() const { return lower_; }
   int64_t upper() const { return upper_; }
...
...
python/oneflow/__init__.py
View file @
4376aba2
...
...
@@ -128,6 +128,7 @@ from oneflow._C import softplus
 from oneflow._C import tril
 from oneflow._C import triu
 from oneflow._C import pad
+from oneflow._C import distributed_partial_fc_sample
 from oneflow._C import transpose
 from oneflow._C import relu
 from oneflow._C import softmax
...
...
@@ -155,7 +156,6 @@ import oneflow.framework.register_python_callback
 INVALID_SPLIT_AXIS = oneflow._oneflow_internal.INVALID_SPLIT_AXIS
 register_class_method_util.RegisterMethod4Class()
-oneflow._oneflow_internal.RegisterGILForeignLockHelper()
 import oneflow.framework.env_util as env_util
 import oneflow.framework.scope_util as scope_util
 import oneflow.framework.session_context as session_ctx
...
...
@@ -165,6 +165,7 @@ if not env_util.HasAllMultiClientEnvVars():
     env_util.SetDefaultMultiClientEnvVars()
 oneflow._oneflow_internal.SetIsMultiClient(True)
 env_util.api_env_init()
+oneflow._oneflow_internal.RegisterGILForeignLockHelper()
 oneflow._oneflow_internal.InitDefaultConsistentTransportTokenScope()
 session_ctx.OpenDefaultSession(
     MultiClientSession(oneflow._oneflow_internal.NewSessionId())
...
...
python/oneflow/test/modules/test_parital_fc.py
0 → 100644
View file @
4376aba2
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from oneflow.test_utils.automated_test_util import *
+import oneflow as flow
+import oneflow.unittest
+@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
+class TestParitalFC(flow.unittest.TestCase):
+    def test_parital_fc(test_case):
+        p = flow.env.all_device_placement("cuda")
+        w = flow.randn(50000, 128, placement=p, sbp=flow.sbp.broadcast)
+        label = flow.randint(0, 50000, (512,), placement=p, sbp=flow.sbp.broadcast)
+        num_sample = 5000
+        out = flow.distributed_partial_fc_sample(w, label, num_sample)
+        test_case.assertTrue(out[0].shape == flow.Size([512]))
+        test_case.assertTrue(out[1].shape == flow.Size([5000]))
+        test_case.assertTrue(out[2].shape == flow.Size([5000, 128]))
+if __name__ == "__main__":
+    unittest.main()