Commit 40d3bd4e authored by qiaolongfei

SelectedRows MergeAdd: add support for multiple inputs

Parent 8cd17c04
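This commit extends the CPU MergeAdd functor for SelectedRows so that one call can merge an arbitrary number of inputs: the output row set becomes the sorted union of all input rows, and values that land on the same row index are accumulated. The existing single-input overload now simply wraps its argument in a vector and forwards to the new overload. Below is a minimal usage sketch that mirrors the unit test added in this commit; it assumes selected_rows1 and selected_rows2 are already-filled SelectedRows objects (as the test does with SetConstant) and that ctx is a CPUDeviceContext:

    // Merge several SelectedRows on CPU; rows that appear in more than one
    // input (or more than once in the same input) have their values summed.
    std::vector<const paddle::framework::SelectedRows*> inputs;
    inputs.push_back(selected_rows1.get());
    inputs.push_back(selected_rows2.get());

    paddle::framework::SelectedRows merged;
    paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
                                               float>
        merge_add;
    merge_add(ctx, inputs, &merged);
    // merged.rows() is now the sorted, de-duplicated union of the input rows.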
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <set>
-#include <vector>
 
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"

@@ -190,7 +189,7 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
 // add or mul.
 namespace scatter {
 
-size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
+static size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
   return std::find(rows.begin(), rows.end(), value) - rows.begin();
 }
 

@@ -206,14 +205,31 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::SelectedRows& input,
                   framework::SelectedRows* output) {
-    framework::SelectedRows& out = *output;
-    auto input_rows = input.rows();
-    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
-    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
-
-    auto input_width = input.value().dims()[1];
+    std::vector<const framework::SelectedRows*> inputs;
+    inputs.push_back(&input);
+    (*this)(context, inputs, output);
+  }
+
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::vector<const framework::SelectedRows*>& inputs,
+                  framework::SelectedRows* output) {
+    PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input");
+    auto input_width = inputs[0]->value().dims()[1];
+    auto input_height = inputs[0]->height();
+    framework::SelectedRows& out = *output;
+    std::set<int64_t> merged_row_set;
+    for (auto* input : inputs) {
+      PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
+                        "all input should have same "
+                        "dimension except for the first one");
+      PADDLE_ENFORCE_EQ(input_height, input->height(),
+                        "all input should have same height");
+      merged_row_set.insert(input->rows().begin(), input->rows().end());
+    }
+    std::vector<int64_t> merge_rows(merged_row_set.begin(),
+                                    merged_row_set.end());
     out.set_rows(merge_rows);
-    out.set_height(input.height());
+    out.set_height(input_height);
     out.mutable_value()->mutable_data<T>(
         framework::make_ddim(
             {static_cast<int64_t>(merge_rows.size()), input_width}),

@@ -223,12 +239,16 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
     constant_functor(context, out.mutable_value(), 0.0);
 
     auto* out_data = out.mutable_value()->data<T>();
-    auto* input_data = input.value().data<T>();
 
-    for (size_t i = 0; i < input_rows.size(); i++) {
-      size_t out_i = FindPos(merge_rows, input_rows[i]);
-      for (int64_t j = 0; j < input_width; j++) {
-        out_data[out_i * input_width + j] += input_data[i * input_width + j];
+    for (auto* input : inputs) {
+      auto* input_data = input->value().data<T>();
+      auto& input_rows = input->rows();
+
+      for (size_t i = 0; i < input_rows.size(); i++) {
+        size_t out_i = FindPos(merge_rows, input_rows[i]);
+        for (int64_t j = 0; j < input_width; j++) {
+          out_data[out_i * input_width + j] += input_data[i * input_width + j];
+        }
       }
     }
   }
...
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <vector>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/platform/device_context.h"

@@ -68,6 +70,9 @@ struct MergeAdd {
   void operator()(const DeviceContext& context,
                   const framework::SelectedRows& input,
                   framework::SelectedRows* output);
+  void operator()(const platform::CPUDeviceContext& context,
+                  const std::vector<const framework::SelectedRows*>& inputs,
+                  framework::SelectedRows* output);
 };
 
 template <typename DeviceContext, typename T>
...
@@ -219,3 +219,62 @@ TEST(selected_rows_functor, cpu_add_to) {
   // row9: 2.0 + 3.0
   EXPECT_EQ(tensor1_data[9 * row_numel + 6], 5.0);
 }
+
+TEST(selected_rows_functor, cpu_merge_add) {
+  paddle::platform::CPUPlace cpu_place;
+  paddle::platform::CPUDeviceContext ctx(cpu_place);
+  paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
+                                       float>
+      set_const;
+
+  int64_t height = 10;
+  int64_t row_numel = 8;
+
+  std::vector<int64_t> rows1{5, 2, 5, 3, 5};
+  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
+      new paddle::framework::SelectedRows(rows1, height)};
+  auto* in1_value = selected_rows1->mutable_value();
+  in1_value->mutable_data<float>(
+      paddle::framework::make_ddim(
+          {static_cast<int64_t>(rows1.size()), row_numel}),
+      cpu_place);
+  set_const(ctx, in1_value, 1.0);
+
+  std::vector<int64_t> rows2{2, 5, 3, 5, 3};
+  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
+      new paddle::framework::SelectedRows(rows2, height)};
+  auto* in2_value = selected_rows2->mutable_value();
+  in2_value->mutable_data<float>(
+      paddle::framework::make_ddim(
+          {static_cast<int64_t>(rows2.size()), row_numel}),
+      cpu_place);
+  set_const(ctx, in2_value, 1.0);
+
+  std::unique_ptr<paddle::framework::SelectedRows> output{
+      new paddle::framework::SelectedRows()};
+  output->set_height(height);
+  paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
+                                             float>
+      merge_add_functor;
+
+  std::vector<const paddle::framework::SelectedRows*> inputs;
+  inputs.push_back(selected_rows1.get());
+  inputs.push_back(selected_rows2.get());
+  merge_add_functor(ctx, inputs, output.get());
+
+  EXPECT_EQ(output->height(), height);
+  EXPECT_EQ(output->value().dims(),
+            paddle::framework::make_ddim({3, row_numel}));
+
+  std::vector<int64_t> ret_rows{2, 3, 5};
+  EXPECT_EQ(output->rows(), ret_rows);
+
+  auto* out_data = output->value().data<float>();
+  for (size_t i = 0; i < ret_rows.size(); ++i) {
+    for (size_t j = 0; j < row_numel; ++j) {
+      EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
+      std::cout << out_data[i * row_numel + j] << " ";
+    }
+    std::cout << "\n";
+  }
+}
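Note on the expected values in this test: both inputs are filled with 1.0, row 5 appears three times in rows1 and twice in rows2, row 3 appears once and twice, and row 2 appears once and once, so after merging each output row holds its total occurrence count — 2.0, 3.0, and 5.0. Those counts happen to equal the row indices, which is why the test can compare out_data directly against ret_rows[i].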