Commit 6ab4006f authored by qq_22305325, committed by Jinhui Yuan

Dev matmul dot multiply (#1189)

* add matmul & dot & multiply

* optimize dot kernel

* fix multiply kernel code style

* optimize matmul kernel
Parent a21dea46
#include "oneflow/core/kernel/dot_kernel.h"
#include "oneflow/core/kernel/kernel_util.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void DotKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
const int64_t piece_size = in_blob->shape().At(0);
const int64_t dim = in_blob->shape().Count(1);
const Blob* weight_blob = BnInOp2Blob("weight");
Blob* out_blob = BnInOp2Blob("out");
Blob* tmp_blob = BnInOp2Blob("tmp");
Blob* tmp_storage_blob = BnInOp2Blob("tmp_storage");
// tmp = in .* weight, then out = rowsum(tmp)
KernelUtil<device_type, T>::Mul(ctx.device_ctx, piece_size * dim, in_blob->dptr<T>(),
weight_blob->dptr<T>(), tmp_blob->mut_dptr<T>());
KernelUtil<device_type, T>::RowSum(ctx.device_ctx, piece_size, dim, tmp_blob->dptr<T>(),
out_blob->mut_dptr<T>(), tmp_storage_blob->mut_dptr<T>(),
sizeof(T) * piece_size * dim);
if (this->op_conf().dot_conf().has_bias()) {
const Blob* bias_blob = BnInOp2Blob("bias");
// out += bias
KernelUtil<device_type, T>::Axpy(ctx.device_ctx, piece_size, OneVal<T>::value,
bias_blob->dptr<T>(), 1, out_blob->mut_dptr<T>(), 1);
}
}
template<DeviceType device_type, typename T>
void DotKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* out_diff_blob = BnInOp2Blob("out_diff");
const Blob* in_blob = BnInOp2Blob("in");
const Blob* diff_mul_blob = BnInOp2Blob("diff_multiplier");
Blob* tmp_blob = BnInOp2Blob("tmp");
Blob* in_diff_blob = BnInOp2Blob("in_diff");
const Blob* weight_blob = BnInOp2Blob("weight");
Blob* weight_diff_blob = BnInOp2Blob("weight_diff");
// tmp = out_diff * diff_multiplier
KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans, OneVal<T>::value,
ZeroVal<T>::value, out_diff_blob, diff_mul_blob, tmp_blob);
// weight_diff = tmp .* in
KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_blob->shape().elem_cnt(), tmp_blob->dptr<T>(),
in_blob->dptr<T>(), weight_diff_blob->mut_dptr<T>());
// in_diff = tmp .* weight
KernelUtil<device_type, T>::Mul(ctx.device_ctx, weight_blob->shape().elem_cnt(),
tmp_blob->dptr<T>(), weight_blob->dptr<T>(),
in_diff_blob->mut_dptr<T>());
if (this->op_conf().dot_conf().has_bias()) {
Blob* bias_diff_blob = BnInOp2Blob("bias_diff");
// bias_diff = out_diff
KernelUtil<device_type, T>::Copy(ctx.device_ctx, out_diff_blob->shape().elem_cnt(),
out_diff_blob->dptr<T>(), 1, bias_diff_blob->mut_dptr<T>(), 1);
}
}
template<DeviceType device_type, typename T>
void DotKernel<device_type, T>::InitConstBufBlobs(
DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
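// diff_multiplier is a (1, dim) const buffer of ones. BackwardDataContent multiplies
// the (piece_size, 1) out_diff by it via BlobGemm to broadcast each row's gradient
// across all dim columns of tmp.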
InitializerConf diff_multiplier_initializer_conf;
diff_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
KernelUtil<device_type, T>::InitializeWithConf(ctx, diff_multiplier_initializer_conf, 0,
BnInOp2Blob("diff_multiplier"));
}
template<DeviceType device_type, typename T>
const PbMessage& DotKernel<device_type, T>::GetCustomizedOpConf() const {
return this->op_conf().dot_conf();
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kDotConf, DotKernel, FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
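// A minimal CPU reference sketch of the kernel above (illustrative only; the function
// name and raw-pointer interface are hypothetical, not part of the OneFlow API). Per
// row r of a row-major (piece_size, dim) layout:
//   out[r] = sum over d of in[r][d] * weight[r][d], plus bias[r] when a bias input exists.
#include <cstdint>
void RowwiseDotReference(const float* in, const float* weight, const float* bias,
                         int64_t piece_size, int64_t dim, float* out) {
  for (int64_t r = 0; r < piece_size; ++r) {
    float acc = 0.0f;
    for (int64_t d = 0; d < dim; ++d) { acc += in[r * dim + d] * weight[r * dim + d]; }
    out[r] = acc + (bias != nullptr ? bias[r] : 0.0f);  // bias is an optional input
  }
}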
#ifndef ONEFLOW_CORE_KERNEL_DOT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_DOT_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class DotKernel final : public KernelIfWithModel<device_type, T> {
public:
OF_DISALLOW_COPY_AND_MOVE(DotKernel);
DotKernel() = default;
~DotKernel() = default;
private:
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
void BackwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
void InitConstBufBlobs(DeviceCtx*,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
const PbMessage& GetCustomizedOpConf() const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_DOT_KERNEL_H_
#include "oneflow/core/kernel/matmul_kernel.h"
#include "oneflow/core/kernel/kernel_util.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void MatmulKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
const Blob* weight_blob = BnInOp2Blob("weight");
Blob* out_blob = BnInOp2Blob("out");
// out = in * weight'
KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasTrans, OneVal<T>::value,
ZeroVal<T>::value, in_blob, weight_blob, out_blob);
if (this->op_conf().matmul_conf().has_bias()) {
const Blob* bias_blob = BnInOp2Blob("bias");
const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
// out = bias_multiplier * bias + out
KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans,
OneVal<T>::value, OneVal<T>::value, bias_mul_blob,
bias_blob, out_blob);
}
}
template<DeviceType device_type, typename T>
void MatmulKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* out_diff_blob = BnInOp2Blob("out_diff");
const Blob* in_blob = BnInOp2Blob("in");
Blob* in_diff_blob = BnInOp2Blob("in_diff");
const Blob* weight_blob = BnInOp2Blob("weight");
Blob* weight_diff_blob = BnInOp2Blob("weight_diff");
// weight_diff = out_diff' * in
KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans, OneVal<T>::value,
ZeroVal<T>::value, out_diff_blob, in_blob, weight_diff_blob);
// in_diff = out_diff * weight
KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasNoTrans, OneVal<T>::value,
ZeroVal<T>::value, out_diff_blob, weight_blob, in_diff_blob);
if (this->op_conf().matmul_conf().has_bias()) {
const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
Blob* bias_diff_blob = BnInOp2Blob("bias_diff");
// bias_diff = bias_multiplier' * out_diff
KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans, OneVal<T>::value,
ZeroVal<T>::value, bias_mul_blob, out_diff_blob,
bias_diff_blob);
}
}
template<DeviceType device_type, typename T>
void MatmulKernel<device_type, T>::InitConstBufBlobs(
DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
if (!this->op_conf().matmul_conf().has_bias()) { return; }
InitializerConf bias_multiplier_initializer_conf;
bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
KernelUtil<device_type, T>::InitializeWithConf(ctx, bias_multiplier_initializer_conf, 0,
BnInOp2Blob("bias_multiplier"));
}
template<DeviceType device_type, typename T>
const PbMessage& MatmulKernel<device_type, T>::GetCustomizedOpConf() const {
return this->op_conf().matmul_conf();
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMatmulConf, MatmulKernel, FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
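// A minimal CPU reference sketch of the forward pass above (illustrative only; the
// function name and raw-pointer interface are hypothetical). With a row-major (n, d)
// input, (u, d) weight, and optional (u,) bias, it computes out = in * weight' + bias.
#include <cstdint>
void MatmulForwardReference(const float* in, const float* weight, const float* bias,
                            int64_t n, int64_t d, int64_t u, float* out) {
  for (int64_t r = 0; r < n; ++r) {
    for (int64_t c = 0; c < u; ++c) {
      float acc = 0.0f;
      // weight is stored (u, d), so out = in * weight' needs no explicit transpose here.
      for (int64_t k = 0; k < d; ++k) { acc += in[r * d + k] * weight[c * d + k]; }
      out[r * u + c] = acc + (bias != nullptr ? bias[c] : 0.0f);
    }
  }
}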
#ifndef ONEFLOW_CORE_KERNEL_MATMUL_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_MATMUL_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class MatmulKernel final : public KernelIfWithModel<device_type, T> {
public:
OF_DISALLOW_COPY_AND_MOVE(MatmulKernel);
MatmulKernel() = default;
~MatmulKernel() = default;
private:
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
void BackwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
void InitConstBufBlobs(DeviceCtx*,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override;
const PbMessage& GetCustomizedOpConf() const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_MATMUL_KERNEL_H_
#include "oneflow/core/kernel/multiply_kernel.h"
#include "oneflow/core/kernel/kernel_util.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void MultiplyKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_0_blob = BnInOp2Blob("in_0");
const Blob* in_1_blob = BnInOp2Blob("in_1");
Blob* out_blob = BnInOp2Blob("out");
// out = in_0 .* in_1
KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_0_blob->shape().elem_cnt(),
in_0_blob->dptr<T>(), in_1_blob->dptr<T>(),
out_blob->mut_dptr<T>());
}
template<DeviceType device_type, typename T>
void MultiplyKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx, std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* out_diff_blob = BnInOp2Blob("out_diff");
const Blob* in_0_blob = BnInOp2Blob("in_0");
Blob* in_0_diff_blob = BnInOp2Blob("in_0_diff");
const Blob* in_1_blob = BnInOp2Blob("in_1");
Blob* in_1_diff_blob = BnInOp2Blob("in_1_diff");
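// Product rule: d(out)/d(in_1) = in_0 and d(out)/d(in_0) = in_1, so each input's
// gradient is out_diff scaled element-wise by the other input.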
// in_1_diff = out_diff .* in_0
KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_0_blob->shape().elem_cnt(),
in_0_blob->dptr<T>(), out_diff_blob->dptr<T>(),
in_1_diff_blob->mut_dptr<T>());
// in_0_diff = out_diff .* in_1
KernelUtil<device_type, T>::Mul(ctx.device_ctx, in_1_blob->shape().elem_cnt(),
in_1_blob->dptr<T>(), out_diff_blob->dptr<T>(),
in_0_diff_blob->mut_dptr<T>());
}
template<DeviceType device_type, typename T>
const PbMessage& MultiplyKernel<device_type, T>::GetCustomizedOpConf() const {
return this->op_conf().multiply_conf();
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kMultiplyConf, MultiplyKernel, FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
#ifndef ONEFLOW_CORE_KERNEL_MULTIPLY_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_MULTIPLY_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class MultiplyKernel final : public KernelIfWithModel<device_type, T> {
public:
OF_DISALLOW_COPY_AND_MOVE(MultiplyKernel);
MultiplyKernel() = default;
~MultiplyKernel() = default;
private:
void ForwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
void BackwardDataContent(const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
const PbMessage& GetCustomizedOpConf() const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_MULTIPLY_KERNEL_H_
#include "oneflow/core/operator/dot_op.h"
namespace oneflow {
void DotOp::InitFromOpConf() {
CHECK(op_conf().has_dot_conf());
EnrollInputBn("in");
EnrollInputBn("weight");
EnrollDataTmpBn("tmp");
EnrollDataTmpBn("tmp_storage");
EnrollConstBufBn("diff_multiplier");
EnrollOutputBn("out");
if (op_conf().dot_conf().has_bias()) { EnrollInputBn("bias"); }
}
const PbMessage& DotOp::GetCustomizedConf() const { return op_conf().dot_conf(); }
void DotOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
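// Expected shapes (as enforced by the checks below): in and weight agree on both
// At(0) == N and Count(1) == D; out collapses to (N,); diff_multiplier is (1, D).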
BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
BlobDesc* weight_blob_desc = GetBlobDesc4BnInOp("weight");
CHECK_EQ(in_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType());
CHECK_EQ(in_blob_desc->shape().At(0), weight_blob_desc->shape().At(0));
CHECK_EQ(in_blob_desc->shape().Count(1), weight_blob_desc->shape().Count(1));
// tmp & tmp storage
BlobDesc* tmp_blob_desc = GetBlobDesc4BnInOp("tmp");
*tmp_blob_desc = *in_blob_desc;
BlobDesc* tmp_storage_blob_desc = GetBlobDesc4BnInOp("tmp_storage");
*tmp_storage_blob_desc = *in_blob_desc;
// out
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
*out_blob_desc = *in_blob_desc;
out_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0)});
// diff_multiplier
GetBlobDesc4BnInOp("diff_multiplier")->mut_shape() = Shape({1, in_blob_desc->shape().Count(1)});
}
REGISTER_OP(OperatorConf::kDotConf, DotOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_DOT_OP_H_
#define ONEFLOW_CORE_OPERATOR_DOT_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class DotOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(DotOp);
DotOp() = default;
~DotOp() = default;
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_DOT_OP_H_
#include "oneflow/core/operator/matmul_op.h"
#include "oneflow/core/common/balanced_splitter.h"
namespace oneflow {
void MatmulOp::InitFromOpConf() {
CHECK(op_conf().has_matmul_conf());
EnrollInputBn("in");
EnrollInputBn("weight");
EnrollOutputBn("out");
if (op_conf().matmul_conf().has_bias()) {
EnrollInputBn("bias");
EnrollConstBufBn("bias_multiplier");
}
}
const PbMessage& MatmulOp::GetCustomizedConf() const { return op_conf().matmul_conf(); }
void MatmulOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
const MatmulOpConf& conf = op_conf().matmul_conf();
BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
CHECK_EQ(in_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType());
int32_t units = conf.units();
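// Under model parallelism the units (output) axis is split across devices.
// BalancedSplitter hands each parallel_id a near-equal contiguous slice, e.g.
// units = 10 over 4 devices gives sizes 3, 3, 2, 2 (an illustrative example).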
if (parallel_ctx->policy() == kModelParallel) {
BalancedSplitter splitter(units, parallel_ctx->parallel_num());
units = splitter.At(parallel_ctx->parallel_id()).size();
}
// out
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
*out_blob_desc = *in_blob_desc;
out_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0), units});
if (conf.has_bias()) {
// bias_multiplier
GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() = Shape({in_blob_desc->shape().At(0), 1});
}
}
REGISTER_OP(OperatorConf::kMatmulConf, MatmulOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_MATMUL_OP_H_
#define ONEFLOW_CORE_OPERATOR_MATMUL_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class MatmulOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(MatmulOp);
MatmulOp() = default;
~MatmulOp() = default;
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
int32_t ModelSplitAxis() const override { return 1; }
int32_t MaxModelSplitNum() const override { return op_conf().matmul_conf().units(); }
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_MATMUL_OP_H_
#include "oneflow/core/operator/multiply_op.h"
#include "oneflow/core/common/balanced_splitter.h"
namespace oneflow {
void MultiplyOp::InitFromOpConf() {
CHECK(op_conf().has_multiply_conf());
EnrollInputBn("in_0");
EnrollInputBn("in_1");
EnrollOutputBn("out");
}
const PbMessage& MultiplyOp::GetCustomizedConf() const { return op_conf().multiply_conf(); }
void MultiplyOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
BlobDesc* in_0_blob_desc = GetBlobDesc4BnInOp("in_0");
BlobDesc* in_1_blob_desc = GetBlobDesc4BnInOp("in_1");
CHECK_EQ(in_0_blob_desc->data_type(), Global<JobDesc>::Get()->DefaultDataType());
CHECK_EQ(in_0_blob_desc->shape(), in_1_blob_desc->shape());
// out
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
*out_blob_desc = *in_0_blob_desc;
}
REGISTER_OP(OperatorConf::kMultiplyConf, MultiplyOp);
} // namespace oneflow
#ifndef ONEFLOW_CORE_OPERATOR_MULTIPLY_OP_H_
#define ONEFLOW_CORE_OPERATOR_MULTIPLY_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class MultiplyOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(MultiplyOp);
MultiplyOp() = default;
~MultiplyOp() = default;
void InitFromOpConf() override;
const PbMessage& GetCustomizedConf() const override;
void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const override;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_MULTIPLY_OP_H_
@@ -609,6 +609,26 @@ message AccuracyOpConf {
required string accuracy = 4;
}
message MatmulOpConf {
required string in = 1;
required string weight = 2;
optional string bias = 3;
required int32 units = 4;
required string out = 5;
}
message DotOpConf {
required string in = 1;
required string weight = 2;
optional string bias = 3;
required string out = 4;
}
message MultiplyOpConf {
required string in_0 = 1;
required string in_1 = 2;
required string out = 4;
}
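// Illustrative textproto usage of the three new ops (all names and values here are
// hypothetical examples, not taken from this commit):
//   matmul_conf { in: "fc0/out" weight: "fc1-weight" bias: "fc1-bias" units: 256 out: "out" }
//   dot_conf { in: "a" weight: "b" out: "out" }
//   multiply_conf { in_0: "a" in_1: "b" out: "out" }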
enum Norm {
L1 = 1;
L2 = 2;
@@ -658,7 +678,10 @@ message OperatorConf {
Conv2DOpConf conv_2d_conf = 126;
Conv3DOpConf conv_3d_conf = 127;
TransposeOpConf transpose_conf = 128;
HingeLossOpConf hinge_loss_conf = 129;
MatmulOpConf matmul_conf = 129;
DotOpConf dot_conf = 130;
MultiplyOpConf multiply_conf = 131;
HingeLossOpConf hinge_loss_conf = 132;
DropoutOpConf dropout_conf = 140;
AveragePooling1DOpConf average_pooling_1d_conf = 200;
MaxPooling1DOpConf max_pooling_1d_conf = 201;
......