未验证 提交 87cec628 编写于 作者: H Houjiang Chen 提交者: GitHub

Sync access and assign indexing tensor. (#5907)

* Sync access and assign indexing tensor.

* Remove unused comments.
Co-authored-by: binbinHan <han_binbin@163.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 1b7511ba
......@@ -152,15 +152,16 @@ Maybe<Tensor> ConvertToIndexingTensor(PyObject* object) {
// Prevent the python object release until the callback is complete.
Py_INCREF(object);
auto handle = std::shared_ptr<PyObject>(PyObjectPtr(object));
JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
JUST(builder->AccessBlobByCallback(
JUST(tensor->AsMirroredTensor()),
[handle](uint64_t of_blob_ptr) {
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
CHECK_JUST(ParseArrayToBlob(handle.get(), of_blob->mut_blob()));
},
"mut"));
return Maybe<void>::Ok();
const auto& callback =
std::make_shared<std::function<void(uint64_t)>>([handle](uint64_t of_blob_ptr) {
auto* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
CHECK_JUST(ParseArrayToBlob(handle.get(), of_blob->mut_blob()));
});
JUST(SpinCounter::SpinWait(1, [&](const std::shared_ptr<SpinCounter>& sc) -> Maybe<void> {
return PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
return builder->SyncAccessBlobByCallback(JUST(tensor->AsMirroredTensor()), sc, callback,
"mut");
});
}));
return tensor;
}
......
......@@ -46,26 +46,10 @@ std::string SbpParallelSymbolToString(const Symbol<cfg::SbpParallel>& sbp_sym) {
// Builds a vector holding `max_split_axis` split SbpParallel symbols, where entry `i`
// is the symbol for a split along axis `i` (i in [0, max_split_axis)).
// Returns the vector wrapped in a shared_ptr (converted to the Maybe return type).
// NOTE(review): this text looks like a rendered diff hunk with the +/- markers stripped,
// so both the removed loop (inline construction via SymbolOf) and the added loop
// (delegating to MakeSplitSbpParallel) are shown; presumably only the second loop exists
// in the committed file — confirm against the repository.
Maybe<std::vector<Symbol<cfg::SbpParallel>>> MakeSplitSbpParallelList(int max_split_axis) {
std::shared_ptr<std::vector<Symbol<cfg::SbpParallel>>> ret =
std::make_shared<std::vector<Symbol<cfg::SbpParallel>>>(max_split_axis);
// First loop: builds each split SbpParallel inline and stores its symbol.
for (int i = 0; i < max_split_axis; ++i) {
cfg::SbpParallel split_sbp_parallel;
split_sbp_parallel.mutable_split_parallel()->set_axis(i);
ret->at(i) = SymbolOf(split_sbp_parallel);
}
// Second loop: fills the same slots through MakeSplitSbpParallel; JUST unwraps its Maybe result.
for (int i = 0; i < max_split_axis; ++i) { ret->at(i) = JUST(MakeSplitSbpParallel(i)); }
return ret;
}
// Creates a default cfg::SbpParallel, calls mutable_broadcast_parallel() on it
// (presumably selecting the broadcast variant — confirm against the cfg definition),
// and returns the result of SymbolOf for that object.
Maybe<Symbol<cfg::SbpParallel>> MakeBroadcastSbpParallel() {
  cfg::SbpParallel sbp;
  sbp.mutable_broadcast_parallel();
  return SymbolOf(sbp);
}
// Creates a default cfg::SbpParallel, calls mutable_partial_sum_parallel() on it
// (presumably selecting the partial-sum variant — confirm against the cfg definition),
// and returns the result of SymbolOf for that object.
Maybe<Symbol<cfg::SbpParallel>> MakePartialSumSbpParallel() {
  cfg::SbpParallel sbp;
  sbp.mutable_partial_sum_parallel();
  return SymbolOf(sbp);
}
Maybe<Symbol<cfg::SbpParallel>> GetSplitSbpParallel(int axis) {
CHECK_LT_OR_RETURN(axis, kMaxSplitAxis);
static std::vector<Symbol<cfg::SbpParallel>> split_sbp_sym_list =
......
......@@ -31,25 +31,6 @@ namespace {
// Lookup table type: keyed by a pair of cfg::SbpParallel values, mapping to a shared
// EagerBoxingInterpreter instance. The visible lookup code builds the key as
// (input sbp_parallel(0), output sbp_parallel(0)).
using SbpPair2EagerBoxingInterpreter =
HashMap<std::pair<cfg::SbpParallel, cfg::SbpParallel>, std::shared_ptr<EagerBoxingInterpreter>>;
// Returns the symbol of a cfg::SbpParallel whose split_parallel axis is set to `axis`.
// CHECK_LT_OR_RETURN guards the upper bound (axis < kMaxSplitAxis); presumably it
// returns an error through the Maybe when the check fails — confirm against the macro.
// NOTE(review): no lower-bound check is visible here, so a negative axis is not rejected.
Maybe<Symbol<cfg::SbpParallel>> GetSplitSbpParallel(int axis) {
  CHECK_LT_OR_RETURN(axis, kMaxSplitAxis);
  cfg::SbpParallel sbp;
  sbp.mutable_split_parallel()->set_axis(axis);
  return SymbolOf(sbp);
}
// Creates a default cfg::SbpParallel, calls mutable_broadcast_parallel() on it
// (presumably selecting the broadcast variant — confirm against the cfg definition),
// and returns the result of SymbolOf for that object.
Maybe<Symbol<cfg::SbpParallel>> MakeBroadcastSbpParallel() {
  cfg::SbpParallel sbp;
  sbp.mutable_broadcast_parallel();
  return SymbolOf(sbp);
}
// Creates a default cfg::SbpParallel, calls mutable_partial_sum_parallel() on it
// (presumably selecting the partial-sum variant — confirm against the cfg definition),
// and returns the result of SymbolOf for that object.
Maybe<Symbol<cfg::SbpParallel>> MakePartialSumSbpParallel() {
  cfg::SbpParallel sbp;
  sbp.mutable_partial_sum_parallel();
  return SymbolOf(sbp);
}
std::string GetSupportedBoxingTypeInfo() {
static std::string supported_boxing_type_info =
"============ Supported eager boxing type============\n"
......@@ -65,15 +46,15 @@ std::string GetSupportedBoxingTypeInfo() {
Maybe<EagerBoxingInterpreter> GetOneDimNcclCollectiveEagerBoxingInterpreter(
Symbol<cfg::NdSbp> in_nd_sbp, Symbol<cfg::NdSbp> out_nd_sbp) {
static SbpPair2EagerBoxingInterpreter sbp_pair2eager_boxing_interpreter = {
{{*JUST(GetSplitSbpParallel(0)), *JUST(MakeBroadcastSbpParallel())}, // S(0) -> B
{{*JUST(MakeSplitSbpParallel(0)), *JUST(MakeBroadcastSbpParallel())}, // S(0) -> B
std::make_shared<NcclCollectiveAllGatherBoxingInterpreter>()},
{{*JUST(MakeBroadcastSbpParallel()), *JUST(GetSplitSbpParallel(0))}, // B -> S(0)
{{*JUST(MakeBroadcastSbpParallel()), *JUST(MakeSplitSbpParallel(0))}, // B -> S(0)
std::make_shared<NcclCollectiveReduceScatterBoxingInterpreter>("max")},
{{*JUST(MakePartialSumSbpParallel()), *JUST(MakeBroadcastSbpParallel())}, // P -> B
std::make_shared<NcclCollectiveAllReduceBoxingInterpreter>()},
{{*JUST(MakePartialSumSbpParallel()), *JUST(GetSplitSbpParallel(0))}, // P -> S(0)
{{*JUST(MakePartialSumSbpParallel()), *JUST(MakeSplitSbpParallel(0))}, // P -> S(0)
std::make_shared<NcclCollectiveReduceScatterBoxingInterpreter>("sum")},
{{*JUST(GetSplitSbpParallel(0)), *JUST(MakePartialSumSbpParallel())}, // S(0) -> P
{{*JUST(MakeSplitSbpParallel(0)), *JUST(MakePartialSumSbpParallel())}, // S(0) -> P
std::make_shared<NcclS2PBoxingInterpreter>()},
};
const auto& key = std::make_pair(in_nd_sbp->sbp_parallel(0), out_nd_sbp->sbp_parallel(0));
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/job/sbp_parallel.h"
namespace oneflow {
namespace one {
......@@ -241,8 +242,10 @@ Maybe<Tensor> ApplyAdvancedIndexing(const std::shared_ptr<Tensor>& input,
packed_indices = JUST(Transpose(packed_indices, permute));
if (transposed_input->is_consistent()) {
// TODO(hjchen2): Cast local indices to consistent.
UNIMPLEMENTED_THEN_RETURN() << "Not support consistent mode.";
const auto& placement = JUST(transposed_input->parallel_desc());
const auto& broadcast_sbp = JUST(MakeBroadcastSbpParallel());
packed_indices = JUST(ToConsistent(packed_indices, placement, {broadcast_sbp},
/*identity_grad=*/false, /*grad_sbp_parallels=*/{}));
}
Symbol<Device> device = JUST(transposed_input->device());
if (JUST(packed_indices->device()) != device) {
......
......@@ -18,6 +18,25 @@ limitations under the License.
namespace oneflow {
// Returns the symbol of a cfg::SbpParallel whose split_parallel axis is set to `axis`.
// CHECK_LT_OR_RETURN guards the upper bound (axis < kMaxSplitAxis); presumably it
// returns an error through the Maybe when the check fails — confirm against the macro.
// NOTE(review): no lower-bound check is visible here, so a negative axis is not rejected.
Maybe<Symbol<cfg::SbpParallel>> MakeSplitSbpParallel(int axis) {
  CHECK_LT_OR_RETURN(axis, kMaxSplitAxis);
  cfg::SbpParallel sbp;
  sbp.mutable_split_parallel()->set_axis(axis);
  return SymbolOf(sbp);
}
// Creates a default cfg::SbpParallel, calls mutable_broadcast_parallel() on it
// (presumably selecting the broadcast variant — confirm against the cfg definition),
// and returns the result of SymbolOf for that object.
Maybe<Symbol<cfg::SbpParallel>> MakeBroadcastSbpParallel() {
  cfg::SbpParallel sbp;
  sbp.mutable_broadcast_parallel();
  return SymbolOf(sbp);
}
// Creates a default cfg::SbpParallel, calls mutable_partial_sum_parallel() on it
// (presumably selecting the partial-sum variant — confirm against the cfg definition),
// and returns the result of SymbolOf for that object.
Maybe<Symbol<cfg::SbpParallel>> MakePartialSumSbpParallel() {
  cfg::SbpParallel sbp;
  sbp.mutable_partial_sum_parallel();
  return SymbolOf(sbp);
}
// S -> S
// P -> B
// B -> P
......
......@@ -22,6 +22,10 @@ limitations under the License.
namespace oneflow {
// Returns the symbol of an SbpParallel with its split_parallel axis set to `axis`;
// the definition checks axis < kMaxSplitAxis before building it.
Maybe<Symbol<cfg::SbpParallel>> MakeSplitSbpParallel(int axis);
// Returns the symbol of an SbpParallel on which mutable_broadcast_parallel() was called.
Maybe<Symbol<cfg::SbpParallel>> MakeBroadcastSbpParallel();
// Returns the symbol of an SbpParallel on which mutable_partial_sum_parallel() was called.
Maybe<Symbol<cfg::SbpParallel>> MakePartialSumSbpParallel();
// Inequality for cfg::SbpParallel, defined as the negation of operator==.
inline bool operator!=(const cfg::SbpParallel& lhs, const cfg::SbpParallel& rhs) {
  if (lhs == rhs) { return false; }
  return true;
}
......@@ -58,6 +62,7 @@ void NdSbpSignatureToSbpSignature(const NdSbpSignatureT& nd_sbp_signature,
cfg::SbpSignature* sbp_signature);
void CheckSbpSignatureAndNdSbpEquals(const cfg::SbpSignature& sbp_sig,
const cfg::NdSbpSignature& nd_sbp_sig);
} // namespace oneflow
namespace std {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册