提交 5db75513 编写于 作者: Q Qiao Longfei

optimize code

上级 eb6d9e3b
......@@ -228,8 +228,25 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
// add or mul.
namespace scatter {
static size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
return std::find(rows.begin(), rows.end(), value) - rows.begin();
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in,
T* out) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
blas.AXPY(data_len, 1., in, out);
}
template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in,
T* out) {
for (int64_t i = 0; i < data_len; i++) {
out[i] += in[i];
}
}
template <typename T>
......@@ -290,9 +307,9 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
for (int64_t j = 0; j < input_width; j++) {
out_data[out_i * input_width + j] += input_data[i * input_width + j];
}
elementwise_add<platform::CPUDeviceContext, T>(
context, static_cast<size_t>(input_width),
&input_data[i * input_width], &out_data[out_i * input_width]);
}
}
}
......@@ -300,6 +317,8 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
template struct MergeAdd<platform::CPUDeviceContext, int>;
template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template <typename T>
struct UpdateToTensor<platform::CPUDeviceContext, T> {
......
......@@ -87,108 +87,6 @@ struct MergeAdd {
framework::SelectedRows* output);
};
template <>
struct MergeAdd<platform::CPUDeviceContext, float> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
auto input_rows = input.rows();
std::vector<int64_t> merge_rows;
merge_rows.reserve(input_rows.size());
std::unordered_map<int64_t, size_t> rows_pos_map;
rows_pos_map.reserve(input_rows.size());
size_t idx = 0u;
for (std::vector<int64_t>::iterator iter = input_rows.begin();
iter != input_rows.end(); ++iter) {
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
rows_pos_map[*iter] = idx++;
merge_rows.emplace_back(*iter);
}
}
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, float> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<float>();
auto* input_data = input.value().data<float>();
auto blas = GetBlas<platform::CPUDeviceContext, float>(context);
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_pos_map[input_rows[i]];
float* y = out_data + out_i * input_width;
const float* x = input_data + i * input_width;
blas.AXPY(input_width, 1., x, y);
}
}
};
template <>
struct MergeAdd<platform::CPUDeviceContext, double> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
(*this)(context, input, &out);
return out;
}
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output) {
framework::SelectedRows& out = *output;
auto input_rows = input.rows();
std::vector<int64_t> merge_rows;
merge_rows.reserve(input_rows.size());
std::unordered_map<int64_t, size_t> rows_pos_map;
rows_pos_map.reserve(input_rows.size());
size_t idx = 0u;
for (std::vector<int64_t>::iterator iter = input_rows.begin();
iter != input_rows.end(); ++iter) {
if (rows_pos_map.find(*iter) == rows_pos_map.end()) {
rows_pos_map[*iter] = idx++;
merge_rows.emplace_back(*iter);
}
}
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<double>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, double> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<double>();
auto* input_data = input.value().data<double>();
auto blas = GetBlas<platform::CPUDeviceContext, double>(context);
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_pos_map[input_rows[i]];
double* y = out_data + out_i * input_width;
const double* x = input_data + i * input_width;
blas.AXPY(input_width, 1., x, y);
}
}
};
template <typename DeviceContext, typename T>
struct Add {
framework::SelectedRows operator()(const DeviceContext& context,
......
......@@ -303,6 +303,65 @@ TEST(selected_rows_functor, cpu_merge_add_int) {
EXPECT_EQ(out_data[2 * row_numel], 1);
}
TEST(selected_rows_functor, cpu_merge_add_multi) {
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place);
paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
float>
set_const;
int64_t height = 10;
int64_t row_numel = 8;
std::vector<int64_t> rows1{5, 2, 5, 3, 5};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
new paddle::framework::SelectedRows(rows1, height)};
auto* in1_value = selected_rows1->mutable_value();
in1_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows1.size()), row_numel}),
cpu_place);
set_const(ctx, in1_value, 1.0);
std::vector<int64_t> rows2{2, 5, 3, 5, 3};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
new paddle::framework::SelectedRows(rows2, height)};
auto* in2_value = selected_rows2->mutable_value();
in2_value->mutable_data<float>(
paddle::framework::make_ddim(
{static_cast<int64_t>(rows2.size()), row_numel}),
cpu_place);
set_const(ctx, in2_value, 1.0);
std::unique_ptr<paddle::framework::SelectedRows> output{
new paddle::framework::SelectedRows()};
output->set_height(height);
paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
float>
merge_add_functor;
std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.push_back(selected_rows1.get());
inputs.push_back(selected_rows2.get());
merge_add_functor(ctx, inputs, output.get());
EXPECT_EQ(output->height(), height);
EXPECT_EQ(output->value().dims(),
paddle::framework::make_ddim({3, row_numel}));
std::vector<int64_t> ret_rows{2, 3, 5};
EXPECT_EQ(output->rows(), ret_rows);
auto* out_data = output->value().data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
std::cout << out_data[i * row_numel + j] << " ";
}
std::cout << "\n";
}
}
TEST(selected_rows_functor, cpu_sum_to) {
paddle::platform::CPUPlace cpu_place;
paddle::platform::CPUDeviceContext ctx(cpu_place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册