提交 c7bf405e 编写于 作者: duduscript's avatar duduscript 提交者: Will Zhang

Fix naming problem (#553)

* add l1 l2 regularization

* fix some problem

* Fix bug

* fix op API

* Fix a bug in op

* add l1l2

* remove regularization_diff_blob

* remove sign in kernel_util

* fix naming problem and add l1l2 function

* simplify code

* remove else

* update modelupdate API

* Fix bug

* ~

* fix bug

* add basic_rnn_op

* Fix naming problem

* clang-format

* Fix bug for hidden_diff_blob update

* Fix bug for hidden_diff compute

* fix bug for h0 initializer

* fix bug for h0 initializer

* add else

* remove else

* fix

* optimize

* optimize

* add has_init_hidden_initializer

* optimize
上级 2c2b79b5
......@@ -7,11 +7,16 @@ const PbMessage& BasicRnnKernel<device_type, T>::GetRecurrentOpConf() const {
return this->op_conf().basic_rnn_conf();
}
template<DeviceType device_type, typename T>
bool BasicRnnKernel<device_type, T>::HasInitHiddenInitializer() const {
  // True iff the op conf explicitly carries an initializer for the initial
  // hidden state (h0); used to decide how h0 is initialized elsewhere.
  const auto& rnn_conf = this->op_conf().basic_rnn_conf();
  return rnn_conf.has_init_hidden_initializer();
}
template<DeviceType device_type, typename T>
void BasicRnnKernel<device_type, T>::ForwardDataContent(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* rec_ht_blob = this->GetHiddenBlob(BnInOp2Blob);
const Blob* hidden_blob = this->GetHiddenBlob(BnInOp2Blob);
Blob* plus_op_out_blob = BnInOp2Blob("plus_op_out");
Blob* out_blob = BnInOp2Blob("out");
......@@ -21,10 +26,10 @@ void BasicRnnKernel<device_type, T>::ForwardDataContent(
static_cast<T>(0), BnInOp2Blob("in"), BnInOp2Blob("i2h_weight"),
plus_op_out_blob);
// plus_op_out += rec_ht * h2h_weight
// plus_op_out += hidden * h2h_weight
KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasNoTrans, CblasTrans,
static_cast<T>(1), static_cast<T>(1),
rec_ht_blob, BnInOp2Blob("h2h_weight"),
hidden_blob, BnInOp2Blob("h2h_weight"),
plus_op_out_blob);
// plus_op_out += bias_multiplier * bias
......@@ -45,17 +50,17 @@ void BasicRnnKernel<device_type, T>::ForwardDataContent(
UNEXPECTED_RUN();
}
// rec_ht = out
BnInOp2Blob("rec_ht")->CopyDataContentFrom<device_type>(ctx.device_ctx,
out_blob);
// rec_out = out
BnInOp2Blob("rec_out")->CopyDataContentFrom<device_type>(ctx.device_ctx,
out_blob);
}
// Propagates the data ids of the input blob to the output-side blobs.
template<DeviceType device_type, typename T>
void BasicRnnKernel<device_type, T>::ForwardDataId(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
// NOTE(review): both "rec_ht" and "out" receive the input's data ids below.
// The "rec_ht" copy looks like a leftover from the rec_ht -> out/rec_out
// renaming done in this change -- confirm only the "out" copy is intended.
BnInOp2Blob("rec_ht")->CopyDataIdFrom<device_type>(ctx.device_ctx,
BnInOp2Blob("in"));
BnInOp2Blob("out")->CopyDataIdFrom<device_type>(ctx.device_ctx,
BnInOp2Blob("in"));
}
template<DeviceType device_type, typename T>
......@@ -63,30 +68,30 @@ void BasicRnnKernel<device_type, T>::BackwardDataContent(
const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* out_blob = BnInOp2Blob("out");
const Blob* rec_ht_blob = this->GetHiddenBlob(BnInOp2Blob);
const Blob* hidden_blob = this->GetHiddenBlob(BnInOp2Blob);
const Blob* out_diff_blob = BnInOp2Blob("out_diff");
const Blob* rec_ht_diff_blob = BnInOp2Blob("rec_ht_diff");
const Blob* rec_out_diff_blob = BnInOp2Blob("rec_out_diff");
// reuse memory
Blob* plus_op_out_diff_blob = BnInOp2Blob("plus_op_out");
if (this->op_conf().basic_rnn_conf().activation() == kTanH) {
BasicRnnKernelUtil<device_type, T>::ComputeTanHDiff(
ctx.device_ctx, out_blob->shape().elem_cnt(), out_blob->dptr<T>(),
out_diff_blob->dptr<T>(), rec_ht_diff_blob->dptr<T>(),
out_diff_blob->dptr<T>(), rec_out_diff_blob->dptr<T>(),
plus_op_out_diff_blob->mut_dptr<T>());
} else if (this->op_conf().basic_rnn_conf().activation() == kSigmoid) {
BasicRnnKernelUtil<device_type, T>::ComputeSigmoidDiff(
ctx.device_ctx, out_blob->shape().elem_cnt(), out_blob->dptr<T>(),
out_diff_blob->dptr<T>(), rec_ht_diff_blob->dptr<T>(),
out_diff_blob->dptr<T>(), rec_out_diff_blob->dptr<T>(),
plus_op_out_diff_blob->mut_dptr<T>());
} else {
UNEXPECTED_RUN();
}
// h2h_weight_diff = plus_op_out_diff * rec_ht
// h2h_weight_diff = plus_op_out_diff * hidden
KernelUtil<device_type, T>::BlobGemm(ctx.device_ctx, CblasTrans, CblasNoTrans,
static_cast<T>(1), static_cast<T>(0),
plus_op_out_diff_blob, rec_ht_blob,
plus_op_out_diff_blob, hidden_blob,
BnInOp2Blob("h2h_weight_diff"));
// i2h_weight_diff = plus_op_out_diff * in
......@@ -107,13 +112,11 @@ void BasicRnnKernel<device_type, T>::BackwardDataContent(
static_cast<T>(0), BnInOp2Blob("bias_multiplier"), plus_op_out_diff_blob,
BnInOp2Blob("bias_diff"));
if (this->NeedExternalH0() && BnInOp2Blob("rec_ht_diff")->col_id() == 0) {
// h0_diff = plus_op_out_diff * h2h_weight
KernelUtil<device_type, T>::BlobGemm(
ctx.device_ctx, CblasNoTrans, CblasNoTrans, static_cast<T>(0),
static_cast<T>(0), plus_op_out_diff_blob, BnInOp2Blob("h2h_weight"),
BnInOp2Blob("h0_diff"));
}
// hidden_diff = plus_op_out_diff * h2h_weight
KernelUtil<device_type, T>::BlobGemm(
ctx.device_ctx, CblasNoTrans, CblasNoTrans, static_cast<T>(1),
static_cast<T>(0), plus_op_out_diff_blob, BnInOp2Blob("h2h_weight"),
this->GetHiddenDiffBlob(BnInOp2Blob));
}
template<DeviceType device_type, typename T>
......@@ -170,18 +173,20 @@ void BasicRnnKernel<device_type, T>::InitModelTmpBlobs(
template<typename T>
class BasicRnnKernelUtil<DeviceType::kCPU, T> final {
public:
static void ComputeTanHDiff(DeviceCtx* ctx, int64_t n, const T* ht,
const T* ht_diff, const T* rec_ht_diff,
static void ComputeTanHDiff(DeviceCtx* ctx, int64_t n, const T* out,
const T* out_diff, const T* rec_out_diff,
T* plus_out_diff) {
FOR_RANGE(int64_t, i, 0, n) {
plus_out_diff[i] = (1 - ht[i] * ht[i]) * (ht_diff[i] + rec_ht_diff[i]);
plus_out_diff[i] =
(1 - out[i] * out[i]) * (out_diff[i] + rec_out_diff[i]);
}
}
static void ComputeSigmoidDiff(DeviceCtx* ctx, int64_t n, const T* ht,
const T* ht_diff, const T* rec_ht_diff,
static void ComputeSigmoidDiff(DeviceCtx* ctx, int64_t n, const T* out,
const T* out_diff, const T* rec_out_diff,
T* plus_out_diff) {
FOR_RANGE(int64_t, i, 0, n) {
plus_out_diff[i] = ht[i] * (1 - ht[i]) * (ht_diff[i] + rec_ht_diff[i]);
plus_out_diff[i] =
out[i] * (1 - out[i]) * (out_diff[i] + rec_out_diff[i]);
}
}
};
......
......@@ -14,6 +14,7 @@ class BasicRnnKernel final : public RecurrentKernel<device_type, T> {
private:
const PbMessage& GetRecurrentOpConf() const override;
bool HasInitHiddenInitializer() const override;
void ForwardDataContent(
const KernelCtx&,
std::function<Blob*(const std::string&)>) const override;
......@@ -37,11 +38,11 @@ class BasicRnnKernel final : public RecurrentKernel<device_type, T> {
template<DeviceType device_type, typename T>
class BasicRnnKernelUtil final {
public:
static void ComputeTanHDiff(DeviceCtx* ctx, int64_t n, const T* ht,
const T* ht_diff, const T* rec_ht_diff,
static void ComputeTanHDiff(DeviceCtx* ctx, int64_t n, const T* out,
const T* out_diff, const T* rec_out_diff,
T* plus_out_diff);
static void ComputeSigmoidDiff(DeviceCtx* ctx, int64_t n, const T* ht,
const T* ht_diff, const T* rec_ht_diff,
static void ComputeSigmoidDiff(DeviceCtx* ctx, int64_t n, const T* out,
const T* out_diff, const T* rec_out_diff,
T* plus_out_diff);
};
......
......@@ -20,7 +20,14 @@ template<DeviceType device_type, typename T>
Blob* RecurrentKernel<device_type, T>::GetHiddenBlob(
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
if (BnInOp2Blob("in")->col_id() == 0) { return BnInOp2Blob("h0"); }
return BnInOp2Blob("rec_ht");
return BnInOp2Blob("rec_in");
}
template<DeviceType device_type, typename T>
Blob* RecurrentKernel<device_type, T>::GetHiddenDiffBlob(
    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  // For the first column the hidden-state diff flows back into h0's diff;
  // every later column accumulates into the recurrent input diff blob.
  const bool is_first_col = (BnInOp2Blob("in")->col_id() == 0);
  return BnInOp2Blob(is_first_col ? "h0_diff" : "rec_in_diff");
}
template<DeviceType device_type, typename T>
......@@ -28,11 +35,14 @@ void RecurrentKernel<device_type, T>::InitModelBlobsWithRandomSeed(
const KernelCtx& ctx, std::mt19937 random_seed_gen,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
if (NeedExternalH0()) {
const InitializerConf& init_hidden_initializer =
static_cast<const InitializerConf&>(GetMessageFromPbMessage(
GetRecurrentOpConf(), "init_hidden_initializer"));
const InitializerConf* init_hidden_initializer = nullptr;
if (HasInitHiddenInitializer()) {
init_hidden_initializer =
static_cast<const InitializerConf*>(&GetMessageFromPbMessage(
GetRecurrentOpConf(), "init_hidden_initializer"));
}
KernelUtil<device_type, T>::InitializeWithProperConf(
ctx.device_ctx, &init_hidden_initializer, random_seed_gen(),
ctx.device_ctx, init_hidden_initializer, random_seed_gen(),
BnInOp2Blob("h0"));
}
VirtualInitModelBlobsWithRandomSeed(ctx, random_seed_gen, BnInOp2Blob);
......
......@@ -15,8 +15,10 @@ class RecurrentKernel : public KernelIf<device_type> {
RecurrentKernel() = default;
virtual const PbMessage& GetRecurrentOpConf() const = 0;
virtual bool HasInitHiddenInitializer() const = 0;
bool NeedExternalH0() const;
Blob* GetHiddenBlob(std::function<Blob*(const std::string&)>) const;
Blob* GetHiddenDiffBlob(std::function<Blob*(const std::string&)>) const;
void InitModelBlobsWithRandomSeed(
const KernelCtx& ctx, std::mt19937 random_seed_gen,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册