提交 49dd19bc 编写于 作者: S Shiyuan Shang-Guan 提交者: Jinhui Yuan

align with tensorflow (#1461)



Former-commit-id: 94be14b8189a7123a4012bb34727b32f7ec07599
上级 9e6347a0
......@@ -8,7 +8,8 @@ void AccuracyPrintKernel<T>::Forward(const KernelCtx& kernel_ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
const Blob* accuracy_acc_blob = BnInOp2Blob("accuracy_acc");
T accuracy_num = accuracy_acc_blob->dptr<T>()[0];
int total_num = Global<JobDesc>::Get()->PieceSize() * Global<JobDesc>::Get()->PieceNumOfPrintAccuracy();
int total_num =
Global<JobDesc>::Get()->PieceSize() * Global<JobDesc>::Get()->PieceNumOfPrintAccuracy();
float accuracy = accuracy_num / total_num;
const char* accuracy_op_name = op_conf().name().c_str() + AccuracyPrintPrefix.length();
auto kernel_conf = this->kernel_conf();
......
......@@ -138,7 +138,7 @@ double ConstantWarmupLearningRate(const ConstantWarmupConf& conf, double lr,
double LinearWarmupLearningRate(const LinearWarmupConf& conf, double lr, int64_t cur_batch_num) {
CHECK_GT(conf.warmup_batches(), 0);
CHECK_GT(conf.start_multiplier(), 0);
CHECK_GE(conf.start_multiplier(), 0);
CHECK_LT(conf.start_multiplier(), 1);
double start_multiplier = conf.start_multiplier();
double multiplier = 1.0;
......
......@@ -6,7 +6,7 @@ namespace oneflow {
void NormalizationOp::InitFromOpConf() {
const auto& conf = op_conf().normalization_conf();
float min_epsilon = CUDNN_BN_MIN_EPSILON + 1e-8;
if(conf.epsilon() < min_epsilon){
if (conf.epsilon() < min_epsilon) {
this->mut_op_conf()->mutable_normalization_conf()->set_epsilon(min_epsilon);
}
CHECK_GE(conf.momentum(), 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册