未验证 提交 7c426be9 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #11342 from jacquesqiao/add-merge-splited-ids

Add merge_ids_op
...@@ -330,8 +330,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -330,8 +330,12 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
// NOTE! Please do not delete this line, it's usefull because the debug
// string before and after op.run are different, after run the output
// will have right shape which is usefull for debug.
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(2) << "Memory used after operator " + op->Type() + " running: "
......
...@@ -69,6 +69,19 @@ static DDim GetDims(const Scope& scope, const std::string& name, ...@@ -69,6 +69,19 @@ static DDim GetDims(const Scope& scope, const std::string& name,
} }
} }
static int GetRowSize(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return -1;
}
if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().rows().size();
}
return -1;
}
static LoD GetLoD(const Scope& scope, const std::string& name) { static LoD GetLoD(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
auto default_lod = LoD({{}}); auto default_lod = LoD({{}});
...@@ -153,6 +166,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { ...@@ -153,6 +166,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < input.second.size(); ++i) { for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i]; ss << input.second[i];
if (scope) { if (scope) {
int row_size = GetRowSize(*scope, input.second[i]);
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
ss << "[" << GetDims(*scope, input.second[i], true) << "]"; ss << "[" << GetDims(*scope, input.second[i], true) << "]";
ss << "(" << GetLoD(*scope, input.second[i]) << ")"; ss << "(" << GetLoD(*scope, input.second[i]) << ")";
} }
...@@ -173,6 +190,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { ...@@ -173,6 +190,10 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
for (size_t i = 0; i < output.second.size(); ++i) { for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i]; ss << output.second[i];
if (scope) { if (scope) {
int row_size = GetRowSize(*scope, output.second[i]);
if (row_size >= 0) {
ss << "[row_size=" << row_size << "]";
}
ss << "[" << GetDims(*scope, output.second[i], true) << "]"; ss << "[" << GetDims(*scope, output.second[i], true) << "]";
ss << "(" << GetLoD(*scope, output.second[i]) << ")"; ss << "(" << GetLoD(*scope, output.second[i]) << ")";
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/merge_ids_op.h"
namespace paddle {
namespace operators {
class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
AddInput(
"X",
"(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
"size of embedding table")
.AsDuplicable();
AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");
AddComment(R"DOC(
Merge multi LoDTensor's into one according to Ids's shard num.
split_ids_op -> prefetch_op -> merge_ids_op
merge_ids_op should be used after split_ids_op and prefetch_op, split_ids_op
will split input Ids into multiple tensors according to Id's shard number.
prefetch_op will send them to parameter server to prefetch embedding value
back. During split, the order of ids is disordered. In merge_ids_op we use
the original Ids to restore the order of the fetched embedding value and
also pass the lod information to the merged output.
Example:
Ids = [1,2,3,4,5,6] # 3 shared
split_ids_op ->
Id0 = [3, 6] # id % 3 == 0
Id1 = [1, 4] # id % 3 == 1
Id2 = [2, 5] # id % 3 == 2
prefetch_op ->
X0 = [[0.3 0.3] # 3
[0.6 0.6]] # 6
X1 = [[0.1 0.1] # 1
[0.4 0.4]] # 4
X2 = [[0.2 0.2] # 2
[0.5 0.5]] # 5
merge_ids_op ->
Out = [[0.1 0.1] # 1
[0.2 0.2] # 2
[0.3 0.3] # 3
[0.4 0.4] # 4
[0.5 0.5] # 5
[0.6 0.6]] # 6
)DOC");
}
};
class MergeIdsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must has input Ids.");
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has input X.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must has output Out.");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
auto ids_dims = ctx->GetInputDim("Ids");
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
auto x_var_type = ctx->GetInputsVarType("X");
for (auto &var_type : x_var_type) {
PADDLE_ENFORCE_EQ(var_type, framework::proto::VarType::LOD_TENSOR,
"input X only support lod tensors");
}
ctx->ShareLoD("Ids", "Out");
}
private:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("X").front()->type()),
ctx.GetPlace());
}
};
class MergeIdsOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto *input_var = block->Var(op_desc.Input("Ids")[0]);
for (auto &out_var : op_desc.Output("Out")) {
block->Var(out_var)->SetType(input_var->GetType());
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
ops::MergeIdsOpInferVarType);
REGISTER_OP_CPU_KERNEL(
merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class MergeIdsOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("MergeIds do not support GPU kernel");
}
VLOG(3) << "run in MergeIdsOpKernel";
const auto *ids_var = ctx.InputVar("Ids");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
"only support to merge Ids of LoDTensor");
const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
const auto &ids_dims = ids_tensor.dims();
const int64_t *ids = ids_tensor.data<int64_t>();
auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
auto *out = ctx.Output<framework::LoDTensor>("Out");
int batch_size = 0;
int embedding_size = 0;
for (auto &input : x_tensors) {
if (framework::product(input->dims()) != 0) {
if (embedding_size == 0) {
embedding_size = input->dims()[1];
}
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
"embedding size of all input should be the same");
batch_size += input->dims()[0];
}
}
PADDLE_ENFORCE_EQ(
batch_size, ids_dims[0],
"the batch size of ids and merged embedding value should be the same");
const size_t shard_num = x_tensors.size();
if (shard_num == 1) {
VLOG(3) << "only one shard, we can copy the data directly";
TensorCopy(*x_tensors[0], place, out);
} else {
std::vector<int> in_indexs(shard_num, 0);
auto *out_data = out->mutable_data<T>(
framework::make_ddim({batch_size, embedding_size}), place);
// copy data from ins[shard_num] to out.
for (int i = 0; i < ids_dims[0]; ++i) {
int64_t id = ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
int index = in_indexs[shard_id];
memcpy(out_data + embedding_size * i,
x_tensors[shard_id]->data<T>() + index * embedding_size,
sizeof(T) * embedding_size);
in_indexs[shard_id] += 1;
}
for (size_t i = 0; i < shard_num; ++i) {
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
"after merge, all data in x_tensor should be used");
}
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestMergeIdsOp(OpTest):
def setUp(self):
self.op_type = "merge_ids"
ids = np.array([[0], [2], [2], [3], [5], [5], [6]]).astype('int64')
x0 = np.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4]]).astype('float32')
x1 = np.array([]).astype('float32')
x2 = np.array([[0.4, 0.5], [0.4, 0.5], [0.5, 0.6],
[0.5, 0.6]]).astype('float32')
out = np.array([[0.1, 0.2], [0.4, 0.5], [0.4, 0.5], [0.2, 0.3],
[0.5, 0.6], [0.5, 0.6], [0.3, 0.4]]).astype('float32')
self.inputs = {'Ids': ids, "X": [('x0', x0), ('x1', x1), ('x2', x2)]}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
...@@ -629,7 +629,7 @@ class DistributeTranspiler: ...@@ -629,7 +629,7 @@ class DistributeTranspiler:
if op.type == LOOKUP_TABLE_TYPE: if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True continue_search_lookup_table_op = True
op_index = list(all_ops).index(op) lookup_table_op_index = list(all_ops).index(op)
ids_name = op.input("Ids") ids_name = op.input("Ids")
out_name = op.output("Out") out_name = op.output("Out")
...@@ -649,7 +649,7 @@ class DistributeTranspiler: ...@@ -649,7 +649,7 @@ class DistributeTranspiler:
# insert split_ids_op # insert split_ids_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index, index=lookup_table_op_index,
type="split_ids", type="split_ids",
inputs={ inputs={
'Ids': [ 'Ids': [
...@@ -661,7 +661,7 @@ class DistributeTranspiler: ...@@ -661,7 +661,7 @@ class DistributeTranspiler:
# insert prefetch_op # insert prefetch_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 1, index=lookup_table_op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': prefetch_input_vars}, inputs={'X': prefetch_input_vars},
outputs={"Out": prefetch_output_vars}, outputs={"Out": prefetch_output_vars},
...@@ -672,16 +672,21 @@ class DistributeTranspiler: ...@@ -672,16 +672,21 @@ class DistributeTranspiler:
# insert concat_op # insert concat_op
program.global_block().insert_op( program.global_block().insert_op(
index=op_index + 2, index=lookup_table_op_index + 2,
type="concat", type="merge_ids",
inputs={'X': prefetch_output_vars}, inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
],
'X': prefetch_output_vars
},
outputs={ outputs={
"Out": [ "Out": [
program.global_block().vars[varname] program.global_block().vars[varname]
for varname in out_name for varname in out_name
] ]
}, })
attrs={"axis": 0})
# delete lookup_table_op # delete lookup_table_op
delete_ops(program.global_block(), [op]) delete_ops(program.global_block(), [op])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册