提交 68ba6532 编写于 作者: Y yao_yf

add field in stra ckpt

上级 cbb4363f
...@@ -129,6 +129,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf ...@@ -129,6 +129,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
param_split_shape->add_dim(dim_pair.first); param_split_shape->add_dim(dim_pair.first);
indices_offset->add_dim(dim_pair.second); indices_offset->add_dim(dim_pair.second);
} }
parallel_layouts->set_field(tensor_layout.get_field_size());
} }
std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary); std::fstream output(save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
......
...@@ -53,6 +53,7 @@ message ParallelLayouts { ...@@ -53,6 +53,7 @@ message ParallelLayouts {
repeated TensorMap tensor_map = 2; repeated TensorMap tensor_map = 2;
repeated ParamSplitShape param_split_shape = 3; repeated ParamSplitShape param_split_shape = 3;
repeated IndicesOffset indices_offset = 4; repeated IndicesOffset indices_offset = 4;
required int32 field = 5;
} }
message ParallelLayoutItem { message ParallelLayoutItem {
......
...@@ -161,13 +161,9 @@ class WideDeepModel(nn.Cell): ...@@ -161,13 +161,9 @@ class WideDeepModel(nn.Cell):
self.layer_dims = self.deep_layer_dims_list + [1] self.layer_dims = self.deep_layer_dims_list + [1]
self.all_dim_list = [self.deep_input_dims] + self.layer_dims self.all_dim_list = [self.deep_input_dims] + self.layer_dims
init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init), init_acts = [('Wide_b', [1], self.emb_init)]
('V_l2', [self.vocab_size, self.emb_dim], self.emb_init),
('Wide_b', [1], self.emb_init)]
var_map = init_var_dict(self.init_args, init_acts) var_map = init_var_dict(self.init_args, init_acts)
self.wide_w = var_map["Wide_w"]
self.wide_b = var_map["Wide_b"] self.wide_b = var_map["Wide_b"]
self.embedding_table = var_map["V_l2"]
if parameter_server: if parameter_server:
self.wide_w.set_param_ps() self.wide_w.set_param_ps()
self.embedding_table.set_param_ps() self.embedding_table.set_param_ps()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册