未验证 提交 653fad08 编写于 作者: Q Qiao Longfei 提交者: GitHub

Optimize selected rows for dist lookup table with pthread rwlock (#12635)

Optimize selected rows for dist lookup table with rwlock 
上级 64d48f4d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <pthread.h>
namespace paddle {
namespace framework {
struct RWLock {
RWLock() { pthread_rwlock_init(&lock_, nullptr); }
~RWLock() { pthread_rwlock_destroy(&lock_); }
void RDLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_), 0,
"acquire read lock failed");
}
void WRLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), 0,
"acquire write lock failed");
}
void UNLock() {
PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed");
}
private:
pthread_rwlock_t lock_;
};
} // namespace framework
} // namespace paddle
......@@ -120,66 +120,76 @@ bool SelectedRows::HasKey(int64_t key) const {
: true;
}
std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
const std::vector<int64_t>& keys, framework::Tensor* value) const {
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) {
rwlock_->RDLock();
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
rwlock_->UNLock();
if (!auto_grown) {
PADDLE_THROW("key %d not found", key);
}
rwlock_->WRLock();
auto map_size = id_to_index_.size();
auto vector_size = rows_.size();
if (map_size != vector_size) {
rwlock_->UNLock();
PADDLE_THROW(
"id_to_index_ size %d should have the same size with rows_ %d",
map_size, vector_size);
}
auto write_iter = id_to_index_.find(key);
if (write_iter == id_to_index_.end()) {
size_t row_num = rows_.size();
if (row_num == value_->dims()[0]) {
rwlock_->UNLock();
PADDLE_THROW("selected rows is full, then length exceed %d", row_num);
}
// key logic to put a key into id_to_index_
rows_.push_back(key);
auto index = static_cast<int64_t>(rows_.size() - 1);
id_to_index_[key] = index;
rwlock_->UNLock();
return index;
} else {
auto index = write_iter->second;
rwlock_->UNLock();
return index;
}
} else {
auto index = iter->second;
rwlock_->UNLock();
return index;
}
}
void SelectedRows::SyncIndex() {
rwlock_->WRLock();
id_to_index_.clear();
for (size_t i = 0; i < rows_.size(); ++i) {
id_to_index_[rows_[i]] = i;
}
rwlock_->UNLock();
}
void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
bool auto_grown) {
PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized.");
std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
if (keys.empty()) {
if (ids.numel() == 0) {
VLOG(3) << "keys is empty, please check data!";
} else {
int64_t value_width = value_->numel() / value_->dims()[0];
PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
"output tensor should have the same shape with table "
"except the dims[0].");
for (size_t i = 0; i < keys.size(); ++i) {
int64_t index = Index(keys[i]);
if (index == -1) {
non_keys_pair.push_back(
std::make_pair(keys[i], static_cast<int64_t>(i)));
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
for (size_t i = 0; i < ids.numel(); ++i) {
int64_t index = AutoGrownIndex(ids.data<int64_t>()[i], auto_grown);
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
}
return non_keys_pair;
}
bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
PADDLE_ENFORCE(value.IsInitialized(), "The value should be initialized.");
if (value_->IsInitialized()) {
PADDLE_ENFORCE_EQ(
value.type(), value_->type(),
"The type of the value should be same with the original value");
}
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
"The first dim of value should be 1.");
std::lock_guard<std::mutex> lock(*auto_grown_mutex_.get());
auto index = Index(key);
bool is_new_key = false;
if (index == -1) {
rows_.push_back(key);
index = rows_.size() - 1;
is_new_key = true;
// whether need to resize the table
if (static_cast<int64_t>(rows_.size()) > value_->dims()[0]) {
auto dims = value_->dims();
dims[0] = (dims[0] + 1) << 1;
framework::VisitDataType(framework::ToDataType(value.type()),
ReAllocateVisitor(dims, value_.get()));
}
}
framework::VisitDataType(
framework::ToDataType(value.type()),
TensorCopyVisitor(value_.get(),
index * value_->numel() / value_->dims()[0], value,
static_cast<int64_t>(0), value.numel()));
return is_new_key;
}
} // namespace framework
......
......@@ -17,10 +17,12 @@ limitations under the License. */
#include <algorithm>
#include <memory>
#include <mutex> // NOLINT
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/memcpy.h"
......@@ -48,13 +50,13 @@ class SelectedRows {
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
: rows_(rows), height_(height) {
value_.reset(new Tensor());
auto_grown_mutex_.reset(new std::mutex);
rwlock_.reset(new RWLock);
}
SelectedRows() {
height_ = 0;
value_.reset(new Tensor());
auto_grown_mutex_.reset(new std::mutex);
rwlock_.reset(new RWLock);
}
platform::Place place() const { return value_->place(); }
......@@ -74,47 +76,51 @@ class SelectedRows {
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
/*
* @brief wheter has the specified key in the table.
* @brief Get the index of key in rows
*
* @return -1 if the key does not exists.
*/
int64_t Index(int64_t key) const {
auto it = std::find(rows_.begin(), rows_.end(), key);
if (it == rows_.end()) {
PADDLE_THROW("id %s not in table", key);
}
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
/*
* @brief whether has the specified key in the table.
*
* @return true if the key is exists.
*/
bool HasKey(int64_t key) const;
/*
* @brief Get value by the key list, if the
* @brief Get value by the key list.
* Note!!! this interface is only used when selected_rows is used as
* parameters
* for distribute lookup table.
*
* @return a list of pair which contains the non-exists key and the index in
* the value
*/
std::vector<std::pair<int64_t, int64_t>> Get(const std::vector<int64_t>& keys,
framework::Tensor* value) const;
void Get(const framework::Tensor& ids, framework::Tensor* value,
bool auto_grown = false);
/*
* @brief Set a key-value pair into the table.
* This function will double the value memory if it's not engouth.
* @brief Get the index of the key from id_to_index_ map. If the key not
* exist,
* add the key into id_to_index_.
*
* @note:
* 1. The first dim of the value should be 1
* 2. The value should be initialized and the data type
* should be the same with the table.
*
* @return true if the key is a new one, otherwise false
* Note!!! this interface is only used when selected_rows is used as
* parameters
* for distribute lookup table.
*
* @return index of the key.
*/
bool Set(int64_t key, const Tensor& value);
int64_t AutoGrownIndex(int64_t key, bool auto_grown);
/*
* @brief Get the index of key in rows
*
* @return -1 if the key does not exists.
*/
int64_t Index(int64_t key) const {
auto it = std::find(rows_.begin(), rows_.end(), key);
if (it == rows_.end()) {
return static_cast<int64_t>(-1);
}
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
void SyncIndex();
DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims());
......@@ -127,9 +133,10 @@ class SelectedRows {
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
Vector<int64_t> rows_;
std::unordered_map<int64_t, int64_t> id_to_index_;
std::unique_ptr<Tensor> value_{nullptr};
int64_t height_;
std::unique_ptr<std::mutex> auto_grown_mutex_{nullptr};
std::unique_ptr<RWLock> rwlock_{nullptr};
};
/*
......
......@@ -9,8 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include <time.h>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle {
namespace framework {
......@@ -59,39 +62,129 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims());
}
TEST_F(SelectedRowsTester, SparseTable) {
TEST(SelectedRows, SparseTable) {
platform::CPUPlace cpu;
SelectedRows table;
int64_t table_size = 100;
int64_t embedding_width = 8;
// initialize a sparse table
table.mutable_value()->Resize(framework::make_ddim({1, 100}));
table.mutable_value()->mutable_data<float>(cpu);
table.mutable_rows()->push_back(1);
table.mutable_value()->Resize(
framework::make_ddim({table_size, embedding_width}));
auto* data = table.mutable_value()->mutable_data<float>(cpu);
for (int64_t i = 0; i < table_size; ++i) {
for (int64_t j = 0; j < embedding_width; ++j) {
data[i * embedding_width + j] = static_cast<float>(i);
}
}
ASSERT_EQ(table.AutoGrownIndex(10, true), 0);
ASSERT_EQ(table.AutoGrownIndex(8, true), 1);
ASSERT_EQ(table.AutoGrownIndex(8, true), 1);
ASSERT_EQ(table.AutoGrownIndex(6, true), 2);
ASSERT_TRUE(table.HasKey(10));
ASSERT_TRUE(table.HasKey(8));
ASSERT_TRUE(table.HasKey(6));
ASSERT_EQ(table.rows().size(), 3);
framework::Tensor ids;
ids.Resize(framework::make_ddim({4}));
auto* ids_data = ids.mutable_data<int64_t>(cpu);
ids_data[0] = static_cast<int64_t>(6);
ids_data[1] = static_cast<int64_t>(6);
ids_data[2] = static_cast<int64_t>(8);
ids_data[3] = static_cast<int64_t>(10);
int64_t key = 10000;
int64_t non_key = 999;
framework::Tensor value;
value.Resize(framework::make_ddim({1, 100}));
auto ptr = value.mutable_data<float>(cpu);
ptr[0] = static_cast<float>(10);
framework::Tensor get_value;
auto* value_data = get_value.mutable_data<float>(
framework::make_ddim({4, embedding_width}), cpu);
table.Get(ids, &get_value);
ASSERT_EQ(table.rows().size(), static_cast<size_t>(1));
ASSERT_EQ(table.HasKey(key), false);
for (int j = 0; j < embedding_width; ++j) {
ASSERT_EQ(value_data[0 * embedding_width + j], 2);
}
for (int j = 0; j < embedding_width; ++j) {
ASSERT_EQ(value_data[1 * embedding_width + j], 2);
}
for (int j = 0; j < embedding_width; ++j) {
ASSERT_EQ(value_data[2 * embedding_width + j], 1);
}
for (int j = 0; j < embedding_width; ++j) {
ASSERT_EQ(value_data[3 * embedding_width + j], 0);
}
}
table.Set(key, value);
void f1(SelectedRows* table, int table_size) {
for (int i = 1000000; i > 0; --i) {
auto id = i % table_size;
int64_t index1 = table->AutoGrownIndex(id, true);
int64_t index2 = table->AutoGrownIndex(id, false);
int64_t index3 = table->AutoGrownIndex(id, true);
ASSERT_EQ(index1, index2);
ASSERT_EQ(index2, index3);
}
}
ASSERT_EQ(table.rows().size(), static_cast<size_t>(2));
ASSERT_EQ(table.HasKey(key), true);
// check re-allocate
ASSERT_EQ(table.value().dims()[0], static_cast<int64_t>(4));
void f2(SelectedRows* table, int table_size) {
for (int i = 0; i < 1000000; ++i) {
auto id = i % table_size;
int64_t index1 = table->AutoGrownIndex(id, true);
int64_t index2 = table->AutoGrownIndex(id, false);
int64_t index3 = table->AutoGrownIndex(id, true);
ASSERT_EQ(index1, index2);
ASSERT_EQ(index2, index3);
}
}
framework::Tensor get_value;
get_value.mutable_data<float>(framework::make_ddim({2, 100}), cpu);
std::vector<int64_t> keys({non_key, key});
auto non_key_pairs = table.Get(keys, &get_value);
void f3(SelectedRows* table, int table_size) {
clock_t t1 = clock();
for (int i = 100000; i > 0; --i) {
auto id1 = table->AutoGrownIndex(i % table_size, true);
auto id2 = table->Index(i % table_size);
ASSERT_EQ(id1, id2);
}
clock_t t2 = clock();
std::cout << "f3 run time:" << t2 - t1 << std::endl;
}
void f4(SelectedRows* table, int table_size) {
clock_t t1 = clock();
for (int i = 0; i < 100000; ++i) {
auto id1 = table->AutoGrownIndex(i % table_size, true);
auto id2 = table->Index(i % table_size);
ASSERT_EQ(id1, id2);
}
clock_t t2 = clock();
std::cout << "f4 run time:" << t2 - t1 << std::endl;
}
TEST(SelectedRows, MultiThreadAutoIndex) {
platform::CPUPlace cpu;
SelectedRows table;
int64_t table_size = 100000;
int64_t embedding_width = 8;
// initialize a sparse table
table.mutable_value()->Resize(
framework::make_ddim({table_size, embedding_width}));
auto* data = table.mutable_value()->mutable_data<float>(cpu);
for (int64_t i = 0; i < table_size; ++i) {
for (int64_t j = 0; j < embedding_width; ++j) {
data[i * embedding_width + j] = static_cast<float>(i);
}
}
ASSERT_EQ(get_value.data<float>()[100], static_cast<float>(10));
ASSERT_EQ(non_key_pairs.size(), static_cast<size_t>(1));
ASSERT_EQ(non_key_pairs[0].first, non_key);
std::thread t1(f1, &table, table_size);
std::thread t11(f1, &table, table_size);
std::thread t2(f2, &table, table_size);
std::thread t22(f2, &table, table_size);
t1.join();
t11.join();
t2.join();
t22.join();
std::thread t3(f3, &table, table_size);
std::thread t4(f4, &table, table_size);
t3.join();
t4.join();
}
} // namespace framework
......
......@@ -78,10 +78,9 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
auto rows = w->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i);
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});
for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true);
auto ptr = w_value->mutable_data<float>(*place);
......
......@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -46,10 +45,6 @@ class LookupSparseTableOp : public framework::OperatorBase {
auto out_var = scope.FindVar(Output("Out"));
auto w_var = scope.FindVar(Input("W"));
auto ids_var = scope.FindVar(Input("Ids"));
unsigned int seed = static_cast<unsigned int>(Attr<int>("seed"));
float min = Attr<float>("min");
float max = Attr<float>("max");
bool auto_grown_table = Attr<bool>("auto_grown_table");
PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
"The type of Out var should be LodTensor.");
......@@ -60,46 +55,17 @@ class LookupSparseTableOp : public framework::OperatorBase {
auto &ids_t = ids_var->Get<framework::LoDTensor>();
auto out_t = out_var->GetMutable<framework::LoDTensor>();
auto w_t = w_var->GetMutable<framework::SelectedRows>();
std::vector<int64_t> keys;
keys.resize(ids_t.numel());
for (int64_t i = 0; i < ids_t.numel(); ++i) {
keys[i] = ids_t.data<int64_t>()[i];
}
// TODO(Yancey1989): support CUDA Place for the sparse table
platform::CPUPlace cpu;
auto out_shape = w_t->value().dims();
out_shape[0] = keys.size();
out_shape[0] = ids_t.numel();
out_t->Resize(out_shape);
out_t->mutable_data(cpu, w_t->value().type());
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
framework::proto::VarType::FP32,
"The sparse table only support FP32");
auto non_keys_pair = w_t->Get(keys, out_t);
if (!auto_grown_table) {
PADDLE_ENFORCE_EQ(non_keys_pair.size(), static_cast<size_t>(0),
"there is some keys does exists in the sparse table.");
}
auto value_shape = w_t->value().dims();
value_shape[0] = 1;
for (const auto &it : non_keys_pair) {
const auto key = it.first;
const auto index = it.second;
framework::Tensor value;
value.Resize(value_shape);
auto data = value.mutable_data<float>(cpu);
std::minstd_rand engine;
engine.seed(seed);
std::uniform_real_distribution<float> dist(min, max);
int64_t size = value.numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
w_t->Set(key, value);
memory::Copy(cpu, out_t->mutable_data<float>(cpu) + index * value.numel(),
cpu, value.data<float>(), value.numel() * sizeof(float));
}
w_t->Get(ids_t, out_t, true);
}
};
......@@ -121,21 +87,6 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding);
AddAttr<float>("min",
"(float, default -1.0) "
"Minimum value of uniform random")
.SetDefault(-1.0f);
AddAttr<float>("max",
"(float, default 1.0) "
"Maximum value of uniform random")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
"Random seed used for generating samples. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time.")
.SetDefault(0);
AddAttr<bool>("auto_grown_table",
"(bool default false)"
"Whether create new value if for nonexistent key.")
......
......@@ -111,7 +111,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < grad.rows().size(); i++) {
PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
"Input rows index should less than height");
int64_t id_index = param.Index(grad.rows()[i]);
int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false);
PADDLE_ENFORCE_GE(id_index, static_cast<int64_t>(0),
"id should be in the table");
for (int64_t j = 0; j < grad_row_width; j++) {
......
......@@ -30,8 +30,10 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = ctx.Attr<std::vector<int>>("shape");
tensor = out_var->GetMutable<framework::SelectedRows>()->mutable_value();
auto* selected_rows = out_var->GetMutable<framework::SelectedRows>();
tensor = selected_rows->mutable_value();
tensor->Resize(framework::make_ddim(shape));
selected_rows->mutable_rows()->reserve(shape[0]);
} else {
PADDLE_THROW(
"uniform_random_op's output only"
......
......@@ -249,6 +249,7 @@ PYBIND11_PLUGIN(core) {
self.set_rows(new_rows);
#endif
})
.def("sync_index", [](SelectedRows &instance) { instance.SyncIndex(); })
.def("rows", [](SelectedRows &self) {
auto rows = self.rows();
std::vector<int64_t> new_rows;
......
......@@ -21,36 +21,27 @@ import paddle.fluid.core as core
from paddle.fluid.op import Operator
def output_hist(out):
hist, _ = np.histogram(out, range=(-5, 10))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones((10))
return hist, prob
class TestLookupSpraseTable(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Id Variable
ids = scope.var("Ids").get_tensor()
ids_array = np.array([0, 2, 3, 5, 100]).astype("int64")
ids.set(ids_array, place)
# create and initialize W Variable
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 10000
table_size = 10000
row_numel = 8
w_selected_rows = scope.var('W').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_selected_rows.set_height(table_size)
w_array = np.ones((table_size, row_numel)).astype("float32")
for i in range(table_size):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
# create and initialize Id Variable
ids = scope.var("Ids").get_tensor()
ids_array1 = np.array([0, 2, 3, 2, 5, 0, 100]).astype("int64")
ids.set(ids_array1, place)
# create Out Variable
out_tensor = scope.var('Out').get_tensor()
......@@ -66,16 +57,28 @@ class TestLookupSpraseTable(OpTest):
lookup_table.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
result_array1 = np.array(out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array[:-2]):
assert (row == result_array[idx]).all()
assert (result_array1[0] == w_array[0]).all()
assert (result_array1[1] == w_array[1]).all()
assert (result_array1[2] == w_array[2]).all()
assert (result_array1[3] == w_array[1]).all()
assert (result_array1[4] == w_array[3]).all()
assert (result_array1[5] == w_array[0]).all()
assert (result_array1[6] == w_array[4]).all()
# create and initialize Id Variable
ids = scope.var("Ids").get_tensor()
ids_array2 = np.array([4, 2, 3, 7, 100000]).astype("int64")
ids.set(ids_array2, place)
lookup_table.run(scope, place)
# check the random value
hist, prob = output_hist(result_array[-1])
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
result_array2 = np.array(out_tensor)
assert (result_array2[0] == w_array[5]).all()
assert (result_array2[1] == w_array[1]).all()
assert (result_array2[2] == w_array[2]).all()
assert (result_array2[3] == w_array[6]).all()
assert (result_array2[4] == w_array[7]).all()
def test_w_is_selected_rows(self):
places = [core.CPUPlace()]
......
......@@ -126,6 +126,7 @@ class TestSGDOpOptimizeSelectedRows(unittest.TestCase):
w_selected_rows = scope.var('Param').get_selected_rows()
w_selected_rows.set_height(len(param_rows))
w_selected_rows.set_rows(param_rows)
w_selected_rows.sync_index()
w_array = np.ones((len(param_rows), row_width)).astype("float32")
for i in range(len(param_rows)):
w_array[i] *= i
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册