未验证 提交 a86f11b5 编写于 作者: C Chengmo 提交者: GitHub

Speed GEO dense calc & communication (#21579)

* test=develop, speed dense calc & communication
上级 666c3bb9
......@@ -364,12 +364,14 @@ class GeoSgdCommunicator : public Communicator {
const std::vector<SparseIdsMap>& ids_send_vec,
const std::string& var_name, const std::string& splited_var_name);
void SendUpdateDenseVars(const std::string& var_name);
void SendUpdateDenseVars(const std::string& var_name,
const std::string& splited_var_name);
void SendUpdateSparseVars(const std::string& var_name,
const std::string& splited_var_name,
const std::unordered_set<int64_t>& ids_table);
void RecvUpdateDenseVars(const std::string& var_name);
void RecvUpdateDenseVars(const std::string& var_name,
const std::string& splited_var_name);
void RecvUpdateSparseVars(const std::string& var_name,
const std::string& splited_var_name);
......@@ -420,21 +422,32 @@ class GeoSgdCommunicator : public Communicator {
int trainer_nums_ = 1;
size_t geo_need_push_nums_ = 100;
bool is_geo_sgd_ = false;
Scope* training_scope_;
std::shared_ptr<Scope> delta_scope_; // parameter local delta: recv - old
std::shared_ptr<Scope>
old_scope_; // parameter local, storage the param after last recv
std::shared_ptr<Scope> pserver_scope_; // parameter on pserver,gloabl scope
int send_var_nums_ = 0;
RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_;
std::unordered_map<std::string, bool>
var_list_; // if var is sparse, using selected rows, bool=true
// parameter for local training
Scope* training_scope_;
// parameter for delta calc and send
std::shared_ptr<Scope> delta_scope_;
// parameter for storage the pserver param after last recv
std::shared_ptr<Scope> old_scope_;
// parameter on pserver
std::shared_ptr<Scope> pserver_scope_;
// if var is sparse, using selected rows, bool=true
std::unordered_map<std::string, bool> var_list_;
std::shared_ptr<BlockingQueue<std::shared_ptr<SparseIdsMap>>>
need_push_queue_;
std::vector<SparseIdsMap> ids_send_vec_;
std::unordered_map<std::string, std::vector<int64_t>> absolute_section_;
std::unordered_map<std::string, int64_t> vars_first_dimension_;
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<std::thread> send_thread_{nullptr};
......
......@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Distribute CTR model for test fleet api
"""
from __future__ import print_function
......@@ -30,10 +33,22 @@ fluid.default_main_program().random_seed = 1
class TestDistCTR2x2(FleetDistRunnerBase):
"""
For test CTR model, using Fleet api
"""
def net(self, batch_size=4, lr=0.01):
"""
network definition
Args:
batch_size(int): the size of mini-batch for training
lr(float): learning rate of training
Returns:
avg_cost: LoDTensor of cost.
"""
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
)
""" network definition """
dnn_data = fluid.layers.data(
name="dnn_data",
shape=[-1, 1],
......@@ -56,7 +71,8 @@ class TestDistCTR2x2(FleetDistRunnerBase):
datas = [dnn_data, lr_data, label]
# build dnn model
dnn_layer_dims = [128, 64, 32, 1]
# add 12800 for test huge dense Variable
dnn_layer_dims = [128, 128000, 64, 32, 1]
dnn_embedding = fluid.layers.embedding(
is_distributed=False,
input=dnn_data,
......@@ -116,6 +132,11 @@ class TestDistCTR2x2(FleetDistRunnerBase):
wn.write(str(program))
def do_training(self, fleet):
"""
do training using dataset, using fetch handler to catch variable
Args:
fleet(Fleet api): the fleet object of Parameter Server, define distribute training role
"""
dnn_input_dim, lr_input_dim, train_file_path = ctr_dataset_reader.prepare_data(
)
......@@ -163,9 +184,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
exe.train_from_dataset(
program=fleet.main_program,
dataset=dataset,
fetch_handler=FH([self.avg_cost.name],
period_secs=2,
return_np=True),
fetch_handler=FH([self.avg_cost.name], period_secs=2),
debug=False)
pass_time = time.time() - pass_start
......
......@@ -46,7 +46,7 @@ class TestDistGeoCtr_2x2(TestFleetBase):
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_v"] = "4"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册