Unverified · Commit b2bbde00 · authored by Juncheng, committed by GitHub

System kernels use memcpy/memset primitive (#6225)
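
In short: device-templated system kernels lose their DeviceType parameter, and all body copies
and fills go through runtime-dispatched memcpy/memset primitives (AutoMemcpy, primitive::Memset)
instead of Blob::CopyDataContentFrom / Memcpy<device_type> / Memset<device_type>. A minimal
sketch of the new calling convention, assuming the OneFlow headers touched in this diff (the
kernel name below is made up for illustration and is not part of the commit):

    #include "oneflow/core/kernel/kernel.h"
    #include "oneflow/core/kernel/kernel_context.h"

    namespace oneflow {

    // Hypothetical pass-through kernel written against the new API. There is no
    // DeviceType template parameter: AutoMemcpy derives the copy direction from
    // the stream's device type and the blobs' memory cases at runtime.
    class ExamplePassThroughKernel final : public Kernel {
     public:
      OF_DISALLOW_COPY_AND_MOVE(ExamplePassThroughKernel);
      ExamplePassThroughKernel() = default;
      ~ExamplePassThroughKernel() override = default;

     private:
      void ForwardDataContent(KernelContext* ctx) const override {
        // Copies the whole blob body; the Blob overload CHECKs that body byte sizes match.
        AutoMemcpy(ctx->stream_ctx(), ctx->BnInOp2Blob("out"), ctx->BnInOp2Blob("in"));
      }
    };

    }  // namespace oneflow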

Parent 09df022c
@@ -151,7 +151,7 @@ class RunLazyJobInstructionType final : public InstructionType {
     const auto& PushCb = [blob](int64_t of_blob_ptr) {
       OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
       of_blob->mut_blob()->CopyHeaderFrom(of_blob->mut_device_ctx(), blob);
-      of_blob->mut_blob()->CopyDataContentFrom(of_blob->mut_device_ctx(), blob);
+      AutoMemcpy(of_blob->mut_device_ctx(), of_blob->mut_blob(), blob);
     };
     CHECK(push_cbs.emplace(op_name, PushCb).second);
   }
@@ -166,7 +166,7 @@ class RunLazyJobInstructionType final : public InstructionType {
     const auto& PullCb = [mut_blob](int64_t of_blob_ptr) {
       OfBlob* of_blob = reinterpret_cast<OfBlob*>(of_blob_ptr);
       mut_blob->CopyHeaderFrom(of_blob->mut_device_ctx(), &of_blob->blob());
-      mut_blob->CopyDataContentFrom(of_blob->mut_device_ctx(), &of_blob->blob());
+      AutoMemcpy(of_blob->mut_device_ctx(), mut_blob, &of_blob->blob());
     };
     CHECK(pull_cbs.emplace(op_name, PullCb).second);
   }
......
@@ -17,7 +17,6 @@ limitations under the License.
 
 namespace oneflow {
 
-template<DeviceType device_type>
 class AssignKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(AssignKernel);
@@ -29,16 +28,12 @@ class AssignKernel final : public Kernel {
   void ForwardDataContent(KernelContext* ctx) const override;
 };
 
-template<DeviceType device_type>
-void AssignKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
-  ctx->BnInOp2Blob("ref")->CopyValidDataContentFrom(ctx->device_ctx(), ctx->BnInOp2Blob("value"));
+void AssignKernel::ForwardDataContent(KernelContext* ctx) const {
+  const Blob* value = ctx->BnInOp2Blob("value");
+  Blob* ref = ctx->BnInOp2Blob("ref");
+  AutoMemcpy(ctx->stream_ctx(), ref, value);
 }
 
-REGISTER_KERNEL_WITH_DEVICE(OperatorConf::kAssignConf, DeviceType::kCPU,
-                            AssignKernel<DeviceType::kCPU>);
-#ifdef WITH_CUDA
-REGISTER_KERNEL_WITH_DEVICE(OperatorConf::kAssignConf, DeviceType::kGPU,
-                            AssignKernel<DeviceType::kGPU>);
-#endif
+REGISTER_KERNEL(OperatorConf::kAssignConf, AssignKernel);
 
 }  // namespace oneflow
@@ -76,11 +76,13 @@ void CalcSumOfBlobs<float16>(DeviceCtx* ctx,
   }
 }
 
-void CopyFromFirstToOtherBlobs(DeviceCtx* ctx,
+void CopyFromFirstToOtherBlobs(KernelContext* ctx,
                                const std::function<Blob*(const std::string&)>& BnInOp2Blob,
-                               const PbRpf<std::string>& bns, CopyBlobFieldMthd Copy) {
+                               const PbRpf<std::string>& bns) {
   const Blob* blob_0 = BnInOp2Blob(bns.Get(0));
-  FOR_RANGE(size_t, i, 1, bns.size()) { (BnInOp2Blob(bns.Get(i))->*Copy)(ctx, blob_0); }
+  FOR_RANGE(size_t, i, 1, bns.size()) {
+    AutoMemcpy(ctx->stream_ctx(), BnInOp2Blob(bns.Get(i)), blob_0);
+  }
 }
 
 class DataContentDesc final {
@@ -225,8 +227,7 @@ void BoxingKernel<T>::ForwardDataContent(KernelContext* ctx) const {
   } else if (boxing_conf.out_box_case() == BoxingOpConf::kCloneBox) {
     ConcatSplitDataContent(device_ctx, BnInOp2Blob, op_attribute().input_bns(),
                            boxing_conf.concat_box().axis(), obn_0_, 0);
-    CopyFromFirstToOtherBlobs(device_ctx, BnInOp2Blob, op_attribute().output_bns(),
-                              DataContentIterator::GetCopyBlobFieldMthd());
+    CopyFromFirstToOtherBlobs(ctx, BnInOp2Blob, op_attribute().output_bns());
   } else {
     UNIMPLEMENTED();
   }
@@ -237,8 +238,7 @@ void BoxingKernel<T>::ForwardDataContent(KernelContext* ctx) const {
                            op_attribute().output_bns(), boxing_conf.split_box().axis());
   } else if (boxing_conf.out_box_case() == BoxingOpConf::kCloneBox) {
     CalcSumOfBlobs<T>(device_ctx, BnInOp2Blob, op_attribute().input_bns(), obn_0_.Get(0));
-    CopyFromFirstToOtherBlobs(device_ctx, BnInOp2Blob, op_attribute().output_bns(),
-                              DataContentIterator::GetCopyBlobFieldMthd());
+    CopyFromFirstToOtherBlobs(ctx, BnInOp2Blob, op_attribute().output_bns());
   } else {
     UNIMPLEMENTED();
   }
......
@@ -15,10 +15,10 @@ limitations under the License.
 */
 #include "oneflow/core/kernel/kernel.h"
 #include "oneflow/core/kernel/kernel_context.h"
+#include "oneflow/core/primitive/memset.h"
 
 namespace oneflow {
 
-template<DeviceType device_type>
 class BoxingZerosKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(BoxingZerosKernel);
@@ -26,15 +26,22 @@ class BoxingZerosKernel final : public Kernel {
   ~BoxingZerosKernel() override = default;
 
  private:
+  void VirtualKernelInit(KernelContext* ctx) override;
   void ForwardDataContent(KernelContext* ctx) const override;
+  std::unique_ptr<primitive::Memset> primitive_;
 };
 
-template<DeviceType device_type>
-void BoxingZerosKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
+void BoxingZerosKernel::VirtualKernelInit(KernelContext* ctx) {
+  primitive_ = primitive::NewPrimitive<primitive::MemsetFactory>(this->op_conf().device_tag());
+  CHECK(primitive_);
+}
+
+void BoxingZerosKernel::ForwardDataContent(KernelContext* ctx) const {
   Blob* out = ctx->BnInOp2Blob("out");
-  Memset<device_type>(ctx->device_ctx(), out->mut_dptr(), 0, out->ByteSizeOfBlobBody());
+  primitive_->Launch(ctx->stream_ctx(), out->mut_dptr(), 0, out->ByteSizeOfBlobBody());
 }
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kBoxingZerosConf, BoxingZerosKernel);
+REGISTER_KERNEL(OperatorConf::kBoxingZerosConf, BoxingZerosKernel);
 
 }  // namespace oneflow
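
Note on the hunk above: the memset primitive is now resolved once in VirtualKernelInit from the
op's device tag and reused on every invocation, so the hot path is a single virtual Launch call.
A condensed sketch of that lifecycle, using only calls visible in this diff:

    // Init phase: resolve a device-appropriate Memset implementation.
    // NewPrimitive returns an empty pointer when no factory is registered
    // for the device tag, hence the CHECK.
    primitive_ = primitive::NewPrimitive<primitive::MemsetFactory>(this->op_conf().device_tag());
    CHECK(primitive_);

    // Per-invocation phase: zero-fill the output body on the kernel's stream.
    primitive_->Launch(ctx->stream_ctx(), out->mut_dptr(), 0, out->ByteSizeOfBlobBody());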
@@ -23,7 +23,6 @@ namespace oneflow {
 
 using namespace boxing::collective;
 
-template<DeviceType device_type>
 class CollectiveBoxingGenericKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(CollectiveBoxingGenericKernel);
@@ -35,8 +34,7 @@ class CollectiveBoxingGenericKernel final : public Kernel {
   void ForwardDataContent(KernelContext* ctx) const override;
 };
 
-template<DeviceType device_type>
-void CollectiveBoxingGenericKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
+void CollectiveBoxingGenericKernel::ForwardDataContent(KernelContext* ctx) const {
   RuntimeRequestInfo request;
   const RankDesc& rank_desc = this->op_conf().collective_boxing_generic_conf().rank_desc();
   const DataType data_type = rank_desc.op_desc().data_type();
@@ -67,7 +65,6 @@ void CollectiveBoxingGenericKernel<device_type>::ForwardDataContent(KernelContex
   Global<CollectiveBoxingExecutor>::Get()->Enqueue(rank_desc, request);
 }
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kCollectiveBoxingGenericConf,
-                               CollectiveBoxingGenericKernel);
+REGISTER_KERNEL(OperatorConf::kCollectiveBoxingGenericConf, CollectiveBoxingGenericKernel);
 
 }  // namespace oneflow
@@ -67,7 +67,7 @@ void CollectiveBoxingPackKernel<device_type, T>::ForwardDataContent(KernelContex
         ctx->device_ctx(), transpose_in_shape.NumAxes(), transpose_in_shape, transpose_out_shape,
         perm, transpose_in_shape.elem_cnt(), in->dptr<T>(), out->mut_dptr<T>());
   } else {
-    out->CopyDataContentFrom(ctx->device_ctx(), in);
+    AutoMemcpy(ctx->stream_ctx(), out, in);
   }
 }
......
@@ -67,7 +67,7 @@ void CollectiveBoxingUnpackKernel<device_type, T>::ForwardDataContent(KernelCont
         ctx->device_ctx(), transpose_in_shape.NumAxes(), transpose_in_shape, transpose_out_shape,
         perm, transpose_in_shape.elem_cnt(), in->dptr<T>(), out->mut_dptr<T>());
   } else {
-    out->CopyDataContentFrom(ctx->device_ctx(), in);
+    AutoMemcpy(ctx->stream_ctx(), out, in);
  }
 }
......
@@ -33,7 +33,6 @@ class CopyHdKernel final : public Kernel {
 };
 
 void CopyHdKernel::VirtualKernelInit(KernelContext* ctx) {
-  const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(this->op_conf().device_tag()));
   CHECK(this->op_conf().has_copy_hd_conf());
   const CopyHdOpConf& copy_hd_conf = this->op_conf().copy_hd_conf();
   primitive::MemcpyKind kind{};
@@ -44,7 +43,8 @@ void CopyHdKernel::VirtualKernelInit(KernelContext* ctx) {
   } else {
     UNIMPLEMENTED();
   }
-  primitive_ = primitive::NewPrimitive<primitive::MemcpyFactory>(device_type, kind);
+  primitive_ =
+      primitive::NewPrimitive<primitive::MemcpyFactory>(this->op_conf().device_tag(), kind);
   CHECK(primitive_);
 }
......
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#include "oneflow/core/kernel/kernel.h"
-#include "oneflow/core/kernel/kernel_context.h"
-
-namespace oneflow {
-
-namespace {
-
-void CheckSizeAndCopyBlob(DeviceCtx* ctx, Blob* dst, const Blob* src) {
-  dst->CopyValidDataContentFrom(ctx, src);
-}
-
-}  // namespace
-
-template<DeviceType device_type>
-class DistributeAddKernel final : public Kernel {
- public:
-  OF_DISALLOW_COPY_AND_MOVE(DistributeAddKernel);
-  DistributeAddKernel() = default;
-  ~DistributeAddKernel() = default;
-
- private:
-  void ForwardDataContent(KernelContext* ctx) const override;
-  const Blob* GetInBlob(KernelContext* ctx) const;
-};
-
-template<DeviceType device_type>
-void DistributeAddKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
-  CheckSizeAndCopyBlob(ctx->device_ctx(), ctx->BnInOp2Blob("out"), GetInBlob(ctx));
-}
-
-template<DeviceType device_type>
-const Blob* DistributeAddKernel<device_type>::GetInBlob(KernelContext* ctx) const {
-  const Blob* in_blob = nullptr;
-  FOR_RANGE(int, i, 0, this->op_attribute().input_bns().size()) {
-    const Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().input_bns().Get(i));
-    if (cur_blob != nullptr && cur_blob != in_blob) {
-      CHECK_ISNULL(in_blob);
-      in_blob = cur_blob;
-    }
-  }
-  return in_blob;
-}
-
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kDistributeAddConf, DistributeAddKernel);
-
-}  // namespace oneflow
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#include "oneflow/core/kernel/kernel.h"
-#include "oneflow/core/kernel/kernel_context.h"
-
-namespace oneflow {
-
-namespace {
-
-void CheckSizeAndCopyBlob(DeviceCtx* ctx, Blob* dst, const Blob* src) {
-  dst->CopyDataContentFrom(ctx, src);
-}
-
-}  // namespace
-
-template<DeviceType device_type>
-class DistributeCloneKernel final : public Kernel {
- public:
-  OF_DISALLOW_COPY_AND_MOVE(DistributeCloneKernel);
-  DistributeCloneKernel() = default;
-  ~DistributeCloneKernel() = default;
-
- private:
-  void ForwardDataContent(KernelContext* ctx) const override;
-  Blob* GetOutBlob(KernelContext* ctx) const;
-};
-
-template<DeviceType device_type>
-void DistributeCloneKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
-  CheckSizeAndCopyBlob(ctx->device_ctx(), GetOutBlob(ctx), ctx->BnInOp2Blob("in"));
-}
-
-template<DeviceType device_type>
-Blob* DistributeCloneKernel<device_type>::GetOutBlob(KernelContext* ctx) const {
-  Blob* out_blob = nullptr;
-  FOR_RANGE(int, i, 0, this->op_attribute().output_bns().size()) {
-    Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().output_bns().Get(i));
-    if (cur_blob != nullptr && cur_blob != out_blob) {
-      CHECK_ISNULL(out_blob);
-      out_blob = cur_blob;
-    }
-  }
-  return out_blob;
-}
-
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kDistributeCloneConf, DistributeCloneKernel);
-
-}  // namespace oneflow
-/*
-Copyright 2020 The OneFlow Authors. All rights reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-#include "oneflow/core/kernel/kernel.h"
-#include "oneflow/core/kernel/kernel_context.h"
-
-namespace oneflow {
-
-namespace {
-
-void CheckSizeAndCopyBlob(DeviceCtx* ctx, Blob* dst, const Blob* src) {
-  dst->CopyValidDataContentFrom(ctx, src);
-}
-
-}  // namespace
-
-template<DeviceType device_type>
-class DistributeConcatKernel final : public Kernel {
- public:
-  OF_DISALLOW_COPY_AND_MOVE(DistributeConcatKernel);
-  DistributeConcatKernel() = default;
-  ~DistributeConcatKernel() = default;
-
- private:
-  void ForwardDataContent(KernelContext* ctx) const override;
-  const Blob* GetInBlob(KernelContext* ctx) const;
-};
-
-template<DeviceType device_type>
-void DistributeConcatKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
-  CheckSizeAndCopyBlob(ctx->device_ctx(), ctx->BnInOp2Blob("out"), GetInBlob(ctx));
-}
-
-template<DeviceType device_type>
-const Blob* DistributeConcatKernel<device_type>::GetInBlob(KernelContext* ctx) const {
-  const Blob* in_blob = nullptr;
-  FOR_RANGE(int, i, 0, this->op_attribute().input_bns().size()) {
-    const Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().input_bns().Get(i));
-    if (cur_blob != nullptr && cur_blob != in_blob) {
-      CHECK_ISNULL(in_blob);
-      in_blob = cur_blob;
-    }
-  }
-  return in_blob;
-}
-
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kDistributeConcatConf, DistributeConcatKernel);
-
-}  // namespace oneflow
@@ -18,15 +18,93 @@ limitations under the License.
 
 namespace oneflow {
 
-namespace {
-
-void CheckSizeAndCopyBlob(DeviceCtx* ctx, Blob* dst, const Blob* src) {
-  dst->CopyValidDataContentFrom(ctx, src);
-}
-
-}  // namespace
-
-template<DeviceType device_type>
+class DistributeAddKernel final : public Kernel {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DistributeAddKernel);
+  DistributeAddKernel() = default;
+  ~DistributeAddKernel() = default;
+
+ private:
+  void ForwardDataContent(KernelContext* ctx) const override;
+  const Blob* GetInBlob(KernelContext* ctx) const;
+};
+
+void DistributeAddKernel::ForwardDataContent(KernelContext* ctx) const {
+  AutoMemcpy(ctx->stream_ctx(), ctx->BnInOp2Blob("out"), GetInBlob(ctx));
+}
+
+const Blob* DistributeAddKernel::GetInBlob(KernelContext* ctx) const {
+  const Blob* in_blob = nullptr;
+  FOR_RANGE(int, i, 0, this->op_attribute().input_bns().size()) {
+    const Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().input_bns().Get(i));
+    if (cur_blob != nullptr && cur_blob != in_blob) {
+      CHECK_ISNULL(in_blob);
+      in_blob = cur_blob;
+    }
+  }
+  return in_blob;
+}
+
+REGISTER_KERNEL(OperatorConf::kDistributeAddConf, DistributeAddKernel);
+
+class DistributeCloneKernel final : public Kernel {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DistributeCloneKernel);
+  DistributeCloneKernel() = default;
+  ~DistributeCloneKernel() = default;
+
+ private:
+  void ForwardDataContent(KernelContext* ctx) const override;
+  Blob* GetOutBlob(KernelContext* ctx) const;
+};
+
+void DistributeCloneKernel::ForwardDataContent(KernelContext* ctx) const {
+  AutoMemcpy(ctx->stream_ctx(), GetOutBlob(ctx), ctx->BnInOp2Blob("in"));
+}
+
+Blob* DistributeCloneKernel::GetOutBlob(KernelContext* ctx) const {
+  Blob* out_blob = nullptr;
+  FOR_RANGE(int, i, 0, this->op_attribute().output_bns().size()) {
+    Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().output_bns().Get(i));
+    if (cur_blob != nullptr && cur_blob != out_blob) {
+      CHECK_ISNULL(out_blob);
+      out_blob = cur_blob;
+    }
+  }
+  return out_blob;
+}
+
+REGISTER_KERNEL(OperatorConf::kDistributeCloneConf, DistributeCloneKernel);
+
+class DistributeConcatKernel final : public Kernel {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DistributeConcatKernel);
+  DistributeConcatKernel() = default;
+  ~DistributeConcatKernel() = default;
+
+ private:
+  void ForwardDataContent(KernelContext* ctx) const override;
+  const Blob* GetInBlob(KernelContext* ctx) const;
+};
+
+void DistributeConcatKernel::ForwardDataContent(KernelContext* ctx) const {
+  AutoMemcpy(ctx->stream_ctx(), ctx->BnInOp2Blob("out"), GetInBlob(ctx));
+}
+
+const Blob* DistributeConcatKernel::GetInBlob(KernelContext* ctx) const {
+  const Blob* in_blob = nullptr;
+  FOR_RANGE(int, i, 0, this->op_attribute().input_bns().size()) {
+    const Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().input_bns().Get(i));
+    if (cur_blob != nullptr && cur_blob != in_blob) {
+      CHECK_ISNULL(in_blob);
+      in_blob = cur_blob;
+    }
+  }
+  return in_blob;
+}
+
+REGISTER_KERNEL(OperatorConf::kDistributeConcatConf, DistributeConcatKernel);
+
 class DistributeSplitKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(DistributeSplitKernel);
@@ -39,19 +117,16 @@ class DistributeSplitKernel final : public Kernel {
   Blob* GetOutBlob(KernelContext* ctx) const;
 };
 
-template<DeviceType device_type>
-void DistributeSplitKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
-  CheckSizeAndCopyBlob(ctx->device_ctx(), GetOutBlob(ctx), ctx->BnInOp2Blob("in"));
+void DistributeSplitKernel::ForwardDataContent(KernelContext* ctx) const {
+  AutoMemcpy(ctx->stream_ctx(), GetOutBlob(ctx), ctx->BnInOp2Blob("in"));
 }
 
-template<DeviceType device_type>
-void DistributeSplitKernel<device_type>::ForwardShape(KernelContext* ctx) const {
+void DistributeSplitKernel::ForwardShape(KernelContext* ctx) const {
   Blob* out_blob = GetOutBlob(ctx);
   out_blob->mut_shape_view()->set_shape(ctx->BnInOp2Blob("in")->shape());
 }
 
-template<DeviceType device_type>
-Blob* DistributeSplitKernel<device_type>::GetOutBlob(KernelContext* ctx) const {
+Blob* DistributeSplitKernel::GetOutBlob(KernelContext* ctx) const {
   Blob* out_blob = nullptr;
   FOR_RANGE(int, i, 0, this->op_attribute().output_bns().size()) {
     Blob* cur_blob = ctx->BnInOp2Blob(this->op_attribute().output_bns().Get(i));
@@ -63,6 +138,6 @@ Blob* DistributeSplitKernel<device_type>::GetOutBlob(KernelContext* ctx) const {
   return out_blob;
 }
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kDistributeSplitConf, DistributeSplitKernel);
+REGISTER_KERNEL(OperatorConf::kDistributeSplitConf, DistributeSplitKernel);
 
 }  // namespace oneflow
@@ -17,7 +17,6 @@ limitations under the License.
 
 namespace oneflow {
 
-template<DeviceType device_type>
 class DynamicReshapeKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(DynamicReshapeKernel);
@@ -28,13 +27,12 @@ class DynamicReshapeKernel final : public Kernel {
   void ForwardDataContent(KernelContext* ctx) const override;
 };
 
-template<DeviceType device_type>
-void DynamicReshapeKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
+void DynamicReshapeKernel::ForwardDataContent(KernelContext* ctx) const {
   const Blob* in_blob = ctx->BnInOp2Blob("in");
   Blob* out_blob = ctx->BnInOp2Blob("out");
-  out_blob->CopyDataContentFrom(ctx->device_ctx(), in_blob);
+  AutoMemcpy(ctx->stream_ctx(), out_blob, in_blob);
 }
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kDynamicReshapeConf, DynamicReshapeKernel);
+REGISTER_KERNEL(OperatorConf::kDynamicReshapeConf, DynamicReshapeKernel);
 
 }  // namespace oneflow
@@ -17,7 +17,6 @@ limitations under the License.
 
 namespace oneflow {
 
-template<DeviceType device_type>
 class DynamicReshapeLikeKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(DynamicReshapeLikeKernel);
@@ -28,13 +27,12 @@ class DynamicReshapeLikeKernel final : public Kernel {
   void ForwardDataContent(KernelContext* ctx) const override;
 };
 
-template<DeviceType device_type>
-void DynamicReshapeLikeKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
+void DynamicReshapeLikeKernel::ForwardDataContent(KernelContext* ctx) const {
   const Blob* in_blob = ctx->BnInOp2Blob("x");
   Blob* out_blob = ctx->BnInOp2Blob("y");
-  out_blob->CopyDataContentFrom(ctx->device_ctx(), in_blob);
+  AutoMemcpy(ctx->stream_ctx(), out_blob, in_blob);
 }
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kDynamicReshapeLikeConf, DynamicReshapeLikeKernel);
+REGISTER_KERNEL(OperatorConf::kDynamicReshapeLikeConf, DynamicReshapeLikeKernel);
 
 }  // namespace oneflow
@@ -15,10 +15,10 @@ limitations under the License.
 */
 #include "oneflow/core/kernel/kernel.h"
 #include "oneflow/core/kernel/kernel_context.h"
+#include "oneflow/core/primitive/memcpy.h"
 
 namespace oneflow {
 
-template<DeviceType device_type>
 class IdentityKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(IdentityKernel);
@@ -30,20 +30,20 @@ class IdentityKernel final : public Kernel {
   void ForwardHeader(KernelContext* ctx) const override;
 };
 
-template<DeviceType device_type>
-void IdentityKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
-  ctx->BnInOp2Blob("out")->CopyValidDataContentFrom(ctx->device_ctx(), ctx->BnInOp2Blob("in"));
+void IdentityKernel::ForwardDataContent(KernelContext* ctx) const {
+  const Blob* in_blob = ctx->BnInOp2Blob("in");
+  Blob* out_blob = ctx->BnInOp2Blob("out");
+  AutoMemcpy(ctx->stream_ctx(), out_blob, in_blob);
 }
 
-template<DeviceType device_type>
-void IdentityKernel<device_type>::ForwardHeader(KernelContext* ctx) const {
+void IdentityKernel::ForwardHeader(KernelContext* ctx) const {
   ctx->BnInOp2Blob("out")->CopyHeaderFrom(ctx->device_ctx(), ctx->BnInOp2Blob("in"));
 }
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kIdentityConf, IdentityKernel);
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kCopyConf, IdentityKernel);
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kCastToMirroredConf, IdentityKernel);
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kCastFromMirroredConf, IdentityKernel);
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kBoxingIdentityConf, IdentityKernel);
+REGISTER_KERNEL(OperatorConf::kIdentityConf, IdentityKernel);
+REGISTER_KERNEL(OperatorConf::kCopyConf, IdentityKernel);
+REGISTER_KERNEL(OperatorConf::kCastToMirroredConf, IdentityKernel);
+REGISTER_KERNEL(OperatorConf::kCastFromMirroredConf, IdentityKernel);
+REGISTER_KERNEL(OperatorConf::kBoxingIdentityConf, IdentityKernel);
 
 }  // namespace oneflow
@@ -23,7 +23,6 @@ namespace oneflow {
 
 namespace {
 
-template<DeviceType device_type>
 class InputKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(InputKernel);
@@ -52,6 +51,6 @@ class InputKernel final : public Kernel {
 
 }  // namespace
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kInputConf, InputKernel);
+REGISTER_KERNEL(OperatorConf::kInputConf, InputKernel);
 
 }  // namespace oneflow
@@ -18,6 +18,9 @@ limitations under the License.
 #include "oneflow/core/register/register_manager.h"
 #include "oneflow/core/kernel/kernel.h"
 #include "oneflow/core/memory/memory_case.pb.h"
+#include "oneflow/core/primitive/memcpy.h"
+#include "oneflow/core/primitive/memset.h"
+#include "oneflow/core/stream/stream_context_adapter.h"
 
 namespace oneflow {
 
@@ -225,71 +228,67 @@ void IntSequenceInitializer(const IntRangeInitializerConf& initializer_conf, uin
   RangeInitializer<T, IntRangeInitializerConf>(initializer_conf, random_seed, blob);
 }
 
-void ComputeOffset(const int32_t num_axes, const int64_t* shape, const int32_t* permutation,
-                   DimVector& offset) {
-  offset.resize(num_axes);
-  DimVector buff(num_axes);
-  int64_t cur_offset = 1;
-  for (int32_t i = num_axes - 1; i >= 0; --i) {
-    buff[i] = cur_offset;
-    cur_offset *= shape[i];
-  }
-  for (int32_t i = 0; i < num_axes; ++i) { offset[permutation[i]] = buff[i]; }
-}
-
-void IncreaseIndex(const int64_t* shape, DimVector& index) {
-  for (int32_t i = index.size() - 1; i >= 0; --i) {
-    ++index[i];
-    if (index[i] >= shape[i]) {
-      index[i] -= shape[i];
-    } else {
-      break;
-    }
-  }
-}
-
 }  // namespace
 
 void AutoMemcpy(DeviceCtx* ctx, void* dst, const void* src, size_t sz,
                 const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case) {
-  void (*func)(DeviceCtx*, void* dst, const void* src, size_t sz);
-  if (src_mem_case.has_host_mem() && dst_mem_case.has_host_mem()) {
-    func = &Memcpy<DeviceType::kCPU>;
-  } else {
-#ifdef WITH_CUDA
-    func = &Memcpy<DeviceType::kGPU>;
-#else
-    UNIMPLEMENTED();
-#endif  // WITH_CUDA
-  }
-  func(ctx, dst, src, sz);
+  std::unique_ptr<StreamContext> stream_ctx(NewStreamContextAdapter(ctx));
+  AutoMemcpy(stream_ctx.get(), dst, src, sz, dst_mem_case, src_mem_case);
+}
+
+void AutoMemcpy(DeviceCtx* ctx, Blob* dst, const Blob* src) {
+  std::unique_ptr<StreamContext> stream_ctx(NewStreamContextAdapter(ctx));
+  AutoMemcpy(stream_ctx.get(), dst, src);
+}
+
+void AutoMemcpy(StreamContext* stream_ctx, void* dst, const void* src, size_t sz,
+                const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case) {
+  primitive::MemcpyKind kind{};
+  if (stream_ctx->device_type() == DeviceType::kCPU) {
+    CHECK(src_mem_case.has_host_mem());
+    CHECK(dst_mem_case.has_host_mem());
+    kind = primitive::MemcpyKind::kDtoD;
+  } else {
+    if (src_mem_case.has_host_mem()) {
+      CHECK(!dst_mem_case.has_host_mem());
+      kind = primitive::MemcpyKind::kHtoD;
+    } else if (dst_mem_case.has_host_mem()) {
+      CHECK(!src_mem_case.has_host_mem());
+      kind = primitive::MemcpyKind::kDtoH;
+    } else {
+      kind = primitive::MemcpyKind::kDtoD;
    }
  }
+  std::unique_ptr<primitive::Memcpy> primitive =
+      primitive::NewPrimitive<primitive::MemcpyFactory>(stream_ctx->device_type(), kind);
+  CHECK(primitive);
+  primitive->Launch(stream_ctx, dst, src, sz);
+}
+
+void AutoMemcpy(StreamContext* stream_ctx, Blob* dst, const Blob* src) {
+  const size_t body_bytes = src->ByteSizeOfBlobBody();
+  CHECK_EQ(dst->ByteSizeOfBlobBody(), body_bytes);
+  AutoMemcpy(stream_ctx, dst->mut_dptr(), src->dptr(), body_bytes, dst->mem_case(),
+             src->mem_case());
 }
 
 void SyncAutoMemcpy(DeviceCtx* ctx, void* dst, const void* src, size_t sz,
                     const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case) {
   AutoMemcpy(ctx, dst, src, sz, dst_mem_case, src_mem_case);
-  if (src_mem_case.has_device_cuda_mem() || dst_mem_case.has_device_cuda_mem()) {
-#ifdef WITH_CUDA
-    OF_CUDA_CHECK(cudaStreamSynchronize(ctx->cuda_stream()));
-#else
-    UNIMPLEMENTED();
-#endif  // WITH_CUDA
-  }
+  ctx->SyncDevice();
 }
 
 void AutoMemset(DeviceCtx* ctx, void* dst, const char value, size_t sz,
                 const MemoryCase& dst_mem_case) {
-  void (*func)(DeviceCtx*, void* dst, const char value, size_t sz);
-  if (dst_mem_case.has_host_mem()) {
-    func = &Memset<DeviceType::kCPU>;
-  } else {
-#ifdef WITH_CUDA
-    func = &Memset<DeviceType::kGPU>;
-#else
-    UNIMPLEMENTED();
-#endif  // WITH_CUDA
-  }
-  func(ctx, dst, value, sz);
+  std::unique_ptr<StreamContext> stream_ctx(NewStreamContextAdapter(ctx));
+  AutoMemset(stream_ctx.get(), dst, value, sz, dst_mem_case);
+}
+
+void AutoMemset(StreamContext* stream_ctx, void* dst, const char value, size_t sz,
+                const MemoryCase& /*dst_mem_case*/) {
+  std::unique_ptr<primitive::Memset> primitive =
+      primitive::NewPrimitive<primitive::MemsetFactory>(stream_ctx->device_type());
+  primitive->Launch(stream_ctx, dst, value, sz);
 }
 
 #define KU_IF_METHOD \
@@ -300,44 +299,6 @@ KU_IF_METHOD Axpy(DeviceCtx* ctx, const int n, const T* alpha, const T* x, const
                   const int incy) {
   Derived::Axpy(ctx, n, *alpha, x, incx, y, incy);
 }
-KU_IF_METHOD CopyColsRegion(DeviceCtx* ctx, const int64_t row_num, const int64_t col_num,
-                            const T* x, const int64_t x_col_offset, const int64_t x_lda, T* y,
-                            const int64_t y_col_offset, const int64_t y_lda) {
-  for (int64_t i = 0; i < row_num; ++i) {
-    for (int64_t j = 0; j < col_num; ++j) {
-      y[i * y_lda + y_col_offset + j] = x[i * x_lda + x_col_offset + j];
-    }
-  }
-}
-KU_IF_METHOD Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
-                       const ShapeView& y_shape, const PbRf<int32_t>& permutation,
-                       const int64_t elem_cnt, const T* x, T* y) {
-  int64_t block_size = 1;
-  int32_t shared_idxs_num = 0;
-  for (int32_t i = num_axis - 1; i >= 0 && permutation[i] == i; --i) {
-    block_size *= y_shape.At(i);
-    ++shared_idxs_num;
-  }
-  if (num_axis < 2 || shared_idxs_num == num_axis) {
-    memcpy(y, x, elem_cnt * sizeof(T));
-    return;
-  }
-  int32_t trans_axis = num_axis - shared_idxs_num;
-  DimVector x_to_y_offset;
-  ComputeOffset(trans_axis, y_shape.ptr(), permutation.data(), x_to_y_offset);
-  DimVector x_index_digits(trans_axis, 0);
-  int64_t num_blocks = elem_cnt / block_size;
-  FOR_RANGE(int64_t, x_idx, 0, num_blocks) {
-    int64_t y_idx = std::inner_product(x_to_y_offset.cbegin(), x_to_y_offset.cend(),
-                                       x_index_digits.cbegin(), 0);
-    if (block_size == 1) {
-      y[y_idx] = x[x_idx];
-    } else {
-      memcpy(y + block_size * y_idx, x + block_size * x_idx, block_size * sizeof(T));
-    }
-    IncreaseIndex(x_shape.ptr(), x_index_digits);
-  }
-}
 KU_IF_METHOD Set(DeviceCtx* ctx, const T value, T* addr) { *addr = value; }
 
 #define KU_FLOATING_METHOD \
......
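
The rewritten AutoMemcpy above reduces every copy to a (device type, MemcpyKind) pair: a CPU
stream only ever sees host memory and always uses kDtoD, while on other devices the kind follows
from which side lives in host memory (kHtoD, kDtoH, otherwise kDtoD). A hedged usage sketch of
the overloads declared in kernel_util.h; the pointer and blob names are illustrative:

    // Raw-pointer form: the caller supplies the byte size and both memory cases,
    // and the memcpy primitive is resolved from the stream's device type.
    AutoMemcpy(stream_ctx, dst_ptr, src_ptr, byte_size, dst_mem_case, src_mem_case);

    // Blob form: size and memory cases come from the blobs themselves, and the
    // body byte sizes are CHECKed to be equal before copying.
    AutoMemcpy(stream_ctx, out_blob, in_blob);

    // Existing DeviceCtx-based call sites keep working: the DeviceCtx overloads
    // wrap the context in a temporary StreamContext via NewStreamContextAdapter.
    AutoMemcpy(device_ctx, out_blob, in_blob);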
@@ -115,106 +115,6 @@ __global__ void gpu_set(const T value, T* addr) {
   *addr = value;
 }
 
-cublasOperation_t CblasTrans2CublasTrans(CBLAS_TRANSPOSE trans) {
-  cublasOperation_t cublas_trans;
-  if (trans == CBLAS_TRANSPOSE::CblasNoTrans) {
-    cublas_trans = cublasOperation_t::CUBLAS_OP_N;
-  } else if (trans == CBLAS_TRANSPOSE::CblasTrans) {
-    cublas_trans = cublasOperation_t::CUBLAS_OP_T;
-  } else if (trans == CBLAS_TRANSPOSE::CblasConjTrans) {
-    cublas_trans = cublasOperation_t::CUBLAS_OP_C;
-  } else {
-    // do nothing
-  }
-  return cublas_trans;
-}
-
-template<int32_t NDIMS>
-struct Int32Array {
-  int32_t val[NDIMS];
-};
-
-template<typename T>
-__global__ void CopyColsRegionGpu(const int64_t row_num, const int64_t col_num, const T* x,
-                                  const int64_t x_col_offset, const int64_t x_lda, T* y,
-                                  const int64_t y_col_offset, const int64_t y_lda) {
-  CUDA_1D_KERNEL_LOOP(index, row_num * col_num) {
-    const int64_t i = index / col_num;
-    const int64_t j = index % col_num;
-    y[i * y_lda + y_col_offset + j] = x[i * x_lda + x_col_offset + j];
-  }
-}
-
-template<int32_t NDIMS>
-__device__ int32_t GetXIndex(const int32_t* y_shape, const int32_t* x_strides, int32_t y_idx) {
-  int32_t x_idx = 0;
-  for (int32_t i = NDIMS - 1; i >= 0; --i) {
-    x_idx += (y_idx % y_shape[i]) * x_strides[i];
-    y_idx /= y_shape[i];
-  }
-  return x_idx;
-}
-
-template<int32_t NDIMS, typename T>
-__global__ void TransposeGpu(const Int32Array<NDIMS> y_shape, const Int32Array<NDIMS> x_strides,
-                             const int32_t elem_cnt, const T* x, T* y) {
-  __shared__ int32_t x_strides_shared[NDIMS];
-  __shared__ int32_t y_dims_shared[NDIMS];
-  const int32_t tid = threadIdx.x;
-  if (tid < NDIMS) {
-    y_dims_shared[tid] = y_shape.val[tid];
-    x_strides_shared[tid] = x_strides.val[tid];
-  }
-  __syncthreads();
-  CUDA_1D_KERNEL_LOOP(y_idx, elem_cnt) {
-    const int32_t x_idx = GetXIndex<NDIMS>(y_dims_shared, x_strides_shared, y_idx);
-#if __CUDA_ARCH__ >= 350
-    y[y_idx] = __ldg(x + x_idx);
-#else
-    y[y_idx] = x[x_idx];
-#endif
-  }
-}
-
-template<int32_t NDIMS, typename T>
-void Transpose(DeviceCtx* ctx, const ShapeView& x_shape, const ShapeView& y_shape,
-               const PbRf<int32_t>& permutation, const int64_t elem_cnt, const T* x, T* y) {
-  CHECK_LE(y_shape.elem_cnt(), GetMaxVal<int32_t>());
-  Int32Array<NDIMS> y_shape_struct;
-  FOR_RANGE(int32_t, i, 0, NDIMS) { y_shape_struct.val[i] = y_shape.At(i); }
-  Int32Array<NDIMS> x_strides;
-  int32_t buff[NDIMS];
-  int32_t cur_stride = 1;
-  for (int32_t i = NDIMS - 1; i >= 0; --i) {
-    buff[i] = cur_stride;
-    cur_stride *= x_shape.At(i);
-  }
-  for (int32_t i = 0; i < NDIMS; ++i) { x_strides.val[i] = buff[permutation[i]]; }
-  TransposeGpu<NDIMS, T>
-      <<<SMBlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-          y_shape_struct, x_strides, elem_cnt, x, y);
-}
-
-template<typename T>
-struct TransposeUtil final {
-#define MAKE_TRANSPOSE_SWITCH_ENTRY(func_name, NDIMS) func_name<NDIMS, T>
-  DEFINE_STATIC_SWITCH_FUNC(void, Transpose, MAKE_TRANSPOSE_SWITCH_ENTRY,
-                            MAKE_NDIM_CTRV_SEQ(DIM_SEQ))
-};
-
-template<typename T>
-__global__ void AssignStridedAddrGpu(T** dev_ptrs, T* start_ptr, int32_t stride_len,
-                                     int32_t stride_num) {
-  CUDA_1D_KERNEL_LOOP(i, stride_num) { dev_ptrs[i] = start_ptr + i * stride_len; }
-}
-
-template<typename T>
-void AssignStridedAddr(DeviceCtx* ctx, T** dev_ptrs, T* start_ptr, int stride_len, int stride_num) {
-  AssignStridedAddrGpu<T>
-      <<<BlocksNum4ThreadsNum(stride_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-          dev_ptrs, start_ptr, stride_len, stride_num);
-}
-
 }  // namespace
 
 #define MAKE_CUB_DEVICE_REDUCE_SWITCH_ENTRY(func_name, T) cub::DeviceReduce::func_name<T*, T*>
@@ -227,24 +127,6 @@ DEFINE_STATIC_SWITCH_FUNC(cudaError_t, Sum, MAKE_CUB_DEVICE_REDUCE_SWITCH_ENTRY,
   template<typename T, typename Derived> \
   void GpuKernelUtilIf<T, Derived>::
 
-KU_IF_METHOD CopyColsRegion(DeviceCtx* ctx, const int64_t row_num, const int64_t col_num,
-                            const T* x, const int64_t x_col_offset, const int64_t x_lda, T* y,
-                            const int64_t y_col_offset, const int64_t y_lda) {
-  CopyColsRegionGpu<T>
-      <<<BlocksNum4ThreadsNum(row_num * col_num), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>(
-          row_num, col_num, x, x_col_offset, x_lda, y, y_col_offset, y_lda);
-}
-KU_IF_METHOD Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
-                       const ShapeView& y_shape, const PbRf<int32_t>& permutation,
-                       const int64_t elem_cnt, const T* x, T* y) {
-  CHECK_LE(y_shape.elem_cnt(), GetMaxVal<int32_t>());
-  CHECK_EQ(num_axis, y_shape.NumAxes());
-  CHECK_EQ(num_axis, x_shape.NumAxes());
-  TransposeUtil<T>::SwitchTranspose(SwitchCase(num_axis), ctx, x_shape, y_shape, permutation,
-                                    elem_cnt, x, y);
-}
 KU_IF_METHOD InitializeWithConf(DeviceCtx* ctx, const InitializerConf& initializer_conf,
                                 uint32_t random_seed, Blob* blob) {
   WithHostBlobAndStreamSynchronizeEnv(ctx, blob, [&](Blob* host_blob) {
......
@@ -30,13 +30,20 @@ namespace oneflow {
 class Blob;
 class InitializerConf;
 class MemoryCase;
+class StreamContext;
 
 void AutoMemcpy(DeviceCtx* ctx, void* dst, const void* src, size_t sz,
                 const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case);
+void AutoMemcpy(DeviceCtx* ctx, Blob* dst, const Blob* src);
+void AutoMemcpy(StreamContext* stream_ctx, void* dst, const void* src, size_t sz,
+                const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case);
+void AutoMemcpy(StreamContext* stream_ctx, Blob* dst, const Blob* src);
 void SyncAutoMemcpy(DeviceCtx* ctx, void* dst, const void* src, size_t sz,
                     const MemoryCase& dst_mem_case, const MemoryCase& src_mem_case);
 void AutoMemset(DeviceCtx* ctx, void* dst, const char value, size_t sz,
                 const MemoryCase& dst_mem_case);
+void AutoMemset(StreamContext* stream_ctx, void* dst, const char value, size_t sz,
+                const MemoryCase& dst_mem_case);
 
 template<DeviceType device_type, typename T, typename U = void>
 struct KernelUtil;
@@ -46,12 +53,6 @@ template<typename T, typename Derived>
 struct CpuKernelUtilIf {
   static void Axpy(DeviceCtx* ctx, const int n, const T* alpha, const T* x, const int incx, T* y,
                    const int incy);
-  static void CopyColsRegion(DeviceCtx* ctx, const int64_t row_num, const int64_t col_num,
-                             const T* x, const int64_t x_col_offset, const int64_t x_lda, T* y,
-                             const int64_t y_col_offset, const int64_t y_lda);
-  static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
-                        const ShapeView& y_shape, const PbRf<int32_t>& permutation,
-                        const int64_t elem_cnt, const T* x, T* y);
   static void Set(DeviceCtx* ctx, const T value, T* addr);
 };
 
@@ -104,12 +105,6 @@ struct KernelUtil<DeviceType::kCPU, T, typename std::enable_if<IsIntegral<T>::va
 // GPU, Integral, Floating
 template<typename T, typename Derived>
 struct GpuKernelUtilIf {
-  static void CopyColsRegion(DeviceCtx* ctx, const int64_t row_num, const int64_t col_num,
-                             const T* x, const int64_t x_col_offset, const int64_t x_lda, T* y,
-                             const int64_t y_col_offset, const int64_t y_lda);
-  static void Transpose(DeviceCtx* ctx, const int32_t num_axis, const ShapeView& x_shape,
-                        const ShapeView& y_shape, const PbRf<int32_t>& permutation,
-                        const int64_t elem_cnt, const T* x, T* y);
   static void InitializeWithConf(DeviceCtx* ctx, const InitializerConf& initializer_conf,
                                  uint32_t random_seed, Blob* blob);
   static void Set(DeviceCtx* ctx, const T value, T* addr);
@@ -158,50 +153,6 @@ struct KernelUtil<DeviceType::kGPU, T, typename std::enable_if<IsIntegral<T>::va
   static void Mul(DeviceCtx* ctx, const int64_t n, const T* x, const T* y, T* z);
 };
 
-using CopyBlobFieldMthd = void (Blob::*)(DeviceCtx*, const Blob*);
-
-class DataContentIterator final {
- public:
-  OF_DISALLOW_COPY_AND_MOVE(DataContentIterator);
-  DataContentIterator() = delete;
-  ~DataContentIterator() = default;
-
-  DataContentIterator(std::function<Blob*(const std::string&)> BnInOp2Blob,
-                      const PbRpf<std::string>* bns, int32_t axis) {
-    BnInOp2Blob_ = BnInOp2Blob;
-    seg_num_ = BnInOp2Blob(bns->Get(0))->static_shape().Count(0, axis);
-    seg_idx_ = 0;
-    bns_ = bns;
-    bn_idx_ = 0;
-    axis_ = axis;
-  }
-
-  std::tuple<char*, size_t> GetNext() {
-    std::tuple<char*, size_t> ret(nullptr, 0);
-    if (seg_idx_ == seg_num_) { return ret; }
-    Blob* blob = BnInOp2Blob_(bns_->Get(bn_idx_));
-    int64_t elem_num = blob->static_shape().Count(axis_);
-    std::get<1>(ret) = elem_num * GetSizeOfDataType(blob->data_type());
-    std::get<0>(ret) = blob->mut_dptr<char>() + seg_idx_ * std::get<1>(ret);
-    bn_idx_ += 1;
-    if (bn_idx_ == bns_->size()) {
-      bn_idx_ = 0;
-      seg_idx_ += 1;
-    }
-    return ret;
-  }
-
-  static CopyBlobFieldMthd GetCopyBlobFieldMthd() { return &Blob::CopyDataContentFrom; }
-
- private:
-  std::function<Blob*(const std::string&)> BnInOp2Blob_;
-  int64_t seg_num_;
-  int64_t seg_idx_;
-  const PbRpf<std::string>* bns_;
-  int32_t bn_idx_;
-  int32_t axis_;
-};
-
 template<typename T, typename U>
 typename std::enable_if<std::is_same<T, U>::value>::type CopyElem(const T* in_dptr, U* out_dptr,
                                                                   int64_t elem_num) {
......
@@ -20,7 +20,6 @@ limitations under the License.
 
 namespace oneflow {
 
-template<DeviceType device_type>
 class OutputKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(OutputKernel);
@@ -32,8 +31,7 @@ class OutputKernel final : public Kernel {
   void ForwardHeader(KernelContext* ctx) const override;
 };
 
-template<DeviceType device_type>
-void OutputKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
+void OutputKernel::ForwardDataContent(KernelContext* ctx) const {
   if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
     CHECK(this->op_conf().output_conf().has_job_name());
     const auto& job_name = this->op_conf().output_conf().job_name();
@@ -48,12 +46,11 @@ void OutputKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
      job_instance->PullBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name);
    }
  } else {
-    ctx->BnInOp2Blob("out")->CopyDataContentFrom(ctx->device_ctx(), ctx->BnInOp2Blob("in"));
+    AutoMemcpy(ctx->stream_ctx(), ctx->BnInOp2Blob("out"), ctx->BnInOp2Blob("in"));
   }
 }
 
-template<DeviceType device_type>
-void OutputKernel<device_type>::ForwardHeader(KernelContext* ctx) const {
+void OutputKernel::ForwardHeader(KernelContext* ctx) const {
   if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
     // Do nothing.
   } else {
@@ -61,6 +58,6 @@ void OutputKernel<device_type>::ForwardHeader(KernelContext* ctx) const {
   }
 }
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kOutputConf, OutputKernel);
+REGISTER_KERNEL(OperatorConf::kOutputConf, OutputKernel);
 
 }  // namespace oneflow
@@ -20,7 +20,6 @@ limitations under the License.
 
 namespace oneflow {
 
-template<DeviceType device_type>
 class ReturnKernel final : public Kernel {
  public:
   OF_DISALLOW_COPY_AND_MOVE(ReturnKernel);
@@ -32,8 +31,7 @@ class ReturnKernel final : public Kernel {
   void ForwardHeader(KernelContext* ctx) const override;
 };
 
-template<DeviceType device_type>
-void ReturnKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
+void ReturnKernel::ForwardDataContent(KernelContext* ctx) const {
   if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
     CHECK(this->op_conf().return_conf().has_job_name());
     const auto& job_name = this->op_conf().return_conf().job_name();
@@ -48,13 +46,12 @@ void ReturnKernel<device_type>::ForwardDataContent(KernelContext* ctx) const {
      job_instance->PullBlobByOpName(reinterpret_cast<uint64_t>(&ofblob), op_name);
    }
  } else {
-    ctx->BnInOp2Blob("out")->CopyDataContentFrom(ctx->device_ctx(), ctx->BnInOp2Blob("in"));
+    AutoMemcpy(ctx->stream_ctx(), ctx->BnInOp2Blob("out"), ctx->BnInOp2Blob("in"));
     ctx->device_ctx()->SyncDevice();
   }
 }
 
-template<DeviceType device_type>
-void ReturnKernel<device_type>::ForwardHeader(KernelContext* ctx) const {
+void ReturnKernel::ForwardHeader(KernelContext* ctx) const {
   if (CHECK_JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
     // Do nothing.
   } else {
@@ -62,6 +59,6 @@ void ReturnKernel<device_type>::ForwardHeader(KernelContext* ctx) const {
   }
 }
 
-ADD_DEVICE_TYPE_KERNEL_CREATOR(OperatorConf::kReturnConf, ReturnKernel);
+REGISTER_KERNEL(OperatorConf::kReturnConf, ReturnKernel);
 
 }  // namespace oneflow
@@ -42,21 +42,6 @@ void Blob::Init(const MemoryCase& mem_case, const BlobDesc* blob_desc, char* hea
   MutShapeView(shape_ptr, static_shape().NumAxes()).set_shape(static_shape());
 }
 
-void Blob::CopyDataContentFrom(DeviceCtx* device_ctx, const Blob* rhs) {
-  if (this == rhs) { return; }
-  this->blob_access_checker()->CheckBodyMutable();
-  AutoMemcpy(device_ctx, mut_dptr(), rhs->dptr(), ByteSizeOfBlobBody(), mem_case(),
-             rhs->mem_case());
-}
-
-void Blob::CopyValidDataContentFrom(DeviceCtx* device_ctx, const Blob* rhs) {
-  if (this == rhs) { return; }
-  this->blob_access_checker()->CheckBodyMutable();
-  const size_t body_byte_size = ByteSizeOfBlobBody();
-  CHECK_EQ(rhs->ByteSizeOfBlobBody(), body_byte_size);
-  AutoMemcpy(device_ctx, mut_dptr(), rhs->dptr(), body_byte_size, mem_case(), rhs->mem_case());
-}
-
 void Blob::CopyHeaderFrom(DeviceCtx* device_ctx, const Blob* rhs) {
   size_t header_size = blob_desc().ByteSizeOfBlobHeader();
   CHECK_EQ(header_size, rhs->blob_desc().ByteSizeOfBlobHeader());
@@ -87,8 +87,6 @@ class Blob final {
 
   void reset_dptr(char* dptr) { dptr_ = dptr; }
 
-  void CopyDataContentFrom(DeviceCtx* device_ctx, const Blob* rhs);
-  void CopyValidDataContentFrom(DeviceCtx* device_ctx, const Blob* rhs);
   void CopyHeaderFrom(DeviceCtx* device_ctx, const Blob* rhs);
 
   bool IsBodyEmpty() const { return shape().elem_cnt() == 0; }
......
@@ -38,10 +38,9 @@ class ReduceKernel final : public user_op::OpKernel, public user_op::CudaGraphSu
     if (input_tensor->shape().elem_cnt() == 0) {
       if (output_tensor->shape().elem_cnt() != 0) {
-        AutoMemset(
+        Memset<device_type>(
             ctx->device_ctx(), output_tensor->mut_dptr<T>(), 0,
-            output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()),
-            output_tensor->mem_case());
+            output_tensor->shape().elem_cnt() * GetSizeOfDataType(output_tensor->data_type()));
       }
       return;
     }
......
@@ -44,9 +44,9 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel, public user_op::Cu
     const auto& axis = ctx->Attr<std::vector<int32_t>>("axis");
     if (tensor_x->shape().elem_cnt() == 0) {
       if (tensor_y->shape().elem_cnt() != 0) {
-        AutoMemset(ctx->device_ctx(), tensor_y->mut_dptr<T>(), 0,
-                   tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()),
-                   tensor_y->mem_case());
+        Memset<device_type>(
+            ctx->device_ctx(), tensor_y->mut_dptr<T>(), 0,
+            tensor_y->shape().elem_cnt() * GetSizeOfDataType(tensor_y->data_type()));
       }
       return;
     }
......