diff --git a/paddle/fluid/operators/merge_ids_op.cc b/paddle/fluid/operators/merge_ids_op.cc index c6ec4ab047d5e91625e646fd26108d2e477cdce5..6e0e13698097ade36449f2e8ff6ab981a1b24311 100644 --- a/paddle/fluid/operators/merge_ids_op.cc +++ b/paddle/fluid/operators/merge_ids_op.cc @@ -20,13 +20,16 @@ namespace operators { class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); - AddInput( - "X", - "(LoDTensors) multi input tensor with shape{batch_num, N}, N is the " - "size of embedding table") + AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}") + .AsDuplicable(); + AddInput("Rows", "(LoDTensor) the input ids with shape{row_size, 1}, ") + .AsDuplicable(); + AddInput("X", + "(LoDTensors) multi input tensor with shape{Rows, N}, N is the " + "size of embedding table") + .AsDuplicable(); + AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.") .AsDuplicable(); - AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors."); AddComment(R"DOC( Merge multi LoDTensor's into one according to Ids's shard num. @@ -79,15 +82,19 @@ class MergeIdsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must has input Ids."); - PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has input X."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must has output Out."); + PADDLE_ENFORCE(ctx->HasInputs("Ids"), + "MergeIdsOp must has multi input Ids."); + PADDLE_ENFORCE(ctx->HasInputs("Rows"), + "MergeIdsOp must has multi input Rows."); + PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has multi input X."); + PADDLE_ENFORCE(ctx->HasOutputs("Out"), + "MergeIdsOp must has multi output Out."); auto ids_var_type = ctx->GetInputsVarType("Ids").front(); - auto ids_dims = ctx->GetInputDim("Ids"); + auto ids_dims = ctx->GetInputsDim("Ids"); if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); + PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[0][1], 1); } auto x_var_type = ctx->GetInputsVarType("X"); for (auto &var_type : x_var_type) { diff --git a/paddle/fluid/operators/merge_ids_op.h b/paddle/fluid/operators/merge_ids_op.h index 83712a8519c6817151e1922c606c0fdd4682a2db..fef9e023d02f45e21ec409ad398ba7d9bdd36880 100644 --- a/paddle/fluid/operators/merge_ids_op.h +++ b/paddle/fluid/operators/merge_ids_op.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include +#include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" @@ -30,59 +32,70 @@ class MergeIdsOpKernel : public framework::OpKernel { if (!platform::is_cpu_place(place)) { PADDLE_THROW("MergeIds do not support GPU kernel"); } - VLOG(3) << "run in MergeIdsOpKernel"; - const auto *ids_var = ctx.InputVar("Ids"); - PADDLE_ENFORCE(ids_var->IsType(), - "only support to merge Ids of LoDTensor"); + const auto ids = ctx.MultiInput("Ids"); + const auto row_ids = ctx.MultiInput("Rows"); + const auto x_tensors = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); - const auto &ids_tensor = ids_var->Get(); - const auto &ids_dims = ids_tensor.dims(); - const int64_t *ids = ids_tensor.data(); + PADDLE_ENFORCE_EQ(row_ids.size(), x_tensors.size(), + "the number of Rows and X should be the same"); + PADDLE_ENFORCE_EQ(ids.size(), outs.size(), + "the number of Ids and Out should be the same"); - auto x_tensors = ctx.MultiInput("X"); + int row_ids_size = 0; + int row_size = 0; + int embedding_size = 0; - auto *out = ctx.Output("Out"); + for (int i = 0; i < x_tensors.size(); ++i) { + const auto *x_tensor = x_tensors[i]; + const auto *row_id = row_ids[i]; - int batch_size = 0; - int embedding_size = 0; - for (auto &input : x_tensors) { - if (framework::product(input->dims()) != 0) { - if (embedding_size == 0) { - embedding_size = input->dims()[1]; - } - PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], - "embedding size of all input should be the same"); - batch_size += input->dims()[0]; + if (embedding_size == 0) { + embedding_size = x_tensor->dims()[1]; } + PADDLE_ENFORCE_EQ(embedding_size, x_tensor->dims()[1], + "embedding size of all input should be the same"); + row_size += x_tensor->dims()[0]; + row_ids_size += row_id->dims()[0]; } + PADDLE_ENFORCE_EQ( - batch_size, ids_dims[0], - "the batch size of ids and merged embedding value should be the same"); + row_size, row_ids_size, + "the merged X dim[0] and merged Rows dim[0] should be the same"); + + std::unordered_map> + selected_rows_idx_map; + for (int i = 0; i < x_tensors.size(); ++i) { + const auto *row_id = row_ids[i]; + + for (int j = 0; j < row_id->numel(); ++j) { + int64_t key = row_id->data()[j]; + std::tuple val = std::make_tuple(i, j); + selected_rows_idx_map.insert(std::make_pair(key, val)); + } + } + PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(), + "the rows and tensor map size should be the same"); + + for (int i = 0; i < outs.size(); ++i) { + auto *out_ids = ids[i]; + auto *out = outs[i]; - const size_t shard_num = x_tensors.size(); + out->set_lod(out_ids->lod()); - if (shard_num == 1) { - VLOG(3) << "only one shard, we can copy the data directly"; - TensorCopy(*x_tensors[0], place, out); - } else { - std::vector in_indexs(shard_num, 0); + int nums = static_cast(out_ids->dims()[0]); auto *out_data = out->mutable_data( - framework::make_ddim({batch_size, embedding_size}), place); - // copy data from ins[shard_num] to out. - for (int i = 0; i < ids_dims[0]; ++i) { - int64_t id = ids[i]; - size_t shard_id = static_cast(id) % shard_num; - int index = in_indexs[shard_id]; - memcpy(out_data + embedding_size * i, - x_tensors[shard_id]->data() + index * embedding_size, + framework::make_ddim({nums, embedding_size}), place); + for (int j = 0; j < nums; ++j) { + int id = out_ids->data()[j]; + auto row_tuple = selected_rows_idx_map[id]; + int64_t row_idx = std::get<1>(row_tuple); + const auto *x_tensor = x_tensors[std::get<0>(row_tuple)]; + + memcpy(out_data + embedding_size * j, + x_tensor->data() + row_idx * embedding_size, sizeof(T) * embedding_size); - in_indexs[shard_id] += 1; - } - - for (size_t i = 0; i < shard_num; ++i) { - PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0], - "after merge, all data in x_tensor should be used"); } } } diff --git a/paddle/fluid/operators/split_ids_op.cc b/paddle/fluid/operators/split_ids_op.cc index c867c46873ae7ddbdbda280351e4ab28235bcc08..243f81e296fb95a2c7e9f717950b8a958ad98852 100644 --- a/paddle/fluid/operators/split_ids_op.cc +++ b/paddle/fluid/operators/split_ids_op.cc @@ -20,20 +20,27 @@ namespace operators { class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); - AddOutput("Out", "(LoDTensor) The outputs of the input Ids.") + AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}") + .AsDuplicable(); + + AddOutput("Out", "(LoDTensors) The outputs of the input Ids.") .AsDuplicable(); AddComment(R"DOC( Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number Example: Input: - X = [1,2,3,4,5,6] + X = [[1,2,3,4,5,6],[2,3]] Out(3 output): - out0 = [3, 6] - out1 = [1, 4] - out2 = [2, 5] + if compress is True: + out0 = [3, 3, 6] + out1 = [1, 4] + out2 = [2, 2, 5] + else: + out0 = [3, 6] + out1 = [1, 4] + out2 = [2, 5] )DOC"); } }; @@ -43,16 +50,24 @@ class SplitIdsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Ids"), "SplitIdsOp must has input Ids."); + PADDLE_ENFORCE(ctx->HasInputs("Ids"), "SplitIdsOp must has input Ids."); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out."); auto ids_var_type = ctx->GetInputsVarType("Ids").front(); - auto ids_dims = ctx->GetInputDim("Ids"); + auto ids_dims = ctx->GetInputsDim("Ids"); if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); + PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2); } } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.MultiInput("Ids").front()->type()), + ctx.GetPlace()); + } }; class SplitIdsOpInferVarType : public framework::VarTypeInference { @@ -66,12 +81,28 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference { } }; +class SplitIdsOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto grad = new framework::OpDesc(); + grad->SetType("concat"); + grad->SetInput("X", OutputGrad("Out")); + grad->SetOutput("Out", InputGrad("Ids")); + grad->SetAttr("axis", 0); + return std::unique_ptr(grad); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker, - ops::SplitIdsOpInferVarType); + ops::SplitIdsOpGradMaker, ops::SplitIdsOpInferVarType); + REGISTER_OP_CPU_KERNEL( split_ids, ops::SplitIdsOpKernel, ops::SplitIdsOpKernel); diff --git a/paddle/fluid/operators/split_ids_op.h b/paddle/fluid/operators/split_ids_op.h index c4af5a65fc5f81c1af7c1fdcca637ca37c940637..69ac6c5a6b9a8b318520eb9a3ff89a3a6be48339 100644 --- a/paddle/fluid/operators/split_ids_op.h +++ b/paddle/fluid/operators/split_ids_op.h @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once +#include +#include #include #include #include "paddle/fluid/framework/op_registry.h" @@ -31,19 +33,39 @@ class SplitIdsOpKernel : public framework::OpKernel { PADDLE_THROW("SplitIds do not support GPU kernel"); } - const auto *ids_var = ctx.InputVar("Ids"); + const auto ids_vars = ctx.MultiInputVar("Ids"); + + PADDLE_ENFORCE_GT(ids_vars.size(), 0, "The number of Ids should > 0"); + auto *ids_var = ids_vars[0]; + if (ids_var->IsType()) { - const auto &ids_dims = ctx.Input("Ids")->dims(); - const T *ids = ctx.Input("Ids")->data(); + int batch_size = 0; + const auto ids_tensors = ctx.MultiInput("Ids"); + for (size_t i = 0; i < ids_tensors.size(); ++i) { + batch_size += ids_tensors[i]->dims()[0]; + } + VLOG(4) << "Get Total BatchSize is: " << batch_size; + + std::vector all_ids(batch_size); + int offset = 0; + for (size_t i = 0; i < ids_tensors.size(); ++i) { + const auto *ids = ids_tensors[i]; + std::memcpy(all_ids.data() + offset, ids->data(), + ids->numel() * sizeof(T)); + offset += ids->numel(); + } + + std::set st(all_ids.begin(), all_ids.end()); + all_ids.assign(st.begin(), st.end()); + auto outs = ctx.MultiOutput("Out"); const size_t shard_num = outs.size(); - std::vector> out_ids; out_ids.resize(outs.size()); // split id by their shard_num. - for (int i = 0; i < ids_dims[0]; ++i) { - T id = ids[i]; + for (int i = 0; i < all_ids.size(); ++i) { + T id = all_ids[i]; size_t shard_id = static_cast(id) % shard_num; out_ids[shard_id].push_back(id); } @@ -64,7 +86,7 @@ class SplitIdsOpKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ(ids_dims[0], static_cast(ids_selected_rows->rows().size()), ""); - const T *ids = ids_selected_rows->value().data(); + const T *ids_data = ids_selected_rows->value().data(); const auto &ids_rows = ids_selected_rows->rows(); auto outs = ctx.MultiOutput("Out"); const size_t shard_num = outs.size(); @@ -87,7 +109,7 @@ class SplitIdsOpKernel : public framework::OpKernel { T *output = out->mutable_value()->mutable_data(ddim, place); for (int64_t i = 0; i < ddim[0]; ++i) { memcpy(output + i * row_width, - ids + id_to_index[out->rows()[i]] * row_width, + ids_data + id_to_index[out->rows()[i]] * row_width, row_width * sizeof(T)); } } diff --git a/python/paddle/fluid/tests/unittests/test_merge_ids_op.py b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py index 26ce7024117162e8bad403a9d8b8518c27578c83..b109e4ea62669c735128f4824eb9d02ad43900e0 100644 --- a/python/paddle/fluid/tests/unittests/test_merge_ids_op.py +++ b/python/paddle/fluid/tests/unittests/test_merge_ids_op.py @@ -22,15 +22,28 @@ from op_test import OpTest class TestMergeIdsOp(OpTest): def setUp(self): self.op_type = "merge_ids" - ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') - x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32') - x1 = np.array([]).astype('float32') - x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6], - [0.5, 0.6]]).astype('float32') - out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3], - [0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32') - self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]} - self.outputs = {'Out': out} + ids1 = np.array([[0], [2], [5], [6]]).astype('int64') + ids2 = np.array([[0], [2], [2], [3]]).astype('int64') + + rows1 = np.array([[0], [2]]).astype('int64') + rows2 = np.array([[3], [5]]).astype('int64') + rows3 = np.array([[6]]).astype('int64') + + x0 = np.array([[0.1, 0.2], [0.2, 0.3]]).astype('float32') + x1 = np.array([[0.3, 0.4], [0.4, 0.5]]).astype('float32') + x2 = np.array([[0.5, 0.6]]).astype('float32') + + out1 = np.array( + [[0.1, 0.2], [0.2, 0.3], [0.4, 0.5], [0.5, 0.6]]).astype('float32') + out2 = np.array( + [[0.1, 0.2], [0.2, 0.3], [0.2, 0.3], [0.3, 0.4]]).astype('float32') + + self.inputs = { + 'Ids': [('ids1', ids1), ('ids2', ids2)], + "Rows": [('rows1', rows1), ('rows2', rows2), ('rows3', rows3)], + "X": [('x0', x0), ('x1', x1), ('x2', x2)] + } + self.outputs = {'Out': [('out1', out1), ('out2', out2)]} def test_check_output(self): self.check_output() diff --git a/python/paddle/fluid/tests/unittests/test_split_ids_op.py b/python/paddle/fluid/tests/unittests/test_split_ids_op.py index 4c3d0258980fd8595704a65219deb520b96e222e..d674dad2293921c06135b4ee528538d266cb2904 100644 --- a/python/paddle/fluid/tests/unittests/test_split_ids_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_ids_op.py @@ -25,18 +25,21 @@ from paddle.fluid.op import Operator class TestSplitIdsOp(OpTest): def setUp(self): self.op_type = "split_ids" - ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') + ids1 = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') + ids2 = np.array([[6], [2], [3], [3], [5], [2], [6]]).astype('int64') + ids3 = np.array([[2], [2], [2], [3], [5], [5], [6]]).astype('int64') + out0 = np.array([[0], [3], [6]]).astype('int64') out1 = np.array([[]]).astype('int64') - out2 = np.array([[2], [2], [5], [5]]).astype('int64') - self.inputs = {'Ids': ids} + out2 = np.array([[2], [5]]).astype('int64') + self.inputs = {'Ids': [('ids1', ids1), ('ids2', ids2), ('ids3', ids3)]} self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]} def test_check_output(self): self.check_output() -class TestSpliteIds(unittest.TestCase): +class TestSplitSelectedRows(unittest.TestCase): def get_places(self): places = [core.CPUPlace()] return places diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 2192139f8d5950286691a77333dd8ec35505b033..677a67d3dbce14097f1ccf799007905cc58551d6 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -712,7 +712,7 @@ in a single call.") for _, op in enumerate(self.optimize_ops): # optimizer is connected to itself if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \ - op not in global_ops: + op not in global_ops: log("append opt op: ", op.type, op.input_arg_names, merged_var) __append_optimize_op__(op, per_opt_block, @@ -1033,15 +1033,11 @@ to transpile() call.") def _replace_lookup_table_op_with_prefetch(self, program, pserver_endpoints): # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op - # self.all_prefetch_input_vars = - # [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1] - # [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]] + self.all_in_ids_vars = [] self.all_prefetch_input_vars = [] - - # self.all_prefetch_input_vars = - # [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1] - # [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]] self.all_prefetch_output_vars = [] + self.all_out_emb_vars = [] + lookup_table_op_index = -1 continue_search_lookup_table_op = True while continue_search_lookup_table_op: @@ -1051,72 +1047,68 @@ to transpile() call.") if op.type == LOOKUP_TABLE_TYPE: continue_search_lookup_table_op = True - lookup_table_op_index = list(all_ops).index(op) + lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list( + all_ops).index(op) ids_name = op.input("Ids") out_name = op.output("Out") ids_var = program.global_block().vars[ids_name[0]] - prefetch_input_vars = self._create_splited_vars( - source_var=ids_var, - block=program.global_block(), - tag="_prefetch_in_") - self.all_prefetch_input_vars.append(prefetch_input_vars) + self.all_in_ids_vars.append(ids_var) out_var = program.global_block().vars[out_name[0]] - prefetch_output_vars = self._create_splited_vars( - source_var=out_var, - block=program.global_block(), - tag="_prefetch_out_") - self.all_prefetch_output_vars.append(prefetch_output_vars) - - # insert split_ids_op - program.global_block()._insert_op( - index=lookup_table_op_index, - type="split_ids", - inputs={ - 'Ids': [ - program.global_block().vars[varname] - for varname in ids_name - ] - }, - outputs={"Out": prefetch_input_vars}) - - # insert prefetch_op - program.global_block()._insert_op( - index=lookup_table_op_index + 1, - type="prefetch", - inputs={'X': prefetch_input_vars}, - outputs={"Out": prefetch_output_vars}, - attrs={ - "epmap": pserver_endpoints, - # FIXME(qiao) temporarily disable this config because prefetch - # is not act as other rpc op, it's more like a forward op - # RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE - }) - - # insert concat_op - program.global_block()._insert_op( - index=lookup_table_op_index + 2, - type="merge_ids", - inputs={ - 'Ids': [ - program.global_block().vars[varname] - for varname in ids_name - ], - 'X': prefetch_output_vars - }, - outputs={ - "Out": [ - program.global_block().vars[varname] - for varname in out_name - ] - }) + self.all_out_emb_vars.append(out_var) # delete lookup_table_op delete_ops(program.global_block(), [op]) # break for loop break + for index in range(len(self.pserver_endpoints)): + in_var = program.global_block().create_var( + name=str("prefetch_compress_in_tmp_" + str(index)), + type=self.all_in_ids_vars[0].type, + shape=self.all_in_ids_vars[0].shape, + dtype=self.all_in_ids_vars[0].dtype) + self.all_prefetch_input_vars.append(in_var) + + out_var = program.global_block().create_var( + name=str("prefetch_compress_out_tmp_" + str(index)), + type=self.all_out_emb_vars[0].type, + shape=self.all_out_emb_vars[0].shape, + dtype=self.all_out_emb_vars[0].dtype) + self.all_prefetch_output_vars.append(out_var) + + # insert split_ids_op + program.global_block()._insert_op( + index=lookup_table_op_index, + type="split_ids", + inputs={'Ids': self.all_in_ids_vars}, + outputs={"Out": self.all_prefetch_input_vars}) + + # insert prefetch_op + program.global_block()._insert_op( + index=lookup_table_op_index + 1, + type="prefetch", + inputs={'X': self.all_prefetch_input_vars}, + outputs={"Out": self.all_prefetch_output_vars}, + attrs={ + "epmap": pserver_endpoints, + # FIXME(qiao) temporarily disable this config because prefetch + # is not act as other rpc op, it's more like a forward op + # RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + # insert concat_op + program.global_block()._insert_op( + index=lookup_table_op_index + 2, + type="merge_ids", + inputs={ + 'Ids': self.all_in_ids_vars, + 'Rows': self.all_prefetch_input_vars, + 'X': self.all_prefetch_output_vars + }, + outputs={"Out": self.all_out_emb_vars}) + def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): # 2. add split_ids_op and send_op to send gradient to pservers @@ -1159,32 +1151,31 @@ to transpile() call.") # STEP: create prefetch block table_var = pserver_program.global_block().vars[self.table_name] prefetch_var_name_to_block_id = [] - for index in range(len(self.all_prefetch_input_vars)): - prefetch_block = pserver_program._create_block(optimize_block.idx) - trainer_ids = self.all_prefetch_input_vars[index][pserver_index] - pserver_ids = pserver_program.global_block().create_var( - name=trainer_ids.name, - type=trainer_ids.type, - shape=trainer_ids.shape, - dtype=trainer_ids.dtype) - trainer_out = self.all_prefetch_output_vars[index][pserver_index] - pserver_out = pserver_program.global_block().create_var( - name=trainer_out.name, - type=trainer_out.type, - shape=trainer_out.shape, - dtype=trainer_out.dtype) - prefetch_block.append_op( - type="lookup_sparse_table", - inputs={'Ids': pserver_ids, - "W": table_var}, - outputs={"Out": pserver_out}, - attrs={ - "is_sparse": True, # has no effect on lookup_table op - "is_distributed": True, - "padding_idx": -1 - }) - prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str( - prefetch_block.idx)) + prefetch_block = pserver_program._create_block(optimize_block.idx) + trainer_ids = self.all_prefetch_input_vars[pserver_index] + pserver_ids = pserver_program.global_block().create_var( + name=trainer_ids.name, + type=trainer_ids.type, + shape=trainer_ids.shape, + dtype=trainer_ids.dtype) + trainer_out = self.all_prefetch_output_vars[pserver_index] + pserver_out = pserver_program.global_block().create_var( + name=trainer_out.name, + type=trainer_out.type, + shape=trainer_out.shape, + dtype=trainer_out.dtype) + prefetch_block.append_op( + type="lookup_sparse_table", + inputs={'Ids': pserver_ids, + "W": table_var}, + outputs={"Out": pserver_out}, + attrs={ + "is_sparse": True, # has no effect on lookup_table op + "is_distributed": True, + "padding_idx": -1 + }) + prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str( + prefetch_block.idx)) return prefetch_var_name_to_block_id def _create_table_optimize_block(self, pserver_index, pserver_program, @@ -1363,16 +1354,6 @@ to transpile() call.") program.global_block()._sync_with_cpp() return var_mapping - def _create_splited_vars(self, source_var, block, tag): - return [ - block.create_var( - name=str(source_var.name + tag + str(index)), - type=source_var.type, - shape=source_var.shape, - dtype=source_var.dtype) - for index in range(len(self.pserver_endpoints)) - ] - def _clone_var(self, block, var, persistable=True): return block.create_var( name=var.name,