From d11c140e280880b9d031fa38361f3230aef6cf9c Mon Sep 17 00:00:00 2001 From: xujiaqi01 <173596896@qq.com> Date: Mon, 3 Aug 2020 11:59:39 +0800 Subject: [PATCH] fix dump, fix cvm check (#25400) * fix dump, fix cvm check test=develop * fix test=develop * fix test=develop * fix test=develop --- paddle/fluid/framework/device_worker.cc | 1 + paddle/fluid/framework/dist_multi_trainer.cc | 4 ++-- paddle/fluid/framework/dist_multi_trainer_test.cc | 5 +++++ paddle/fluid/framework/multi_trainer.cc | 4 ++-- paddle/fluid/framework/trainer.cc | 2 ++ paddle/fluid/operators/cvm_op.cc | 8 -------- python/paddle/fluid/trainer_factory.py | 9 ++++++--- 7 files changed, 18 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/device_worker.cc b/paddle/fluid/framework/device_worker.cc index f7e64b4f65..aeec616171 100644 --- a/paddle/fluid/framework/device_worker.cc +++ b/paddle/fluid/framework/device_worker.cc @@ -111,6 +111,7 @@ void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) { writer_ << os.str(); } } + void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) { bool enable_random_dump = desc.enable_random_dump(); if (!enable_random_dump) { diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 6ed68bb096..e2a7375df9 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -99,7 +99,7 @@ void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program, } void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) { - if (need_dump_field_) { + if (need_dump_field_ || need_dump_param_) { InitDumpEnv(); } pull_dense_worker_->SetRootScope(root_scope_); @@ -158,7 +158,7 @@ void DistMultiTrainer::Finalize() { } } - if (need_dump_field_) { + if (need_dump_field_ || need_dump_param_) { FinalizeDumpEnv(); } pull_dense_worker_->Stop(); diff --git a/paddle/fluid/framework/dist_multi_trainer_test.cc b/paddle/fluid/framework/dist_multi_trainer_test.cc index f54029fd17..75543b7b30 100644 --- a/paddle/fluid/framework/dist_multi_trainer_test.cc +++ b/paddle/fluid/framework/dist_multi_trainer_test.cc @@ -49,7 +49,12 @@ TEST(DisMultiTrainerTest, test1) { dataset->SetTrainerNum(1); dataset->SetDataFeedDesc(str); dataset->CreateReaders(); + Scope root_scope; + tmp1->SetScope(&root_scope); tmp1->Initialize(t, dataset.get()); + ProgramDesc p; + tmp1->InitOtherEnv(p); + tmp1->Finalize(); #endif } } // namespace framework diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 4ffd9a2f9c..4ae26903e6 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -106,7 +106,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, } void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) { - if (need_dump_field_) { + if (need_dump_field_ || need_dump_param_) { InitDumpEnv(); } VLOG(3) << "init other env done."; @@ -133,7 +133,7 @@ void MultiTrainer::Run() { } void MultiTrainer::Finalize() { - if (need_dump_field_) { + if (need_dump_field_ || need_dump_param_) { FinalizeDumpEnv(); } root_scope_->DropKids(); diff --git a/paddle/fluid/framework/trainer.cc b/paddle/fluid/framework/trainer.cc index 99a1589200..b033f9a99d 100644 --- a/paddle/fluid/framework/trainer.cc +++ b/paddle/fluid/framework/trainer.cc @@ -22,6 +22,8 @@ void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; } void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) { dump_fields_path_ = desc.dump_fields_path(); + need_dump_field_ = false; + need_dump_param_ = false; if (dump_fields_path_ == "") { VLOG(2) << "dump_fields_path_ is empty"; return; diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc index 995ff4a9c7..a1a8744c32 100644 --- a/paddle/fluid/operators/cvm_op.cc +++ b/paddle/fluid/operators/cvm_op.cc @@ -27,19 +27,11 @@ class CVMOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CVM"); - OP_INOUT_CHECK(ctx->HasInput("CVM"), "Input", "CVM", "CVM"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM"); auto x_dims = ctx->GetInputDim("X"); - auto cvm_dims = ctx->GetInputDim("CVM"); PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, platform::errors::InvalidArgument( "Input(X)'s rank should be 2.")); - PADDLE_ENFORCE_EQ( - cvm_dims.size(), 2UL, - platform::errors::InvalidArgument("Input(CVM)'s rank should be 2.")); - PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument( - "The 2nd dimension of " - "Input(CVM) should be 2.")); if (ctx->Attrs().Get("use_cvm")) { ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]}); diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index 22ba46b90d..c2d80f52b8 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -62,15 +62,18 @@ class TrainerFactory(object): trainer._set_mpi_rank(opt_info["mpi_rank"]) if opt_info.get("mpi_size") is not None: trainer._set_mpi_size(opt_info["mpi_size"]) - if opt_info.get("dump_fields") is not None: + if opt_info.get("dump_fields") is not None and len( + opt_info.get("dump_fields")) != 0: trainer._set_dump_fields(opt_info["dump_fields"]) - if opt_info.get("dump_fields_path") is not None: + if opt_info.get("dump_fields_path") is not None and len( + opt_info.get("dump_fields_path")) != 0: trainer._set_dump_fields_path(opt_info["dump_fields_path"]) if opt_info.get("dump_file_num") is not None: trainer._set_dump_file_num(opt_info["dump_file_num"]) if opt_info.get("dump_converter") is not None: trainer._set_dump_converter(opt_info["dump_converter"]) - if opt_info.get("dump_param") is not None: + if opt_info.get("dump_param") is not None and len( + opt_info.get("dump_param")) != 0: trainer._set_dump_param(opt_info["dump_param"]) if opt_info.get("enable_random_dump") is not None: trainer._set_enable_random_dump(opt_info[ -- GitLab