提交 e74bb463 编写于 作者: S Shiyuan Shang-Guan 提交者: Jinhui Yuan

fix bug of normalization load model from snapshot (#1119)

上级 953e1aec
......@@ -166,24 +166,25 @@ void NormalizationKernel<device_type, T>::InitModelBlobsWithDir(
DeviceCtx* ctx, int32_t part_id, int32_t part_num, const std::string& model_load_dir,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const auto& conf = this->op_conf().normalization_conf();
int32_t dim_num = this->kernel_conf().normalization_conf().transpose_cols();
if (conf.scale()) {
Blob* gamma_blob = BnInOp2Blob("gamma");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, gamma_blob,
"gamma", dim_num, gamma_blob->shape().Count(1));
"gamma", gamma_blob->shape().At(0),
gamma_blob->shape().Count(1));
}
if (conf.center()) {
Blob* beta_blob = BnInOp2Blob("beta");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, beta_blob,
"beta", dim_num, beta_blob->shape().Count(1));
"beta", beta_blob->shape().At(0),
beta_blob->shape().Count(1));
}
Blob* mean_blob = BnInOp2Blob("moving_mean");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, mean_blob,
"moving_mean", dim_num,
"moving_mean", mean_blob->shape().At(0),
mean_blob->shape().Count(1));
Blob* variance_blob = BnInOp2Blob("moving_variance");
KernelUtil<device_type, T>::InitializeWithDir(ctx, 0, part_num, model_load_dir, variance_blob,
"moving_variance", dim_num,
"moving_variance", variance_blob->shape().At(0),
variance_blob->shape().Count(1));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册