diff --git a/oneflow/core/kernel/dot_kernel.cpp b/oneflow/core/kernel/dot_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..43c5ef3c818df80bd238f7c0e0e29c0dfd0dd4b0
--- /dev/null
+++ b/oneflow/core/kernel/dot_kernel.cpp
@@ -0,0 +1,76 @@
+#include "oneflow/core/kernel/dot_kernel.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void DotKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  const int64_t piece_size = in_blob->shape().At(0);
+  const int64_t dim = in_blob->shape().Count(1);
+  const Blob* weight_blob = BnInOp2Blob("weight");
+  Blob* out_blob = BnInOp2Blob("out");
+  Blob* tmp_blob = BnInOp2Blob("tmp");
+  Blob* tmp_storage_blob = BnInOp2Blob("tmp_storage");
+  // tmp = in .* weight; out = row-wise sum of tmp
+  KernelUtil<device_type, T>::Mul(ctx.device_ctx, piece_size * dim, in_blob->dptr<T>(),
+                                  weight_blob->dptr<T>(), tmp_blob->mut_dptr<T>());
+  KernelUtil<device_type, T>::RowSum(ctx.device_ctx, piece_size, dim, tmp_blob->dptr<T>(),
+                                     out_blob->mut_dptr<T>(), tmp_storage_blob->mut_dptr<T>(),
+                                     sizeof(T) * piece_size * dim);
+  if (this->op_conf().dot_conf().has_bias()) {
+    const Blob* bias_blob = BnInOp2Blob("bias");
+    // out += bias
+    KernelUtil<device_type, T>::Axpy(ctx.device_ctx, piece_size, OneVal<T>::value,
+                                     bias_blob->dptr<T>(), 1, out_blob->mut_dptr<T>(), 1);
+  }
+}
+
+template<DeviceType device_type, typename T>
+void DotKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  const Blob* in_blob = BnInOp2Blob("in");
+  const Blob* diff_mul_blob = BnInOp2Blob("diff_multiplier");
+  Blob* tmp_blob = BnInOp2Blob("tmp");
+  Blob* in_diff_blob = BnInOp2Blob("in_diff");
+  const Blob* weight_blob = BnInOp2Blob("weight");
+  Blob* weight_diff_blob = BnInOp2Blob("weight_diff");
+
+  // tmp = out_diff * diff_multiplier
+  KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans, OneVal<T>::value,
+                                       ZeroVal<T>::value, out_diff_blob, diff_mul_blob, tmp_blob);
+  // weight_diff = tmp .* in
+  KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_blob->shape().elem_cnt(), tmp_blob->dptr<T>(),
+                                  in_blob->dptr<T>(), weight_diff_blob->mut_dptr<T>());
+  // in_diff = tmp .* weight
+  KernelUtil<device_type, T>::Mul(ctx.device_ctx, weight_blob->shape().elem_cnt(),
+                                  tmp_blob->dptr<T>(), weight_blob->dptr<T>(),
+                                  in_diff_blob->mut_dptr<T>());
+
+  if (this->op_conf().dot_conf().has_bias()) {
+    Blob* bias_diff_blob = BnInOp2Blob("bias_diff");
+    // bias_diff = out_diff
+    KernelUtil<device_type, T>::Copy(ctx.device_ctx, out_diff_blob->shape().elem_cnt(),
+                                     out_diff_blob->dptr<T>(), 1, bias_diff_blob->mut_dptr<T>(), 1);
+  }
+}
+
+template<DeviceType device_type, typename T>
+void DotKernel<device_type, T>::InitConstBufBlobs(
+    DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  InitializerConf diff_multiplier_initializer_conf;
+  diff_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
+  KernelUtil<device_type, T>::InitializeWithConf(ctx, diff_multiplier_initializer_conf, 0,
+                                                 BnInOp2Blob("diff_multiplier"));
+}
+
+template<DeviceType device_type, typename T>
+const PbMessage& DotKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().dot_conf();
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kDotConf, DotKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/dot_kernel.h b/oneflow/core/kernel/dot_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..8bb4998539f7c4f3a1bdb46500986b635925412a
--- /dev/null
+++ b/oneflow/core/kernel/dot_kernel.h
@@ -0,0 +1,26 @@
+#ifndef ONEFLOW_CORE_KERNEL_DOT_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_DOT_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class DotKernel final : public KernelIfWithModel<device_type, T> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DotKernel);
+  DotKernel() = default;
+  ~DotKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  void InitConstBufBlobs(DeviceCtx*,
+                         std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_DOT_KERNEL_H_
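For reference, the row-wise dot product that dot_kernel.cpp computes with Mul + RowSum can be summarized by the following standalone sketch. This is plain C++ for illustration only, not OneFlow code; the buffer names simply mirror the blob names above.

#include <cstdint>
#include <vector>

// Reference semantics of the dot op: for each row i of the (piece_size x dim) inputs,
//   out[i] = sum_j in[i][j] * weight[i][j]  (+ bias[i] if a bias blob is present).
// All buffers are stored row-major.
std::vector<float> DotForwardReference(const std::vector<float>& in,
                                       const std::vector<float>& weight,
                                       const float* bias,  // may be nullptr
                                       int64_t piece_size, int64_t dim) {
  std::vector<float> out(piece_size, 0.0f);
  for (int64_t i = 0; i < piece_size; ++i) {
    for (int64_t j = 0; j < dim; ++j) {
      out[i] += in[i * dim + j] * weight[i * dim + j];
    }
    if (bias != nullptr) { out[i] += bias[i]; }
  }
  return out;
}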
diff --git a/oneflow/core/kernel/matmul_kernel.cpp b/oneflow/core/kernel/matmul_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e47f6e0bfd6c066ad90cd82df7382dff5f6abcb3
--- /dev/null
+++ b/oneflow/core/kernel/matmul_kernel.cpp
@@ -0,0 +1,66 @@
+#include "oneflow/core/kernel/matmul_kernel.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void MatmulKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_blob = BnInOp2Blob("in");
+  const Blob* weight_blob = BnInOp2Blob("weight");
+  Blob* out_blob = BnInOp2Blob("out");
+  // out = in * weight'
+  KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasTrans, OneVal<T>::value,
+                                       ZeroVal<T>::value, in_blob, weight_blob, out_blob);
+  if (this->op_conf().matmul_conf().has_bias()) {
+    const Blob* bias_blob = BnInOp2Blob("bias");
+    const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
+    // out = bias_multiplier * bias + out
+    KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans,
+                                         OneVal<T>::value, OneVal<T>::value, bias_mul_blob,
+                                         bias_blob, out_blob);
+  }
+}
+
+template<DeviceType device_type, typename T>
+void MatmulKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  const Blob* in_blob = BnInOp2Blob("in");
+  Blob* in_diff_blob = BnInOp2Blob("in_diff");
+  const Blob* weight_blob = BnInOp2Blob("weight");
+  Blob* weight_diff_blob = BnInOp2Blob("weight_diff");
+  // weight_diff = out_diff' * in
+  KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans, OneVal<T>::value,
+                                       ZeroVal<T>::value, out_diff_blob, in_blob, weight_diff_blob);
+  // in_diff = out_diff * weight
+  KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans, OneVal<T>::value,
+                                       ZeroVal<T>::value, out_diff_blob, weight_blob, in_diff_blob);
+  if (this->op_conf().matmul_conf().has_bias()) {
+    const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
+    Blob* bias_diff_blob = BnInOp2Blob("bias_diff");
+    // bias_diff = bias_multiplier' * out_diff
+    KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans, OneVal<T>::value,
+                                         ZeroVal<T>::value, bias_mul_blob, out_diff_blob,
+                                         bias_diff_blob);
+  }
+}
+
+template<DeviceType device_type, typename T>
+void MatmulKernel<device_type, T>::InitConstBufBlobs(
+    DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  if (!this->op_conf().matmul_conf().has_bias()) { return; }
+  InitializerConf bias_multiplier_initializer_conf;
+  bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
+  KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0,
+                                                 BnInOp2Blob("bias_multiplier"));
+}
+
+template<DeviceType device_type, typename T>
+const PbMessage& MatmulKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().matmul_conf();
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMatmulConf, MatmulKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/matmul_kernel.h b/oneflow/core/kernel/matmul_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..4a82c3da53b4dab006fb322335c2549cc84f3a0b
--- /dev/null
+++ b/oneflow/core/kernel/matmul_kernel.h
@@ -0,0 +1,26 @@
+#ifndef ONEFLOW_CORE_KERNEL_MATMUL_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_MATMUL_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class MatmulKernel final : public KernelIfWithModel<device_type, T> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(MatmulKernel);
+  MatmulKernel() = default;
+  ~MatmulKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  void InitConstBufBlobs(DeviceCtx*,
+                         std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
+  const PbMessage& GetCustomizedOpConf() const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_MATMUL_KERNEL_H_
diff --git a/oneflow/core/kernel/multiply_kernel.cpp b/oneflow/core/kernel/multiply_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..eb9defa32fcd261b4331233ae2b58fb1b0466778
--- /dev/null
+++ b/oneflow/core/kernel/multiply_kernel.cpp
@@ -0,0 +1,43 @@
+#include "oneflow/core/kernel/multiply_kernel.h"
+#include "oneflow/core/kernel/kernel_util.h"
+
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+void MultiplyKernel<device_type, T>::ForwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* in_0_blob = BnInOp2Blob("in_0");
+  const Blob* in_1_blob = BnInOp2Blob("in_1");
+  Blob* out_blob = BnInOp2Blob("out");
+  // out = in_0 .* in_1
+  KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_0_blob->shape().elem_cnt(),
+                                  in_0_blob->dptr<T>(), in_1_blob->dptr<T>(),
+                                  out_blob->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+void MultiplyKernel<device_type, T>::BackwardDataContent(
+    const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
+  const Blob* out_diff_blob = BnInOp2Blob("out_diff");
+  const Blob* in_0_blob = BnInOp2Blob("in_0");
+  Blob* in_0_diff_blob = BnInOp2Blob("in_0_diff");
+  const Blob* in_1_blob = BnInOp2Blob("in_1");
+  Blob* in_1_diff_blob = BnInOp2Blob("in_1_diff");
+  // in_1_diff = out_diff .* in_0
+  KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_0_blob->shape().elem_cnt(),
+                                  in_0_blob->dptr<T>(), out_diff_blob->dptr<T>(),
+                                  in_1_diff_blob->mut_dptr<T>());
+  // in_0_diff = out_diff .* in_1
+  KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_1_blob->shape().elem_cnt(),
+                                  in_1_blob->dptr<T>(), out_diff_blob->dptr<T>(),
+                                  in_0_diff_blob->mut_dptr<T>());
+}
+
+template<DeviceType device_type, typename T>
+const PbMessage& MultiplyKernel<device_type, T>::GetCustomizedOpConf() const {
+  return this->op_conf().multiply_conf();
+}
+
+ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMultiplyConf, MultiplyKernel, FLOATING_DATA_TYPE_SEQ);
+
+}  // namespace oneflow
diff --git a/oneflow/core/kernel/multiply_kernel.h b/oneflow/core/kernel/multiply_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..f9957fd519096ae9af7687c44b50f15adcff9580
--- /dev/null
+++ b/oneflow/core/kernel/multiply_kernel.h
@@ -0,0 +1,24 @@
+#ifndef ONEFLOW_CORE_KERNEL_MULTIPLY_KERNEL_H_
+#define ONEFLOW_CORE_KERNEL_MULTIPLY_KERNEL_H_
+
+#include "oneflow/core/kernel/kernel.h"
+namespace oneflow {
+
+template<DeviceType device_type, typename T>
+class MultiplyKernel final : public KernelIfWithModel<device_type, T> {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(MultiplyKernel);
+  MultiplyKernel() = default;
+  ~MultiplyKernel() = default;
+
+ private:
+  void ForwardDataContent(const KernelCtx&,
+                          std::function<Blob*(const std::string&)>) const override;
+  void BackwardDataContent(const KernelCtx&,
+                           std::function<Blob*(const std::string&)>) const override;
+  const PbMessage& GetCustomizedOpConf() const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_KERNEL_MULTIPLY_KERNEL_H_
diff --git a/oneflow/core/operator/dot_op.cpp b/oneflow/core/operator/dot_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..09fb47423555f92326871e517161ca65c45cbfbf
--- /dev/null
+++ b/oneflow/core/operator/dot_op.cpp
@@ -0,0 +1,40 @@
+#include "oneflow/core/operator/dot_op.h"
+namespace oneflow {
+
+void DotOp::InitFromOpConf() {
+  CHECK(op_conf().has_dot_conf());
+
+  EnrollInputBn("in");
+  EnrollInputBn("weight");
+  EnrollDataTmpBn("tmp");
+  EnrollDataTmpBn("tmp_storage");
+  EnrollConstBufBn("diff_multiplier");
+  EnrollOutputBn("out");
+  if (op_conf().dot_conf().has_bias()) { EnrollInputBn("bias"); }
+}
+
+const PbMessage& DotOp::GetCustomizedConf() const { return op_conf().dot_conf(); }
+
+void DotOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                           const ParallelContext* parallel_ctx) const {
+  BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
+  BlobDesc* weight_blob_desc = GetBlobDesc4BnInOp("weight");
+  CHECK_EQ(in_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType());
+  CHECK_EQ(in_blob_desc->shape().At(0), weight_blob_desc->shape().At(0));
+  CHECK_EQ(in_blob_desc->shape().Count(1), weight_blob_desc->shape().Count(1));
+  // tmp & tmp_storage
+  BlobDesc* tmp_blob_desc = GetBlobDesc4BnInOp("tmp");
+  *tmp_blob_desc = *in_blob_desc;
+  BlobDesc* tmp_storage_blob_desc = GetBlobDesc4BnInOp("tmp_storage");
+  *tmp_storage_blob_desc = *in_blob_desc;
+  // out
+  BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
+  *out_blob_desc = *in_blob_desc;
+  out_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0)});
+  // diff_multiplier
+  GetBlobDesc4BnInOp("diff_multiplier")->mut_shape() = Shape({1, in_blob_desc->shape().Count(1)});
+}
+
+REGISTER_OP(OperatorConf::kDotConf, DotOp);
+
+}  // namespace oneflow
diff --git a/oneflow/core/operator/dot_op.h b/oneflow/core/operator/dot_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..dc2291eb6d50b517100765c30021a4dea06bbb99
--- /dev/null
+++ b/oneflow/core/operator/dot_op.h
@@ -0,0 +1,19 @@
+#ifndef ONEFLOW_CORE_OPERATOR_DOT_OP_H_
+#define ONEFLOW_CORE_OPERATOR_DOT_OP_H_
+#include "oneflow/core/operator/operator.h"
+namespace oneflow {
+
+class DotOp final : public Operator {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(DotOp);
+  DotOp() = default;
+  ~DotOp() = default;
+  void InitFromOpConf() override;
+  const PbMessage& GetCustomizedConf() const override;
+  void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
+                      const ParallelContext* parallel_ctx) const override;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_OPERATOR_DOT_OP_H_
GetBlobDesc4BnInOp("in"); + CHECK_EQ(in_blob_desc->data_type(), Global::Get()->DefaultDataType()); + int32_t units = conf.units(); + if (parallel_ctx->policy() == kModelParallel) { + BalancedSplitter splitter(units, parallel_ctx->parallel_num()); + units = splitter.At(parallel_ctx->parallel_id()).size(); + } + // out + BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); + *out_blob_desc = *in_blob_desc; + out_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0), units}); + if (conf.has_bias()) { + // bias_multiplier + GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape({in_blob_desc->shape().At(0), 1}); + } +} + +REGISTER_OP(OperatorConf::kMatmulConf, MatmulOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/matmul_op.h b/oneflow/core/operator/matmul_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9c07cb64157015f71c3467727c730919c21f25ea --- /dev/null +++ b/oneflow/core/operator/matmul_op.h @@ -0,0 +1,21 @@ +#ifndef ONEFLOW_CORE_OPERATOR_MATMUL_OP_H_ +#define ONEFLOW_CORE_OPERATOR_MATMUL_OP_H_ +#include "oneflow/core/operator/operator.h" +namespace oneflow { + +class MatmulOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(MatmulOp); + MatmulOp() = default; + ~MatmulOp() = default; + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const override; + void InferBlobDescs(std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + int32_t ModelSplitAxis() const override { return 1; } + int32_t MaxModelSplitNum() const override { return op_conf().matmul_conf().units(); } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_OPERATOR_MATMUL_OP_H_ diff --git a/oneflow/core/operator/multiply_op.cpp b/oneflow/core/operator/multiply_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be51110b729bf77ce16db29d098f7484bf6f9ca7 --- /dev/null +++ b/oneflow/core/operator/multiply_op.cpp @@ -0,0 +1,26 @@ +#include "oneflow/core/operator/multiply_op.h" +#include "oneflow/core/common/balanced_splitter.h" +namespace oneflow { + +void MultiplyOp::InitFromOpConf() { + CHECK(op_conf().has_multiply_conf()); + EnrollInputBn("in_0"); + EnrollInputBn("in_1"); + EnrollOutputBn("out"); +} + +const PbMessage& MultiplyOp::GetCustomizedConf() const { return op_conf().multiply_conf(); } + +void MultiplyOp::InferBlobDescs(std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + BlobDesc* in_0_blob_desc = GetBlobDesc4BnInOp("in_0"); + BlobDesc* in_1_blob_desc = GetBlobDesc4BnInOp("in_1"); + CHECK_EQ(in_0_blob_desc->data_type(), Global::Get()->DefaultDataType()); + CHECK_EQ(in_0_blob_desc->shape(), in_1_blob_desc->shape()); + // out + BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); + *out_blob_desc = *in_0_blob_desc; +} +REGISTER_OP(OperatorConf::kMultiplyConf, MultiplyOp); + +} // namespace oneflow diff --git a/oneflow/core/operator/multiply_op.h b/oneflow/core/operator/multiply_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9194a1c210a94cd188d422e0f0fb85388a56555f --- /dev/null +++ b/oneflow/core/operator/multiply_op.h @@ -0,0 +1,19 @@ +#ifndef ONEFLOW_CORE_OPERATOR_MULTIPLY_OP_H_ +#define ONEFLOW_CORE_OPERATOR_MULTIPLY_OP_H_ +#include "oneflow/core/operator/operator.h" +namespace oneflow { + +class MultiplyOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(MultiplyOp); + MultiplyOp() = default; + ~MultiplyOp() = default; + void InitFromOpConf() override; + const PbMessage& GetCustomizedConf() const 
diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto
index 97202ec9bf56c998edd8a5ef1f584a4f440ca2af..1dfd084efa173b090985c034b7f94344ac968ae2 100644
--- a/oneflow/core/operator/op_conf.proto
+++ b/oneflow/core/operator/op_conf.proto
@@ -609,6 +609,26 @@ message AccuracyOpConf {
   required string accuracy = 4;
 }
+message MatmulOpConf {
+  required string in = 1;
+  required string weight = 2;
+  optional string bias = 3;
+  required int32 units = 4;
+  required string out = 5;
+}
+
+message DotOpConf {
+  required string in = 1;
+  required string weight = 2;
+  optional string bias = 3;
+  required string out = 4;
+}
+
+message MultiplyOpConf {
+  required string in_0 = 1;
+  required string in_1 = 2;
+  required string out = 4;
+}
 
 enum Norm {
   L1 = 1;
   L2 = 2;
@@ -658,7 +678,10 @@ message OperatorConf {
     Conv2DOpConf conv_2d_conf = 126;
     Conv3DOpConf conv_3d_conf = 127;
     TransposeOpConf transpose_conf = 128;
-    HingeLossOpConf hinge_loss_conf = 129;
+    MatmulOpConf matmul_conf = 129;
+    DotOpConf dot_conf = 130;
+    MultiplyOpConf multiply_conf = 131;
+    HingeLossOpConf hinge_loss_conf = 132;
     DropoutOpConf dropout_conf = 140;
     AveragePooling1DOpConf average_pooling_1d_conf = 200;
     MaxPooling1DOpConf max_pooling_1d_conf = 201;
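The new messages are consumed through the usual protobuf-generated C++ accessors. A minimal sketch of configuring the matmul op this way (the op name and producer blob names are hypothetical, chosen only for illustration; the generated header path follows the usual protoc layout for this proto file):

#include "oneflow/core/operator/op_conf.pb.h"

// Builds an OperatorConf that selects the new matmul_conf field added above.
oneflow::OperatorConf MakeFcOpConf() {
  oneflow::OperatorConf op_conf;
  op_conf.set_name("fc1");                    // hypothetical op name
  auto* matmul_conf = op_conf.mutable_matmul_conf();
  matmul_conf->set_in("pool5/out");           // hypothetical producer blobs
  matmul_conf->set_weight("fc1_weight/out");
  matmul_conf->set_bias("fc1_bias/out");      // bias is optional
  matmul_conf->set_units(1000);
  matmul_conf->set_out("out");
  return op_conf;
}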