提交 b6590b05 编写于 作者: S seiriosPlus

submit by tangwei12, test=develop

上级 eef77fdd
...@@ -20,13 +20,16 @@ namespace operators { ...@@ -20,13 +20,16 @@ namespace operators {
class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker { class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}")
AddInput( .AsDuplicable();
"X", AddInput("Rows", "(LoDTensor) the input ids with shape{row_size, 1}, ")
"(LoDTensors) multi input tensor with shape{batch_num, N}, N is the " .AsDuplicable();
AddInput("X",
"(LoDTensors) multi input tensor with shape{Rows, N}, N is the "
"size of embedding table") "size of embedding table")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors."); AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.")
.AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Merge multi LoDTensor's into one according to Ids's shard num. Merge multi LoDTensor's into one according to Ids's shard num.
...@@ -79,15 +82,19 @@ class MergeIdsOp : public framework::OperatorWithKernel { ...@@ -79,15 +82,19 @@ class MergeIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must has input Ids."); PADDLE_ENFORCE(ctx->HasInputs("Ids"),
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has input X."); "MergeIdsOp must has multi input Ids.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must has output Out."); PADDLE_ENFORCE(ctx->HasInputs("Rows"),
"MergeIdsOp must has multi input Rows.");
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has multi input X.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
"MergeIdsOp must has multi output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front(); auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1); PADDLE_ENFORCE_EQ(ids_dims[0][1], 1);
} }
auto x_var_type = ctx->GetInputsVarType("X"); auto x_var_type = ctx->GetInputsVarType("X");
for (auto &var_type : x_var_type) { for (auto &var_type : x_var_type) {
......
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <tuple>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
...@@ -30,59 +32,70 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -30,59 +32,70 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
if (!platform::is_cpu_place(place)) { if (!platform::is_cpu_place(place)) {
PADDLE_THROW("MergeIds do not support GPU kernel"); PADDLE_THROW("MergeIds do not support GPU kernel");
} }
VLOG(3) << "run in MergeIdsOpKernel";
const auto *ids_var = ctx.InputVar("Ids"); const auto ids = ctx.MultiInput<framework::LoDTensor>("Ids");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(), const auto row_ids = ctx.MultiInput<framework::LoDTensor>("Rows");
"only support to merge Ids of LoDTensor"); const auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const auto &ids_tensor = ids_var->Get<framework::LoDTensor>(); PADDLE_ENFORCE_EQ(row_ids.size(), x_tensors.size(),
const auto &ids_dims = ids_tensor.dims(); "the number of Rows and X should be the same");
const int64_t *ids = ids_tensor.data<int64_t>(); PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
"the number of Ids and Out should be the same");
auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X"); int row_ids_size = 0;
int row_size = 0;
int embedding_size = 0;
auto *out = ctx.Output<framework::LoDTensor>("Out"); for (int i = 0; i < x_tensors.size(); ++i) {
const auto *x_tensor = x_tensors[i];
const auto *row_id = row_ids[i];
int batch_size = 0;
int embedding_size = 0;
for (auto &input : x_tensors) {
if (framework::product(input->dims()) != 0) {
if (embedding_size == 0) { if (embedding_size == 0) {
embedding_size = input->dims()[1]; embedding_size = x_tensor->dims()[1];
} }
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1], PADDLE_ENFORCE_EQ(embedding_size, x_tensor->dims()[1],
"embedding size of all input should be the same"); "embedding size of all input should be the same");
batch_size += input->dims()[0]; row_size += x_tensor->dims()[0];
} row_ids_size += row_id->dims()[0];
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
batch_size, ids_dims[0], row_size, row_ids_size,
"the batch size of ids and merged embedding value should be the same"); "the merged X dim[0] and merged Rows dim[0] should be the same");
std::unordered_map<int64_t, std::tuple<int64_t, int64_t>>
selected_rows_idx_map;
for (int i = 0; i < x_tensors.size(); ++i) {
const auto *row_id = row_ids[i];
for (int j = 0; j < row_id->numel(); ++j) {
int64_t key = row_id->data<int64_t>()[j];
std::tuple<int64_t, int64_t> val = std::make_tuple(i, j);
selected_rows_idx_map.insert(std::make_pair(key, val));
}
}
PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(),
"the rows and tensor map size should be the same");
for (int i = 0; i < outs.size(); ++i) {
auto *out_ids = ids[i];
auto *out = outs[i];
const size_t shard_num = x_tensors.size(); out->set_lod(out_ids->lod());
if (shard_num == 1) { int nums = static_cast<int>(out_ids->dims()[0]);
VLOG(3) << "only one shard, we can copy the data directly";
TensorCopy(*x_tensors[0], place, out);
} else {
std::vector<int> in_indexs(shard_num, 0);
auto *out_data = out->mutable_data<T>( auto *out_data = out->mutable_data<T>(
framework::make_ddim({batch_size, embedding_size}), place); framework::make_ddim({nums, embedding_size}), place);
// copy data from ins[shard_num] to out. for (int j = 0; j < nums; ++j) {
for (int i = 0; i < ids_dims[0]; ++i) { int id = out_ids->data<int64_t>()[j];
int64_t id = ids[i]; auto row_tuple = selected_rows_idx_map[id];
size_t shard_id = static_cast<size_t>(id) % shard_num; int64_t row_idx = std::get<1>(row_tuple);
int index = in_indexs[shard_id]; const auto *x_tensor = x_tensors[std::get<0>(row_tuple)];
memcpy(out_data + embedding_size * i,
x_tensors[shard_id]->data<T>() + index * embedding_size, memcpy(out_data + embedding_size * j,
x_tensor->data<T>() + row_idx * embedding_size,
sizeof(T) * embedding_size); sizeof(T) * embedding_size);
in_indexs[shard_id] += 1;
}
for (size_t i = 0; i < shard_num; ++i) {
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
"after merge, all data in x_tensor should be used");
} }
} }
} }
......
...@@ -20,17 +20,24 @@ namespace operators { ...@@ -20,17 +20,24 @@ namespace operators {
class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker { class SplitIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}"); AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}")
AddOutput("Out", "(LoDTensor) The outputs of the input Ids.") .AsDuplicable();
AddOutput("Out", "(LoDTensors) The outputs of the input Ids.")
.AsDuplicable(); .AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number Split a LoDTensor of Ids into multi LoDTensors, the number is pserver's number
Example: Example:
Input: Input:
X = [1,2,3,4,5,6] X = [[1,2,3,4,5,6],[2,3]]
Out(3 output): Out(3 output):
if compress is True:
out0 = [3, 3, 6]
out1 = [1, 4]
out2 = [2, 2, 5]
else:
out0 = [3, 6] out0 = [3, 6]
out1 = [1, 4] out1 = [1, 4]
out2 = [2, 5] out2 = [2, 5]
...@@ -43,16 +50,24 @@ class SplitIdsOp : public framework::OperatorWithKernel { ...@@ -43,16 +50,24 @@ class SplitIdsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "SplitIdsOp must has input Ids."); PADDLE_ENFORCE(ctx->HasInputs("Ids"), "SplitIdsOp must has input Ids.");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out."); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "SplitIdsOp must has output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front(); auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputsDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) { if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2); PADDLE_ENFORCE_EQ(ids_dims[0].size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
} }
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("Ids").front()->type()),
ctx.GetPlace());
}
}; };
class SplitIdsOpInferVarType : public framework::VarTypeInference { class SplitIdsOpInferVarType : public framework::VarTypeInference {
...@@ -66,12 +81,28 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference { ...@@ -66,12 +81,28 @@ class SplitIdsOpInferVarType : public framework::VarTypeInference {
} }
}; };
class SplitIdsOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto grad = new framework::OpDesc();
grad->SetType("concat");
grad->SetInput("X", OutputGrad("Out"));
grad->SetOutput("Out", InputGrad("Ids"));
grad->SetAttr("axis", 0);
return std::unique_ptr<framework::OpDesc>(grad);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker, REGISTER_OPERATOR(split_ids, ops::SplitIdsOp, ops::SplitIdsOpMaker,
ops::SplitIdsOpInferVarType); ops::SplitIdsOpGradMaker, ops::SplitIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>, split_ids, ops::SplitIdsOpKernel<paddle::platform::CPUPlace, int64_t>,
ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>); ops::SplitIdsOpKernel<paddle::platform::CPUPlace, float>);
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <iterator>
#include <set>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -31,19 +33,39 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -31,19 +33,39 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
PADDLE_THROW("SplitIds do not support GPU kernel"); PADDLE_THROW("SplitIds do not support GPU kernel");
} }
const auto *ids_var = ctx.InputVar("Ids"); const auto ids_vars = ctx.MultiInputVar("Ids");
PADDLE_ENFORCE_GT(ids_vars.size(), 0, "The number of Ids should > 0");
auto *ids_var = ids_vars[0];
if (ids_var->IsType<framework::LoDTensor>()) { if (ids_var->IsType<framework::LoDTensor>()) {
const auto &ids_dims = ctx.Input<framework::LoDTensor>("Ids")->dims(); int batch_size = 0;
const T *ids = ctx.Input<framework::LoDTensor>("Ids")->data<T>(); const auto ids_tensors = ctx.MultiInput<framework::LoDTensor>("Ids");
for (size_t i = 0; i < ids_tensors.size(); ++i) {
batch_size += ids_tensors[i]->dims()[0];
}
VLOG(4) << "Get Total BatchSize is: " << batch_size;
std::vector<T> all_ids(batch_size);
int offset = 0;
for (size_t i = 0; i < ids_tensors.size(); ++i) {
const auto *ids = ids_tensors[i];
std::memcpy(all_ids.data() + offset, ids->data<T>(),
ids->numel() * sizeof(T));
offset += ids->numel();
}
std::set<T> st(all_ids.begin(), all_ids.end());
all_ids.assign(st.begin(), st.end());
auto outs = ctx.MultiOutput<framework::LoDTensor>("Out"); auto outs = ctx.MultiOutput<framework::LoDTensor>("Out");
const size_t shard_num = outs.size(); const size_t shard_num = outs.size();
std::vector<std::vector<T>> out_ids; std::vector<std::vector<T>> out_ids;
out_ids.resize(outs.size()); out_ids.resize(outs.size());
// split id by their shard_num. // split id by their shard_num.
for (int i = 0; i < ids_dims[0]; ++i) { for (int i = 0; i < all_ids.size(); ++i) {
T id = ids[i]; T id = all_ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num; size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id); out_ids[shard_id].push_back(id);
} }
...@@ -64,7 +86,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -64,7 +86,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(ids_dims[0], PADDLE_ENFORCE_EQ(ids_dims[0],
static_cast<int64_t>(ids_selected_rows->rows().size()), static_cast<int64_t>(ids_selected_rows->rows().size()),
""); "");
const T *ids = ids_selected_rows->value().data<T>(); const T *ids_data = ids_selected_rows->value().data<T>();
const auto &ids_rows = ids_selected_rows->rows(); const auto &ids_rows = ids_selected_rows->rows();
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out"); auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
const size_t shard_num = outs.size(); const size_t shard_num = outs.size();
...@@ -87,7 +109,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> { ...@@ -87,7 +109,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
T *output = out->mutable_value()->mutable_data<T>(ddim, place); T *output = out->mutable_value()->mutable_data<T>(ddim, place);
for (int64_t i = 0; i < ddim[0]; ++i) { for (int64_t i = 0; i < ddim[0]; ++i) {
memcpy(output + i * row_width, memcpy(output + i * row_width,
ids + id_to_index[out->rows()[i]] * row_width, ids_data + id_to_index[out->rows()[i]] * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} }
} }
......
...@@ -22,15 +22,28 @@ from op_test import OpTest ...@@ -22,15 +22,28 @@ from op_test import OpTest
class TestMergeIdsOp(OpTest): class TestMergeIdsOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "merge_ids" self.op_type = "merge_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') ids1 = np.array([[0], [2], [5], [6]]).astype('int64')
x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32') ids2 = np.array([[0], [2], [2], [3]]).astype('int64')
x1 = np.array([]).astype('float32')
x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6], rows1 = np.array([[0], [2]]).astype('int64')
[0.5, 0.6]]).astype('float32') rows2 = np.array([[3], [5]]).astype('int64')
out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3], rows3 = np.array([[6]]).astype('int64')
[0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')
self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]} x0 = np.array([[0.1, 0.2], [0.2, 0.3]]).astype('float32')
self.outputs = {'Out': out} x1 = np.array([[0.3, 0.4], [0.4, 0.5]]).astype('float32')
x2 = np.array([[0.5, 0.6]]).astype('float32')
out1 = np.array(
[[0.1, 0.2], [0.2, 0.3], [0.4, 0.5], [0.5, 0.6]]).astype('float32')
out2 = np.array(
[[0.1, 0.2], [0.2, 0.3], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
self.inputs = {
'Ids': [('ids1', ids1), ('ids2', ids2)],
"Rows": [('rows1', rows1), ('rows2', rows2), ('rows3', rows3)],
"X": [('x0', x0), ('x1', x1), ('x2', x2)]
}
self.outputs = {'Out': [('out1', out1), ('out2', out2)]}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
...@@ -25,18 +25,21 @@ from paddle.fluid.op import Operator ...@@ -25,18 +25,21 @@ from paddle.fluid.op import Operator
class TestSplitIdsOp(OpTest): class TestSplitIdsOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "split_ids" self.op_type = "split_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64') ids1 = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
ids2 = np.array([[6], [2], [3], [3], [5], [2], [6]]).astype('int64')
ids3 = np.array([[2], [2], [2], [3], [5], [5], [6]]).astype('int64')
out0 = np.array([[0], [3], [6]]).astype('int64') out0 = np.array([[0], [3], [6]]).astype('int64')
out1 = np.array([[]]).astype('int64') out1 = np.array([[]]).astype('int64')
out2 = np.array([[2], [2], [5], [5]]).astype('int64') out2 = np.array([[2], [5]]).astype('int64')
self.inputs = {'Ids': ids} self.inputs = {'Ids': [('ids1', ids1), ('ids2', ids2), ('ids3', ids3)]}
self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]} self.outputs = {'Out': [('out0', out0), ('out1', out1), ('out2', out2)]}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestSpliteIds(unittest.TestCase): class TestSplitSelectedRows(unittest.TestCase):
def get_places(self): def get_places(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
return places return places
......
...@@ -1033,15 +1033,11 @@ to transpile() call.") ...@@ -1033,15 +1033,11 @@ to transpile() call.")
def _replace_lookup_table_op_with_prefetch(self, program, def _replace_lookup_table_op_with_prefetch(self, program,
pserver_endpoints): pserver_endpoints):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
# self.all_prefetch_input_vars = self.all_in_ids_vars = []
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_input_vars = [] self.all_prefetch_input_vars = []
# self.all_prefetch_input_vars =
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
self.all_prefetch_output_vars = [] self.all_prefetch_output_vars = []
self.all_out_emb_vars = []
lookup_table_op_index = -1
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
while continue_search_lookup_table_op: while continue_search_lookup_table_op:
...@@ -1051,42 +1047,50 @@ to transpile() call.") ...@@ -1051,42 +1047,50 @@ to transpile() call.")
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
lookup_table_op_index = list(all_ops).index(op) lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list(
all_ops).index(op)
ids_name = op.input("Ids") ids_name = op.input("Ids")
out_name = op.output("Out") out_name = op.output("Out")
ids_var = program.global_block().vars[ids_name[0]] ids_var = program.global_block().vars[ids_name[0]]
prefetch_input_vars = self._create_splited_vars( self.all_in_ids_vars.append(ids_var)
source_var=ids_var,
block=program.global_block(),
tag="_prefetch_in_")
self.all_prefetch_input_vars.append(prefetch_input_vars)
out_var = program.global_block().vars[out_name[0]] out_var = program.global_block().vars[out_name[0]]
prefetch_output_vars = self._create_splited_vars( self.all_out_emb_vars.append(out_var)
source_var=out_var,
block=program.global_block(), # delete lookup_table_op
tag="_prefetch_out_") delete_ops(program.global_block(), [op])
self.all_prefetch_output_vars.append(prefetch_output_vars) # break for loop
break
for index in range(len(self.pserver_endpoints)):
in_var = program.global_block().create_var(
name=str("prefetch_compress_in_tmp_" + str(index)),
type=self.all_in_ids_vars[0].type,
shape=self.all_in_ids_vars[0].shape,
dtype=self.all_in_ids_vars[0].dtype)
self.all_prefetch_input_vars.append(in_var)
out_var = program.global_block().create_var(
name=str("prefetch_compress_out_tmp_" + str(index)),
type=self.all_out_emb_vars[0].type,
shape=self.all_out_emb_vars[0].shape,
dtype=self.all_out_emb_vars[0].dtype)
self.all_prefetch_output_vars.append(out_var)
# insert split_ids_op # insert split_ids_op
program.global_block()._insert_op( program.global_block()._insert_op(
index=lookup_table_op_index, index=lookup_table_op_index,
type="split_ids", type="split_ids",
inputs={ inputs={'Ids': self.all_in_ids_vars},
'Ids': [ outputs={"Out": self.all_prefetch_input_vars})
program.global_block().vars[varname]
for varname in ids_name
]
},
outputs={"Out": prefetch_input_vars})
# insert prefetch_op # insert prefetch_op
program.global_block()._insert_op( program.global_block()._insert_op(
index=lookup_table_op_index + 1, index=lookup_table_op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': prefetch_input_vars}, inputs={'X': self.all_prefetch_input_vars},
outputs={"Out": prefetch_output_vars}, outputs={"Out": self.all_prefetch_output_vars},
attrs={ attrs={
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
# FIXME(qiao) temporarily disable this config because prefetch # FIXME(qiao) temporarily disable this config because prefetch
...@@ -1099,23 +1103,11 @@ to transpile() call.") ...@@ -1099,23 +1103,11 @@ to transpile() call.")
index=lookup_table_op_index + 2, index=lookup_table_op_index + 2,
type="merge_ids", type="merge_ids",
inputs={ inputs={
'Ids': [ 'Ids': self.all_in_ids_vars,
program.global_block().vars[varname] 'Rows': self.all_prefetch_input_vars,
for varname in ids_name 'X': self.all_prefetch_output_vars
],
'X': prefetch_output_vars
}, },
outputs={ outputs={"Out": self.all_out_emb_vars})
"Out": [
program.global_block().vars[varname]
for varname in out_name
]
})
# delete lookup_table_op
delete_ops(program.global_block(), [op])
# break for loop
break
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints): def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
# 2. add split_ids_op and send_op to send gradient to pservers # 2. add split_ids_op and send_op to send gradient to pservers
...@@ -1159,15 +1151,14 @@ to transpile() call.") ...@@ -1159,15 +1151,14 @@ to transpile() call.")
# STEP: create prefetch block # STEP: create prefetch block
table_var = pserver_program.global_block().vars[self.table_name] table_var = pserver_program.global_block().vars[self.table_name]
prefetch_var_name_to_block_id = [] prefetch_var_name_to_block_id = []
for index in range(len(self.all_prefetch_input_vars)):
prefetch_block = pserver_program._create_block(optimize_block.idx) prefetch_block = pserver_program._create_block(optimize_block.idx)
trainer_ids = self.all_prefetch_input_vars[index][pserver_index] trainer_ids = self.all_prefetch_input_vars[pserver_index]
pserver_ids = pserver_program.global_block().create_var( pserver_ids = pserver_program.global_block().create_var(
name=trainer_ids.name, name=trainer_ids.name,
type=trainer_ids.type, type=trainer_ids.type,
shape=trainer_ids.shape, shape=trainer_ids.shape,
dtype=trainer_ids.dtype) dtype=trainer_ids.dtype)
trainer_out = self.all_prefetch_output_vars[index][pserver_index] trainer_out = self.all_prefetch_output_vars[pserver_index]
pserver_out = pserver_program.global_block().create_var( pserver_out = pserver_program.global_block().create_var(
name=trainer_out.name, name=trainer_out.name,
type=trainer_out.type, type=trainer_out.type,
...@@ -1363,16 +1354,6 @@ to transpile() call.") ...@@ -1363,16 +1354,6 @@ to transpile() call.")
program.global_block()._sync_with_cpp() program.global_block()._sync_with_cpp()
return var_mapping return var_mapping
def _create_splited_vars(self, source_var, block, tag):
return [
block.create_var(
name=str(source_var.name + tag + str(index)),
type=source_var.type,
shape=source_var.shape,
dtype=source_var.dtype)
for index in range(len(self.pserver_endpoints))
]
def _clone_var(self, block, var, persistable=True): def _clone_var(self, block, var, persistable=True):
return block.create_var( return block.create_var(
name=var.name, name=var.name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册