提交 1a598800 编写于 作者: Q qiaolongfei

update test_sum_op

上级 40d3bd4e
......@@ -14,4 +14,4 @@ if(WITH_INFERENCE)
add_subdirectory(inference)
endif()
add_subdirectory(train)
#add_subdirectory(train)
......@@ -70,7 +70,7 @@ struct MergeAdd {
void operator()(const DeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output);
void operator()(const platform::CPUDeviceContext& context,
void operator()(const DeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output);
};
......
......@@ -69,80 +69,18 @@ class SumKernel : public framework::OpKernel<T> {
}
}
} else if (out_var->IsType<framework::SelectedRows>()) {
std::unique_ptr<framework::SelectedRows> in0;
if (in_place) {
// If is in_place, we store the input[0] to in0
auto &in_sel0 = in_vars[0]->Get<SelectedRows>();
auto &rows = in_sel0.rows();
#ifdef PADDLE_WITH_CUDA
std::vector<int64_t> rows_in_cpu;
rows_in_cpu.reserve(rows.size());
for (auto item : rows) {
rows_in_cpu.push_back(item);
}
in0.reset(new framework::SelectedRows(rows_in_cpu, in_sel0.height()));
#else
in0.reset(new framework::SelectedRows(rows, in_sel0.height()));
#endif
in0->mutable_value()->ShareDataWith(in_sel0.value());
}
auto get_selected_row = [&](size_t i) -> const SelectedRows & {
if (i == 0 && in0) {
return *in0.get();
} else {
return in_vars[i]->Get<SelectedRows>();
}
};
PADDLE_ENFORCE(!in_place, "SelectedRows not support inplace sum now");
auto *out = context.Output<SelectedRows>("Out");
out->mutable_rows()->clear();
auto *out_value = out->mutable_value();
// Runtime InferShape
size_t first_dim = 0;
for (size_t i = 0; i < in_num; i++) {
auto &sel_row = get_selected_row(i);
first_dim += sel_row.rows().size();
}
std::vector<int64_t> in_dim;
for (size_t i = 0; i < in_num; i++) {
auto &sel_row = get_selected_row(i);
if (sel_row.rows().size() > 0) {
in_dim = framework::vectorize(sel_row.value().dims());
break;
}
}
if (in_dim.empty()) {
VLOG(3) << "WARNING: all the inputs are empty";
in_dim =
framework::vectorize(get_selected_row(in_num - 1).value().dims());
} else {
in_dim[0] = static_cast<int64_t>(first_dim);
}
std::vector<const paddle::framework::SelectedRows *> inputs;
out_value->Resize(framework::make_ddim(in_dim));
out_value->mutable_data<T>(context.GetPlace());
// if all the input sparse vars are empty, no need to
// merge these vars.
if (first_dim == 0UL) {
return;
for (auto &in_var : in_vars) {
inputs.push_back(&in_var->Get<SelectedRows>());
}
math::SelectedRowsAddTo<DeviceContext, T> functor;
int64_t offset = 0;
for (size_t i = 0; i < in_num; i++) {
auto &sel_row = get_selected_row(i);
if (sel_row.rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(out->height(), sel_row.height());
functor(context.template device_context<DeviceContext>(), sel_row,
offset, out);
offset += sel_row.value().numel();
}
math::scatter::MergeAdd<DeviceContext, T> merge_add;
merge_add(context.template device_context<DeviceContext>(), inputs, out);
} else if (out_var->IsType<framework::LoDTensorArray>()) {
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
......
......@@ -47,11 +47,22 @@ class TestSumOp(OpTest):
class TestSelectedRowsSumOp(OpTest):
# Run the SelectedRows sum check on `place` for several input combinations.
# A fresh Scope is created and shared by all four checks below.
def check_with_place(self, place):
scope = core.Scope()
# Fixture parameters stored on self so the other helpers of this test
# class (which read self.height / self.row_numel / self.rows) can use them.
self.height = 10
self.row_numel = 12
self.rows = [0, 1, 2, 3, 4, 5, 6]
# The three trailing booleans are presumably the w1/w2/w3 "has data" flags
# of check_input_and_optput -- its full signature is not visible here,
# TODO confirm. The calls go from all three set to True down to all False.
self.check_input_and_optput(scope, place, True, True, True)
self.check_input_and_optput(scope, place, False, True, True)
self.check_input_and_optput(scope, place, False, False, True)
self.check_input_and_optput(scope, place, False, False, False)
# Build a float32 array of shape (row_num, row_numel) in which every element
# of row i equals i (row 0 is all zeros, row 1 all ones, and so on).
# With row_num == 0 the result is an empty (0, row_numel) array.
def _get_array(self, row_num, row_numel):
array = np.ones((row_num, row_numel)).astype("float32")
for i in range(row_num):
# Each row starts as all ones, so scaling by i fills it with the value i.
array[i] *= i
return array
def check_input_and_optput(self,
scope,
place,
......@@ -71,28 +82,36 @@ class TestSelectedRowsSumOp(OpTest):
sum_op.run(scope, place)
has_data_w_num = 0
for w in [w1_has_data, w2_has_data, w3_has_data]:
if not w:
for has_data in [w1_has_data, w2_has_data, w3_has_data]:
if has_data:
has_data_w_num += 1
self.assertEqual(7 * has_data_w_num, len(out.rows()))
if has_data_w_num > 0:
self.assertEqual(len(out.rows()), 7)
self.assertTrue(
np.array_equal(
np.array(out.get_tensor()),
self._get_array(len(self.rows), self.row_numel) *
has_data_w_num))
else:
self.assertEqual(len(out.rows()), 0)
self.assertTrue(
np.array_equal(
np.array(out.get_tensor()),
self._get_array(0, self.row_numel) * has_data_w_num))
def create_selected_rows(self, scope, place, var_name, isEmpty):
def create_selected_rows(self, scope, place, var_name, has_data):
# create and initialize W Variable
if not isEmpty:
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 12
if has_data:
rows = self.rows
else:
rows = []
row_numel = 12
var = scope.var(var_name)
w_selected_rows = var.get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_height(self.height)
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_array[i] *= i
w_array = self._get_array(len(rows), self.row_numel)
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册