From bd2b6d7f8f62397df9bd39da8a41978d888751ed Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Wed, 17 Oct 2018 10:05:33 +0800 Subject: [PATCH] sum_op support inplace --- paddle/fluid/operators/sum_op.h | 27 +++++++++++++++---- .../fluid/tests/unittests/test_sum_op.py | 22 +++++++++------ 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index bc571cd619f..c8ff532e1b8 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -69,16 +69,33 @@ class SumKernel : public framework::OpKernel { } } } else if (out_var->IsType()) { - PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now"); - auto *out = context.Output("Out"); - out->mutable_rows()->clear(); + if (in_place && in_vars.size() < 2) { + return; + } std::vector inputs; + SelectedRows temp_in0; - for (auto &in_var : in_vars) { - inputs.push_back(&in_var->Get()); + if (in_place) { + auto &in0 = in_vars[0]->Get(); + temp_in0.set_height(in0.height()); + temp_in0.set_rows(in0.rows()); + framework::TensorCopy(in0.value(), in0.place(), + context.device_context(), + temp_in0.mutable_value()); + inputs.push_back(&temp_in0); + for (size_t i = 1; i < in_vars.size(); ++i) { + inputs.push_back(&in_vars[i]->Get()); + } + } else { + for (auto &in_var : in_vars) { + inputs.push_back(&in_var->Get()); + } } + auto *out = context.Output("Out"); + out->mutable_rows()->clear(); + math::scatter::MergeAdd merge_add; merge_add(context.template device_context(), inputs, out); } else if (out_var->IsType()) { diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index a461c0a239a..1125dbd398e 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -45,17 +45,17 @@ class TestSumOp(OpTest): class TestSelectedRowsSumOp(OpTest): - def check_with_place(self, place): + def check_with_place(self, place, inplace): scope = core.Scope() self.height = 10 self.row_numel = 12 self.rows = [0, 1, 2, 3, 4, 5, 6] - self.check_input_and_optput(scope, place, True, True, True) - self.check_input_and_optput(scope, place, False, True, True) - self.check_input_and_optput(scope, place, False, False, True) - self.check_input_and_optput(scope, place, False, False, False) + self.check_input_and_optput(scope, place, inplace, True, True, True) + self.check_input_and_optput(scope, place, inplace, False, True, True) + self.check_input_and_optput(scope, place, inplace, False, False, True) + self.check_input_and_optput(scope, place, inplace, False, False, False) def _get_array(self, row_num, row_numel): array = np.ones((row_num, row_numel)).astype("float32") @@ -66,6 +66,7 @@ class TestSelectedRowsSumOp(OpTest): def check_input_and_optput(self, scope, place, + inplace, w1_has_data=False, w2_has_data=False, w3_has_data=False): @@ -75,10 +76,14 @@ class TestSelectedRowsSumOp(OpTest): self.create_selected_rows(scope, place, "W3", w3_has_data) # create Out Variable - out = scope.var('Out').get_selected_rows() + if inplace: + out_var_name = "W1" + else: + out_var_name = "Out" + out = scope.var(out_var_name).get_selected_rows() # create and run sum operator - sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out') + sum_op = Operator("sum", X=["W1", "W2", "W3"], Out=out_var_name) sum_op.run(scope, place) has_data_w_num = 0 @@ -121,7 +126,8 @@ class TestSelectedRowsSumOp(OpTest): places = [core.CPUPlace()] # currently only support CPU for place in places: - self.check_with_place(place) + for inplace in [True, False]: + self.check_with_place(place, inplace) if __name__ == "__main__": -- GitLab