未验证 提交 424dd2fc 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #9597 from jacquesqiao/sgd-support-update-selected-rows

Sgd support update selected rows
...@@ -35,6 +35,17 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = { ...@@ -35,6 +35,17 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
std::make_tuple(platform::CPUPlace(), LibraryType::kPlain), std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
}; };
proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
if (var->IsType<framework::LoDTensor>()) {
return framework::ToDataType(var->Get<framework::LoDTensor>().type());
} else if (var->IsType<framework::SelectedRows>()) {
return framework::ToDataType(
var->Get<framework::SelectedRows>().value().type());
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}
}
static DDim GetDims(const Scope& scope, const std::string& name) { static DDim GetDims(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) { if (var == nullptr) {
......
...@@ -61,6 +61,8 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -61,6 +61,8 @@ inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix; return var_name + kGradVarSuffix;
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -13,6 +16,7 @@ limitations under the License. */ ...@@ -13,6 +16,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx) { const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version { // the 1st field, uint32_t version
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -47,6 +50,15 @@ class SelectedRows { ...@@ -47,6 +50,15 @@ class SelectedRows {
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; } void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
/**
* get the index of id in rows
*/
int64_t index(int64_t id) const {
auto it = std::find(rows_.begin(), rows_.end(), id);
PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
DDim GetCompleteDims() const { DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims()); std::vector<int64_t> dims = vectorize(value_->dims());
dims[0] = height_; dims[0] = height_;
......
...@@ -18,22 +18,6 @@ limitations under the License. */ ...@@ -18,22 +18,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static inline framework::OpKernelType ExpectedKernelType(
const framework::ExecutionContext& ctx) {
auto* table_var = ctx.InputVar("W");
if (table_var->IsType<LoDTensor>()) {
return framework::OpKernelType(
framework::ToDataType(table_var->Get<LoDTensor>().type()),
ctx.device_context());
} else if (table_var->IsType<SelectedRows>()) {
return framework::OpKernelType(
framework::ToDataType(table_var->Get<SelectedRows>().value().type()),
ctx.device_context());
} else {
PADDLE_THROW("W should be LoDTensor or SelectedRows");
}
}
class LookupTableOp : public framework::OperatorWithKernel { class LookupTableOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -67,7 +51,8 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -67,7 +51,8 @@ class LookupTableOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return ExpectedKernelType(ctx); auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
...@@ -138,7 +123,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -138,7 +123,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return ExpectedKernelType(ctx); auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
...@@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor; ...@@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
static constexpr int64_t kNoPadding = -1; constexpr int64_t kNoPadding = -1;
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
auto it = std::find(rows.begin(), rows.end(), value);
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
return static_cast<size_t>(std::distance(rows.begin(), it));
}
template <typename T> template <typename T>
class LookupTableKernel : public framework::OpKernel<T> { class LookupTableKernel : public framework::OpKernel<T> {
...@@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<SelectedRows>("W"); auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW("table only support LoDTensor and SelectedRows"); PADDLE_THROW(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows");
} }
int64_t *ids; int64_t *ids;
...@@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
memset(output + i * row_width, 0, row_width * sizeof(T)); memset(output + i * row_width, 0, row_width * sizeof(T));
} else { } else {
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = getIndex(table_t.rows(), ids[i]); auto id_index = table_t.index(ids[i]);
memcpy(output + i * row_width, table + id_index * row_width, memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} }
...@@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<SelectedRows>("W"); auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW("table only support LoDTensor and SelectedRows"); PADDLE_THROW(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows");
} }
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
......
...@@ -43,9 +43,8 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -43,9 +43,8 @@ class SGDOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
framework::ToDataType(ctx.Input<framework::LoDTensor>("Param")->type()), return framework::OpKernelType(data_type, ctx.device_context());
ctx.GetPlace());
} }
}; };
...@@ -53,10 +52,12 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -53,10 +52,12 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker) SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "(Tensor) Input parameter"); AddInput("Param", "(Tensor or SelectedRows) Input parameter");
AddInput("LearningRate", "(Tensor) Learning rate of SGD"); AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddInput("Grad", "(Tensor) Input gradient"); AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
AddOutput("ParamOut", "(Tensor) Output parameter"); AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param");
AddComment(R"DOC( AddComment(R"DOC(
SGD operator SGD operator
......
...@@ -23,21 +23,25 @@ namespace operators { ...@@ -23,21 +23,25 @@ namespace operators {
template <typename T> template <typename T>
class SGDOpKernel : public framework::OpKernel<T> { class SGDOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto* param = ctx.Input<framework::Tensor>("Param"); const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate"); const auto *param_var = ctx.InputVar("Param");
const auto *grad_var = ctx.InputVar("Grad");
if (param_var->IsType<framework::LoDTensor>()) {
const auto *param = ctx.Input<framework::Tensor>("Param");
auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* grad_var = ctx.InputVar("Grad");
// Actually, all tensors are LoDTensor except SelectedRows. // Actually, all tensors are LoDTensor except SelectedRows.
if (grad_var->IsType<framework::LoDTensor>()) { if (grad_var->IsType<framework::LoDTensor>()) {
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
auto* grad = ctx.Input<framework::Tensor>("Grad"); const auto *grad = ctx.Input<framework::Tensor>("Grad");
auto p = framework::EigenVector<T>::Flatten(*param); auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad); auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = framework::EigenVector<T>::Flatten(*param_out); auto o = framework::EigenVector<T>::Flatten(*param_out);
auto* lr = learning_rate->data<T>(); auto *lr = learning_rate->data<T>();
o = p - lr[0] * g; o = p - lr[0] * g;
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
...@@ -45,7 +49,7 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -45,7 +49,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
// This manual optimization brings difficulty to track data dependency. // This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution. // It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out); PADDLE_ENFORCE_EQ(param, param_out);
auto* grad = ctx.Input<framework::SelectedRows>("Grad"); const auto *grad = ctx.Input<framework::SelectedRows>("Grad");
// for distributed training, a sparse var may be empty, // for distributed training, a sparse var may be empty,
// just skip updating. // just skip updating.
...@@ -53,31 +57,64 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -53,31 +57,64 @@ class SGDOpKernel : public framework::OpKernel<T> {
return; return;
} }
auto in_height = grad->height(); auto grad_height = grad->height();
auto out_dims = param_out->dims(); auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(in_height, out_dims[0]); PADDLE_ENFORCE_EQ(grad_height, out_dims[0]);
auto& in_value = grad->value(); auto &grad_value = grad->value();
auto& in_rows = grad->rows(); auto &grad_rows = grad->rows();
int64_t in_row_numel = in_value.numel() / in_rows.size(); size_t grad_row_numel = grad_value.numel() / grad_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height); PADDLE_ENFORCE_EQ(grad_row_numel, param_out->numel() / grad_height);
auto* in_data = in_value.data<T>(); auto *grad_data = grad_value.data<T>();
auto* out_data = param_out->data<T>(); auto *out_data = param_out->data<T>();
auto* lr = learning_rate->data<T>(); auto *lr = learning_rate->data<T>();
for (size_t i = 0; i < in_rows.size(); i++) { for (size_t i = 0; i < grad_rows.size(); i++) {
PADDLE_ENFORCE(in_rows[i] < in_height, PADDLE_ENFORCE(grad_rows[i] < grad_height,
"Input rows index should less than height"); "Input rows index should less than height");
for (int64_t j = 0; j < in_row_numel; j++) { for (int64_t j = 0; j < grad_row_numel; j++) {
out_data[in_rows[i] * in_row_numel + j] -= out_data[grad_rows[i] * grad_row_numel + j] -=
lr[0] * in_data[i * in_row_numel + j]; lr[0] * grad_data[i * grad_row_numel + j];
} }
} }
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Grad"); PADDLE_THROW("Unsupported Variable Type of Grad");
} }
} else if (param_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(grad_var->IsType<framework::SelectedRows>(),
"when param "
"is SelectedRows, gradient should also be SelectedRows");
const auto &param = param_var->Get<framework::SelectedRows>();
auto *param_out = ctx.Output<framework::SelectedRows>("ParamOut");
const auto &grad = grad_var->Get<framework::SelectedRows>();
// for distributed training, a sparse var may be empty,
// just skip updating.
if (grad.rows().size() == 0) {
return;
}
size_t param_row_width = param.value().numel() / param.rows().size();
size_t grad_row_width = grad.value().numel() / grad.rows().size();
PADDLE_ENFORCE_EQ(param_row_width, grad_row_width,
"param_row should have the same size with grad_row");
const auto *lr = learning_rate->data<T>();
const auto *grad_data = grad.value().data<T>();
auto *out_data = param_out->mutable_value()->data<T>();
for (size_t i = 0; i < grad.rows().size(); i++) {
PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
"Input rows index should less than height");
int64_t id_index = param.index(grad.rows()[i]);
for (int64_t j = 0; j < grad_row_width; j++) {
out_data[id_index * grad_row_width + j] -=
lr[0] * grad_data[i * grad_row_width + j];
}
}
} else {
PADDLE_THROW("Unsupported Variable Type of Parameter");
}
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -115,18 +115,18 @@ class TestLookupTableWIsSelectedRows(OpTest): ...@@ -115,18 +115,18 @@ class TestLookupTableWIsSelectedRows(OpTest):
w_array = np.ones((len(rows), row_numel)).astype("float32") w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)): for i in range(len(rows)):
w_array[i] *= i w_array[i] *= i
ids_tensor = w_selected_rows.get_tensor() w_tensor = w_selected_rows.get_tensor()
ids_tensor.set(w_array, place) w_tensor.set(w_array, place)
# create Out Variable # create Out Variable
Out_tensor = scope.var('Out').get_tensor() out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator # create and run lookup_table operator
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out') lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
lookup_table.run(scope, place) lookup_table.run(scope, place)
# get result from Out # get result from Out
result_array = np.array(Out_tensor) result_array = np.array(out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty) # all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array): for idx, row in enumerate(ids_array):
assert (row[0] == result_array[idx]).all() assert (row[0] == result_array[idx]).all()
......
...@@ -97,5 +97,72 @@ class TestSparseSGDOp(unittest.TestCase): ...@@ -97,5 +97,72 @@ class TestSparseSGDOp(unittest.TestCase):
self.check_with_place(place) self.check_with_place(place)
class TestSGDOpOptimizeSelectedRows(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
row_width = 12
# create and initialize Grad Variable
grad_height = 10
grad_rows = [0, 4, 7]
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(grad_height)
grad_selected_rows.set_rows(grad_rows)
grad_array = np.ones((len(grad_rows), row_width)).astype("float32")
grad_array[0, 0] = 2.0
grad_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(grad_array, place)
# create and initialize Param Variable
# create and initialize W Variable
param_rows = [0, 1, 2, 3, 4, 5, 6, 7]
# init Param
w_selected_rows = scope.var('Param').get_selected_rows()
w_selected_rows.set_height(len(param_rows))
w_selected_rows.set_rows(param_rows)
w_array = np.ones((len(param_rows), row_width)).astype("float32")
for i in range(len(param_rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
w_before_optimize = np.array(w_tensor)
# create and initialize LeraningRate Variable
lr_value = 0.1
lr = scope.var('LearningRate').get_tensor()
lr_array = np.full((1), lr_value).astype("float32")
lr.set(lr_array, place)
# optimize with Python
w_after_optimize = np.copy(w_before_optimize)
for index, id in enumerate(grad_rows):
w_after_optimize[id] = w_before_optimize[
id] - lr_value * grad_array[index]
# create and run sgd operator
sgd_op = Operator(
"sgd",
Param='Param',
Grad='Grad',
ParamOut='Param',
LearningRate='LearningRate')
sgd_op.run(scope, place)
# get and compare result
result_array = np.array(w_tensor)
assert (result_array == w_after_optimize).all()
def test_sparse_parameter_sgd(self):
places = [core.CPUPlace()]
# do not support GPU kernel currently
for place in places:
self.check_with_place(place)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册