未验证 提交 407f883f 编写于 作者: C Chen Weihang 提交者: GitHub

Add SelectedRows support for dygraph DebugString (#21415)

* add selected rows support for debug string, test=develop

* refactor unittest of debug string, test=develop

* polish unittest name, test=develop
上级 9107bf20
......@@ -140,6 +140,22 @@ static std::string DebugString(
ss << "NOT_INITED";
}
ss << ">";
} else if (var.IsType<framework::SelectedRows>()) {
ss << "SelectedRows<";
auto& selected_rows = var.Get<framework::SelectedRows>();
auto& tensor = selected_rows.value();
auto& rows = selected_rows.rows();
if (tensor.IsInitialized()) {
ss << framework::DataTypeToString(tensor.type()) << ", ";
ss << tensor.place() << ", ";
ss << "height(" << selected_rows.height() << "), rows(";
std::for_each(rows.cbegin(), rows.cend(),
[&ss](const int64_t r) { ss << r << " "; });
ss << "), dims(" << tensor.dims() << ")";
} else {
ss << "NOT_INITED";
}
ss << ">";
} else {
ss << "UNRESOLVED_TYPE";
}
......
......@@ -61,34 +61,62 @@ std::string LayerDebugString(const std::string& op_type,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs);
TEST(test_layer, test_debug_string_test_debug_Test) {
TEST(test_layer, test_debug_string) {
platform::CPUPlace place;
std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, "vin"));
std::shared_ptr<imperative::VarBase> vin_error(
new imperative::VarBase(false, "vin_error"));
std::shared_ptr<imperative::VarBase> vout(
new imperative::VarBase(false, "vout"));
std::shared_ptr<imperative::VarBase> vout_error(
new imperative::VarBase(false, "vout_error"));
vin_error->MutableVar()->GetMutable<framework::LoDTensor>();
vout->MutableVar()->GetMutable<framework::LoDTensor>();
vout_error->MutableVar()->GetMutable<framework::SelectedRows>();
var_pair in_pair = var_pair("X", vb_vector(1, vin));
vb_vector vb_in_error = {vin_error, nullptr};
var_pair vin_error_pair = var_pair("X", vb_in_error);
auto test_func = [&](std::shared_ptr<imperative::VarBase>& vout) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
var_pair vout_error_pair = var_pair("Out2", vb_vector(1, vout_error));
imperative::NameVarBaseMap ins = {in_pair};
imperative::NameVarBaseMap ins_error = {vin_error_pair};
imperative::NameVarBaseMap outs = {out_pair};
imperative::NameVarBaseMap outs_error = {vout_error_pair};
ASSERT_NO_FATAL_FAILURE(LayerDebugString("test_op", ins, outs));
std::string res = LayerDebugString("test_op", ins, outs_error);
ASSERT_TRUE(res.find("UNRESOLVED_TYPE") != std::string::npos);
std::string res2 = LayerDebugString("test_op", ins_error, outs_error);
VLOG(3) << res2;
ASSERT_TRUE(res2.find("NOT_INITED") != std::string::npos);
ASSERT_TRUE(res2.find("NULL") != std::string::npos);
return LayerDebugString("test_op", ins, outs);
};
// 1. test null
std::shared_ptr<imperative::VarBase> null_out(nullptr);
std::string res_null = test_func(null_out);
ASSERT_TRUE(res_null.find("NULL") != std::string::npos);
// 2. test uninit var
std::shared_ptr<imperative::VarBase> un_init_out(
new imperative::VarBase(false, "un_init_out"));
std::string res_un_init = test_func(un_init_out);
ASSERT_TRUE(res_un_init.find("NOT_INITED_VAR") != std::string::npos);
// 3. test unresolved type
std::shared_ptr<imperative::VarBase> ut_out(
new imperative::VarBase(false, "ut_out"));
ut_out->MutableVar()->GetMutable<framework::LoDTensorArray>();
std::string res_ut = test_func(ut_out);
ASSERT_TRUE(res_ut.find("UNRESOLVED_TYPE") != std::string::npos);
// 4. test uninit lod tensor
std::shared_ptr<imperative::VarBase> lod_tensor(
new imperative::VarBase(false, "lod_tensor"));
auto tensor_l = lod_tensor->MutableVar()->GetMutable<framework::LoDTensor>();
std::string res_ui_lod_t = test_func(lod_tensor);
ASSERT_TRUE(res_ui_lod_t.find("NOT_INITED") != std::string::npos);
// 5. test init lod tensor
tensor_l->mutable_data<float>(place);
std::string res_lod_t = test_func(lod_tensor);
ASSERT_TRUE(res_lod_t.find("LoDTensor") != std::string::npos);
// 6. test uninit selected rows
std::shared_ptr<imperative::VarBase> selected_rows(
new imperative::VarBase(false, "selected_rows"));
auto tensor_sr = selected_rows->MutableVar()
->GetMutable<framework::SelectedRows>()
->mutable_value();
std::string res_ui_sr = test_func(selected_rows);
ASSERT_TRUE(res_ui_sr.find("NOT_INITED") != std::string::npos);
// 7. test init selected rows
tensor_sr->mutable_data<float>(place);
std::string res_sr = test_func(selected_rows);
ASSERT_TRUE(res_sr.find("SelectedRows") != std::string::npos);
}
TEST(test_layer, test_clear_backward_info) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册