...
 
Commits (11)
add --inplace (#6661) | Shenghang Tsai <jackalcooper@gmail.com> | 2021-11-01T14:11:44+08:00
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/808bf377696620cb40d61520a33e046c6d55d616
migrate parital fc op from lazy to functor (#6387) | Yao Chi <later@usopp.net> | 2021-11-01T18:46:31+08:00
  * migrate partial_fc * add test and fix DistributedPariticalFCSample release bug * fix typos in functional_api.yaml * initialization * refine testcase * skip cpu-only test * reformat
  Co-authored-by: bbuf, Xiaoyu Zhang, Yinggang Wang, oneflow-ci-bot
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/2e96920bbcf926768c92bbde579a6fe528c3b798
update speed test threshold (#6664) | daquexian <daquexian566@gmail.com> | 2021-11-01T20:31:21+08:00
  Signed-off-by: daquexian; Co-authored-by: oneflow-ci-bot
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/f88c979a345d6fef5bae083daf0e536b4b848882
just macro: rename local variables to prevent shadowing (#6667) | Twice <i@twice.moe> | 2021-11-01T21:37:51+08:00
  Co-authored-by: oneflow-ci-bot
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/21caffd9d94e70538a9035abdcf215405d045167
restruct reshape gradient funcs (#6634) | Luyang <flowingsun007@163.com> | 2021-11-02T04:15:23+00:00
  * restruct * refine
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/8b94ac9b8fd0578aeed91a85c955a7f2a400b6aa
Fix model update pass adam (#6673) | ZZK <42901638+MARD1NO@users.noreply.github.com> | 2021-11-02T18:11:41+08:00
  * add first version of unary primitive op * fix * remove redundant file * Revert * fix format * use has input to check
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/2247386074c82dd56248405b52095d9b1609bae8
adjust GILForeignLockHelper order to avoid glog print to stderr (#6671) | Xiaoyu Xu <xiaoyulink@gmail.com> | 2021-11-02T18:50:45+08:00
  Co-authored-by: oneflow-ci-bot
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/55d32c333c8a298da5307bc1d219f48967fe5490
modify by review | leaves-zwx <kunta0932@gmail.com> | 2021-11-02T19:37:21+08:00
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/9ec6871dec4c38acb9badec2813d7617e1b0849f
modify by review | leaves-zwx <kunta0932@gmail.com> | 2021-11-02T19:56:40+08:00
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/91523d64b615146e9b0594a7c2d67b34863f899b
fix | leaves-zwx <kunta0932@gmail.com> | 2021-11-02T20:27:58+08:00
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/188d97504f5de980588a807ffa4a2649044d512d
Merge branch 'master' into ref_id_util | leaves-zwx <kunta0932@gmail.com> | 2021-11-02T20:31:26+08:00
  https://gitcode.net/Oneflow-Inc/oneflow/-/commit/4376aba2d24897b963498da6a720fbc0b441cb94
......@@ -121,7 +121,7 @@ docker pull oneflowinc/oneflow:nightly-cuda11.1
- In the root directory of OneFlow source code, run:
```
python3 docker/package/manylinux/build_wheel.py --python_version=3.6
python3 docker/package/manylinux/build_wheel.py --inplace --python_version=3.6
```
This should produce `.whl` files in the directory `wheelhouse`
......
......@@ -18,13 +18,13 @@ function write_to_file_and_print {
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.01 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 | check_relative_speed 1.05 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.05 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.09 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.95 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.01 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 | check_relative_speed 1.06 | write_to_file_and_print
python3 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 | check_relative_speed 0.94 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 16x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 0.99 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 8x3x224x224 --no-show-memory --times 100 --ddp | check_relative_speed 0.99 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.93 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 4x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.91 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 2x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.83 | write_to_file_and_print
python3 -m oneflow.distributed.launch --nproc_per_node 2 scripts/compare_speed_with_pytorch.py resnet50/models/resnet50.py resnet50 1x3x224x224 --no-show-memory --times 200 --ddp | check_relative_speed 0.82 | write_to_file_and_print
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct PartialFCSampleState : public AutoGradCaptureState {
bool requires_grad = false;
int32_t index_sampled_label = -1;
int32_t index_weight = -1;
};
class PartialFCSample : public OpExprGradFunction<PartialFCSampleState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(PartialFCSampleState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> PartialFCSample::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> PartialFCSample::Capture(PartialFCSampleState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->index_sampled_label = ctx->SaveTensorForBackward(outputs.at(1)); // sampled_label
ctx->index_weight = ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
Maybe<void> PartialFCSample::Apply(const PartialFCSampleState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
CHECK_EQ_OR_RETURN(out_grads.size(), 3);
in_grads->resize(1);
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
const auto& diff_sampled_weight = out_grads.at(2); // diff of sampled_weight
const auto& sampled_tensor = ctx->SavedTensors().at(ctx->index_sampled_label);
const auto& weight = ctx->SavedTensors().at(ctx->index_weight);
const auto& out_tensors_of_op0 = JUST(
functional::DistributedPariticalFCSampleDisableBoxing(diff_sampled_weight, sampled_tensor));
const auto& out_tensors_of_op1 = JUST(functional::UnsortedSegmentSumLike(
out_tensors_of_op0->at(0), out_tensors_of_op0->at(1), weight, 0));
in_grads->at(0) = out_tensors_of_op1;
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("distributed_partial_fc_sample", PartialFCSample);
} // namespace one
} // namespace oneflow
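For intuition, the Apply above scatter-adds the gradient of the sampled weight rows back into a full-size weight gradient via UnsortedSegmentSumLike, using the sampled labels as segment ids. A minimal standalone C++ sketch of that scatter-add semantics (illustrative only, not the OneFlow kernel):

```
// Illustrative only: plain C++ sketch of the unsorted-segment-sum semantics used
// above. Gradients of the sampled weight rows are scatter-added back into a
// full-size weight gradient at the positions given by the sampled labels.
#include <cstdio>
#include <vector>

int main() {
  const int num_classes = 6, feat = 2;
  std::vector<std::vector<float>> sampled_weight_diff = {{1, 1}, {2, 2}, {3, 3}};
  std::vector<int> sampled_label = {4, 0, 4};  // class index each sampled row belongs to

  std::vector<std::vector<float>> weight_diff(num_classes, std::vector<float>(feat, 0.f));
  for (size_t i = 0; i < sampled_weight_diff.size(); ++i) {
    for (int j = 0; j < feat; ++j) {
      weight_diff[sampled_label[i]][j] += sampled_weight_diff[i][j];
    }
  }
  // Only rows 0 and 4 receive gradient; all other rows stay zero.
  for (int c = 0; c < num_classes; ++c) {
    std::printf("row %d: %.0f %.0f\n", c, weight_diff[c][0], weight_diff[c][1]);
  }
  return 0;
}
```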
......@@ -24,7 +24,11 @@ limitations under the License.
namespace oneflow {
namespace one {
class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
struct ReshapeCaptureState : public AutoGradCaptureState {
DimVector input_shape_vec;
};
class ReshapeOpExprGrad : public OpExprGradFunction<ReshapeCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
......@@ -32,17 +36,18 @@ class ReshapeOpExprGrad : public OpExprGradFunction<AutoGradCaptureState> {
return Maybe<void>::Ok();
}
Maybe<void> Capture(AutoGradCaptureState* ctx, const TensorTuple& inputs,
Maybe<void> Capture(ReshapeCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->SaveTensorForBackward(inputs.at(0));
ctx->input_shape_vec = inputs.at(0)->shape()->dim_vec();
return Maybe<void>::Ok();
}
Maybe<void> Apply(const AutoGradCaptureState* ctx, const TensorTuple& out_grads,
Maybe<void> Apply(const ReshapeCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
const auto& saved_tensors = ctx->SavedTensors();
in_grads->resize(1);
in_grads->at(0) = JUST(functional::ReshapeLike(out_grads.at(0), saved_tensors.at(0)));
Shape shape(ctx->input_shape_vec);
in_grads->at(0) = JUST(functional::Reshape(out_grads.at(0), shape));
return Maybe<void>::Ok();
}
};
......
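The point of the refactor above is that reshape's backward only needs the input's shape, so capturing a DimVector is enough and the forward input tensor no longer has to be kept alive for backward. A tiny standalone sketch of that idea (illustrative, not OneFlow API):

```
// Illustrative only: reshape is a metadata change, so its backward just gives the
// incoming gradient the captured input shape; no data from the forward input is needed.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> input_shape = {2, 3, 4};  // captured at forward time (cheap)
  std::vector<float> out_grad(24, 1.f);          // gradient of the reshaped output

  int64_t elem_cnt = std::accumulate(input_shape.begin(), input_shape.end(), int64_t{1},
                                     std::multiplies<int64_t>());
  // Backward "reshape": reuse the same buffer, just reinterpret it with input_shape.
  std::printf("out_grad elements: %zu, input elements: %lld\n", out_grad.size(),
              static_cast<long long>(elem_cnt));
  return 0;
}
```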
......@@ -90,62 +90,62 @@ typename std::remove_const<typename std::remove_reference<T>::type>::type&& Remo
#if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)
#define JUST(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \
__FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \
} \
std::forward<decltype(value_to_check_)>(value_to_check_); \
#define JUST(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
__FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \
} \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST(...) \
([&](const char* func_name) { \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(value_to_check_), __FILE__, __LINE__, \
func_name, OF_PP_STRINGIZE(__VA_ARGS__))); \
} \
return std::forward<decltype(value_to_check_)>(value_to_check_); \
})(__FUNCTION__) \
#define CHECK_JUST(...) \
([&](const char* _just_closure_func_name_) { \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \
_just_closure_func_name_, OF_PP_STRINGIZE(__VA_ARGS__))); \
} \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_MSG(value, ...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \
} \
std::forward<decltype(value_to_check_)>(value_to_check_); \
#define JUST_MSG(value, ...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \
} \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST_MSG(value, ...) \
([&](const char* func_name) { \
auto&& value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, func_name), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \
.error_proto()); \
} \
return std::forward<decltype(value_to_check_)>(value_to_check_); \
})(__FUNCTION__) \
#define CHECK_JUST_MSG(value, ...) \
([&](const char* _just_closure_func_name_) { \
auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \
::oneflow::private_details::JustErrorAddMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, _just_closure_func_name_), \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \
.error_proto()); \
} \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_OPT(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!value_to_check_.has_value()) { return NullOpt; } \
std::forward<decltype(value_to_check_)>(value_to_check_); \
#define JUST_OPT(...) \
::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!_just_value_to_check_.has_value()) { return NullOpt; } \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#else
......
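The renames above exist because each macro declares a local inside a statement expression (or a lambda); with a plain name like value_to_check_ or func_name, any caller-side variable of the same name is shadowed by the expansion, which triggers -Wshadow and invites subtle capture bugs. A minimal sketch of the collision, using a simplified stand-in macro rather than the real JUST:

```
// Simplified stand-in for JUST (GCC/Clang statement expression), not the real macro.
// The point: with a generic internal name, the expansion shadows same-named caller
// locals; the _just_-prefixed names in the diff make such collisions unlikely.
#include <iostream>
#include <optional>

#define UNWRAP(expr)                      \
  ({                                      \
    auto&& value_to_check_ = (expr);      \
    *value_to_check_;                     \
  })

int main() {
  std::optional<int> value_to_check_ = 7;  // caller-side local with the "generic" name
  std::optional<int> other = 35;
  // Inside this expansion the macro's value_to_check_ shadows the local above
  // (-Wshadow warns here); a unique internal name avoids the clash entirely.
  int v = UNWRAP(other);
  std::cout << v + *value_to_check_ << "\n";  // 42
  return 0;
}
```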
......@@ -30,19 +30,19 @@ CPUStreamIndexGenerator::CPUStreamIndexGenerator()
next_stream_index_++;
}
StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateComputeStreamIndex() {
StreamIndexGenerator::stream_index_t CPUStreamIndexGenerator::GenerateComputeStreamIndex() {
return compute_stream_index_begin_ + (compute_stream_index_counter_++ % compute_stream_num_);
}
StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateCommNetStreamIndex() {
StreamIndexGenerator::stream_index_t CPUStreamIndexGenerator::GenerateCommNetStreamIndex() {
return comm_net_stream_index_;
}
StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateTickTockStreamIndex() {
StreamIndexGenerator::stream_index_t CPUStreamIndexGenerator::GenerateTickTockStreamIndex() {
return tick_tock_stream_index_;
}
StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateIndependentTaskStreamIndex(
StreamIndexGenerator::stream_index_t CPUStreamIndexGenerator::GenerateIndependentTaskStreamIndex(
TaskType task_type) {
auto max_num_iter = task_type2max_stream_num_.end();
if (IsClassRegistered<int32_t, IndependentThreadNum4TaskType>(task_type)) {
......@@ -52,8 +52,8 @@ StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateIndependentTaskSt
max_num_iter = task_type2max_stream_num_.find(task_type);
if (max_num_iter == task_type2max_stream_num_.end()) {
task_type2max_stream_num_.emplace(task_type, max_num);
CHECK(
task_type2allocated_stream_index_vec_.emplace(task_type, std::vector<index_t>{}).second);
CHECK(task_type2allocated_stream_index_vec_.emplace(task_type, std::vector<stream_index_t>{})
.second);
} else {
CHECK_EQ(max_num_iter->second, max_num);
CHECK(task_type2allocated_stream_index_vec_.find(task_type)
......@@ -61,7 +61,7 @@ StreamIndexGenerator::index_t CPUStreamIndexGenerator::GenerateIndependentTaskSt
}
}
index_t index = next_stream_index_;
stream_index_t index = next_stream_index_;
if (max_num_iter != task_type2max_stream_num_.end()) {
auto& allocated_stream_index_vec = task_type2allocated_stream_index_vec_[task_type];
if (allocated_stream_index_vec.size() < max_num_iter->second) {
......
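For reference, GenerateComputeStreamIndex above hands out indices round-robin over a fixed block of compute streams. A standalone sketch of that allocation pattern (variable names mirror the diff, the class itself is omitted):

```
// Illustrative only: round-robin allocation over [begin, begin + num), as in
// GenerateComputeStreamIndex above.
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t compute_stream_index_begin = 3;
  uint32_t compute_stream_num = 4;
  uint32_t compute_stream_index_counter = 0;
  for (int i = 0; i < 8; ++i) {
    uint32_t idx =
        compute_stream_index_begin + (compute_stream_index_counter++ % compute_stream_num);
    std::printf("%u ", idx);  // prints: 3 4 5 6 3 4 5 6
  }
  std::printf("\n");
  return 0;
}
```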
......@@ -27,24 +27,24 @@ class CPUStreamIndexGenerator final : public StreamIndexGenerator {
OF_DISALLOW_COPY_AND_MOVE(CPUStreamIndexGenerator);
~CPUStreamIndexGenerator() = default;
index_t GenerateComputeStreamIndex() override;
index_t GenerateH2DStreamIndex() override { UNIMPLEMENTED(); }
index_t GenerateD2HStreamIndex() override { UNIMPLEMENTED(); }
index_t GenerateCommNetStreamIndex();
index_t GenerateTickTockStreamIndex();
index_t GenerateIndependentTaskStreamIndex(TaskType task_type);
stream_index_t GenerateComputeStreamIndex() override;
stream_index_t GenerateH2DStreamIndex() override { UNIMPLEMENTED(); }
stream_index_t GenerateD2HStreamIndex() override { UNIMPLEMENTED(); }
stream_index_t GenerateCommNetStreamIndex();
stream_index_t GenerateTickTockStreamIndex();
stream_index_t GenerateIndependentTaskStreamIndex(TaskType task_type);
private:
index_t next_stream_index_;
index_t compute_stream_index_begin_;
index_t compute_stream_num_;
index_t comm_net_stream_index_;
index_t tick_tock_stream_index_;
stream_index_t next_stream_index_;
stream_index_t compute_stream_index_begin_;
stream_index_t compute_stream_num_;
stream_index_t comm_net_stream_index_;
stream_index_t tick_tock_stream_index_;
// for GenerateComputeStreamIndex
index_t compute_stream_index_counter_;
stream_index_t compute_stream_index_counter_;
// for GenerateIndependentStreamIndex
HashMap<TaskType, size_t> task_type2max_stream_num_;
HashMap<TaskType, std::vector<index_t>> task_type2allocated_stream_index_vec_;
HashMap<TaskType, std::vector<stream_index_t>> task_type2allocated_stream_index_vec_;
HashMap<TaskType, size_t> task_type2allocated_stream_index_vec_index_;
};
......
......@@ -21,12 +21,12 @@ CudaStreamIndexGenerator::CudaStreamIndexGenerator() { next_stream_index_ = kD2H
CudaStreamIndexGenerator::~CudaStreamIndexGenerator() = default;
StreamIndexGenerator::index_t CudaStreamIndexGenerator::GenerateNamedStreamIndex(
StreamIndexGenerator::stream_index_t CudaStreamIndexGenerator::GenerateNamedStreamIndex(
const std::string& name) {
std::lock_guard<std::mutex> lock(named_stream_index_mutex_);
auto it = named_stream_index_.find(name);
if (it == named_stream_index_.end()) {
index_t index = next_stream_index_;
stream_index_t index = next_stream_index_;
next_stream_index_ += 1;
named_stream_index_.emplace(name, index);
return index;
......@@ -35,7 +35,7 @@ StreamIndexGenerator::index_t CudaStreamIndexGenerator::GenerateNamedStreamIndex
}
}
bool CudaStreamIndexGenerator::IsNamedStreamIndex(const std::string& name, index_t index) {
bool CudaStreamIndexGenerator::IsNamedStreamIndex(const std::string& name, stream_index_t index) {
std::lock_guard<std::mutex> lock(named_stream_index_mutex_);
auto it = named_stream_index_.find(name);
if (it == named_stream_index_.end()) {
......
......@@ -25,19 +25,19 @@ class CudaStreamIndexGenerator final : public StreamIndexGenerator {
OF_DISALLOW_COPY_AND_MOVE(CudaStreamIndexGenerator);
CudaStreamIndexGenerator();
~CudaStreamIndexGenerator();
index_t GenerateComputeStreamIndex() override { return kCompute; }
index_t GenerateH2DStreamIndex() override { return kH2D; }
index_t GenerateD2HStreamIndex() override { return kD2H; }
index_t GenerateNamedStreamIndex(const std::string& name);
bool IsNamedStreamIndex(const std::string& name, index_t index);
stream_index_t GenerateComputeStreamIndex() override { return kCompute; }
stream_index_t GenerateH2DStreamIndex() override { return kH2D; }
stream_index_t GenerateD2HStreamIndex() override { return kD2H; }
stream_index_t GenerateNamedStreamIndex(const std::string& name);
bool IsNamedStreamIndex(const std::string& name, stream_index_t index);
private:
static const index_t kCompute = 0;
static const index_t kH2D = 1;
static const index_t kD2H = 2;
HashMap<std::string, index_t> named_stream_index_;
static const stream_index_t kCompute = 0;
static const stream_index_t kH2D = 1;
static const stream_index_t kD2H = 2;
HashMap<std::string, stream_index_t> named_stream_index_;
std::mutex named_stream_index_mutex_;
index_t next_stream_index_;
stream_index_t next_stream_index_;
};
} // namespace oneflow
......
......@@ -29,27 +29,32 @@ namespace oneflow {
class DeviceId {
public:
using index_t = uint32_t;
using node_index_t = uint32_t;
using device_type_t = uint32_t;
using device_index_t = uint32_t;
constexpr static size_t kNodeIndexBits = 19;
constexpr static size_t kDeviceTypeBits = 5;
constexpr static size_t kDeviceIndexBits = 7;
constexpr static index_t kMaxNodeIndex = (index_t{1} << kNodeIndexBits) - index_t{1};
constexpr static index_t kMaxDeviceTypeVal = (index_t{1} << kDeviceTypeBits) - index_t{1};
constexpr static index_t kMaxDeviceIndex = (index_t{1} << kDeviceIndexBits) - index_t{1};
DeviceId(index_t node_index, DeviceType device_type, index_t device_index)
constexpr static node_index_t kMaxNodeIndex =
(node_index_t{1} << kNodeIndexBits) - node_index_t{1};
constexpr static device_type_t kMaxDeviceTypeVal =
(device_type_t{1} << kDeviceTypeBits) - device_type_t{1};
constexpr static device_index_t kMaxDeviceIndex =
(device_index_t{1} << kDeviceIndexBits) - device_index_t{1};
DeviceId(node_index_t node_index, DeviceType device_type, device_index_t device_index)
: node_index_(node_index),
device_type_(static_cast<index_t>(device_type)),
device_type_(static_cast<device_type_t>(device_type)),
device_index_(device_index) {
CHECK_LE(node_index_, kMaxNodeIndex);
CHECK_LE(device_type_, kMaxDeviceTypeVal);
CHECK_LE(device_index, kMaxDeviceIndex);
CHECK_LE(device_index_, kMaxDeviceIndex);
}
index_t node_index() const { return node_index_; }
node_index_t node_index() const { return node_index_; }
DeviceType device_type() const { return static_cast<DeviceType>(device_type_); }
index_t device_index() const { return device_index_; }
device_index_t device_index() const { return device_index_; }
bool operator==(const DeviceId& rhs) const {
return node_index_ == rhs.node_index_ && device_type_ == rhs.device_type_
......@@ -59,16 +64,16 @@ class DeviceId {
bool operator!=(const DeviceId& rhs) const { return !(*this == rhs); }
size_t hash() const {
size_t hash = std::hash<index_t>{}(node_index_);
HashCombine(&hash, std::hash<index_t>{}(device_type_));
HashCombine(&hash, std::hash<index_t>{}(device_index_));
size_t hash = std::hash<node_index_t>{}(node_index_);
HashCombine(&hash, std::hash<device_type_t>{}(device_type_));
HashCombine(&hash, std::hash<device_index_t>{}(device_index_));
return hash;
}
private:
index_t node_index_;
index_t device_type_;
index_t device_index_;
node_index_t node_index_;
device_type_t device_type_;
device_index_t device_index_;
};
} // namespace oneflow
......
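The per-field typedefs above all alias uint32_t; only the names change, while the field widths stay 19/5/7 bits, which is what fixes the kMax* constants. A small standalone check of those bounds and of one hypothetical packed layout (the packing here is an illustration, not OneFlow's actual id encoding):

```
// Illustrative only: the 19/5/7-bit widths from the diff and the bounds they imply.
// The packing below is a hypothetical example, not OneFlow's id serialization.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr size_t kNodeIndexBits = 19;
  constexpr size_t kDeviceTypeBits = 5;
  constexpr size_t kDeviceIndexBits = 7;

  constexpr uint32_t kMaxNodeIndex = (uint32_t{1} << kNodeIndexBits) - 1;       // 524287
  constexpr uint32_t kMaxDeviceTypeVal = (uint32_t{1} << kDeviceTypeBits) - 1;  // 31
  constexpr uint32_t kMaxDeviceIndex = (uint32_t{1} << kDeviceIndexBits) - 1;   // 127

  uint32_t node = 3, type = 2, dev = 5;  // arbitrary in-range values
  uint32_t packed =
      (node << (kDeviceTypeBits + kDeviceIndexBits)) | (type << kDeviceIndexBits) | dev;
  std::printf("max node=%u  max type=%u  max dev=%u  packed=0x%05x\n",
              kMaxNodeIndex, kMaxDeviceTypeVal, kMaxDeviceIndex, packed);
  return 0;
}
```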
......@@ -25,11 +25,11 @@ namespace oneflow {
class StreamIndexGenerator {
public:
virtual ~StreamIndexGenerator() {}
using index_t = StreamId::index_t;
using stream_index_t = StreamId::stream_index_t;
virtual index_t GenerateComputeStreamIndex() = 0;
virtual index_t GenerateH2DStreamIndex() = 0;
virtual index_t GenerateD2HStreamIndex() = 0;
virtual stream_index_t GenerateComputeStreamIndex() = 0;
virtual stream_index_t GenerateH2DStreamIndex() = 0;
virtual stream_index_t GenerateD2HStreamIndex() = 0;
};
class StreamIndexGeneratorManager final {
......
......@@ -1511,6 +1511,17 @@
signature: "TensorTuple (Tensor log_probs, Tensor input_lengths, Bool merge_repeated=True) => CtcGreedyDecoder"
bind_python: True
- name: "distributed_partial_fc_sample"
signature:
"TensorTuple (Tensor weight, Tensor label, Int64 num_sample) => DistributedPariticalFCSample"
bind_python: True
- name: "distributed_partial_fc_sample_disable_boxing"
signature:
"TensorTuple (Tensor sampled_weight_diff, Tensor sampled_label) => DistributedPariticalFCSampleDisableBoxing"
bind_python: False
- name: "meshgrid"
signature: "TensorTuple (TensorTuple tensors) => Meshgrid"
bind_python: True

......@@ -1872,6 +1872,48 @@ class CtcGreedyDecoderFunctor {
std::shared_ptr<OpExpr> op_;
};
class PartialFCSampleFunctor {
public:
PartialFCSampleFunctor() {
op_ = CHECK_JUST(one::OpBuilder("distributed_partial_fc_sample")
.Input("weight")
.Input("label")
.Output("mapped_label")
.Output("sampled_label")
.Output("sampled_weight")
.Build());
}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& weight,
const std::shared_ptr<one::Tensor>& label,
const int64_t& num_sample) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("num_sample", num_sample));
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {weight, label}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class PariticalFCSampleDisableBoxing {
public:
PariticalFCSampleDisableBoxing() {
op_ = CHECK_JUST(one::OpBuilder("distributed_partial_fc_sample_disable_boxing")
.Input("sampled_weight_diff")
.Input("sampled_label")
.Output("boxing_disabled_sampled_weight_diff")
.Output("boxing_disabled_sampled_label")
.Build());
}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& sampled_weight_diff,
const std::shared_ptr<one::Tensor>& sampled_label) const {
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {sampled_weight_diff, sampled_label});
}
private:
std::shared_ptr<OpExpr> op_;
};
} // namespace impl
ONEFLOW_FUNCTION_LIBRARY(m) {
......@@ -1932,6 +1974,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::FusedBiasAddDropoutFunctor>("FusedBiasAddDropout");
m.add_functor<impl::FusedScaleTrilFunctor>("FusedScaleTril");
m.add_functor<impl::CtcGreedyDecoderFunctor>("CtcGreedyDecoder");
m.add_functor<impl::PartialFCSampleFunctor>("DistributedPariticalFCSample");
m.add_functor<impl::PariticalFCSampleDisableBoxing>("DistributedPariticalFCSampleDisableBoxing");
};
} // namespace functional
......
......@@ -65,8 +65,8 @@ void NcclInitCollectiveNode(CollectiveBoxingGenericTaskNode* node,
const int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
const int64_t device_index = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(device_index)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(device_index)};
auto* stream_index_generator = dynamic_cast<CudaStreamIndexGenerator*>(
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id));
CHECK_NOTNULL(stream_index_generator);
......@@ -191,8 +191,8 @@ class NcclCollectiveBoxingP2SNoncontinuousSubTskGphBuilder final : public SubTsk
FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i));
const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i));
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(device_index)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(device_index)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......@@ -293,8 +293,8 @@ class NcclCollectiveBoxingS2BNoncontinuousSubTskGphBuilder final : public SubTsk
FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
const int64_t machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(i));
const int64_t device_index = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(i));
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(device_index)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(device_index)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......@@ -406,7 +406,7 @@ class CollectiveBoxingScatterThenNcclAllGatherSubTskGphBuilder final : public Su
SliceBoxingTaskNode* slice_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>();
// slice on cpu
const auto in_machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(0));
DeviceId device_id{static_cast<DeviceId::index_t>(in_machine_id), DeviceType::kCPU, 0};
DeviceId device_id{static_cast<DeviceId::node_index_t>(in_machine_id), DeviceType::kCPU, 0};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......@@ -522,8 +522,8 @@ class NcclCollectiveBoxingAll2AllSubTskGphBuilder final : public SubTskGphBuilde
FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
const int64_t machine_id = CHECK_JUST(in_parallel_desc.MachineId4ParallelId(i));
const int64_t device_index = CHECK_JUST(in_parallel_desc.DeviceId4ParallelId(i));
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(device_index)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(device_index)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......
......@@ -58,8 +58,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
int64_t thrd_id = -1;
if (out_parallel_desc.device_type() == DeviceType::kGPU) {
#ifdef WITH_CUDA
DeviceId device_id{static_cast<DeviceId::index_t>(out_machine_id), DeviceType::kGPU,
static_cast<DeviceId::index_t>(out_dev_phy_id)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(out_machine_id), DeviceType::kGPU,
static_cast<DeviceId::device_index_t>(out_dev_phy_id)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......@@ -68,7 +68,8 @@ Maybe<SubTskGphBuilderStatus> NaiveB2PSubTskGphBuilder::Build(
UNIMPLEMENTED();
#endif
} else if (out_parallel_desc.device_type() == DeviceType::kCPU) {
DeviceId device_id{static_cast<DeviceId::index_t>(out_machine_id), DeviceType::kCPU, 0};
DeviceId device_id{static_cast<DeviceId::node_index_t>(out_machine_id), DeviceType::kCPU,
0};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......
......@@ -61,8 +61,8 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
} else {
dev_id = CHECK_JUST(pd.DeviceId4ParallelId(parallel_id));
}
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), pd.device_type(),
static_cast<DeviceId::index_t>(dev_id)};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), pd.device_type(),
static_cast<DeviceId::device_index_t>(dev_id)};
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
auto stream_index = stream_index_generator->GenerateComputeStreamIndex();
......
......@@ -45,7 +45,7 @@ void CopyHdTaskNode::Init(CopyHdOpConf::Type copy_type, const DeviceId& device_i
set_machine_id(device_id.node_index());
auto* stream_index_generator =
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id);
StreamId::index_t stream_index = 0;
StreamId::stream_index_t stream_index = 0;
if (copy_type == CopyHdOpConf::H2D) {
stream_index = stream_index_generator->GenerateH2DStreamIndex();
} else if (copy_type == CopyHdOpConf::D2H) {
......@@ -84,7 +84,7 @@ OperatorConf CopyHdTaskNode::NewCopyOpConf() {
void CopyCommNetTaskNode::Init(int64_t machine_id, const LogicalBlobId& lbi) {
set_machine_id(machine_id);
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), DeviceType::kCPU, 0};
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id), DeviceType::kCPU, 0};
auto* generator = dynamic_cast<CPUStreamIndexGenerator*>(
Global<IDMgr>::Get()->GetStreamIndexGeneratorManager()->GetGenerator(device_id));
CHECK_NOTNULL(generator);
......
......@@ -22,7 +22,7 @@ StreamIndexGetterRegistryManager& StreamIndexGetterRegistryManager::Get() {
return mgr;
}
StreamId::index_t StreamIndexGetterRegistryManager::StreamIndex4DeviceIdAndTaskType(
StreamId::stream_index_t StreamIndexGetterRegistryManager::StreamIndex4DeviceIdAndTaskType(
DeviceId device_id, TaskType task_type) {
auto index_getter_fn = StreamIndexGetterRegistryManager::GetStreamIndexGetterFunc(
device_id.device_type(), task_type);
......
......@@ -47,7 +47,7 @@ class StreamIndexGetterRegistryManager final {
StreamIndexKeyMap<StreamIndexGetterFn>& StreamIndexGetterFuncs();
StreamId::index_t StreamIndex4DeviceIdAndTaskType(DeviceId device_id, TaskType task_type);
StreamId::stream_index_t StreamIndex4DeviceIdAndTaskType(DeviceId device_id, TaskType task_type);
private:
StreamIndexGetterFn GetStreamIndexGetterFunc(DeviceType dev_type, TaskType task_type);
......
......@@ -284,16 +284,17 @@ void GenSortedCompTaskNodes(const OpNode* op_node, std::vector<CompTaskNode*>* s
comp_task_node->mut_parallel_ctx()->set_parallel_id(parallel_idx++);
comp_task_node->mut_parallel_ctx()->set_parallel_num(parallel_num);
DeviceId::index_t device_index = parallel_desc.device_type() == DeviceType::kCPU
? 0
: static_cast<DeviceId::index_t>(dev_phy_id);
DeviceId device_id{static_cast<DeviceId::index_t>(machine_id), parallel_desc.device_type(),
device_index};
StreamId::index_t stream_index{};
DeviceId::device_index_t device_index =
parallel_desc.device_type() == DeviceType::kCPU
? 0
: static_cast<DeviceId::device_index_t>(dev_phy_id);
DeviceId device_id{static_cast<DeviceId::node_index_t>(machine_id),
parallel_desc.device_type(), device_index};
StreamId::stream_index_t stream_index{};
if (op_node->op().op_conf().has_stream_index_hint()) {
int32_t stream_index_hint = op_node->op().op_conf().stream_index_hint();
LOG(INFO) << "set op: " << op_node->op().op_name() << " to stream: " << stream_index_hint;
stream_index = static_cast<StreamId::index_t>(stream_index_hint);
stream_index = static_cast<StreamId::stream_index_t>(stream_index_hint);
} else {
stream_index = StreamIndexGetterRegistryManager::Get().StreamIndex4DeviceIdAndTaskType(
device_id, comp_task_node->GetTaskType());
......@@ -522,8 +523,9 @@ TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
DeviceType device_type = dst_parallel_desc.device_type();
auto device_index =
(device_type == DeviceType::kCPU ? 0 : static_cast<DeviceId::index_t>(dev_id));
MemZoneId mem_zone_id{static_cast<MemZoneId::index_t>(dst_machine_id), device_type, device_index};
(device_type == DeviceType::kCPU ? 0 : static_cast<DeviceId::device_index_t>(dev_id));
MemZoneId mem_zone_id{static_cast<MemZoneId::node_index_t>(dst_machine_id), device_type,
device_index};
return GetProxyNode(src_node, lbi, mem_zone_id);
}
......
......@@ -65,10 +65,11 @@ TaskId DecodeTaskIdFromInt64(int64_t task_id_val) {
int64_t device_index = (task_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;
int64_t stream_index = (task_id_val & kStreamIndexInt64Mask) >> kStreamIndexShift;
int64_t task_index = task_id_val & kTaskIndexInt64Mask;
StreamId stream_id{
static_cast<DeviceId::index_t>(node_index), static_cast<DeviceType>(device_type),
static_cast<DeviceId::index_t>(device_index), static_cast<StreamId::index_t>(stream_index)};
return TaskId{stream_id, static_cast<TaskId::index_t>(task_index)};
StreamId stream_id{static_cast<DeviceId::node_index_t>(node_index),
static_cast<DeviceType>(device_type),
static_cast<DeviceId::device_index_t>(device_index),
static_cast<StreamId::stream_index_t>(stream_index)};
return TaskId{stream_id, static_cast<TaskId::task_index_t>(task_index)};
}
int64_t MachineId4ActorId(int64_t actor_id) {
......
......@@ -22,18 +22,19 @@ namespace oneflow {
class TaskId {
public:
using index_t = uint32_t;
using task_index_t = uint32_t;
const static size_t kTaskIndexBits = 21;
constexpr static index_t kMaxTaskIndex = (index_t{1} << kTaskIndexBits) - index_t{1};
constexpr static task_index_t kMaxTaskIndex =
(task_index_t{1} << kTaskIndexBits) - task_index_t{1};
TaskId(const StreamId& stream_id, index_t task_index)
TaskId(const StreamId& stream_id, task_index_t task_index)
: stream_id_(stream_id), task_index_(task_index) {
CHECK_LE(task_index_, kMaxTaskIndex);
}
const StreamId& stream_id() const { return stream_id_; }
index_t task_index() const { return task_index_; }
task_index_t task_index() const { return task_index_; }
bool operator==(const TaskId& rhs) const {
return stream_id_ == rhs.stream_id_ && task_index_ == rhs.task_index_;
......@@ -42,13 +43,13 @@ class TaskId {
size_t hash() const {
size_t hash = stream_id_.hash();
HashCombine(&hash, std::hash<index_t>{}(task_index_));
HashCombine(&hash, std::hash<task_index_t>{}(task_index_));
return hash;
}
private:
StreamId stream_id_;
index_t task_index_;
task_index_t task_index_;
};
int64_t EncodeTaskIdToInt64(const TaskId&);
......
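Taken together, the widths in these diffs (19 node + 5 device-type + 7 device-index + 12 stream + 21 task bits) sum to 64, so a TaskId round-trips through a single int64. The sketch below assumes that high-to-low field order to mirror DecodeTaskIdFromInt64; it is an illustration of the arithmetic, not the exact OneFlow encoding:

```
// Illustrative decode assuming fields are packed high-to-low as
// node(19) | device_type(5) | device_index(7) | stream(12) | task(21) = 64 bits.
// The real encode/decode lives in OneFlow's task-id serialization helpers.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kNodeBits = 19, kTypeBits = 5, kDevBits = 7, kStreamBits = 12, kTaskBits = 21;

  // Pack some in-range example values.
  uint64_t node = 1, type = 2, dev = 3, stream = 4, task = 5;
  uint64_t id = (((((((node << kTypeBits) | type) << kDevBits) | dev) << kStreamBits) | stream)
                 << kTaskBits) | task;

  // Decode with masks/shifts, mirroring the structure of DecodeTaskIdFromInt64.
  uint64_t task_out = id & ((uint64_t{1} << kTaskBits) - 1);
  uint64_t stream_out = (id >> kTaskBits) & ((uint64_t{1} << kStreamBits) - 1);
  uint64_t dev_out = (id >> (kTaskBits + kStreamBits)) & ((uint64_t{1} << kDevBits) - 1);
  uint64_t type_out = (id >> (kTaskBits + kStreamBits + kDevBits)) & ((uint64_t{1} << kTypeBits) - 1);
  uint64_t node_out = id >> (kTaskBits + kStreamBits + kDevBits + kTypeBits);

  std::printf("node=%llu type=%llu dev=%llu stream=%llu task=%llu\n",
              (unsigned long long)node_out, (unsigned long long)type_out,
              (unsigned long long)dev_out, (unsigned long long)stream_out,
              (unsigned long long)task_out);
  return 0;
}
```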
......@@ -22,7 +22,7 @@ namespace oneflow {
class TaskIdGenerator final {
public:
using task_index_t = TaskId::index_t;
using task_index_t = TaskId::task_index_t;
TaskIdGenerator() = default;
OF_DISALLOW_COPY_AND_MOVE(TaskIdGenerator);
......
......@@ -173,6 +173,12 @@ Maybe<void> FuseUpdateOpsPass::Apply(const OpGraph& op_graph, JobBuilder* job_bu
.Attr<float>("beta1", user_op_conf.attr<float>("beta1"))
.Attr<float>("beta2", user_op_conf.attr<float>("beta2"))
.Attr<float>("epsilon", user_op_conf.attr<float>("epsilon"));
if (user_op_conf.has_input("bias_correction1", 0)) {
fused_op_builder.Input("bias_correction1", user_op_conf.input("bias_correction1", 0));
}
if (user_op_conf.has_input("bias_correction2", 0)) {
fused_op_builder.Input("bias_correction2", user_op_conf.input("bias_correction2", 0));
}
} else if (user_op_conf.op_type_name() == "rmsprop_update") {
const bool centered = user_op_conf.attr<bool>("centered");
fused_op_builder.Input("mean_square", user_op_conf.input("mean_square", 0.f))
......
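For context on the pass change above: when the adam_update op carries bias_correction1/bias_correction2 inputs, these are the standard Adam correction factors 1 - beta1^t and 1 - beta2^t, and the fix forwards them to the fused op only when they actually exist on the original op. A plain C++ sketch of vanilla Adam showing where the two factors enter (illustrative, not the OneFlow kernel):

```
// Illustrative only: vanilla Adam with explicit bias_correction1/2 factors.
// The fused OneFlow op consumes these as optional inputs; this is not its kernel.
#include <cmath>
#include <cstdio>

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f, lr = 1e-3f, eps = 1e-8f;
  float m = 0.f, v = 0.f, w = 1.f;
  for (int t = 1; t <= 3; ++t) {
    const float g = 0.5f;  // toy gradient
    m = beta1 * m + (1.f - beta1) * g;
    v = beta2 * v + (1.f - beta2) * g * g;
    const float bias_correction1 = 1.f - std::pow(beta1, t);  // 1 - beta1^t
    const float bias_correction2 = 1.f - std::pow(beta2, t);  // 1 - beta2^t
    w -= lr * (m / bias_correction1) / (std::sqrt(v / bias_correction2) + eps);
    std::printf("step %d: w = %f\n", t, w);
  }
  return 0;
}
```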
......@@ -32,7 +32,7 @@ constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDe
const MemZoneId kInvalidMemZoneId = MemZoneId{0, DeviceType::kInvalidDevice, 0};
MemZoneId GetNodeCPUMemZoneId(MemZoneId::index_t node_index) {
MemZoneId GetNodeCPUMemZoneId(MemZoneId::node_index_t node_index) {
return MemZoneId{node_index, DeviceType::kCPU, 0};
}
......@@ -47,9 +47,9 @@ MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) {
int64_t node_index = (mem_zone_id & kMemZoneIdNodeIndexInt64Mask) >> kMemZoneIdNodeIndexShift;
int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift;
int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask;
return MemZoneId(static_cast<MemZoneId::index_t>(node_index),
return MemZoneId(static_cast<MemZoneId::node_index_t>(node_index),
static_cast<DeviceType>(device_type),
static_cast<MemZoneId::index_t>(device_index));
static_cast<MemZoneId::device_index_t>(device_index));
}
} // namespace oneflow
......@@ -25,7 +25,7 @@ using MemZoneId = DeviceId;
int64_t EncodeMemZoneIdToInt64(const MemZoneId&);
MemZoneId DecodeMemZoneIdFromInt64(int64_t);
MemZoneId GetNodeCPUMemZoneId(MemZoneId::index_t node_index);
MemZoneId GetNodeCPUMemZoneId(MemZoneId::node_index_t node_index);
extern const MemZoneId kInvalidMemZoneId;
......
......@@ -59,9 +59,10 @@ StreamId DecodeStreamIdFromInt64(int64_t stream_id_val) {
int64_t device_type = (stream_id_val & kDeviceTypeInt64Mask) >> kDeviceTypeShift;
int64_t device_index = (stream_id_val & kDeviceIndexInt64Mask) >> kDeviceIndexShift;
int64_t stream_index = (stream_id_val & kStreamIndexInt64Mask);
return StreamId{static_cast<DeviceId::index_t>(node_index), static_cast<DeviceType>(device_type),
static_cast<DeviceId::index_t>(device_index),
static_cast<StreamId::index_t>(stream_index)};
return StreamId{static_cast<DeviceId::node_index_t>(node_index),
static_cast<DeviceType>(device_type),
static_cast<DeviceId::device_index_t>(device_index),
static_cast<StreamId::stream_index_t>(stream_index)};
}
} // namespace oneflow
......@@ -22,26 +22,27 @@ namespace oneflow {
class StreamId {
public:
using index_t = uint32_t;
using stream_index_t = uint32_t;
constexpr static size_t kStreamIndexBits = 12;
constexpr static index_t kMaxStreamIndex = (index_t{1} << kStreamIndexBits) - index_t{1};
constexpr static stream_index_t kMaxStreamIndex =
(stream_index_t{1} << kStreamIndexBits) - stream_index_t{1};
StreamId(const DeviceId& device_id, index_t stream_index)
StreamId(const DeviceId& device_id, stream_index_t stream_index)
: device_id_(device_id), stream_index_(stream_index) {
CHECK_LE(stream_index, kMaxStreamIndex);
}
StreamId(DeviceId::index_t node_index, DeviceType device_type, DeviceId::index_t device_index,
index_t stream_index)
StreamId(DeviceId::node_index_t node_index, DeviceType device_type,
DeviceId::device_index_t device_index, stream_index_t stream_index)
: device_id_(node_index, device_type, device_index), stream_index_(stream_index) {
CHECK_LE(stream_index, kMaxStreamIndex);
}
const DeviceId& device_id() const { return device_id_; }
DeviceId::index_t node_index() const { return device_id_.node_index(); }
DeviceId::node_index_t node_index() const { return device_id_.node_index(); }
DeviceType device_type() const { return device_id_.device_type(); }
DeviceId::index_t device_index() const { return device_id_.device_index(); }
index_t stream_index() const { return stream_index_; }
DeviceId::device_index_t device_index() const { return device_id_.device_index(); }
stream_index_t stream_index() const { return stream_index_; }
bool operator==(const StreamId& rhs) const {
return device_id_ == rhs.device_id_ && stream_index_ == rhs.stream_index_;
......@@ -51,13 +52,13 @@ class StreamId {
size_t hash() const {
size_t hash = device_id_.hash();
HashCombine(&hash, std::hash<index_t>{}(stream_index_));
HashCombine(&hash, std::hash<stream_index_t>{}(stream_index_));
return hash;
}
private:
DeviceId device_id_;
index_t stream_index_;
stream_index_t stream_index_;
};
int64_t EncodeStreamIdToInt64(const StreamId&);
......
......@@ -152,7 +152,10 @@ class DistributedPartialFcSampleOpKernelState final : public user_op::OpKernelSt
SetupKernel<<<BlocksNum4ThreadsNum(num_classes), kCudaThreadsNumPerBlock, 0,
ctx->cuda_stream()>>>(seed, curand_states_);
}
~DistributedPartialFcSampleOpKernelState() { OF_CUDA_CHECK(cudaFree(curand_states_)); };
~DistributedPartialFcSampleOpKernelState() {
cudaError_t ret = cudaFree(curand_states_);
if (ret != cudaErrorCudartUnloading) { OF_CUDA_CHECK(ret); }
};
int64_t lower() const { return lower_; }
int64_t upper() const { return upper_; }
......
......@@ -128,6 +128,7 @@ from oneflow._C import softplus
from oneflow._C import tril
from oneflow._C import triu
from oneflow._C import pad
from oneflow._C import distributed_partial_fc_sample
from oneflow._C import transpose
from oneflow._C import relu
from oneflow._C import softmax
......@@ -155,7 +156,6 @@ import oneflow.framework.register_python_callback
INVALID_SPLIT_AXIS = oneflow._oneflow_internal.INVALID_SPLIT_AXIS
register_class_method_util.RegisterMethod4Class()
oneflow._oneflow_internal.RegisterGILForeignLockHelper()
import oneflow.framework.env_util as env_util
import oneflow.framework.scope_util as scope_util
import oneflow.framework.session_context as session_ctx
......@@ -165,6 +165,7 @@ if not env_util.HasAllMultiClientEnvVars():
env_util.SetDefaultMultiClientEnvVars()
oneflow._oneflow_internal.SetIsMultiClient(True)
env_util.api_env_init()
oneflow._oneflow_internal.RegisterGILForeignLockHelper()
oneflow._oneflow_internal.InitDefaultConsistentTransportTokenScope()
session_ctx.OpenDefaultSession(
MultiClientSession(oneflow._oneflow_internal.NewSessionId())
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
from oneflow.test_utils.automated_test_util import *
import oneflow as flow
import oneflow.unittest
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestParitalFC(flow.unittest.TestCase):
def test_parital_fc(test_case):
p = flow.env.all_device_placement("cuda")
w = flow.randn(50000, 128, placement=p, sbp=flow.sbp.broadcast)
label = flow.randint(0, 50000, (512,), placement=p, sbp=flow.sbp.broadcast)
num_sample = 5000
out = flow.distributed_partial_fc_sample(w, label, num_sample)
test_case.assertTrue(out[0].shape == flow.Size([512]))
test_case.assertTrue(out[1].shape == flow.Size([5000]))
test_case.assertTrue(out[2].shape == flow.Size([5000, 128]))
if __name__ == "__main__":
unittest.main()