diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 5e2024e0ea9040b758e1cec4dbaa4b329bbb727e..07b0388b607e862cd7cd64fcb6e3d00ee277b178 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -91,7 +91,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
                               .stream()>>>(dx_data, dy_data, x_data, label_data,
                                            batch_size, class_num);
     } else {
-      math::SetConstant<platform::GPUPlace, T>(ctx.device_context(), dx, 0);
+      math::SetConstant<platform::GPUPlace, T> functor;
+      functor(ctx.device_context(), dx, 0);
       auto* label_data = label->data<int>();
       grid = (batch_size + block - 1) / block;
       CrossEntropyGradientKernel<T><<<
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index d2d321aa7ed8e32cc19d5a171beea34d36195b10..19c276d23fc0f7f89b45c6d52c1c8e40d7f7d51d 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -70,7 +70,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
       const T* x_data = x->data<T>();
       const int* label_data = label->data<int>();
 
-      math::SetConstant<platform::CPUPlace, T>(ctx.device_context(), dx, 0);
+      math::SetConstant<platform::CPUPlace, T> functor;
+      functor(ctx.device_context(), dx, 0);
 
       for (int i = 0; i < batch_size; ++i) {
         PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index ba653afa2cb175ae2e5e21088b6dc7ba76a6018f..75a705b3460dfeaa5f2d4b1ea0f54c491160783f 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/operators/math/math_function.h"
+#include <vector>
 
 namespace paddle {
 namespace operators {
@@ -130,6 +131,65 @@ void matmul(
       matrix_b.data<double>(), beta, matrix_out->data<double>());
 }
 
+template struct SetConstant<platform::CPUPlace, float>;
+
+namespace detail {
+size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
+  for (size_t i = 0; i < rows.size(); i++) {
+    if (rows[i] == value) {
+      return i;
+    }
+  }
+  return 0;
+}
+}  // namespace detail
+
+template <typename T>
+struct SelectedRowsAdd<platform::CPUPlace, T> {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& input1,
+                  const framework::SelectedRows& input2,
+                  framework::SelectedRows* output) {
+    auto in1_height = input1.height();
+    PADDLE_ENFORCE_EQ(in1_height, input2.height());
+    PADDLE_ENFORCE_EQ(in1_height, output->height());
+
+    auto& in1_rows = input1.rows();
+    auto& in2_rows = input2.rows();
+    auto& out_rows = output->rows();
+
+    auto* out_value = output->mutable_value();
+    auto& in1_value = input1.value();
+    auto& in2_value = input2.value();
+
+    auto in1_row_numel = in1_value.numel() / in1_rows.size();
+    PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
+    PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());
+
+    SetConstant<platform::CPUPlace, T> functor;
+    functor(context, out_value, 0.0);
+    auto* out_data = out_value->data<T>();
+
+    auto* in1_data = in1_value.data<T>();
+    for (size_t i = 0; i < in1_rows.size(); i++) {
+      auto row = detail::FindPos(out_rows, in1_rows[i]);
+      for (size_t j = 0; j < in1_row_numel; j++) {
+        out_data[row * in1_row_numel + j] += in1_data[i * in1_row_numel + j];
+      }
+    }
+
+    auto* in2_data = in2_value.data<T>();
+    for (size_t i = 0; i < in2_rows.size(); i++) {
+      auto row = detail::FindPos(out_rows, in2_rows[i]);
+      for (size_t j = 0; j < in1_row_numel; j++) {
+        out_data[row * in1_row_numel + j] += in2_data[i * in1_row_numel + j];
+      }
+    }
+  }
+};
+
+template struct SelectedRowsAdd<platform::CPUPlace, float>;
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h
index 473eff4d198ca9b17b6af8eebd6dfe39d49d138d..f298f34baba9867351601ec4b87c0a0160d2a38d 100644
--- a/paddle/operators/math/math_function.h
+++ b/paddle/operators/math/math_function.h
@@ -53,6 +53,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
 #include <cmath>
 
 #include "paddle/framework/eigen.h"
+#include "paddle/framework/selected_rows.h"
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
@@ -86,11 +87,22 @@ void matmul(const platform::DeviceContext& context,
             framework::Tensor* matrix_out, T beta);
 
 template <typename Place, typename T>
-void SetConstant(const platform::DeviceContext& context,
-                 framework::Tensor* tensor, T num) {
-  auto t = framework::EigenVector<T>::Flatten(*tensor);
-  t.device(*context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(num));
-}
+struct SetConstant {
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor* tensor, T num) {
+    auto t = framework::EigenVector<T>::Flatten(*tensor);
+    t.device(*context.GetEigenDevice<Place>()) =
+        t.constant(static_cast<T>(num));
+  }
+};
+
+template <typename Place, typename T>
+struct SelectedRowsAdd {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& input1,
+                  const framework::SelectedRows& input2,
+                  framework::SelectedRows* output);
+};
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc
index c87d200c3aa5a9336c0f73d3a8bb88d2e9eafbab..43760bc6015839a351622597bdc865ca9a3b6c27 100644
--- a/paddle/operators/math/math_function_test.cc
+++ b/paddle/operators/math/math_function_test.cc
@@ -1,4 +1,5 @@
 #include "paddle/operators/math/math_function.h"
+#include "glog/logging.h"
 #include "gtest/gtest.h"
 
 #ifdef PADDLE_WITH_CUDA
@@ -253,18 +254,69 @@ TEST(math_function, zero) {
   auto* cpu_place = new paddle::platform::CPUPlace();
   float* t = tensor.mutable_data<float>({2, 2}, *cpu_place);
   paddle::platform::CPUDeviceContext context(*cpu_place);
 
-  paddle::operators::math::SetConstant<paddle::platform::CPUPlace, float>(
-      context, &tensor, 0);
+  paddle::operators::math::SetConstant<paddle::platform::CPUPlace, float>
+      functor;
+  functor(context, &tensor, 0);
   EXPECT_EQ(t[0], 0);
   EXPECT_EQ(t[1], 0);
   EXPECT_EQ(t[2], 0);
   EXPECT_EQ(t[3], 0);
 
-  paddle::operators::math::SetConstant<paddle::platform::CPUPlace, float>(
-      context, &tensor, 1);
+  functor(context, &tensor, 1);
   EXPECT_EQ(t[0], 1);
   EXPECT_EQ(t[1], 1);
   EXPECT_EQ(t[2], 1);
   EXPECT_EQ(t[3], 1);
 }
+
+TEST(math_function, selected_rows_add) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  using namespace paddle::operators::math;
+
+  CPUPlace cpu_place;
+  CPUDeviceContext ctx(cpu_place);
+  SetConstant<CPUPlace, float> functor;
+  int64_t height = 10;
+  int64_t row_numel = 10;
+
+  std::vector<int64_t> rows1{0, 4, 7};
+  std::unique_ptr<SelectedRows> selected_rows1{new SelectedRows(rows1, height)};
+  auto* in1_value = selected_rows1->mutable_value();
+  in1_value->mutable_data<float>(
+      make_ddim({static_cast<int64_t>(rows1.size()), row_numel}), cpu_place);
+  functor(ctx, in1_value, 2.0);
+
+  std::vector<int64_t> rows2{0, 5, 7, 9};
+  std::unique_ptr<SelectedRows> selected_rows2{new SelectedRows(rows2, height)};
+  auto* in2_value = selected_rows2->mutable_value();
+  in2_value->mutable_data<float>(
+      make_ddim({static_cast<int64_t>(rows2.size()), row_numel}), cpu_place);
+  functor(ctx, in2_value, 1.0);
+
+  std::unique_ptr<SelectedRows> output{new SelectedRows()};
+  output->set_height(height);
+  std::vector<int64_t> out_rows = {0, 4, 5, 7, 9};
+  output->set_rows(out_rows);
+
+  auto* out_value = output->mutable_value();
+  out_value->mutable_data<float>(make_ddim({5, 10}), cpu_place);
+
+  SelectedRowsAdd<CPUPlace, float> add_functor;
+  add_functor(ctx, *selected_rows1, *selected_rows2, output.get());
+
+  auto* data = output->value().data<float>();
+  // out_rows[0] = 0
+  EXPECT_EQ(data[0 * row_numel + 0], 3.0);
+  EXPECT_EQ(data[0 * row_numel + 8], 3.0);
+  // out_rows[1] = 4
+  EXPECT_EQ(data[1 * row_numel + 1], 2.0);
+  // out_rows[2] = 5
+  EXPECT_EQ(data[2 * row_numel + 6], 1.0);
+  // out_rows[3] = 7
+  EXPECT_EQ(data[3 * row_numel + 3], 3.0);
+  EXPECT_EQ(data[3 * row_numel + 8], 3.0);
+  // out_rows[4] = 9
+  EXPECT_EQ(data[4 * row_numel + 4], 1.0);
+}