未验证 提交 ff99d941 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #10164 from Yancey1989/lookup_sparse_table_op

add lookup_sparse_table_op
...@@ -255,11 +255,11 @@ TEST(LoDTensor, RecordIO) { ...@@ -255,11 +255,11 @@ TEST(LoDTensor, RecordIO) {
std::unique_ptr<std::istream> stream_ptr(stream); std::unique_ptr<std::istream> stream_ptr(stream);
recordio::Scanner scanner(std::move(stream_ptr)); recordio::Scanner scanner(std::move(stream_ptr));
auto tensors = ReadFromRecordIO(&scanner, ctx); auto tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2); ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]); assert_tensor_ok(tensors[1]);
tensors = ReadFromRecordIO(&scanner, ctx); tensors = ReadFromRecordIO(&scanner, ctx);
ASSERT_EQ(tensors.size(), 2); ASSERT_EQ(tensors.size(), static_cast<size_t>(2));
assert_tensor_ok(tensors[0]); assert_tensor_ok(tensors[0]);
assert_tensor_ok(tensors[1]); assert_tensor_ok(tensors[1]);
} }
......
...@@ -120,11 +120,11 @@ bool SelectedRows::HasKey(int64_t key) const { ...@@ -120,11 +120,11 @@ bool SelectedRows::HasKey(int64_t key) const {
: true; : true;
} }
std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys, std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
framework::Tensor* value) const { std::vector<int64_t> keys, framework::Tensor* value) const {
PADDLE_ENFORCE(value->IsInitialized(), PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized."); "The value tensor should be initialized.");
std::vector<int64_t> non_keys; std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
int64_t value_width = value_->numel() / value_->dims()[0]; int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table " "output tensor should have the same shape with table "
...@@ -133,7 +133,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys, ...@@ -133,7 +133,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
for (size_t i = 0; i < keys.size(); ++i) { for (size_t i = 0; i < keys.size(); ++i) {
int64_t index = Index(keys[i]); int64_t index = Index(keys[i]);
if (index == -1) { if (index == -1) {
non_keys.push_back(keys[i]); non_keys_pair.push_back(std::make_pair(keys[i], static_cast<int64_t>(i)));
} else { } else {
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), framework::ToDataType(value_->type()),
...@@ -141,7 +141,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys, ...@@ -141,7 +141,7 @@ std::vector<int64_t> SelectedRows::Get(std::vector<int64_t> keys,
index * value_width, value_width)); index * value_width, value_width));
} }
} }
return non_keys; return non_keys_pair;
} }
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -78,10 +79,11 @@ class SelectedRows { ...@@ -78,10 +79,11 @@ class SelectedRows {
/* /*
* @brief Get value by the key list, if the * @brief Get value by the key list, if the
* *
* @return a list of keys which does not exists in table * @return a list of pair which contains the non-exists key and the index in
* the value
*/ */
std::vector<int64_t> Get(std::vector<int64_t> keys, std::vector<std::pair<int64_t, int64_t>> Get(std::vector<int64_t> keys,
framework::Tensor* tensor) const; framework::Tensor* value) const;
/* /*
* @brief Set a key-value pair into the table. * @brief Set a key-value pair into the table.
......
...@@ -59,7 +59,7 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ...@@ -59,7 +59,7 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
} }
TEST_F(SelectedRowsTester, Table) { TEST_F(SelectedRowsTester, SparseTable) {
platform::CPUPlace cpu; platform::CPUPlace cpu;
SelectedRows table; SelectedRows table;
// initialize a sparse table // initialize a sparse table
...@@ -87,11 +87,11 @@ TEST_F(SelectedRowsTester, Table) { ...@@ -87,11 +87,11 @@ TEST_F(SelectedRowsTester, Table) {
framework::Tensor get_value; framework::Tensor get_value;
get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu); get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu);
std::vector<int64_t> keys({non_key, key}); std::vector<int64_t> keys({non_key, key});
auto non_keys = table.Get(keys, &get_value); auto non_key_pairs = table.Get(keys, &get_value);
ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10)); ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10));
ASSERT_EQ(non_keys.size(), static_cast<size_t>(1)); ASSERT_EQ(non_key_pairs.size(), static_cast<size_t>(1));
ASSERT_EQ(non_keys[0], non_key); ASSERT_EQ(non_key_pairs[0].first, non_key);
} }
} // namespace framework } // namespace framework
......
...@@ -108,7 +108,7 @@ void RunSerdeTestSelectedRows(platform::Place place) { ...@@ -108,7 +108,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7); EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
} }
for (size_t i = 0; i < rows2->size(); ++i) { for (size_t i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i); EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
} }
EXPECT_EQ(slr2->height(), 1000); EXPECT_EQ(slr2->height(), 1000);
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
constexpr int64_t kNoPadding = -1;
class LookupSparseTableInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of LookupSparseTableOp should not be null.");
auto shape_w = ctx->GetInputDim("W");
auto shape_ids = ctx->GetInputDim("Ids");
shape_w[0] = shape_ids.size();
ctx->SetOutputDim("Out", shape_w);
}
};
class LookupSparseTableOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto out_var = scope.FindVar(Output("Out"));
auto w_var = scope.FindVar(Input("W"));
auto ids_var = scope.FindVar(Input("Ids"));
unsigned int seed = static_cast<unsigned int>(Attr<int>("seed"));
float min = Attr<float>("min");
float max = Attr<float>("max");
bool auto_grown_table = Attr<bool>("auto_grown_table");
PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
"The type of Out var should be LodTensor.");
PADDLE_ENFORCE(w_var->IsType<framework::SelectedRows>(),
"The type of W var should be SelectedRows.");
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
"The type of Ids var should be LoDTensor.");
auto &ids_t = ids_var->Get<framework::LoDTensor>();
auto out_t = out_var->GetMutable<framework::LoDTensor>();
auto w_t = w_var->GetMutable<framework::SelectedRows>();
std::vector<int64_t> keys;
keys.resize(ids_t.numel());
for (size_t i = 0; i < ids_t.numel(); ++i) {
keys[i] = ids_t.data<int64_t>()[i];
}
// TODO(Yancey1989): support CUDA Place for the sparse table
platform::CPUPlace cpu;
auto out_shape = w_t->value().dims();
out_shape[0] = keys.size();
out_t->Resize(out_shape);
out_t->mutable_data(cpu, w_t->value().type());
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
framework::proto::VarType::FP32,
"The sparse table only support FP32");
auto non_keys_pair = w_t->Get(keys, out_t);
if (!auto_grown_table) {
PADDLE_ENFORCE_EQ(non_keys_pair.size(), static_cast<size_t>(0),
"there is some keys does exists in the sparse table.");
}
auto value_shape = w_t->value().dims();
value_shape[0] = 1;
for (const auto &it : non_keys_pair) {
const auto key = it.first;
const auto index = it.second;
framework::Tensor value;
value.Resize(value_shape);
auto data = value.mutable_data<float>(cpu);
std::minstd_rand engine;
engine.seed(seed);
std::uniform_real_distribution<float> dist(min, max);
int64_t size = value.numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
w_t->Set(key, value);
memory::Copy(cpu, out_t->mutable_data<float>(cpu) + index * value.numel(),
cpu, value.data<float>(), value.numel() * sizeof(float));
}
}
};
class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LookupSparseTableOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"(SelectedRows) The input represents embedding table, "
"which is a learnable parameter.");
AddInput("Ids",
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.");
AddOutput("Out",
"(LoDTensor) The lookup results, which have the "
"same type as W.");
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding);
AddAttr<float>("min",
"(float, default -1.0) "
"Minimum value of uniform random")
.SetDefault(-1.0f);
AddAttr<float>("max",
"(float, default 1.0) "
"Maximun value of uniform random")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
"Random seed used for generating samples. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<bool>("auto_grown_table",
"(bool default false)"
"Whether create new value if for nonexistent key.")
.SetDefault(true);
AddComment(R"DOC(
Lookup Sprase Tablel Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lookup_sparse_table, ops::LookupSparseTableOp,
ops::LookupSparseTableInferShape,
ops::LookupSparseTableOpMaker,
paddle::framework::EmptyGradOpMaker);
...@@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel {
} }
}; };
class SGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto input_var = op_desc.Input("Param")[0];
for (auto& out_var : op_desc.Output("ParamOut")) {
if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::LOD_TENSOR);
}
}
}
};
class SGDOpMaker : public framework::OpProtoAndCheckerMaker { class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker) SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
...@@ -74,5 +92,6 @@ $$param\_out = param - learning\_rate * grad$$ ...@@ -74,5 +92,6 @@ $$param\_out = param - learning\_rate * grad$$
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker); REGISTER_OPERATOR(sgd, ops::SGDOp, ops::SGDOpMaker,
paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType);
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>); REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
...@@ -116,11 +116,31 @@ uniform distribution. ...@@ -116,11 +116,31 @@ uniform distribution.
.SetDefault(framework::proto::VarType::FP32); .SetDefault(framework::proto::VarType::FP32);
} }
}; };
class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("Out").front();
if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, REGISTER_OPERATOR(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker); paddle::operators::UniformRandomOpMaker,
paddle::framework::EmptyGradOpMaker,
paddle::operators::UniformRandomOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>, paddle::operators::CPUUniformRandomKernel<float>,
paddle::operators::CPUUniformRandomKernel<double>); paddle::operators::CPUUniformRandomKernel<double>);
......
...@@ -661,7 +661,7 @@ class DistributeTranspiler: ...@@ -661,7 +661,7 @@ class DistributeTranspiler:
shape=trainer_out.shape, shape=trainer_out.shape,
dtype=trainer_out.dtype) dtype=trainer_out.dtype)
prefetch_block.append_op( prefetch_block.append_op(
type=LOOKUP_TABLE_TYPE, type="lookup_sparse_table",
inputs={'Ids': pserver_ids, inputs={'Ids': pserver_ids,
"W": table_var}, "W": table_var},
outputs={"Out": pserver_out}, outputs={"Out": pserver_out},
...@@ -685,9 +685,14 @@ class DistributeTranspiler: ...@@ -685,9 +685,14 @@ class DistributeTranspiler:
# STEP: create table optimize block # STEP: create table optimize block
# create table param and grad var in pserver program # create table param and grad var in pserver program
param_var = _clone_var( origin_param_var = self.origin_program.global_block().vars[
pserver_program.global_block(), self.table_name]
self.origin_program.global_block().vars[self.table_name]) param_var = pserver_program.global_block().create_var(
name=origin_param_var.name,
shape=origin_param_var.shape,
dtype=origin_param_var.dtype,
type=core.VarDesc.VarType.SELECTED_ROWS,
persistable=True)
grad_var = _clone_var( grad_var = _clone_var(
pserver_program.global_block(), pserver_program.global_block(),
self.origin_program.global_block().vars[framework.grad_var_name( self.origin_program.global_block().vars[framework.grad_var_name(
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
def output_hist(out):
hist, _ = np.histogram(out, range=(-5, 10))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones((10))
return hist, prob
class TestLookupSpraseTable(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Id Variable
ids = scope.var("Ids").get_tensor()
ids_array = np.array([0, 2, 3, 5, 100]).astype("int64")
ids.set(ids_array, place)
# create and initialize W Variable
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 10000
w_selected_rows = scope.var('W').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
# create Out Variable
out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator
lookup_table = Operator(
"lookup_sparse_table",
W='W',
Ids='Ids',
Out='Out',
min=-5.0,
max=10.0,
seed=10)
lookup_table.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array[:-2]):
assert (row == result_array[idx]).all()
# check the random value
hist, prob = output_hist(result_array[-1])
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
def test_w_is_selected_rows(self):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册