提交 766ac488 编写于 作者: T tangwei12

sum_op selectedRows dim bug fix

上级 c4c8f60b
...@@ -175,18 +175,31 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -175,18 +175,31 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& sel_row = get_selected_row(i); auto& sel_row = get_selected_row(i);
first_dim += sel_row.rows().size(); first_dim += sel_row.rows().size();
} }
auto in_dim =
framework::vectorize(get_selected_row(N - 1).value().dims()); std::vector<int64_t> in_dim;
for (int i = 0; i < N; i++) {
auto& sel_row = get_selected_row(i);
if (sel_row.rows().size() > 0) {
in_dim = framework::vectorize(sel_row.value().dims());
break;
}
}
if (in_dim.empty()) {
in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
}
in_dim[0] = static_cast<int64_t>(first_dim); in_dim[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim)); out_value->Resize(framework::make_ddim(in_dim));
out_value->mutable_data<T>(ctx.GetPlace());
// if all the input sparse vars are empty, no need to // if all the input sparse vars are empty, no need to
// merge these vars. // merge these vars.
if (first_dim == 0UL) { if (first_dim == 0UL) {
return; return;
} }
out_value->mutable_data<T>(ctx.GetPlace());
math::SelectedRowsAddTo<CPUDeviceContext, T> functor; math::SelectedRowsAddTo<CPUDeviceContext, T> functor;
int64_t offset = 0; int64_t offset = 0;
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
......
...@@ -114,16 +114,20 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -114,16 +114,20 @@ class SumKernel : public framework::OpKernel<T> {
break; break;
} }
} }
if (in_dim.empty()) {
in_dim = framework::vectorize(get_selected_row(N - 1).value().dims());
}
in_dim[0] = static_cast<int64_t>(first_dim); in_dim[0] = static_cast<int64_t>(first_dim);
out_value->Resize(framework::make_ddim(in_dim)); out_value->Resize(framework::make_ddim(in_dim));
out_value->mutable_data<T>(context.GetPlace());
// if all the input sparse vars are empty, no need to // if all the input sparse vars are empty, no need to
// merge these vars. // merge these vars.
if (first_dim == 0UL) { if (first_dim == 0UL) {
return; return;
} }
out_value->mutable_data<T>(context.GetPlace());
math::SelectedRowsAddTo<DeviceContext, T> functor; math::SelectedRowsAddTo<DeviceContext, T> functor;
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestSumOp(OpTest): class TestSumOp(OpTest):
...@@ -40,5 +42,61 @@ class TestSumOp(OpTest): ...@@ -40,5 +42,61 @@ class TestSumOp(OpTest):
pass pass
class TestSelectedRowsSumOp(OpTest):
def check_with_place(self, place):
scope = core.Scope()
self.check_input_and_optput(scope, place, True, True, True)
self.check_input_and_optput(scope, place, False, True, True)
self.check_input_and_optput(scope, place, False, False, True)
self.check_input_and_optput(scope, place, False, False, False)
def check_input_and_optput(self, scope, place, w1=False, w2=False,
w3=False):
W1 = self.create_selected_rows(scope, place, "W1", w1)
W2 = self.create_selected_rows(scope, place, "W2", w2)
W3 = self.create_selected_rows(scope, place, "W3", w3)
# create Out Variable
out = scope.var('Out').get_selected_rows()
# create and run sum operator
sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out')
sum_op.run(scope, place)
trues = 0
for w in [w1, w2, w3]:
if not w:
trues += 1
self.assertEqual(7 * trues, len(out.rows()))
def create_selected_rows(self, scope, place, var_name, isEmpty):
# create and initialize W Variable
if not isEmpty:
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 12
else:
rows = []
row_numel = 12
var = scope.var(var_name)
w_selected_rows = var.get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
return var
def test_w_is_selected_rows(self):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册