Commit 931572e2 authored by qijun

SelectedRowsAddTensor method

Parent 7b183433
@@ -45,6 +45,9 @@ class SelectedRows {
   }

 private:
+  // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
+  // SelectedRows are simply concatenated when added together. Not until a
+  // SelectedRows is added to a Tensor are the duplicate rows merged.
   std::vector<int64_t> rows_;
   std::unique_ptr<Tensor> value_{nullptr};
   int64_t height_;
...
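To illustrate the new comment: adding two SelectedRows only concatenates their row lists, so an index can appear more than once, and the duplicates are merged only when the result is scattered into a dense Tensor. A minimal stand-alone sketch of these semantics, using a hypothetical MiniSelectedRows type rather than the Paddle API:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for SelectedRows: a (possibly duplicate) row
    // index list plus one dense value row per entry.
    struct MiniSelectedRows {
      std::vector<std::int64_t> rows;
      std::vector<std::vector<float>> value;  // value[i] belongs to rows[i]
    };

    // SelectedRows + SelectedRows: concatenate only, duplicates survive.
    MiniSelectedRows Concat(const MiniSelectedRows& a,
                            const MiniSelectedRows& b) {
      MiniSelectedRows out = a;
      out.rows.insert(out.rows.end(), b.rows.begin(), b.rows.end());
      out.value.insert(out.value.end(), b.value.begin(), b.value.end());
      return out;
    }

    // SelectedRows + Tensor: every entry accumulates into dense[rows[i]],
    // so repeated row indices are summed at this step.
    void AddToDense(const MiniSelectedRows& s,
                    std::vector<std::vector<float>>* dense) {
      for (std::size_t i = 0; i < s.rows.size(); ++i)
        for (std::size_t j = 0; j < s.value[i].size(); ++j)
          (*dense)[s.rows[i]][j] += s.value[i][j];
    }

    int main() {
      MiniSelectedRows a{{0, 4, 7}, {{1.0f}, {1.0f}, {1.0f}}};
      MiniSelectedRows b{{0, 5, 7, 9}, {{2.0f}, {2.0f}, {2.0f}, {2.0f}}};
      MiniSelectedRows sum = Concat(a, b);  // rows = {0, 4, 7, 0, 5, 7, 9}
      std::vector<std::vector<float>> dense(10, std::vector<float>(1, 3.0f));
      AddToDense(sum, &dense);
      std::printf("row 0: %.1f\n", dense[0][0]);  // 1 + 2 + 3 = 6.0
      return 0;
    }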
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/operators/math/math_function.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/memory/memcpy.h"

 namespace paddle {
 namespace operators {
@@ -151,11 +153,17 @@ struct SelectedRowsAdd<platform::CPUPlace, T> {
                   framework::SelectedRows* output) {
     auto in1_height = input1.height();
     PADDLE_ENFORCE_EQ(in1_height, input2.height());
-    PADDLE_ENFORCE_EQ(in1_height, output->height());
+    output->set_height(in1_height);

     auto& in1_rows = input1.rows();
     auto& in2_rows = input2.rows();
-    auto& out_rows = output->rows();
+    std::vector<int64_t> out_rows;
+    out_rows.reserve(in1_rows.size() + in2_rows.size());
+
+    // concat rows
+    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
+    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
+    output->set_rows(out_rows);

     auto* out_value = output->mutable_value();
     auto& in1_value = input1.value();
@@ -165,29 +173,59 @@ struct SelectedRowsAdd<platform::CPUPlace, T> {
     PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
     PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());

-    SetConstant<platform::CPUPlace, T> functor;
-    functor(context, out_value, 0.0);
-
     auto* out_data = out_value->data<T>();
     auto* in1_data = in1_value.data<T>();
-    for (size_t i = 0; i < in1_rows.size(); i++) {
-      auto row = detail::FindPos(out_rows, in1_rows[i]);
-      for (size_t j = 0; j < in1_row_numel; j++) {
-        out_data[row * in1_row_numel + j] += in1_data[i * in1_row_numel + j];
-      }
-    }
+    memory::Copy(platform::CPUPlace(), out_data, platform::CPUPlace(), in1_data,
+                 in1_value.numel() * sizeof(T));

     auto* in2_data = in2_value.data<T>();
-    for (size_t i = 0; i < in2_rows.size(); i++) {
-      auto row = detail::FindPos(out_rows, in2_rows[i]);
-      for (size_t j = 0; j < in1_row_numel; j++) {
-        out_data[row * in1_row_numel + j] += in2_data[i * in1_row_numel + j];
-      }
-    }
+    memory::Copy(platform::CPUPlace(), out_data + in1_value.numel(),
+                 platform::CPUPlace(), in2_data, in2_value.numel() * sizeof(T));
   }
 };

 template struct SelectedRowsAdd<platform::CPUPlace, float>;

+template <typename T>
+struct SelectedRowsAddTensor<platform::CPUPlace, T> {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& input1,
+                  const framework::Tensor& input2, framework::Tensor* output) {
+    auto in1_height = input1.height();
+    auto in2_dims = input2.dims();
+    auto out_dims = output->dims();
+    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
+    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);
+
+    auto& in1_value = input1.value();
+    auto& in1_rows = input1.rows();
+
+    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
+    PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
+    PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);
+
+    SetConstant<platform::CPUPlace, T> functor;
+    functor(context, output, 0.0);
+
+    auto* in1_data = in1_value.data<T>();
+    auto* out_data = output->data<T>();
+
+    for (size_t i = 0; i < in1_rows.size(); i++) {
+      for (int64_t j = 0; j < in1_row_numel; j++) {
+        out_data[in1_rows[i] * in1_row_numel + j] +=
+            in1_data[i * in1_row_numel + j];
+      }
+    }
+
+    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
+    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
+    out_eigen.device(*context.GetEigenDevice<platform::CPUPlace>()) =
+        out_eigen + in2_eigen;
+  }
+};
+
+template struct SelectedRowsAddTensor<platform::CPUPlace, float>;
+
 }  // namespace math
 }  // namespace operators
...
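On CPUPlace, concatenating the two value tensors reduces to two contiguous copies: the rows of input1 land first and the rows of input2 directly after them, so the value layout matches the concatenated row-index vector. A stand-alone sketch of the same layout, using std::memcpy in place of paddle's memory::Copy (the Paddle call additionally carries destination and source places so the same pattern works for device memory):

    #include <cstring>
    #include <vector>

    int main() {
      // in1: 3 rows x 10 cols, in2: 4 rows x 10 cols -> out: 7 rows x 10 cols,
      // matching the concatenated row-index vector {0, 4, 7, 0, 5, 7, 9}.
      std::vector<float> in1(3 * 10, 1.0f), in2(4 * 10, 2.0f), out(7 * 10);
      std::memcpy(out.data(), in1.data(), in1.size() * sizeof(float));
      std::memcpy(out.data() + in1.size(), in2.data(),
                  in2.size() * sizeof(float));
      return 0;
    }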
@@ -96,6 +96,8 @@ struct SetConstant {
   }
 };

+// SelectedRows + SelectedRows will simply concat value and rows.
+// The real computation happens later, when dealing with LoDTensor.
 template <typename Place, typename T>
 struct SelectedRowsAdd {
   void operator()(const platform::DeviceContext& context,
@@ -104,6 +106,13 @@ struct SelectedRowsAdd {
                   framework::SelectedRows* output);
 };

+template <typename Place, typename T>
+struct SelectedRowsAddTensor {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& input1,
+                  const framework::Tensor& input2, framework::Tensor* output);
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
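Note the pattern this header relies on: the functor is declared generically over Place, each device backend supplies its own specialization (the CPUPlace one above), and the `template struct ...<platform::CPUPlace, float>;` line in the .cc explicitly instantiates it so other translation units can link against it. A minimal sketch of a similar specialization pattern, with hypothetical names:

    #include <cstdio>

    struct FakeCPUPlace {};  // hypothetical stand-in for platform::CPUPlace

    // Primary template: declared so headers can name it for any (Place, T).
    template <typename Place, typename T>
    struct AddFunctor;

    // Each backend provides a specialization with the actual kernel.
    template <typename T>
    struct AddFunctor<FakeCPUPlace, T> {
      T operator()(T a, T b) const { return a + b; }
    };

    int main() {
      AddFunctor<FakeCPUPlace, float> add;
      std::printf("%.1f\n", add(1.0f, 2.0f));  // prints 3.0
      return 0;
    }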
...
@@ -286,37 +286,76 @@ TEST(math_function, selected_rows_add) {
   auto* in1_value = selected_rows1->mutable_value();
   in1_value->mutable_data<float>(
       make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), cpu_place);
-  functor(ctx, in1_value, 2.0);
+  functor(ctx, in1_value, 1.0);

   std::vector<int64_t> rows2{0, 5, 7, 9};
   std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
   auto* in2_value = selected_rows2->mutable_value();
   in2_value->mutable_data<float>(
       make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), cpu_place);
-  functor(ctx, in2_value, 1.0);
+  functor(ctx, in2_value, 2.0);

   std::unique_ptr<SelectedRows> output{new SelectedRows()};
-  output->set_height(height);
-  std::vector<int64_t> out_rows = {0, 4, 5, 7, 9};
-  output->set_rows(out_rows);
   auto* out_value = output->mutable_value();
-  out_value->mutable_data<float>(make_ddim({5, 10}), cpu_place);
+
+  // simply concat two SelectedRows
+  out_value->mutable_data<float>(make_ddim({7, 10}), cpu_place);

   SelectedRowsAdd<CPUPlace, float> add_functor;
   add_functor(ctx, *selected_rows1, *selected_rows2, output.get());

-  auto* data = output->value().data<float>();
-  // out_rows[0] = 0
-  EXPECT_EQ(data[0 * row_numel + 0], 3.0);
-  EXPECT_EQ(data[0 * row_numel + 8], 3.0);
-  // out_rows[1] = 4
-  EXPECT_EQ(data[1 * row_numel + 1], 2.0);
-  // out_rows[2] = 5
-  EXPECT_EQ(data[2 * row_numel + 6], 1.0);
-  // out_rows[3] = 7
-  EXPECT_EQ(data[3 * row_numel + 3], 3.0);
-  EXPECT_EQ(data[3 * row_numel + 8], 3.0);
-  // out_rows[4] = 9
-  EXPECT_EQ(data[4 * row_numel + 4], 1.0);
+  auto out_height = output->height();
+  EXPECT_EQ(out_height, height);
+
+  auto& out_rows = output->rows();
+
+  // input1 rows
+  EXPECT_EQ(out_rows[0], 0);
+  EXPECT_EQ(out_rows[1], 4);
+  EXPECT_EQ(out_rows[2], 7);
+  // input2 rows
+  EXPECT_EQ(out_rows[3], 0);
+  EXPECT_EQ(out_rows[4], 5);
+  EXPECT_EQ(out_rows[5], 7);
+  EXPECT_EQ(out_rows[6], 9);
+
+  auto* out_data = output->value().data<float>();
+  // input1 value
+  EXPECT_EQ(out_data[0 * row_numel + 0], 1.0);
+  EXPECT_EQ(out_data[0 * row_numel + 8], 1.0);
+  EXPECT_EQ(out_data[1 * row_numel + 1], 1.0);
+  EXPECT_EQ(out_data[2 * row_numel + 6], 1.0);
+  // input2 value
+  EXPECT_EQ(out_data[3 * row_numel + 3], 2.0);
+  EXPECT_EQ(out_data[3 * row_numel + 8], 2.0);
+  EXPECT_EQ(out_data[4 * row_numel + 4], 2.0);
+  EXPECT_EQ(out_data[5 * row_numel + 7], 2.0);
+  EXPECT_EQ(out_data[6 * row_numel + 9], 2.0);
+
+  std::unique_ptr<Tensor> tensor1{new Tensor()};
+  tensor1->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
+  SetConstant<CPUPlace, float> constant_functor;
+  constant_functor(ctx, tensor1.get(), 3.0);
+
+  std::unique_ptr<Tensor> tensor2{new Tensor()};
+  tensor2->mutable_data<float>(make_ddim({height, row_numel}), cpu_place);
+
+  SelectedRowsAddTensor<CPUPlace, float> add_tensor_functor;
+  add_tensor_functor(ctx, *output, *tensor1, tensor2.get());
+
+  auto* tensor2_data = tensor2->data<float>();
+  // row0: 1.0 + 2.0 + 3.0
+  EXPECT_EQ(tensor2_data[0 * row_numel + 0], 6.0);
+  // row1: 3.0
+  EXPECT_EQ(tensor2_data[1 * row_numel + 1], 3.0);
+  // row4: 1.0 + 3.0
+  EXPECT_EQ(tensor2_data[4 * row_numel + 6], 4.0);
+  // row5: 2.0 + 3.0
+  EXPECT_EQ(tensor2_data[5 * row_numel + 7], 5.0);
+  // row6: 3.0
+  EXPECT_EQ(tensor2_data[6 * row_numel + 1], 3.0);
+  // row7: 1.0 + 2.0 + 3.0
+  EXPECT_EQ(tensor2_data[7 * row_numel + 3], 6.0);
+  // row9: 2.0 + 3.0
+  EXPECT_EQ(tensor2_data[9 * row_numel + 6], 5.0);
 }
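Putting the test's numbers together: input1 holds rows {0, 4, 7} filled with 1.0, input2 holds rows {0, 5, 7, 9} filled with 2.0, and tensor1 is filled with 3.0, so the expected rows of tensor2 are

    rows 0 and 7:   1.0 + 2.0 + 3.0 = 6.0   (index present in both inputs)
    row 4:          1.0 + 3.0       = 4.0   (input1 only)
    rows 5 and 9:   2.0 + 3.0       = 5.0   (input2 only)
    all other rows:                   3.0   (tensor1 value only)

The duplicate indices 0 and 7 are accumulated only in this SelectedRowsAddTensor step, matching the comment added to SelectedRows.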