Commit ced21d0d authored by Yi Zhu, committed by Will Zhang

Update innerproduct proto (#531)

* update names in ip

* update comment && fix bug

* fix bug && update proto

* rename innerproduct to fullyconnected

* update field names in fully connected op

* roll back impl of initialization of bias

* fix bug of bias_multiplier_initializer

* remove redundant SetHasDataIdField

* remove redundant set_data_type

* refine code


Former-commit-id: 534138dc
Parent e564558b
@@ -37,11 +37,10 @@ op {
op {
name: "ip10"
innerproduct_conf {
fully_connected_conf {
in: "conv/out"
out: "out"
out_num: 10
has_bias_term: true
units: 10
}
}
......
#include "oneflow/core/kernel/innerproduct_kernel.h"
#include "oneflow/core/kernel/fully_connected_kernel.h"
#include "oneflow/core/kernel/kernel_util.h"
namespace oneflow {
template<DeviceType device_type, typename T>
void InnerProductKernel<device_type, T>::ForwardDataContent(
void FullyConnectedKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
@@ -16,19 +16,17 @@ void InnerProductKernel<device_type, T>::ForwardDataContent(
static_cast<T>(1.0), static_cast<T>(0.0),
in_blob, weight_blob, out_blob);
if (this->op_conf().innerproduct_conf().has_bias_term()) {
const Blob* bias_blob = BnInOp2Blob("bias");
const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
const Blob* bias_blob = BnInOp2Blob("bias");
const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
// out = bias_multiplier * bias + out
KernelUtil<device_type, T>::BlobGemm(
ctx.device_ctx, CblasNoTrans, CblasNoTrans, static_cast<T>(1.0),
static_cast<T>(1.0), bias_mul_blob, bias_blob, out_blob);
}
// out = bias_multiplier * bias + out
KernelUtil<device_type, T>::BlobGemm(
ctx.device_ctx, CblasNoTrans, CblasNoTrans, static_cast<T>(1.0),
static_cast<T>(1.0), bias_mul_blob, bias_blob, out_blob);
}
template<DeviceType device_type, typename T>
void InnerProductKernel<device_type, T>::BackwardDataContent(
void FullyConnectedKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* in_blob = BnInOp2Blob("in");
@@ -50,66 +48,59 @@ void InnerProductKernel<device_type, T>::BackwardDataContent(
static_cast<T>(0.0), out_diff_blob, weight_blob, in_diff_blob);
}
if (this->op_conf().innerproduct_conf().has_bias_term()) {
const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
Blob* bias_diff_blob = BnInOp2Blob("bias_diff");
const Blob* bias_mul_blob = BnInOp2Blob("bias_multiplier");
Blob* bias_diff_blob = BnInOp2Blob("bias_diff");
// bias_diff = bias_multiplier * out_diff
KernelUtil<device_type, T>::BlobGemm(
ctx.device_ctx, CblasTrans, CblasNoTrans, static_cast<T>(1.0),
static_cast<T>(0.0), bias_mul_blob, out_diff_blob, bias_diff_blob);
}
}
// bias_diff = bias_multiplier * out_diff
KernelUtil<device_type, T>::BlobGemm(
ctx.device_ctx, CblasTrans, CblasNoTrans, static_cast<T>(1.0),
static_cast<T>(0.0), bias_mul_blob, out_diff_blob, bias_diff_blob);
}
template<DeviceType device_type, typename T>
void InnerProductKernel<device_type, T>::InitModelBlobsWithRandomSeed(
void FullyConnectedKernel<device_type, T>::InitModelBlobsWithRandomSeed(
const KernelCtx& ctx, std::mt19937 random_seed_gen,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
KernelUtil<device_type, T>::InitializeWithProperConf(
ctx.device_ctx,
OF_PB_POINTER_GET(this->op_conf().innerproduct_conf(),
OF_PB_POINTER_GET(this->op_conf().fully_connected_conf(),
weight_initializer),
random_seed_gen(), BnInOp2Blob("weight"));
if (this->op_conf().innerproduct_conf().has_bias_term()) {
KernelUtil<device_type, T>::InitializeWithProperConf(
ctx.device_ctx,
OF_PB_POINTER_GET(this->op_conf().innerproduct_conf(),
bias_initializer),
random_seed_gen(), BnInOp2Blob("bias"));
}
KernelUtil<device_type, T>::InitializeWithProperConf(
ctx.device_ctx,
OF_PB_POINTER_GET(this->op_conf().fully_connected_conf(),
bias_initializer),
random_seed_gen(), BnInOp2Blob("bias"));
}
template<DeviceType device_type, typename T>
void InnerProductKernel<device_type, T>::InitModelBlobsWithDir(
void FullyConnectedKernel<device_type, T>::InitModelBlobsWithDir(
const KernelCtx& ctx, int32_t part_id, int32_t part_num,
const std::string& model_load_dir,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
Blob* weight_blob = BnInOp2Blob("weight");
int32_t dim_num = this->op_conf().innerproduct_conf().out_num();
Blob* weight_blob = BnInOp2Blob("weightes");
int32_t dim_num = this->op_conf().fully_connected_conf().units();
KernelUtil<device_type, T>::InitializeWithModelDir(
ctx.device_ctx, part_id, part_num, model_load_dir, weight_blob, "weight",
dim_num, weight_blob->shape().Count(1));
if (this->op_conf().innerproduct_conf().has_bias_term()) {
KernelUtil<device_type, T>::InitializeWithModelDir(
ctx.device_ctx, part_id, part_num, model_load_dir, BnInOp2Blob("bias"),
"bias", dim_num, 1);
}
KernelUtil<device_type, T>::InitializeWithModelDir(
ctx.device_ctx, part_id, part_num, model_load_dir, BnInOp2Blob("bias"),
"bias", dim_num, 1);
}
template<DeviceType device_type, typename T>
void InnerProductKernel<device_type, T>::InitModelTmpBlobs(
void FullyConnectedKernel<device_type, T>::InitModelTmpBlobs(
const KernelCtx& ctx, const ParallelContext* parallel_ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
if (this->op_conf().innerproduct_conf().has_bias_term()) {
InitializerConf bias_multiplier_initializer_conf;
bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
KernelUtil<device_type, T>::Initialize(ctx.device_ctx,
bias_multiplier_initializer_conf, 0,
BnInOp2Blob("bias_multiplier"));
}
InitializerConf bias_multiplier_initializer_conf;
bias_multiplier_initializer_conf.mutable_constant_conf()->set_value(1.0f);
KernelUtil<device_type, T>::Initialize(ctx.device_ctx,
bias_multiplier_initializer_conf, 0,
BnInOp2Blob("bias_multiplier"));
}
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kInnerproductConf, InnerProductKernel,
FLOATING_DATA_TYPE_SEQ);
ADD_DEFAULT_KERNEL_CREATOR(OperatorConf::kFullyConnectedConf,
FullyConnectedKernel, FLOATING_DATA_TYPE_SEQ);
} // namespace oneflow
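
The renamed kernel still computes a standard fully connected layer. From the shapes inferred by the op below (`weight` is `{units, in.Count(1)}`, `bias` is `{1, units}`, and `bias_multiplier` is an all-ones `{N, 1}` column per the constant initializer above), the forward pass is `out = in * weight^T + bias_multiplier * bias`. The following is a minimal plain-loop C++ sketch of that arithmetic only; it does not use OneFlow's Blob/KernelUtil API, and the function name and flat row-major layout are illustrative assumptions:

```cpp
#include <cstddef>
#include <vector>

// Plain-loop illustration (hypothetical helper, not OneFlow's Blob/KernelUtil API).
// in:     N x K, row-major
// weight: units x K   (as inferred by the op: {units, in.Count(1)})
// bias:   units       (stored as {1, units} in the op)
// out:    N x units, out[n][u] = sum_k in[n][k] * weight[u][k] + bias[u]
std::vector<float> FullyConnectedForward(const std::vector<float>& in,
                                         const std::vector<float>& weight,
                                         const std::vector<float>& bias,
                                         std::size_t N, std::size_t K,
                                         std::size_t units) {
  std::vector<float> out(N * units, 0.0f);
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t u = 0; u < units; ++u) {
      // bias_multiplier is an all-ones {N, 1} column, so the second BlobGemm
      // in ForwardDataContent simply adds bias[u] to every row of out.
      float acc = bias[u];
      for (std::size_t k = 0; k < K; ++k) {
        acc += in[n * K + k] * weight[u * K + k];  // out = in * weight^T
      }
      out[n * units + u] = acc;
    }
  }
  return out;
}
```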
#ifndef ONEFLOW_CORE_KERNEL_INNERPRODUCT_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_INNERPRODUCT_KERNEL_H_
#ifndef ONEFLOW_CORE_KERNEL_FULLY_CONNECTED_KERNEL_H_
#define ONEFLOW_CORE_KERNEL_FULLY_CONNECTED_KERNEL_H_
#include "oneflow/core/kernel/kernel.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class InnerProductKernel final : public KernelIf<device_type> {
class FullyConnectedKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(InnerProductKernel);
InnerProductKernel() = default;
~InnerProductKernel() = default;
OF_DISALLOW_COPY_AND_MOVE(FullyConnectedKernel);
FullyConnectedKernel() = default;
~FullyConnectedKernel() = default;
private:
void ForwardDataContent(
@@ -33,4 +33,4 @@ class InnerProductKernel final : public KernelIf<device_type> {
} // namespace oneflow
#endif // ONEFLOW_CORE_KERNEL_INNERPRODUCT_KERNEL_H_
#endif // ONEFLOW_CORE_KERNEL_FULLY_CONNECTED_KERNEL_H_
#include "oneflow/core/operator/innerproduct_op.h"
#include "oneflow/core/operator/fully_connected_op.h"
#include "oneflow/core/common/balanced_splitter.h"
namespace oneflow {
void InnerProductOp::InitFromOpConf() {
CHECK(op_conf().has_innerproduct_conf());
void FullyConnectedOp::InitFromOpConf() {
CHECK(op_conf().has_fully_connected_conf());
EnrollInputBn("in");
EnrollOutputBn("out");
EnrollModelBn("weight");
if (GetBoolFromSpecialConf("has_bias_term")) {
EnrollModelBn("bias");
EnrollModelTmpBn("bias_multiplier");
}
EnrollModelBn("bias");
EnrollModelTmpBn("bias_multiplier");
}
const PbMessage& InnerProductOp::GetSpecialConf() const {
return op_conf().innerproduct_conf();
const PbMessage& FullyConnectedOp::GetSpecialConf() const {
return op_conf().fully_connected_conf();
}
void InnerProductOp::InferBlobDescs(
void FullyConnectedOp::InferBlobDescs(
std::function<BlobDesc*(const std::string)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx) const {
// useful vars
const InnerProductOpConf& conf = op_conf().innerproduct_conf();
const FullyConnectedOpConf& conf = op_conf().fully_connected_conf();
const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
CHECK_EQ(in_blob_desc->data_type(), JobDesc::Singleton()->DefaultDataType());
int32_t out_num = conf.out_num();
int32_t units = conf.units();
if (parallel_ctx->policy() == kModelParallel) {
BalancedSplitter splitter(out_num, parallel_ctx->parallel_num());
out_num = splitter.At(parallel_ctx->parallel_id()).size();
BalancedSplitter splitter(units, parallel_ctx->parallel_num());
units = splitter.At(parallel_ctx->parallel_id()).size();
}
// out
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
out_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0), out_num});
out_blob_desc->set_data_type(JobDesc::Singleton()->DefaultDataType());
out_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0), units});
out_blob_desc->set_has_data_id_field(in_blob_desc->has_data_id_field());
// weight
BlobDesc* weight_blob_desc = GetBlobDesc4BnInOp("weight");
weight_blob_desc->mut_shape() =
Shape({out_num, in_blob_desc->shape().Count(1)});
weight_blob_desc->set_data_type(JobDesc::Singleton()->DefaultDataType());
weight_blob_desc->set_has_data_id_field(false);
GetBlobDesc4BnInOp("weight")->mut_shape() =
Shape({units, in_blob_desc->shape().Count(1)});
if (conf.has_bias_term()) {
// bias
BlobDesc* bias_blob_desc = GetBlobDesc4BnInOp("bias");
bias_blob_desc->mut_shape() = Shape({1, out_num});
bias_blob_desc->set_data_type(JobDesc::Singleton()->DefaultDataType());
bias_blob_desc->set_has_data_id_field(false);
// bias
GetBlobDesc4BnInOp("bias")->mut_shape() = Shape({1, units});
// bias_multiplier
BlobDesc* bias_mt_blob_desc = GetBlobDesc4BnInOp("bias_multiplier");
bias_mt_blob_desc->mut_shape() = Shape({in_blob_desc->shape().At(0), 1});
bias_mt_blob_desc->set_data_type(JobDesc::Singleton()->DefaultDataType());
bias_mt_blob_desc->set_has_data_id_field(false);
}
// bias_multiplier
GetBlobDesc4BnInOp("bias_multiplier")->mut_shape() =
Shape({in_blob_desc->shape().At(0), 1});
}
REGISTER_OP(OperatorConf::kInnerproductConf, InnerProductOp);
REGISTER_OP(OperatorConf::kFullyConnectedConf, FullyConnectedOp);
} // namespace oneflow
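
Under `kModelParallel`, `InferBlobDescs` narrows `units` to this device's share via `BalancedSplitter(units, parallel_num)`. Only the splitter's use is visible in this diff; the sketch below illustrates the balanced-split arithmetic it is assumed to perform, where part sizes differ by at most one. The helper name and the placement of the larger parts are assumptions, not OneFlow's actual BalancedSplitter interface:

```cpp
#include <cassert>
#include <cstdint>

// Assumed balanced-split behaviour: `total` is divided into `parts` ranges whose
// sizes differ by at most one, with the first (total % parts) ranges one larger.
int32_t BalancedPartSize(int32_t total, int32_t parts, int32_t part_id) {
  assert(parts > 0 && part_id >= 0 && part_id < parts);
  const int32_t base = total / parts;
  const int32_t remainder = total % parts;
  return base + (part_id < remainder ? 1 : 0);
}

// Example: units = 10 over parallel_num = 4 gives shares 3, 3, 2, 2; each device
// then infers its own weight {share, in.Count(1)} and bias {1, share} shapes.
```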
#ifndef ONEFLOW_CORE_OPERATOR_INNERPRODUCT_OP_H_
#define ONEFLOW_CORE_OPERATOR_INNERPRODUCT_OP_H_
#ifndef ONEFLOW_CORE_OPERATOR_FULLY_CONNECTED_OP_H_
#define ONEFLOW_CORE_OPERATOR_FULLY_CONNECTED_OP_H_
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class InnerProductOp final : public Operator {
class FullyConnectedOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(InnerProductOp);
InnerProductOp() = default;
~InnerProductOp() = default;
OF_DISALLOW_COPY_AND_MOVE(FullyConnectedOp);
FullyConnectedOp() = default;
~FullyConnectedOp() = default;
void InitFromOpConf() override;
bool NeedExtraInDiffMemWhenBackward() const override { return false; }
@@ -20,7 +20,7 @@ class InnerProductOp final : public Operator {
const ParallelContext* parallel_ctx) const override;
int32_t ModelSplitAxis() const override { return 1; }
int32_t MaxModelSplitNum() const override {
return op_conf().innerproduct_conf().out_num();
return op_conf().fully_connected_conf().units();
}
private:
@@ -28,4 +28,4 @@ class InnerProductOp final : public Operator {
} // namespace oneflow
#endif // ONEFLOW_CORE_OPERATOR_INNERPRODUCT_OP_H_
#endif // ONEFLOW_CORE_OPERATOR_FULLY_CONNECTED_OP_H_
@@ -67,14 +67,12 @@ message ConvolutionOpConf {
optional InitializerConf bias_initializer = 14;
}
message InnerProductOpConf {
message FullyConnectedOpConf {
required string in = 1;
required string out = 2;
required int32 out_num = 3;
optional bool has_bias_term = 4 [default = true];
optional InitializerConf weight_initializer = 5;
optional InitializerConf bias_initializer = 6;
required int32 units = 3;
optional InitializerConf weight_initializer = 4;
optional InitializerConf bias_initializer = 5;
}
message BasicDataLoaderOpConf {
@@ -255,7 +253,7 @@ message OperatorConf {
optional string model_load_dir = 2;
oneof op_type {
ConvolutionOpConf convolution_conf = 100;
InnerProductOpConf innerproduct_conf = 101;
FullyConnectedOpConf fully_connected_conf = 101;
BasicDataLoaderOpConf basic_data_loader_conf = 102;
PoolingOpConf pooling_conf = 103;
ReluOpConf relu_conf = 104;
......
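
For code that builds an `OperatorConf` programmatically instead of through prototxt, the renamed message is reached through the usual protobuf-generated accessors. A hedged sketch follows; the generated header path, the `oneflow` proto package, and the presence of a `name` field on `OperatorConf` are assumptions not shown in this hunk:

```cpp
#include "oneflow/core/operator/op_conf.pb.h"  // assumed path of the generated proto header

// Builds the same op as the prototxt snippet near the top of this commit,
// using the renamed message and fields.
oneflow::OperatorConf MakeIp10Conf() {
  oneflow::OperatorConf op_conf;
  op_conf.set_name("ip10");  // assumes OperatorConf declares a `name` field (not shown here)
  oneflow::FullyConnectedOpConf* fc = op_conf.mutable_fully_connected_conf();
  fc->set_in("conv/out");
  fc->set_out("out");
  fc->set_units(10);  // replaces the old out_num; has_bias_term is gone, bias is always enrolled
  return op_conf;
}
```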