提交 bd2b6d7f 编写于 作者: Qiao Longfei

sum_op supports in-place sum for SelectedRows

上级 b4a32eaf
...@@ -69,16 +69,33 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -69,16 +69,33 @@ class SumKernel : public framework::OpKernel<T> {
} }
} }
} else if (out_var->IsType<framework::SelectedRows>()) { } else if (out_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now"); if (in_place && in_vars.size() < 2) {
auto *out = context.Output<SelectedRows>("Out"); return;
out->mutable_rows()->clear(); }
std::vector<const paddle::framework::SelectedRows *> inputs; std::vector<const paddle::framework::SelectedRows *> inputs;
SelectedRows temp_in0;
for (auto &in_var : in_vars) { if (in_place) {
inputs.push_back(&in_var->Get<SelectedRows>()); auto &in0 = in_vars[0]->Get<SelectedRows>();
temp_in0.set_height(in0.height());
temp_in0.set_rows(in0.rows());
framework::TensorCopy(in0.value(), in0.place(),
context.device_context(),
temp_in0.mutable_value());
inputs.push_back(&temp_in0);
for (size_t i = 1; i < in_vars.size(); ++i) {
inputs.push_back(&in_vars[i]->Get<SelectedRows>());
}
} else {
for (auto &in_var : in_vars) {
inputs.push_back(&in_var->Get<SelectedRows>());
}
} }
auto *out = context.Output<SelectedRows>("Out");
out->mutable_rows()->clear();
math::scatter::MergeAdd<DeviceContext, T> merge_add; math::scatter::MergeAdd<DeviceContext, T> merge_add;
merge_add(context.template device_context<DeviceContext>(), inputs, out); merge_add(context.template device_context<DeviceContext>(), inputs, out);
} else if (out_var->IsType<framework::LoDTensorArray>()) { } else if (out_var->IsType<framework::LoDTensorArray>()) {
......
...@@ -45,17 +45,17 @@ class TestSumOp(OpTest): ...@@ -45,17 +45,17 @@ class TestSumOp(OpTest):
class TestSelectedRowsSumOp(OpTest): class TestSelectedRowsSumOp(OpTest):
def check_with_place(self, place): def check_with_place(self, place, inplace):
scope = core.Scope() scope = core.Scope()
self.height = 10 self.height = 10
self.row_numel = 12 self.row_numel = 12
self.rows = [0, 1, 2, 3, 4, 5, 6] self.rows = [0, 1, 2, 3, 4, 5, 6]
self.check_input_and_optput(scope, place, True, True, True) self.check_input_and_optput(scope, place, inplace, True, True, True)
self.check_input_and_optput(scope, place, False, True, True) self.check_input_and_optput(scope, place, inplace, False, True, True)
self.check_input_and_optput(scope, place, False, False, True) self.check_input_and_optput(scope, place, inplace, False, False, True)
self.check_input_and_optput(scope, place, False, False, False) self.check_input_and_optput(scope, place, inplace, False, False, False)
def _get_array(self, row_num, row_numel): def _get_array(self, row_num, row_numel):
array = np.ones((row_num, row_numel)).astype("float32") array = np.ones((row_num, row_numel)).astype("float32")
...@@ -66,6 +66,7 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -66,6 +66,7 @@ class TestSelectedRowsSumOp(OpTest):
def check_input_and_optput(self, def check_input_and_optput(self,
scope, scope,
place, place,
inplace,
w1_has_data=False, w1_has_data=False,
w2_has_data=False, w2_has_data=False,
w3_has_data=False): w3_has_data=False):
...@@ -75,10 +76,14 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -75,10 +76,14 @@ class TestSelectedRowsSumOp(OpTest):
self.create_selected_rows(scope, place, "W3", w3_has_data) self.create_selected_rows(scope, place, "W3", w3_has_data)
# create Out Variable # create Out Variable
out = scope.var('Out').get_selected_rows() if inplace:
out_var_name = "W1"
else:
out_var_name = "Out"
out = scope.var(out_var_name).get_selected_rows()
# create and run sum operator # create and run sum operator
sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out') sum_op = Operator("sum", X=["W1", "W2", "W3"], Out=out_var_name)
sum_op.run(scope, place) sum_op.run(scope, place)
has_data_w_num = 0 has_data_w_num = 0
...@@ -121,7 +126,8 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -121,7 +126,8 @@ class TestSelectedRowsSumOp(OpTest):
places = [core.CPUPlace()] places = [core.CPUPlace()]
# currently only support CPU # currently only support CPU
for place in places: for place in places:
self.check_with_place(place) for inplace in [True, False]:
self.check_with_place(place, inplace)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册