Commit 4a3f6a24 authored by bubuface1987, committed by chengtbf

add the initial version of concat kernel (#209)

* add the initial version of concat kernel

* when concat axis num is negative, converts it to positive

* indicates the Memcpy kind as cudaMemcpyDeviceToDevice
Parent commit: 6bbc8f0f
#include "oneflow/core/kernel/concat_kernel.h"
namespace oneflow {
template<DeviceType device_type, typename FloatingPointType>
void ConcatKernel<device_type, FloatingPointType>::Forward(
    const KernelCtx& ctx,
    std::function<Blob*(const std::string&)> BnInOp2BlobPtr) const {
  // Concatenate every input blob into the sole output blob along the
  // configured axis, copying one contiguous row per (slice, input) pair.
  const std::vector<std::string>& input_bns = op()->input_bns();
  if (input_bns.empty()) { return; }
  Blob* out_blob = BnInOp2BlobPtr(op()->SoleObn());
  // A negative axis counts from the end, Python-style; normalize it.
  int32_t axis = op()->op_conf().concat_conf().axis();
  if (axis < 0) { axis += out_blob->shape().NumAxes(); }
  // num_slices: product of dims before the axis; row_size: product of dims
  // after it. Each copy moves one input's chunk of one slice.
  const int64_t num_slices = out_blob->shape().Count(0, axis);
  const int64_t row_size = out_blob->shape().Count(axis + 1);
  const int64_t out_axis_dim = out_blob->shape().At(axis);
  FloatingPointType* out_base =
      static_cast<FloatingPointType*>(out_blob->mut_dptr());
  int64_t axis_offset = 0;  // running position inside the output axis
  for (const std::string& ibn : input_bns) {
    const Blob* in_blob = BnInOp2BlobPtr(ibn);
    const int64_t in_axis_dim = in_blob->shape().At(axis);
    const FloatingPointType* in_base =
        static_cast<const FloatingPointType*>(in_blob->dptr());
    const int64_t chunk_elems = in_axis_dim * row_size;
    for (int64_t slice = 0; slice < num_slices; ++slice) {
      KernelUtil<device_type, FloatingPointType>::Memcpy(
          ctx, out_base + (slice * out_axis_dim + axis_offset) * row_size,
          in_base + slice * chunk_elems,
          chunk_elems * sizeof(FloatingPointType),
          cudaMemcpyKind::cudaMemcpyDeviceToDevice);
    }
    axis_offset += in_axis_dim;
  }
}
template<DeviceType device_type, typename FloatingPointType>
void ConcatKernel<device_type, FloatingPointType>::Backward(
    const KernelCtx& ctx,
    std::function<Blob*(const std::string&)> BnInOp2BlobPtr) const {
  // Split the sole output-diff blob back into the per-input diff blobs,
  // reversing the slice layout produced by Forward.
  const Blob* out_diff_blob = BnInOp2BlobPtr(op()->SoleOdbn());
  const std::vector<std::string>& input_diff_bns = op()->input_diff_bns();
  // A negative axis counts from the end; normalize it.
  int32_t axis = op()->op_conf().concat_conf().axis();
  if (axis < 0) { axis += out_diff_blob->shape().NumAxes(); }
  const int64_t num_slices = out_diff_blob->shape().Count(0, axis);
  const int64_t row_size = out_diff_blob->shape().Count(axis + 1);
  const int64_t out_axis_dim = out_diff_blob->shape().At(axis);
  const FloatingPointType* out_base =
      static_cast<const FloatingPointType*>(out_diff_blob->dptr());
  int64_t axis_offset = 0;  // running position inside the out-diff axis
  for (const std::string& idbn : input_diff_bns) {
    Blob* in_diff_blob = BnInOp2BlobPtr(idbn);
    const int64_t in_axis_dim = in_diff_blob->shape().At(axis);
    FloatingPointType* in_base =
        static_cast<FloatingPointType*>(in_diff_blob->mut_dptr());
    const int64_t chunk_elems = in_axis_dim * row_size;
    for (int64_t slice = 0; slice < num_slices; ++slice) {
      KernelUtil<device_type, FloatingPointType>::Memcpy(
          ctx, in_base + slice * chunk_elems,
          out_base + (slice * out_axis_dim + axis_offset) * row_size,
          chunk_elems * sizeof(FloatingPointType),
          cudaMemcpyKind::cudaMemcpyDeviceToDevice);
    }
    axis_offset += in_axis_dim;
  }
}
INSTANTIATE_KERNEL_CLASS(ConcatKernel);
REGISTER_KERNEL(OperatorConf::kConcatConf, ConcatKernel);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_CONCAT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_CONCAT_KERNEL_H_
#include "oneflow/core/kernel/kernel_manager.h"
namespace oneflow {
// Kernel implementing tensor concatenation along a configurable axis.
// Forward copies each input blob into its slot of the sole output blob;
// Backward scatters the output diff back into the per-input diff blobs.
template<DeviceType device_type, typename FloatingPointType>
class ConcatKernel final : public Kernel {
public:
OF_DISALLOW_COPY_AND_MOVE(ConcatKernel);
ConcatKernel() = default;
~ConcatKernel() = default;
// Concatenates all input blobs along concat_conf().axis() into the output.
void Forward(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
// Splits the sole output-diff blob back into the input-diff blobs.
void Backward(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_CONCAT_KERNEL_H_
#include "oneflow/core/kernel/concat_kernel.h"
#include <random>
#include "oneflow/core/device/cpu_device_context.h"
#include "oneflow/core/device/cuda_device_context.h"
#include "oneflow/core/kernel/kernel_test_common.h"
namespace oneflow {
namespace test {
namespace {
template<DeviceType device_type, typename FloatingPointType>
Kernel* BuildConcatKernel() {
  // Describe a two-input concat along axis 1.
  OperatorConf op_conf;
  op_conf.set_name("concat_test");
  ConcatOpConf* concat_conf = op_conf.mutable_concat_conf();
  concat_conf->add_in("concat/in0");
  concat_conf->add_in("concat/in1");
  concat_conf->set_axis(1);
  concat_conf->set_out("concat_kernel_test");
  // Materialize the operator, serialize it, and init the kernel from the
  // serialized proto (caller owns the returned kernel).
  auto concat_op = ConstructOp(op_conf);
  OperatorProto op_proto;
  concat_op->ToProto(&op_proto);
  auto* kernel = new ConcatKernel<device_type, FloatingPointType>();
  kernel->InitFromOpProto(op_proto);
  return kernel;
}
// Builds the blob-name -> Blob* map used by the concat kernel test.
// Inputs are two 2x3 matrices concatenated along axis 1 into a 2x6 output;
// "expected_*" blobs hold the hand-computed oracle values.
template<DeviceType device_type, typename FloatingPointType>
std::function<Blob*(const std::string&)> BuildBnInOp2BlobPtr() {
using KTCommon = KernelTestCommon<device_type, FloatingPointType>;
// NOTE(review): these matrices are stack-local; this assumes
// CreateBlobWithVector copies the data into the blob -- confirm.
FloatingPointType in_0_mat[] = {1, 2, 3, 4, 5, 6};
FloatingPointType in_1_mat[] = {7, 8, 9, 10, 11, 12};
FloatingPointType out_mat[12] = {0};
// Row-wise interleaving: row i of out = [row i of in_0, row i of in_1].
FloatingPointType expected_out_mat[] = {1, 2, 3, 7, 8, 9,
4, 5, 6, 10, 11, 12};
FloatingPointType out_diff_mat[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
FloatingPointType in_0_diff_mat[6] = {0};
FloatingPointType in_1_diff_mat[6] = {0};
// Backward splits out_diff the same way: left halves of each row go to
// in_0_diff, right halves to in_1_diff.
FloatingPointType expected_in_0_diff_mat[] = {1, 2, 3, 7, 8, 9};
FloatingPointType expected_in_1_diff_mat[] = {4, 5, 6, 10, 11, 12};
// Deliberately leaked for the lifetime of the test process.
auto bn2blob_ptr = new HashMap<std::string, Blob*>;
(*bn2blob_ptr)["in_0"] = KTCommon::CreateBlobWithVector({2, 3}, in_0_mat);
(*bn2blob_ptr)["in_1"] = KTCommon::CreateBlobWithVector({2, 3}, in_1_mat);
(*bn2blob_ptr)["out"] = KTCommon::CreateBlobWithVector({2, 6}, out_mat);
(*bn2blob_ptr)["expected_out"] =
KTCommon::CreateBlobWithVector({2, 6}, expected_out_mat);
(*bn2blob_ptr)["out_diff"] =
KTCommon::CreateBlobWithVector({2, 6}, out_diff_mat);
(*bn2blob_ptr)["in_0_diff"] =
KTCommon::CreateBlobWithVector({2, 3}, in_0_diff_mat);
(*bn2blob_ptr)["in_1_diff"] =
KTCommon::CreateBlobWithVector({2, 3}, in_1_diff_mat);
(*bn2blob_ptr)["expected_in_0_diff"] =
KTCommon::CreateBlobWithVector({2, 3}, expected_in_0_diff_mat);
(*bn2blob_ptr)["expected_in_1_diff"] =
KTCommon::CreateBlobWithVector({2, 3}, expected_in_1_diff_mat);
return [bn2blob_ptr](const std::string& bn) { return bn2blob_ptr->at(bn); };
}
template<DeviceType device_type, typename FloatingPointType>
void TestConcatKernel() {
using KTCommon = KernelTestCommon<device_type, FloatingPointType>;
KernelCtx ctx;
KTCommon::BuildKernelCtx(&ctx);
auto BnInOp2BlobPtr = BuildBnInOp2BlobPtr<device_type, FloatingPointType>();
auto concat_kernel = BuildConcatKernel<device_type, FloatingPointType>();
concat_kernel->Forward(ctx, BnInOp2BlobPtr);
concat_kernel->Backward(ctx, BnInOp2BlobPtr);
KTCommon::SyncStream(&ctx);
KTCommon::CheckResult(BnInOp2BlobPtr, "out", "expected_out");
KTCommon::CheckResult(BnInOp2BlobPtr, "in_0_diff", "expected_in_0_diff");
KTCommon::CheckResult(BnInOp2BlobPtr, "in_1_diff", "expected_in_1_diff");
}
} // namespace
} // namespace test
// Exercise the concat kernel on CPU for both float and double.
TEST(ConcatKernel, concat_cpu) {
test::TestConcatKernel<DeviceType::kCPU, float>();
test::TestConcatKernel<DeviceType::kCPU, double>();
}
// Exercise the concat kernel on GPU for both float and double.
TEST(ConcatKernel, concat_gpu) {
test::TestConcatKernel<DeviceType::kGPU, float>();
test::TestConcatKernel<DeviceType::kGPU, double>();
}
} // namespace oneflow
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment