未验证 提交 d11c140e 编写于 作者: xujiaqi01 提交者: GitHub

fix dump (also init/finalize the dump env when only param dump is enabled), fix cvm check (#25400)

* fix dump, fix cvm check
test=develop

* fix
test=develop

* fix
test=develop

* fix
test=develop
上级 8ebffc78
......@@ -111,6 +111,7 @@ void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) {
writer_ << os.str();
}
}
void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) {
bool enable_random_dump = desc.enable_random_dump();
if (!enable_random_dump) {
......
......@@ -99,7 +99,7 @@ void DistMultiTrainer::InitTrainerEnv(const ProgramDesc &main_program,
}
void DistMultiTrainer::InitOtherEnv(const ProgramDesc &main_program) {
if (need_dump_field_) {
if (need_dump_field_ || need_dump_param_) {
InitDumpEnv();
}
pull_dense_worker_->SetRootScope(root_scope_);
......@@ -158,7 +158,7 @@ void DistMultiTrainer::Finalize() {
}
}
if (need_dump_field_) {
if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv();
}
pull_dense_worker_->Stop();
......
......@@ -49,7 +49,12 @@ TEST(DisMultiTrainerTest, test1) {
dataset->SetTrainerNum(1);
dataset->SetDataFeedDesc(str);
dataset->CreateReaders();
Scope root_scope;
tmp1->SetScope(&root_scope);
tmp1->Initialize(t, dataset.get());
ProgramDesc p;
tmp1->InitOtherEnv(p);
tmp1->Finalize();
#endif
}
} // namespace framework
......
......@@ -106,7 +106,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
}
void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
if (need_dump_field_) {
if (need_dump_field_ || need_dump_param_) {
InitDumpEnv();
}
VLOG(3) << "init other env done.";
......@@ -133,7 +133,7 @@ void MultiTrainer::Run() {
}
void MultiTrainer::Finalize() {
if (need_dump_field_) {
if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv();
}
root_scope_->DropKids();
......
......@@ -22,6 +22,8 @@ void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; }
void TrainerBase::ParseDumpConfig(const TrainerDesc& desc) {
dump_fields_path_ = desc.dump_fields_path();
need_dump_field_ = false;
need_dump_param_ = false;
if (dump_fields_path_ == "") {
VLOG(2) << "dump_fields_path_ is empty";
return;
......
......@@ -27,19 +27,11 @@ class CVMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CVM");
OP_INOUT_CHECK(ctx->HasInput("CVM"), "Input", "CVM", "CVM");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "CVM");
auto x_dims = ctx->GetInputDim("X");
auto cvm_dims = ctx->GetInputDim("CVM");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, platform::errors::InvalidArgument(
"Input(X)'s rank should be 2."));
PADDLE_ENFORCE_EQ(
cvm_dims.size(), 2UL,
platform::errors::InvalidArgument("Input(CVM)'s rank should be 2."));
PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, platform::errors::InvalidArgument(
"The 2nd dimension of "
"Input(CVM) should be 2."));
if (ctx->Attrs().Get<bool>("use_cvm")) {
ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]});
......
......@@ -62,15 +62,18 @@ class TrainerFactory(object):
trainer._set_mpi_rank(opt_info["mpi_rank"])
if opt_info.get("mpi_size") is not None:
trainer._set_mpi_size(opt_info["mpi_size"])
if opt_info.get("dump_fields") is not None:
if opt_info.get("dump_fields") is not None and len(
opt_info.get("dump_fields")) != 0:
trainer._set_dump_fields(opt_info["dump_fields"])
if opt_info.get("dump_fields_path") is not None:
if opt_info.get("dump_fields_path") is not None and len(
opt_info.get("dump_fields_path")) != 0:
trainer._set_dump_fields_path(opt_info["dump_fields_path"])
if opt_info.get("dump_file_num") is not None:
trainer._set_dump_file_num(opt_info["dump_file_num"])
if opt_info.get("dump_converter") is not None:
trainer._set_dump_converter(opt_info["dump_converter"])
if opt_info.get("dump_param") is not None:
if opt_info.get("dump_param") is not None and len(
opt_info.get("dump_param")) != 0:
trainer._set_dump_param(opt_info["dump_param"])
if opt_info.get("enable_random_dump") is not None:
trainer._set_enable_random_dump(opt_info[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册