diff --git a/paddle/fluid/operators/sum_mkldnn_op.cc b/paddle/fluid/operators/sum_mkldnn_op.cc
index d2035777ee2289291a02594ee289156504df09d9..c2cf13f7bf45066c07367ac48862208856feb518 100644
--- a/paddle/fluid/operators/sum_mkldnn_op.cc
+++ b/paddle/fluid/operators/sum_mkldnn_op.cc
@@ -175,18 +175,31 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         auto& sel_row = get_selected_row(i);
         first_dim += sel_row.rows().size();
       }
-      auto in_dim =
-          framework::vectorize(get_selected_row(N - 1).value().dims());
+
+      std::vector<int64_t> in_dim;
+      for (int i = 0; i < N; i++) {
+        auto& sel_row = get_selected_row(i);
+        if (sel_row.rows().size() > 0) {
+          in_dim = framework::vectorize(sel_row.value().dims());
+          break;
+        }
+      }
+
+      if (in_dim.empty()) {
+        in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
+      }
       in_dim[0] = static_cast<int64_t>(first_dim);
 
       out_value->Resize(framework::make_ddim(in_dim));
 
+      out_value->mutable_data<T>(ctx.GetPlace());
+
       // if all the input sparse vars are empty, no need to
       // merge these vars.
       if (first_dim == 0UL) {
         return;
       }
-      out_value->mutable_data<T>(ctx.GetPlace());
+
       math::SelectedRowsAddTo<CPUDeviceContext, T> functor;
       int64_t offset = 0;
       for (int i = 0; i < N; i++) {
diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h
index a497c91289a8a5e744e75a54dd477f176c8f2609..da192c6212b5094fbdbdf3546f73dd04a517bbab 100644
--- a/paddle/fluid/operators/sum_op.h
+++ b/paddle/fluid/operators/sum_op.h
@@ -114,16 +114,20 @@ class SumKernel : public framework::OpKernel<T> {
           break;
         }
       }
+      if (in_dim.empty()) {
+        in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
+      }
+
       in_dim[0] = static_cast<int64_t>(first_dim);
 
       out_value->Resize(framework::make_ddim(in_dim));
+      out_value->mutable_data<T>(context.GetPlace());
 
       // if all the input sparse vars are empty, no need to
       // merge these vars.
       if (first_dim == 0UL) {
         return;
       }
-      out_value->mutable_data<T>(context.GetPlace());
 
       math::SelectedRowsAddTo<DeviceContext, T> functor;
 
diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py
index 1d90414e137a70e6265042e24e106fe565802778..3c42607918209af7bba9c2ba94aab5d521944588 100644
--- a/python/paddle/fluid/tests/unittests/test_sum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sum_op.py
@@ -15,6 +15,8 @@
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
 
 
 class TestSumOp(OpTest):
@@ -40,5 +42,61 @@ class TestSumOp(OpTest):
         pass
 
 
+class TestSelectedRowsSumOp(OpTest):
+    """Runs the sum op on SelectedRows inputs, any of which may be empty."""
+
+    def check_with_place(self, place):
+        """Cover three, two, one and zero empty inputs on the given place."""
+        scope = core.Scope()
+        self.check_input_and_optput(scope, place, True, True, True)
+        self.check_input_and_optput(scope, place, False, True, True)
+        self.check_input_and_optput(scope, place, False, False, True)
+        self.check_input_and_optput(scope, place, False, False, False)
+
+    def check_input_and_optput(self, scope, place, w1=False, w2=False,
+                               w3=False):
+        """Sum W1..W3 into Out; a True flag makes that input empty."""
+        empty_flags = [w1, w2, w3]
+        for var_name, is_empty in zip(["W1", "W2", "W3"], empty_flags):
+            self.create_selected_rows(scope, place, var_name, is_empty)
+
+        # create Out Variable
+        out = scope.var('Out').get_selected_rows()
+
+        # create and run sum operator
+        sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out')
+        sum_op.run(scope, place)
+
+        # every non-empty input contributes its 7 rows to the output
+        non_empty = sum(1 for is_empty in empty_flags if not is_empty)
+        self.assertEqual(7 * non_empty, len(out.rows()))
+
+    def create_selected_rows(self, scope, place, var_name, isEmpty):
+        """Create var_name in scope as SelectedRows with 7 rows, or none."""
+        rows = [] if isEmpty else [0, 1, 2, 3, 4, 5, 6]
+        row_numel = 12
+
+        # create and initialize W Variable
+        var = scope.var(var_name)
+        w_selected_rows = var.get_selected_rows()
+        w_selected_rows.set_height(len(rows))
+        w_selected_rows.set_rows(rows)
+        # row i of the value tensor is filled with the constant i
+        w_array = np.ones((len(rows), row_numel)).astype("float32")
+        for i in range(len(rows)):
+            w_array[i] *= i
+        w_tensor = w_selected_rows.get_tensor()
+        w_tensor.set(w_array, place)
+
+        return var
+
+    def test_w_is_selected_rows(self):
+        """Entry point: run the SelectedRows checks on each listed place."""
+        places = [core.CPUPlace()]
+        # currently only CPU is supported
+        for place in places:
+            self.check_with_place(place)
+
+
 if __name__ == "__main__":
     unittest.main()