Commit 1a598800 authored by Q qiaolongfei

update test_sum_op

Parent 40d3bd4e
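The diff touches four files (paths are truncated in this view): a CMake build file, the header declaring the `MergeAdd` functor, the `SumKernel` implementation, and the Python unit test `test_sum_op`. Taken together: the SelectedRows branch of the sum kernel no longer accumulates inputs one by one via `SelectedRowsAddTo` with hand-rolled runtime shape inference; it now rejects in-place summation and delegates to `math::scatter::MergeAdd`, which, as the updated test assertions show, merges duplicate rows instead of concatenating them. The `train` subdirectory is also commented out of the build.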
```diff
@@ -14,4 +14,4 @@ if(WITH_INFERENCE)
   add_subdirectory(inference)
 endif()
-add_subdirectory(train)
+#add_subdirectory(train)
```
```diff
@@ -70,7 +70,7 @@ struct MergeAdd {
   void operator()(const DeviceContext& context,
                   const framework::SelectedRows& input,
                   framework::SelectedRows* output);
-  void operator()(const platform::CPUDeviceContext& context,
+  void operator()(const DeviceContext& context,
                   const std::vector<const framework::SelectedRows*>& inputs,
                   framework::SelectedRows* output);
 };
```
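The only change here is the context type of the multi-input overload: it was previously pinned to `platform::CPUDeviceContext`, while the single-input overload already used the struct's `DeviceContext` template parameter. Generalizing it lets `SumKernel` below invoke this overload through `context.template device_context<DeviceContext>()` for whichever device the kernel is instantiated on.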
```diff
@@ -69,80 +69,18 @@ class SumKernel : public framework::OpKernel<T> {
         }
       }
     } else if (out_var->IsType<framework::SelectedRows>()) {
-      std::unique_ptr<framework::SelectedRows> in0;
-      if (in_place) {
-        // If is in_place, we store the input[0] to in0
-        auto &in_sel0 = in_vars[0]->Get<SelectedRows>();
-        auto &rows = in_sel0.rows();
-#ifdef PADDLE_WITH_CUDA
-        std::vector<int64_t> rows_in_cpu;
-        rows_in_cpu.reserve(rows.size());
-        for (auto item : rows) {
-          rows_in_cpu.push_back(item);
-        }
-        in0.reset(new framework::SelectedRows(rows_in_cpu, in_sel0.height()));
-#else
-        in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
-#endif
-        in0->mutable_value()->ShareDataWith(in_sel0.value());
-      }
-
-      auto get_selected_row = [&](size_t i) -> const SelectedRows & {
-        if (i == 0 && in0) {
-          return *in0.get();
-        } else {
-          return in_vars[i]->Get<SelectedRows>();
-        }
-      };
+      PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
 
       auto *out = context.Output<SelectedRows>("Out");
       out->mutable_rows()->clear();
-      auto *out_value = out->mutable_value();
-
-      // Runtime InferShape
-      size_t first_dim = 0;
-      for (size_t i = 0; i < in_num; i++) {
-        auto &sel_row = get_selected_row(i);
-        first_dim += sel_row.rows().size();
-      }
-
-      std::vector<int64_t> in_dim;
-      for (size_t i = 0; i < in_num; i++) {
-        auto &sel_row = get_selected_row(i);
-        if (sel_row.rows().size() > 0) {
-          in_dim = framework::vectorize(sel_row.value().dims());
-          break;
-        }
-      }
-      if (in_dim.empty()) {
-        VLOG(3) << "WARNING: all the inputs are empty";
-        in_dim =
-            framework::vectorize(get_selected_row(in_num - 1).value().dims());
-      } else {
-        in_dim[0] = static_cast<int64_t>(first_dim);
-      }
-
-      out_value->Resize(framework::make_ddim(in_dim));
-      out_value->mutable_data<T>(context.GetPlace());
-
-      // if all the input sparse vars are empty, no need to
-      // merge these vars.
-      if (first_dim == 0UL) {
-        return;
-      }
 
-      math::SelectedRowsAddTo<DeviceContext, T> functor;
+      std::vector<const paddle::framework::SelectedRows *> inputs;
+      for (auto &in_var : in_vars) {
+        inputs.push_back(&in_var->Get<SelectedRows>());
+      }
 
-      int64_t offset = 0;
-      for (size_t i = 0; i < in_num; i++) {
-        auto &sel_row = get_selected_row(i);
-        if (sel_row.rows().size() == 0) {
-          continue;
-        }
-        PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
-        functor(context.template device_context<DeviceContext>(), sel_row,
-                offset, out);
-        offset += sel_row.value().numel();
-      }
+      math::scatter::MergeAdd<DeviceContext, T> merge_add;
+      merge_add(context.template device_context<DeviceContext>(), inputs, out);
     } else if (out_var->IsType<framework::LoDTensorArray>()) {
       auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
       for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
```
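The behavioral shift is easier to see outside C++. Below is a minimal NumPy sketch of merge-add semantics over SelectedRows-like `(rows, values)` pairs; the `merge_add` helper, the dict-based representation, and the sorted output order are illustrative assumptions, not Paddle API. The point is that duplicate row ids are combined by summing, whereas the old `SelectedRowsAddTo` path appended each input's rows at a running offset.

```python
import numpy as np

def merge_add(inputs, row_numel):
    """Emulate merge-add over SelectedRows-like (rows, values) pairs.

    `values` has shape (len(rows), row_numel); duplicate row ids,
    within or across inputs, are combined by summing their value rows.
    """
    merged = {}  # row id -> accumulated value row
    for rows, values in inputs:
        for row, value in zip(rows, values):
            if row in merged:
                merged[row] = merged[row] + value
            else:
                merged[row] = value.copy()
    out_rows = sorted(merged)
    if not out_rows:
        return [], np.zeros((0, row_numel), dtype="float32")
    return out_rows, np.stack([merged[r] for r in out_rows])

# Two inputs with overlapping rows: the old concatenation path would
# produce 6 output rows; merge-add produces the 4 unique rows, summed.
a = ([0, 1, 2], np.ones((3, 4), dtype="float32"))
b = ([1, 2, 3], np.ones((3, 4), dtype="float32"))
rows, values = merge_add([a, b], row_numel=4)
print(rows)          # [0, 1, 2, 3]
print(values[:, 0])  # [1. 2. 2. 1.]
```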
```diff
@@ -47,11 +47,22 @@ class TestSumOp(OpTest):
 class TestSelectedRowsSumOp(OpTest):
     def check_with_place(self, place):
         scope = core.Scope()
+        self.height = 10
+        self.row_numel = 12
+        self.rows = [0, 1, 2, 3, 4, 5, 6]
+
         self.check_input_and_optput(scope, place, True, True, True)
         self.check_input_and_optput(scope, place, False, True, True)
         self.check_input_and_optput(scope, place, False, False, True)
         self.check_input_and_optput(scope, place, False, False, False)
 
+    def _get_array(self, row_num, row_numel):
+        array = np.ones((row_num, row_numel)).astype("float32")
+        for i in range(row_num):
+            array[i] *= i
+        return array
+
     def check_input_and_optput(self,
                                scope,
                                place,
@@ -71,28 +82,36 @@ class TestSelectedRowsSumOp(OpTest):
         sum_op.run(scope, place)
 
         has_data_w_num = 0
-        for w in [w1_has_data, w2_has_data, w3_has_data]:
-            if not w:
+        for has_data in [w1_has_data, w2_has_data, w3_has_data]:
+            if has_data:
                 has_data_w_num += 1
 
-        self.assertEqual(7 * has_data_w_num, len(out.rows()))
+        if has_data_w_num > 0:
+            self.assertEqual(len(out.rows()), 7)
+            self.assertTrue(
+                np.array_equal(
+                    np.array(out.get_tensor()),
+                    self._get_array(len(self.rows), self.row_numel) *
+                    has_data_w_num))
+        else:
+            self.assertEqual(len(out.rows()), 0)
+            self.assertTrue(
+                np.array_equal(
+                    np.array(out.get_tensor()),
+                    self._get_array(0, self.row_numel) * has_data_w_num))
 
-    def create_selected_rows(self, scope, place, var_name, isEmpty):
+    def create_selected_rows(self, scope, place, var_name, has_data):
         # create and initialize W Variable
-        if not isEmpty:
-            rows = [0, 1, 2, 3, 4, 5, 6]
-            row_numel = 12
+        if has_data:
+            rows = self.rows
         else:
             rows = []
-            row_numel = 12
 
         var = scope.var(var_name)
         w_selected_rows = var.get_selected_rows()
-        w_selected_rows.set_height(len(rows))
+        w_selected_rows.set_height(self.height)
         w_selected_rows.set_rows(rows)
-        w_array = np.ones((len(rows), row_numel)).astype("float32")
-        for i in range(len(rows)):
-            w_array[i] *= i
+        w_array = self._get_array(len(rows), self.row_numel)
         w_tensor = w_selected_rows.get_tensor()
         w_tensor.set(w_array, place)
```
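Two test fixes are worth spelling out. First, the expected tensor is `self._get_array(len(self.rows), self.row_numel) * has_data_w_num` because row `i` of every non-empty input is filled with the value `i`, and merge-add sums the `has_data_w_num` identical copies of each row. Second, the old test set the height to `len(rows)`, but a SelectedRows' height is the first dimension of the dense tensor it abstracts, independent of how many rows are actually materialized; hence `self.height = 10` with only 7 rows. A small sketch of that relationship (`to_dense` is illustrative, not Paddle API):

```python
import numpy as np

def to_dense(height, rows, values, row_numel):
    # A SelectedRows holds a subset of the rows of a logical dense
    # (height, row_numel) tensor; `rows` are indices into [0, height).
    dense = np.zeros((height, row_numel), dtype="float32")
    for i, row in enumerate(rows):
        dense[row] += values[i]
    return dense

# height=10 with rows [0..6], as in the fixed test: logical rows
# 7, 8, 9 are simply absent from the sparse representation.
values = np.ones((7, 12), dtype="float32") * np.arange(7)[:, None]
print(to_dense(10, list(range(7)), values, 12)[:, 0])
# [0. 1. 2. 3. 4. 5. 6. 0. 0. 0.]
```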