提交 bd2b6d7f 编写于 作者: Q Qiao Longfei

sum_op support inplace

上级 b4a32eaf
......@@ -69,15 +69,32 @@ class SumKernel : public framework::OpKernel<T> {
}
}
} else if (out_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
auto *out = context.Output<SelectedRows>("Out");
out->mutable_rows()->clear();
if (in_place && in_vars.size() < 2) {
return;
}
std::vector<const paddle::framework::SelectedRows *> inputs;
SelectedRows temp_in0;
if (in_place) {
auto &in0 = in_vars[0]->Get<SelectedRows>();
temp_in0.set_height(in0.height());
temp_in0.set_rows(in0.rows());
framework::TensorCopy(in0.value(), in0.place(),
context.device_context(),
temp_in0.mutable_value());
inputs.push_back(&temp_in0);
for (size_t i = 1; i < in_vars.size(); ++i) {
inputs.push_back(&in_vars[i]->Get<SelectedRows>());
}
} else {
for (auto &in_var : in_vars) {
inputs.push_back(&in_var->Get<SelectedRows>());
}
}
auto *out = context.Output<SelectedRows>("Out");
out->mutable_rows()->clear();
math::scatter::MergeAdd<DeviceContext, T> merge_add;
merge_add(context.template device_context<DeviceContext>(), inputs, out);
......
......@@ -45,17 +45,17 @@ class TestSumOp(OpTest):
class TestSelectedRowsSumOp(OpTest):
def check_with_place(self, place):
def check_with_place(self, place, inplace):
scope = core.Scope()
self.height = 10
self.row_numel = 12
self.rows = [0, 1, 2, 3, 4, 5, 6]
self.check_input_and_optput(scope, place, True, True, True)
self.check_input_and_optput(scope, place, False, True, True)
self.check_input_and_optput(scope, place, False, False, True)
self.check_input_and_optput(scope, place, False, False, False)
self.check_input_and_optput(scope, place, inplace, True, True, True)
self.check_input_and_optput(scope, place, inplace, False, True, True)
self.check_input_and_optput(scope, place, inplace, False, False, True)
self.check_input_and_optput(scope, place, inplace, False, False, False)
def _get_array(self, row_num, row_numel):
array = np.ones((row_num, row_numel)).astype("float32")
......@@ -66,6 +66,7 @@ class TestSelectedRowsSumOp(OpTest):
def check_input_and_optput(self,
scope,
place,
inplace,
w1_has_data=False,
w2_has_data=False,
w3_has_data=False):
......@@ -75,10 +76,14 @@ class TestSelectedRowsSumOp(OpTest):
self.create_selected_rows(scope, place, "W3", w3_has_data)
# create Out Variable
out = scope.var('Out').get_selected_rows()
if inplace:
out_var_name = "W1"
else:
out_var_name = "Out"
out = scope.var(out_var_name).get_selected_rows()
# create and run sum operator
sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out')
sum_op = Operator("sum", X=["W1", "W2", "W3"], Out=out_var_name)
sum_op.run(scope, place)
has_data_w_num = 0
......@@ -121,7 +126,8 @@ class TestSelectedRowsSumOp(OpTest):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
for inplace in [True, False]:
self.check_with_place(place, inplace)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册