diff --git a/oneflow/core/kernel/concat_kernel.cpp b/oneflow/core/kernel/concat_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..90768f37220bb4277f9ef9ba384531ab5ee7b141 --- /dev/null +++ b/oneflow/core/kernel/concat_kernel.cpp @@ -0,0 +1,71 @@ +#include "oneflow/core/kernel/concat_kernel.h" + +namespace oneflow { + +template +void ConcatKernel::Forward( + const KernelCtx& ctx, + std::function BnInOp2BlobPtr) const { + const std::vector& ibns = op()->input_bns(); + if (ibns.size() == 0) return; + Blob* out_blob = BnInOp2BlobPtr(op()->SoleObn()); + int32_t concat_axis = op()->op_conf().concat_conf().axis(); + if (concat_axis < 0) concat_axis += out_blob->shape().NumAxes(); + const int64_t concat_num_each_blob = out_blob->shape().Count(0, concat_axis); + const int64_t concat_element_size = out_blob->shape().Count(concat_axis + 1); + const int64_t out_concat_axis_size = out_blob->shape().At(concat_axis); + int64_t offset_concat_axis = 0; + for (size_t ibn_idx = 0; ibn_idx < ibns.size(); ++ibn_idx) { + const Blob* ibn_blob = BnInOp2BlobPtr(ibns[ibn_idx]); + const int64_t in_concat_axis_size = ibn_blob->shape().At(concat_axis); + for (int64_t concat_idx = 0; concat_idx < concat_num_each_blob; + ++concat_idx) { + KernelUtil::Memcpy( + ctx, + (static_cast(out_blob->mut_dptr())) + + (concat_idx * out_concat_axis_size + offset_concat_axis) + * concat_element_size, + (static_cast(ibn_blob->dptr())) + + concat_idx * in_concat_axis_size * concat_element_size, + in_concat_axis_size * concat_element_size * sizeof(FloatingPointType), + cudaMemcpyKind::cudaMemcpyDeviceToDevice); + } + offset_concat_axis += in_concat_axis_size; + } +} + +template +void ConcatKernel::Backward( + const KernelCtx& ctx, + std::function BnInOp2BlobPtr) const { + const Blob* odbn_blob = BnInOp2BlobPtr(op()->SoleOdbn()); + const std::vector& idbns = op()->input_diff_bns(); + int32_t split_axis = op()->op_conf().concat_conf().axis(); + if (split_axis < 0) split_axis += odbn_blob->shape().NumAxes(); + const int64_t split_num_each_blob = odbn_blob->shape().Count(0, split_axis); + const int64_t split_element_size = odbn_blob->shape().Count(split_axis + 1); + const int64_t out_diff_split_axis_size = odbn_blob->shape().At(split_axis); + int64_t offset_split_axis = 0; + for (size_t idbns_idx = 0; idbns_idx < idbns.size(); ++idbns_idx) { + Blob* idbn_blob = BnInOp2BlobPtr(idbns[idbns_idx]); + const int64_t in_diff_split_axis_size = idbn_blob->shape().At(split_axis); + for (int64_t split_idx = 0; split_idx < split_num_each_blob; ++split_idx) { + KernelUtil::Memcpy( + ctx, + static_cast(idbn_blob->mut_dptr()) + + split_idx * in_diff_split_axis_size * split_element_size, + static_cast(odbn_blob->dptr()) + + (split_idx * out_diff_split_axis_size + offset_split_axis) + * split_element_size, + in_diff_split_axis_size * split_element_size + * sizeof(FloatingPointType), + cudaMemcpyKind::cudaMemcpyDeviceToDevice); + } + offset_split_axis += in_diff_split_axis_size; + } +} + +INSTANTIATE_KERNEL_CLASS(ConcatKernel); +REGISTER_KERNEL(OperatorConf::kConcatConf, ConcatKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/concat_kernel.h b/oneflow/core/kernel/concat_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..9f9c0cbb6ac989f7c6477733068c0857942312ac --- /dev/null +++ b/oneflow/core/kernel/concat_kernel.h @@ -0,0 +1,23 @@ +#ifndef ONEFLOW_CORE_KERNEL_CONCAT_KERNEL_H_ +#define ONEFLOW_CORE_KERNEL_CONCAT_KERNEL_H_ + +#include "oneflow/core/kernel/kernel_manager.h" + +namespace oneflow { + +template +class ConcatKernel final : public Kernel { + public: + OF_DISALLOW_COPY_AND_MOVE(ConcatKernel); + ConcatKernel() = default; + ~ConcatKernel() = default; + + void Forward(const KernelCtx&, + std::function) const override; + void Backward(const KernelCtx&, + std::function) const override; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_CONCAT_KERNEL_H_ diff --git a/oneflow/core/kernel/concat_kernel_test.cpp b/oneflow/core/kernel/concat_kernel_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f44745f40caf261bfd1b1779ee8da4b264ef819 --- /dev/null +++ b/oneflow/core/kernel/concat_kernel_test.cpp @@ -0,0 +1,97 @@ +#include "oneflow/core/kernel/concat_kernel.h" +#include +#include "oneflow/core/device/cpu_device_context.h" +#include "oneflow/core/device/cuda_device_context.h" +#include "oneflow/core/kernel/kernel_test_common.h" + +namespace oneflow { + +namespace test { + +namespace { + +template +Kernel* BuildConcatKernel() { + OperatorConf op_conf; + op_conf.set_name("concat_test"); + ConcatOpConf* concat_conf = op_conf.mutable_concat_conf(); + concat_conf->add_in("concat/in0"); + concat_conf->add_in("concat/in1"); + concat_conf->set_axis(1); + concat_conf->set_out("concat_kernel_test"); + auto concat_op = ConstructOp(op_conf); + OperatorProto op_proto; + concat_op->ToProto(&op_proto); + auto concat_kernel = new ConcatKernel(); + concat_kernel->InitFromOpProto(op_proto); + return concat_kernel; +} + +template +std::function BuildBnInOp2BlobPtr() { + using KTCommon = KernelTestCommon; + FloatingPointType in_0_mat[] = {1, 2, 3, 4, 5, 6}; + FloatingPointType in_1_mat[] = {7, 8, 9, 10, 11, 12}; + FloatingPointType out_mat[12] = {0}; + FloatingPointType expected_out_mat[] = {1, 2, 3, 7, 8, 9, + 4, 5, 6, 10, 11, 12}; + FloatingPointType out_diff_mat[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + FloatingPointType in_0_diff_mat[6] = {0}; + FloatingPointType in_1_diff_mat[6] = {0}; + FloatingPointType expected_in_0_diff_mat[] = {1, 2, 3, 7, 8, 9}; + FloatingPointType expected_in_1_diff_mat[] = {4, 5, 6, 10, 11, 12}; + + auto bn2blob_ptr = new HashMap; + + (*bn2blob_ptr)["in_0"] = KTCommon::CreateBlobWithVector({2, 3}, in_0_mat); + (*bn2blob_ptr)["in_1"] = KTCommon::CreateBlobWithVector({2, 3}, in_1_mat); + (*bn2blob_ptr)["out"] = KTCommon::CreateBlobWithVector({2, 6}, out_mat); + (*bn2blob_ptr)["expected_out"] = + KTCommon::CreateBlobWithVector({2, 6}, expected_out_mat); + (*bn2blob_ptr)["out_diff"] = + KTCommon::CreateBlobWithVector({2, 6}, out_diff_mat); + (*bn2blob_ptr)["in_0_diff"] = + KTCommon::CreateBlobWithVector({2, 3}, in_0_diff_mat); + (*bn2blob_ptr)["in_1_diff"] = + KTCommon::CreateBlobWithVector({2, 3}, in_1_diff_mat); + (*bn2blob_ptr)["expected_in_0_diff"] = + KTCommon::CreateBlobWithVector({2, 3}, expected_in_0_diff_mat); + (*bn2blob_ptr)["expected_in_1_diff"] = + KTCommon::CreateBlobWithVector({2, 3}, expected_in_1_diff_mat); + + return [bn2blob_ptr](const std::string& bn) { return bn2blob_ptr->at(bn); }; +} + +template +void TestConcatKernel() { + using KTCommon = KernelTestCommon; + KernelCtx ctx; + KTCommon::BuildKernelCtx(&ctx); + + auto BnInOp2BlobPtr = BuildBnInOp2BlobPtr(); + auto concat_kernel = BuildConcatKernel(); + + concat_kernel->Forward(ctx, BnInOp2BlobPtr); + concat_kernel->Backward(ctx, BnInOp2BlobPtr); + KTCommon::SyncStream(&ctx); + + KTCommon::CheckResult(BnInOp2BlobPtr, "out", "expected_out"); + KTCommon::CheckResult(BnInOp2BlobPtr, "in_0_diff", "expected_in_0_diff"); + KTCommon::CheckResult(BnInOp2BlobPtr, "in_1_diff", "expected_in_1_diff"); +} + +} // namespace + +} // namespace test + +TEST(ConcatKernel, concat_cpu) { + test::TestConcatKernel(); + test::TestConcatKernel(); +} + +TEST(ConcatKernel, concat_gpu) { + test::TestConcatKernel(); + test::TestConcatKernel(); +} // namespace oneflow + +} // namespace oneflow