Unverified commit d63ccc09, authored by Li Xinqi and committed by GitHub

Multi client push pull (#5492)

* replace ForeignJobInstance with JobInstance

* LazyJobStreamType

* NNGraphIf

* NNGraph -> NNGraphIf

* fix compile bugs

* add unit tests for instruction RunLazyJob

* GetInputBufferName/GetOutputBufferName

* multi-client push pull

* refactor wait_and_send_ids_kernel to support multi-client mode

* support multi-client mode return_op
Co-authored-by: liufengwei <2472937968@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent 8004ffc1
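For orientation before the diff: the push/pull path this commit wires up keeps, for each input and output op of a multi-client job, a named buffer of std::shared_ptr<JobInstance> (the names come from GetInputBufferName / GetOutputBufferName / GetSourceTickBufferName). The driver side enqueues a JobInstance whose per-op callbacks know how to fill or read the corresponding blob; the kernels changed below dequeue it and call PushBlobByOpName / PullBlobByOpName on it. The sketch below is a self-contained toy model of that hand-off, not OneFlow code: Channel, ToyJobInstance, and the callback signatures are invented for illustration.

// Toy model of the multi-client push/pull hand-off. Channel and ToyJobInstance
// are illustrative stand-ins, not OneFlow's BufferMgr/Buffer/JobInstance.
#include <condition_variable>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <queue>
#include <string>

// A job instance carries per-op callbacks registered by the driver.
struct ToyJobInstance {
  std::map<std::string, std::function<void(double*)>> push_cbs;        // fill an input blob
  std::map<std::string, std::function<void(const double*)>> pull_cbs;  // read an output blob
};

// Minimal blocking channel of shared_ptr<ToyJobInstance>.
class Channel {
 public:
  void Send(std::shared_ptr<ToyJobInstance> x) {
    std::lock_guard<std::mutex> lock(mu_);
    queue_.push(std::move(x));
    cv_.notify_one();
  }
  std::shared_ptr<ToyJobInstance> Receive() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [&] { return !queue_.empty(); });
    auto x = queue_.front();
    queue_.pop();
    return x;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::shared_ptr<ToyJobInstance>> queue_;
};

int main() {
  Channel input_buffer;   // analogue of the buffer named GetInputBufferName(job, "x")
  Channel output_buffer;  // analogue of the buffer named GetOutputBufferName(job, "y")

  // Driver side: enqueue one job instance per graph launch.
  auto inst = std::make_shared<ToyJobInstance>();
  inst->push_cbs["x"] = [](double* blob) { *blob = 3.14; };
  inst->pull_cbs["y"] = [](const double* blob) { std::cout << "pulled " << *blob << "\n"; };
  input_buffer.Send(inst);
  output_buffer.Send(inst);

  // Kernel side: an InputKernel-like consumer fills its output blob ...
  double x = 0.0;
  input_buffer.Receive()->push_cbs.at("x")(&x);
  // ... the graph "runs" ...
  double y = x * 2.0;
  // ... and an OutputKernel-like consumer hands the result back to the driver.
  output_buffer.Receive()->pull_cbs.at("y")(&y);
  return 0;
}

The WaitAndSendIds hunk at the bottom receives the same kind of JobInstance from GetSourceTickBufferName, which suggests the driver enqueues one instance per launch into the source-tick buffer and into each input/output buffer.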
......@@ -52,14 +52,9 @@ Maybe<void> EagerBlobObject::TryInitBlob() {
 Maybe<void> EagerBlobObject::InitBlob() {
   CHECK_NE_OR_RETURN(blob_desc_.data_type(), DataType::kInvalidDataType);
-  {
-    header_buffer_.reset();
-    int64_t header_byte_size = blob_desc_.AlignedByteSizeOfBlobHeader();
-    const auto& FreeHeader = [header_byte_size](char* dptr) { std::free(dptr); };
-    char* ptr = reinterpret_cast<char*>(std::malloc(header_byte_size));
-    header_buffer_ = std::unique_ptr<char, std::function<void(char*)>>(ptr, FreeHeader);
-  }
-  blob_.reset(new Blob(*mem_case_, &blob_desc_, header_buffer_.get(), nullptr));
+  char* header_buffer =
+      reinterpret_cast<char*>(const_cast<int64_t*>(blob_desc_.shape().dim_vec().data()));
+  blob_.reset(new Blob(*mem_case_, &blob_desc_, header_buffer, nullptr));
   return Maybe<void>::Ok();
 }
......
......@@ -51,7 +51,6 @@ class EagerBlobObject final : public BlobObject {
   ~EagerBlobObject() override {
     non_pod_initer_.reset();
     tensor_buffer_.reset();
-    header_buffer_.reset();
     blob_.reset();
   }
......@@ -79,7 +78,6 @@ class EagerBlobObject final : public BlobObject {
  private:
   std::unique_ptr<Blob> blob_;
-  std::unique_ptr<char, std::function<void(char*)>> header_buffer_;
   std::shared_ptr<TensorBuffer> tensor_buffer_;
   std::size_t blob_body_bytes_;
   std::unique_ptr<MemoryAllocator> non_pod_initer_;
......
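The two EagerBlobObject hunks above go together: InitBlob() no longer mallocs a separate header buffer with a custom deleter; the blob header is now a byte view aliasing the shape's dim_vec storage that blob_desc_ already owns, so the header_buffer_ member and its reset() in the destructor can be deleted. A minimal sketch of that aliasing pattern follows, with toy types standing in for the real BlobDesc/Blob.

// Toy illustration of the header-buffer aliasing in EagerBlobObject::InitBlob().
// ToyBlobDesc and ToyBlob are stand-ins, not the real BlobDesc/Blob.
#include <cstdint>
#include <iostream>
#include <vector>

struct ToyBlobDesc {
  std::vector<int64_t> dim_vec;  // shape storage, owned by the descriptor
};

struct ToyBlob {
  // The header is not a separate allocation: it aliases memory the descriptor
  // already owns and keeps alive for as long as the blob exists.
  explicit ToyBlob(ToyBlobDesc* desc)
      : desc_(desc), header_ptr_(reinterpret_cast<char*>(desc->dim_vec.data())) {}
  ToyBlobDesc* desc_;
  char* header_ptr_;
};

int main() {
  ToyBlobDesc desc{{2, 3, 4}};
  ToyBlob blob(&desc);
  // Read the first dimension back through the aliased header view.
  std::cout << *reinterpret_cast<int64_t*>(blob.header_ptr_) << "\n";  // prints 2
  return 0;
}

The const_cast in the real diff is only there because the shape accessor exposes const storage; the pattern is the same. It relies on the blob header being nothing more than the shape data and on blob_desc_ outliving blob_, both presumably guaranteed by the surrounding class but not visible in this excerpt.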
......@@ -13,7 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/kernel/kernel.h"
+#include "oneflow/core/common/buffer_manager.h"
+#include "oneflow/core/job/job_instance.h"
+#include "oneflow/core/job/global_for.h"
 
 namespace oneflow {
......@@ -30,7 +34,21 @@ class InputKernel final : public KernelIf<device_type> {
   void Forward(const KernelCtx& ctx,
                std::function<Blob*(const std::string&)> BnInOp2Blob) const override {}
   void ForwardDataContent(const KernelCtx& ctx,
-                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override {}
+                          std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
+    if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+      const auto& job_name = this->job_desc().job_name();
+      const auto& op_name = this->op_conf().name();
+      auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
+      auto* buffer = buffer_mgr->Get(GetInputBufferName(job_name, op_name));
+      std::shared_ptr<JobInstance> job_instance;
+      BufferStatus buffer_status = buffer->TryReceive(&job_instance);
+      CHECK_NE(buffer_status, kBufferStatusEmpty);
+      if (buffer_status == kBufferStatusSuccess) {
+        OfBlob ofblob(ctx.device_ctx, BnInOp2Blob("out"));
+        job_instance->PushBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name);
+      }
+    }
+  }
   void ForwardHeader(const KernelCtx& ctx,
                      std::function<Blob*(const std::string&)> BnInOp2Blob) const override {}
 };
......
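In the InputKernel hunk above (presumably input_kernel.cpp) the kernel uses TryReceive rather than a blocking Receive: by the time the actor graph reaches this kernel, the driver side is expected to have already enqueued the JobInstance, so an empty buffer is treated as a protocol violation (CHECK_NE against kBufferStatusEmpty), while kBufferStatusErrorClosed, i.e. the buffer being shut down, is tolerated and simply skips the push. The sketch below captures that three-way handling with a hypothetical HandleStatus helper and a toy status enum; neither exists in OneFlow.

// Toy version of the three-way BufferStatus handling in the multi-client branch.
// ToyBufferStatus and HandleStatus are hypothetical, not OneFlow APIs.
#include <cassert>
#include <functional>
#include <iostream>

enum class ToyBufferStatus { kSuccess, kEmpty, kErrorClosed };

// Run the push/pull action only on success; treat "empty" as a protocol bug
// (the driver must enqueue before the kernel runs); ignore "closed" (shutdown).
void HandleStatus(ToyBufferStatus status, const std::function<void()>& action) {
  assert(status != ToyBufferStatus::kEmpty && "driver must enqueue the JobInstance first");
  if (status == ToyBufferStatus::kSuccess) { action(); }
}

int main() {
  HandleStatus(ToyBufferStatus::kSuccess, [] { std::cout << "push blob to graph input\n"; });
  HandleStatus(ToyBufferStatus::kErrorClosed, [] { std::cout << "never runs\n"; });
  return 0;
}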
......@@ -14,19 +14,40 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/kernel/output_kernel.h"
+#include "oneflow/core/common/buffer_manager.h"
+#include "oneflow/core/job/job_instance.h"
+#include "oneflow/core/job/global_for.h"
 
 namespace oneflow {
 
 template<DeviceType device_type>
 void OutputKernel<device_type>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in"));
+  if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+    const auto& job_name = this->job_desc().job_name();
+    const auto& op_name = this->op_conf().name();
+    auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
+    auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name));
+    std::shared_ptr<JobInstance> job_instance;
+    BufferStatus buffer_status = buffer->TryReceive(&job_instance);
+    CHECK_NE(buffer_status, kBufferStatusEmpty);
+    if (buffer_status == kBufferStatusSuccess) {
+      OfBlob ofblob(ctx.device_ctx, BnInOp2Blob("in"));
+      job_instance->PullBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name);
+    }
+  } else {
+    BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in"));
+  }
 }
 
 template<DeviceType device_type>
 void OutputKernel<device_type>::ForwardHeader(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  BnInOp2Blob("out")->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob("in"));
+  if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+    // Do nothing.
+  } else {
+    BnInOp2Blob("out")->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob("in"));
+  }
 }
 
 ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kOutputConf, OutputKernel);
......
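OutputKernel above (presumably output_kernel.cpp) hands the output blob to the driver-registered callback as reinterpret_cast<uint64_t>(&ofblob). Passing the pointer as an opaque integer makes sense if the JobInstance callbacks are implemented on the Python/foreign side, which the commit message's "replace ForeignJobInstance with JobInstance" suggests: the handle crosses a language boundary and is cast back before use. A toy round-trip of that handle, with ToyOfBlob standing in for OfBlob:

// Toy round-trip of the opaque uint64_t blob handle used by Push/PullBlobByOpName.
// ToyOfBlob is a stand-in for OfBlob.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

struct ToyOfBlob {
  std::vector<float> data;
};

int main() {
  // Driver side: a pull callback registered per op name. It only ever sees an
  // opaque integer handle and casts it back to the concrete blob wrapper.
  std::function<void(uint64_t)> pull_cb = [](uint64_t of_blob_ptr) {
    auto* blob = reinterpret_cast<ToyOfBlob*>(of_blob_ptr);
    std::cout << "pulled " << blob->data.size() << " floats, first = " << blob->data[0] << "\n";
  };

  // Kernel side: wrap the blob and pass its address as an integer, mirroring
  // job_instance->PullBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name).
  ToyOfBlob ofblob{{1.f, 2.f, 3.f}};
  pull_cb(reinterpret_cast<uint64_t>(&ofblob));
  return 0;
}

One lifetime constraint follows directly from the diff: ofblob is a stack local of ForwardDataContent, so the callback has to consume the handle before the kernel returns.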
......@@ -14,20 +14,41 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/kernel/return_kernel.h"
+#include "oneflow/core/common/buffer_manager.h"
+#include "oneflow/core/job/job_instance.h"
+#include "oneflow/core/job/global_for.h"
 
 namespace oneflow {
 
 template<DeviceType device_type>
 void ReturnKernel<device_type>::ForwardDataContent(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in"));
-  ctx.device_ctx->SyncDevice();
+  if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+    const auto& job_name = this->job_desc().job_name();
+    const auto& op_name = this->op_conf().name();
+    auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
+    auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name));
+    std::shared_ptr<JobInstance> job_instance;
+    BufferStatus buffer_status = buffer->TryReceive(&job_instance);
+    CHECK_NE(buffer_status, kBufferStatusEmpty);
+    if (buffer_status == kBufferStatusSuccess) {
+      OfBlob ofblob(ctx.device_ctx, BnInOp2Blob("in"));
+      job_instance->PullBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name);
+    }
+  } else {
+    BnInOp2Blob("out")->CopyDataContentFrom(ctx.device_ctx, BnInOp2Blob("in"));
+    ctx.device_ctx->SyncDevice();
+  }
 }
 
 template<DeviceType device_type>
 void ReturnKernel<device_type>::ForwardHeader(
     const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
-  BnInOp2Blob("out")->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob("in"));
+  if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+    // Do nothing.
+  } else {
+    BnInOp2Blob("out")->CopyHeaderFrom(ctx.device_ctx, BnInOp2Blob("in"));
+  }
}
 
 ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReturnConf, ReturnKernel);
......
......@@ -13,7 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/kernel/wait_and_send_ids_kernel.h"
+#include "oneflow/core/common/buffer_manager.h"
+#include "oneflow/core/job/job_instance.h"
+#include "oneflow/core/job/global_for.h"
 
 namespace oneflow {
......@@ -24,11 +28,25 @@ void WaitAndSendIdsKernel<T>::ForwardDataContent(
   auto* status = static_cast<WaitAndSendIdsStatus*>(ctx.other);
   const auto& conf = this->op_conf().wait_and_send_ids_conf();
   if (status->out_idx_ >= status->out_num_) {
-    status->buffer_status_ =
-        Global<BufferMgr<int64_t>>::Get()->Get(conf.wait_buffer_name())->Receive(&status->in_id_);
-    if (status->buffer_status_ == kBufferStatusErrorClosed) { return; }
-    status->out_idx_ = 0;
-    status->out_num_ = conf.id_list(status->in_id_).value_size();
+    if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
+      const auto& job_name = this->job_desc().job_name();
+      auto* buffer_mgr = Global<BufferMgr<std::shared_ptr<JobInstance>>>::Get();
+      auto* buffer = buffer_mgr->Get(GetSourceTickBufferName(job_name));
+      status->in_id_ = 0;
+      {
+        std::shared_ptr<JobInstance> job_instance;
+        status->buffer_status_ = buffer->Receive(&job_instance);
+      }
+      if (status->buffer_status_ == kBufferStatusErrorClosed) { return; }
+      status->out_idx_ = 0;
+      status->out_num_ = 1;
+    } else {
+      auto* buffer_mgr = Global<BufferMgr<int64_t>>::Get();
+      status->buffer_status_ = buffer_mgr->Get(conf.wait_buffer_name())->Receive(&status->in_id_);
+      if (status->buffer_status_ == kBufferStatusErrorClosed) { return; }
+      status->out_idx_ = 0;
+      status->out_num_ = conf.id_list(status->in_id_).value_size();
+    }
   }
   *BnInOp2Blob("out")->mut_dptr<T>() = conf.id_list(status->in_id_).value(status->out_idx_);
   ++status->out_idx_;
......
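In multi-client mode the WaitAndSendIds kernel above no longer reads integer ids from the per-job int64_t wait buffer; it blocks on the job's source-tick buffer of JobInstances (GetSourceTickBufferName(job_name)) and turns each received instance into exactly one tick by pinning in_id_ to 0 and out_num_ to 1, so the unchanged line below the branch, conf.id_list(status->in_id_).value(status->out_idx_), still has something to read (presumably the generated op conf carries a one-element id_list in this mode). A toy model of that status progression:

// Toy model of the WaitAndSendIdsStatus progression in multi-client mode:
// one tick per received job instance (out_num_ pinned to 1, in_id_ to 0).
#include <cstdint>
#include <iostream>
#include <queue>

struct ToyStatus {
  int64_t in_id = 0;
  int64_t out_idx = 0;
  int64_t out_num = 0;
};

int main() {
  // Stand-in for the source-tick buffer: three queued graph launches.
  std::queue<int> source_tick_buffer;
  for (int i = 0; i < 3; ++i) { source_tick_buffer.push(i); }

  ToyStatus status;
  int ticks_sent = 0;
  while (true) {
    if (status.out_idx >= status.out_num) {
      if (source_tick_buffer.empty()) { break; }  // analogue of kBufferStatusErrorClosed
      source_tick_buffer.pop();                   // analogue of buffer->Receive(&job_instance)
      status.in_id = 0;
      status.out_idx = 0;
      status.out_num = 1;  // exactly one tick per received instance
    }
    ++ticks_sent;  // analogue of writing id_list(in_id_).value(out_idx_) to "out"
    ++status.out_idx;
  }
  std::cout << "ticks sent: " << ticks_sent << "\n";  // prints 3
  return 0;
}

The received JobInstance itself is dropped immediately (note the extra scope around it in the diff); only its arrival matters for the tick.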