[Unverified] Commit eeeef957, authored by Chengmo, committed by GitHub

Fix ps gpu (#26218)

* support ps-gpu:
  - move DistributedLookupTableKernel into a device-templated header and register a CUDA kernel for distributed_lookup_table
  - add a GPU branch to distributed::prefetchs() that copies fetched embedding rows host-to-device
  - pin prefetch_core and the recv_save kernel to CPUPlace
  - fall back to a CPU executor in fleet save_inference_model / save_persistables
  - add PS-GPU CTR tests
Parent 6cbeafb6
@@ -110,7 +110,7 @@ void prefetch_core(
   int pservers = context.Attr<int>("pserver_num");
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-  auto &actual_ctx = *pool.Get(context.GetPlace());
+  auto &actual_ctx = *pool.Get(platform::CPUPlace());
   std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
@@ -144,7 +144,6 @@ void prefetch_core(
       VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
     }
   }
-
   for (size_t i = 0; i < rets.size(); i++) {
     PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
                                                "internal error in RPCClient"));
@@ -167,6 +166,7 @@ void prefetch_core(
   for (int64_t i = 0; i < dims[0]; ++i) {
     auto origin_id = ids_in_this_section[i];
     std::vector<float> vecs(row_numel);
+
     std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin());
     (*recved_vec_map)[origin_id] = vecs;
   }
@@ -213,18 +213,18 @@ void prefetchs(const std::vector<std::string> &id_var_names,
   const auto place =
       scope.FindVar(id_var_names[0])->Get<framework::LoDTensor>().place();

-  if (!platform::is_cpu_place(place)) {
-    PADDLE_THROW("multi prefetch only support CPU currently");
-  }
-
+  std::vector<std::vector<int64_t>> ids_group;
   std::vector<int64_t> ids_union;
+  std::vector<framework::LoD> ids_lods;
   TableAndEndpoints tables;

   for (auto &id_name : id_var_names) {
-    auto *in_var = scope.FindVar(id_name);
-    auto &id_tensor = in_var->Get<framework::LoDTensor>();
-    std::copy_n(id_tensor.data<int64_t>(), id_tensor.numel(),
-                back_inserter(ids_union));
+    auto &id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
+    std::vector<int64_t> ids;
+    TensorToVector(id_tensor, context.device_context(), &ids);
+    ids_union.insert(ids_union.end(), ids.begin(), ids.end());
+    ids_group.push_back(ids);
+    ids_lods.push_back(id_tensor.lod());
   }

   std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
@@ -258,25 +258,48 @@ void prefetchs(const std::vector<std::string> &id_var_names,
   }

   for (size_t i = 0; i < out_var_names.size(); i++) {
-    auto *in_var = scope.FindVar(id_var_names[i]);
-    auto &id_tensor = in_var->Get<framework::LoDTensor>();
-    auto ids_size = id_tensor.dims()[0];
-    const auto *id_data = id_tensor.data<int64_t>();
+    std::vector<int64_t> ids = ids_group[i];
+    auto ids_size = ids.size();
     auto *out_t =
         scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
-    out_t->set_lod(id_tensor.lod());
-    out_t->Resize(framework::make_ddim({ids_size, vec_dim_1}));
+    out_t->set_lod(ids_lods[i]);
+    out_t->Resize(
+        framework::make_ddim({static_cast<int64_t>(ids_size), vec_dim_1}));
     auto *out_d = out_t->mutable_data<float>(place);

-    for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
-      const auto &id = id_data[idx];
-      if (padding_idx != distributed::kNoPadding && id == padding_idx) {
-        memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
-      } else {
-        std::copy_n(recved_vec_map[id].begin(), vec_dim_1,
-                    out_d + idx * vec_dim_1);
-      }
-    }
+    if (platform::is_cpu_place(out_t->place())) {
+      for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
+        const auto &id = ids[idx];
+        if (padding_idx != distributed::kNoPadding && id == padding_idx) {
+          memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
+        } else {
+          std::copy_n(recved_vec_map[id].begin(), vec_dim_1,
+                      out_d + idx * vec_dim_1);
+        }
+      }
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
+        const auto &id = ids[idx];
+        auto stream = context.cuda_device_context().stream();
+        if (padding_idx != distributed::kNoPadding && id == padding_idx) {
+          platform::GpuMemsetAsync(out_d + idx * vec_dim_1, 0,
+                                   sizeof(float) * vec_dim_1, stream);
+        } else {
+          auto &cpu_place =
+              BOOST_GET_CONST(platform::CPUPlace,
+                              paddle::platform::CPUDeviceContext().GetPlace());
+          auto &gpu_place =
+              BOOST_GET_CONST(platform::CUDAPlace, out_t->place());
+          memory::Copy(gpu_place, out_d + idx * vec_dim_1, cpu_place,
+                       &recved_vec_map[id][0], sizeof(float) * vec_dim_1,
+                       stream);
+        }
+      }
+#else
+      PADDLE_ENFORCE(true, platform::errors::PermissionDenied(
+                               "Paddle is not compiled with GPU!"));
+#endif
+    }
   }
 }
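For orientation: the new GPU branch above issues one async memset (for padding ids) or one host-to-device copy per embedding row, all queued on the op's CUDA stream. Below is a minimal standalone sketch of that per-row scatter, using raw CUDA runtime calls rather than Paddle's platform::GpuMemsetAsync / memory::Copy wrappers; all names are illustrative, not Paddle's.

#include <cuda_runtime.h>

#include <cstdint>
#include <unordered_map>
#include <vector>

// Scatter looked-up rows from host memory into a device output buffer.
// Each row in `rows` is assumed to hold `dim` floats. Padding ids get a
// zeroed row; every other id gets its fetched row copied host-to-device.
void scatter_rows_to_device(
    const std::vector<int64_t> &ids,
    std::unordered_map<int64_t, std::vector<float>> &rows,
    int64_t padding_idx, int64_t dim, float *dev_out, cudaStream_t stream) {
  for (size_t i = 0; i < ids.size(); ++i) {
    if (ids[i] == padding_idx) {
      cudaMemsetAsync(dev_out + i * dim, 0, sizeof(float) * dim, stream);
    } else {
      cudaMemcpyAsync(dev_out + i * dim, rows[ids[i]].data(),
                      sizeof(float) * dim, cudaMemcpyHostToDevice, stream);
    }
  }
}

One transfer per row means many small copies; staging all rows into one contiguous host buffer and issuing a single memcpy would cut launch overhead, but the sketch mirrors the per-row structure of the committed code.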
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -17,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/distributed/parameter_prefetch.h"
+#include "paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
 namespace paddle {
@@ -75,47 +73,6 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
   }
 };

-template <typename T>
-class DistributedLookupTableKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto ids_vars = context.MultiInputVar("Ids");
-    auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
-
-    auto id_names = context.InputNames("Ids");
-    auto embedding_name = context.InputNames("W").front();
-    auto out_names = context.OutputNames("Outputs");
-    auto lookup_tables = context.Attr<std::vector<std::string>>("table_names");
-    auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
-    auto is_distributed = context.Attr<bool>("is_distributed");
-
-    auto lookup_table_version =
-        context.Attr<std::string>("lookup_table_version");
-
-    operators::distributed::prefetchs(id_names, out_names, embedding_name,
-                                      is_distributed, lookup_tables, endpoints,
-                                      context, context.scope());
-
-    if (lookup_table_version == "lookup_table_v2") {
-      auto &scope = context.scope();
-      auto emb_dim =
-          scope.FindVar(embedding_name)->Get<framework::LoDTensor>().dims()[1];
-
-      for (size_t i = 0; i < id_names.size(); ++i) {
-        auto *id_var = scope.FindVar(id_names[i]);
-        auto *out_var = scope.FindVar(out_names[i]);
-        auto *id_tensor = id_var->GetMutable<framework::LoDTensor>();
-        auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
-
-        auto id_dims = id_tensor->dims();
-        out_tensor->Resize(framework::make_ddim(
-            {static_cast<int64_t>(id_dims[0]),
-             static_cast<int64_t>(id_dims[1]),
-             static_cast<int64_t>(emb_dim)}));
-      }
-    }
-  }
-};
-
 class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -170,15 +127,12 @@ class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 Lookup Tablel Prefetch Operator.
 This operator is used to perform lookup on parameter W,
 then concatenated into a sparse tensor.
 The type of Ids(Input) is SelectedRows, the rows of Ids contains
 the ids to be looked up in W;
 if the Id is not in the sparse table, this operator will return a
 random value and set the value into the table for the next looking up.
 )DOC");
   }
 };
@@ -191,4 +145,5 @@ REGISTER_OPERATOR(distributed_lookup_table, ops::DistributedLookupTableOp,
                   ops::DistributedLookupTableOpMaker);
 REGISTER_OP_CPU_KERNEL(distributed_lookup_table,
-                       ops::DistributedLookupTableKernel<float>);
+                       ops::DistributedLookupTableKernel<
+                           paddle::platform::CPUDeviceContext, float>);
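Why the kernel body left this file: once the kernel class is templated on the device context and lives in a header (the new files below), this CPU registration and the new CUDA registration can both instantiate one shared implementation. A toy sketch of that pattern in plain C++, assuming nothing about Paddle's actual registration macros:

#include <iostream>

struct CPUContext {};
struct CUDAContext {};

// One kernel body, parameterized on the device context type.
template <typename DeviceContext, typename T>
struct LookupKernel {
  void Compute() const { std::cout << "shared lookup compute body\n"; }
};

// Each translation unit instantiates the shared template for its own device,
// which is roughly what the CPU (.cc) and CUDA (.cu) registrations boil down to.
template struct LookupKernel<CPUContext, float>;   // the .cc side
template struct LookupKernel<CUDAContext, float>;  // the .cu side

int main() { LookupKernel<CPUContext, float>{}.Compute(); }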
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/distributed_lookup_table_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
distributed_lookup_table,
ops::DistributedLookupTableKernel<plat::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class DistributedLookupTableKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto ids_vars = context.MultiInputVar("Ids");
    auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");

    auto id_names = context.InputNames("Ids");
    auto embedding_name = context.InputNames("W").front();
    auto out_names = context.OutputNames("Outputs");
    auto lookup_tables = context.Attr<std::vector<std::string>>("table_names");
    auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
    auto is_distributed = context.Attr<bool>("is_distributed");

    operators::distributed::prefetchs(id_names, out_names, embedding_name,
                                      is_distributed, lookup_tables, endpoints,
                                      context, context.scope());
  }
};

}  // namespace operators
}  // namespace paddle
@@ -44,7 +44,7 @@ class RecvSaveOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
-        ctx.GetPlace());
+        platform::CPUPlace());
   }
 };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
@@ -393,6 +393,12 @@ class FleetTranspiler(Fleet):
                 "in fleet.save_inference_model() function, executor must be as Executor type"
             )

+        # Todo(MrChengmo): support recv&save GPU-Kernel for ps-gpu model save
+        if not isinstance(executor.place, fluid.CPUPlace):
+            save_executor = Executor(fluid.CPUPlace())
+        else:
+            save_executor = executor
+
         if main_program is not None:
             if isinstance(main_program, CompiledProgram):
                 raise TypeError(
@@ -670,6 +676,11 @@ if you would like to save all variables in a
             raise TypeError(
                 "in fleet.save_persistables() function, executor must be as Executor type"
             )
+        # Todo(MrChengmo): support recv&save GPU-Kernel for ps-gpu model save
+        if not isinstance(executor.place, fluid.CPUPlace):
+            save_executor = Executor(fluid.CPUPlace())
+        else:
+            save_executor = executor

         if main_program is None:
             main_program = self.main_program
@@ -679,7 +690,8 @@ if you would like to save all variables in a
                 "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
             )

-        self._save_distributed_persistables(executor, dirname, main_program)
+        self._save_distributed_persistables(save_executor, dirname,
+                                            main_program)

     @staticmethod
     def __exclude_vars(exclude_var_names=[]):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Distribute CTR model for test fleet api
"""
from __future__ import print_function
import shutil
import tempfile
import time
import paddle
import paddle.fluid as fluid
import os
import numpy as np
import ctr_dataset_reader
from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader
from paddle.distributed.fleet.base.util_factory import fleet_util
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
class TestDistGpuPsCTR2x2(TestDistCTR2x2):
    """
    For test CTR model, using Fleet api & PS-GPU
    """

    def check_model_right(self, dirname):
        model_filename = os.path.join(dirname, "__model__")
        with open(model_filename, "rb") as f:
            program_desc_str = f.read()

        program = fluid.Program.parse_from_string(program_desc_str)
        with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
            wn.write(str(program))

    def do_pyreader_training(self, fleet):
        """
        do training using dataset, using fetch handler to catch variable
        Args:
            fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
        """
        device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        place = fluid.CUDAPlace(device_id)
        exe = fluid.Executor(place)

        fleet.init_worker()
        exe.run(fleet.startup_program)

        batch_size = 4
        train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
        self.reader.decorate_sample_list_generator(train_reader)

        for epoch_id in range(1):
            self.reader.start()
            try:
                pass_start = time.time()
                while True:
                    loss_val = exe.run(program=fleet.main_program,
                                       fetch_list=[self.avg_cost.name])
                    loss_val = np.mean(loss_val)
                    reduce_output = fleet_util.all_reduce(
                        np.array(loss_val), mode="sum")
                    loss_all_trainer = fleet_util.all_gather(float(loss_val))
                    loss_val = float(reduce_output) / len(loss_all_trainer)
                    message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
                                                                      loss_val)
                    fleet_util.print_on_rank(message, 0)

                pass_time = time.time() - pass_start
            except fluid.core.EOFException:
                self.reader.reset()

        model_dir = tempfile.mkdtemp()
        fleet.save_inference_model(
            exe, model_dir, [feed.name for feed in self.feeds], self.avg_cost)
        self.check_model_right(model_dir)
        if fleet.is_first_worker():
            fleet.save_persistables(executor=exe, dirname=model_dir)
        shutil.rmtree(model_dir)
        fleet.stop_worker()

    def do_dataset_training(self, fleet):
        dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
        )

        device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        place = fluid.CUDAPlace(device_id)
        exe = fluid.Executor(place)

        fleet.init_worker()
        exe.run(fleet.startup_program)

        thread_num = 2
        batch_size = 128
        filelist = []
        for _ in range(thread_num):
            filelist.append(train_file_path)

        # config dataset
        dataset = paddle.fleet.DatasetFactory().create_dataset()
        dataset.set_batch_size(batch_size)
        dataset.set_use_var(self.feeds)
        pipe_command = 'python ctr_dataset_reader.py'
        dataset.set_pipe_command(pipe_command)
        dataset.set_filelist(filelist)
        dataset.set_thread(thread_num)

        for epoch_id in range(1):
            pass_start = time.time()
            dataset.set_filelist(filelist)
            exe.train_from_dataset(
                program=fleet.main_program,
                dataset=dataset,
                fetch_list=[self.avg_cost],
                fetch_info=["cost"],
                print_period=2,
                debug=int(os.getenv("Debug", "0")))
            pass_time = time.time() - pass_start

        if os.getenv("SAVE_MODEL") == "1":
            model_dir = tempfile.mkdtemp()
            fleet.save_inference_model(exe, model_dir,
                                       [feed.name for feed in self.feeds],
                                       self.avg_cost)
            self.check_model_right(model_dir)
            if fleet.is_first_worker():
                fleet.save_persistables(executor=exe, dirname=model_dir)
            shutil.rmtree(model_dir)

        fleet.stop_worker()


if __name__ == "__main__":
    runtime_main(TestDistGpuPsCTR2x2)
@@ -278,6 +278,23 @@ class TestFleetBase(unittest.TestCase):
         tr0_ret = tr0.returncode
         tr1_ret = tr0.returncode
+        if tr0_ret != 0:
+            print(
+                "========================Error tr0_err begin==========================="
+            )
+            os.system("cat {}".format(tempfile.gettempdir() + "/tr0_err.log"))
+            print(
+                "========================Error tr0_err end==========================="
+            )
+
+        if tr1_ret != 0:
+            print(
+                "========================Error tr1_err begin==========================="
+            )
+            os.system("cat {}".format(tempfile.gettempdir() + "/tr1_err.log"))
+            print(
+                "========================Error tr1_err end==========================="
+            )

         self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
         self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")
......
@@ -156,5 +156,40 @@ class TestDistCtrHalfAsync2x2(TestFleetBase):
             "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)


+class TestDistCtrPsGpuPyreaderAsync2x2(TestFleetBase):
+    def _setup_config(self):
+        self._mode = "async"
+        self._reader = "pyreader"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "30000",  # 30s, to fail fast
+            "http_proxy": "",
+            "FLAGS_communicator_send_queue_size": "2",
+            "FLAGS_communicator_max_merge_var_num": "2",
+            "CPU_NUM": "2",
+            "SAVE_MODEL": "1"
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+    def test_dist_train(self):
+        self.check_with_place(
+            "dist_fleet_ctr_ps_gpu.py", delta=1e-5, check_error_log=True)
+
+
 if __name__ == "__main__":
     unittest.main()