Commit c1f881e6 authored by mindspore-ci-bot, committed by Gitee

!2592 Keep parameters of previous step in TensorLoader

Merge pull request !2592 from ShidaHe/debugger_dev
@@ -313,4 +313,10 @@ message TensorProto {
   // If the tensor content transferring is finished.
   optional bool finished = 6;
+  // The iteration of the tensor. Supported: "prev" or leave empty.
+  optional string iter = 7;
+  // If the tensor name should be truncated.
+  optional bool truncate = 8;
 }
\ No newline at end of file
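Note: the two new fields extend the view command's tensor requests. A minimal sketch of how a debugger client could ask for the previous-step value of a parameter, assuming the C++ classes generated from this .proto (the node name below is hypothetical):

  debugger::TensorProto request;
  request.set_node_name("Default/network/fc1.weight");  // hypothetical full node name
  request.set_slot("0");
  request.set_iter("prev");    // ask for the value kept from the previous step
  request.set_truncate(true);  // look the tensor up by its scope-less name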
@@ -178,7 +178,7 @@ void Debugger::CheckDatasetGraph() {
   is_dataset_graph_ = false;
 }
-GraphProto Debugger::GetGraphProto() {
+GraphProto Debugger::GetGraphProto() const {
   // convert kernel graph to debugger modelproto
   ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
   return model.graph();
@@ -261,12 +261,9 @@ void Debugger::CommandLoop() {
       MS_LOG(INFO) << "node name: " << node.node_name();
       MS_LOG(INFO) << "node type: " << node.node_type();
     }
-    WatchCondition recieved_condition = GetWatchcondition(reply);
-    MS_LOG(INFO) << "condition: " << recieved_condition.condition();
-    int32_t id = GetWatchpointID(reply);
-    MS_LOG(INFO) << "id: " << id;
-    bool delete_ = GetWatchpointDelete(reply);
-    MS_LOG(INFO) << "delete: " << delete_;
+    MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
+    MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
+    MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
   }
   MS_LOG(INFO) << "Setting watchpoint";
   if (GetWatchpointDelete(reply)) {
@@ -284,15 +281,20 @@ void Debugger::CommandLoop() {
       MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
       MS_LOG(INFO) << "tensor slot: " << tensor.slot();
       MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
+      MS_LOG(INFO) << "tensor iter: " << tensor.iter();
+      MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
     }
   }
   MS_LOG(INFO) << "Sending tensors";
   std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
   {
+    // print view cmd reply
     for (auto tensor : tensors) {
       MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
       MS_LOG(INFO) << "tensor slot: " << tensor.slot();
       MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
+      MS_LOG(INFO) << "tensor iter: " << tensor.iter();
+      MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
       MS_LOG(INFO) << "tensor dims: ";
       for (auto dim : tensor.dims()) {
         MS_LOG(INFO) << dim << ",";
@@ -309,68 +311,6 @@ void Debugger::CommandLoop() {
   }
 }
-DebuggerCommand Debugger::GetCommand(const EventReply &reply) {
-  DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
-  switch (reply.cmd_case()) {
-    case debugger::EventReply::CmdCase::kExit:
-      cmd = DebuggerCommand::kExitCMD;
-      break;
-    case debugger::EventReply::CmdCase::kRunCmd:
-      cmd = DebuggerCommand::kRunCMD;
-      break;
-    case debugger::EventReply::CmdCase::kSetCmd:
-      cmd = DebuggerCommand::kSetCMD;
-      break;
-    case debugger::EventReply::CmdCase::kViewCmd:
-      cmd = DebuggerCommand::kViewCMD;
-      break;
-    default:
-      MS_LOG(ERROR) << "Error: UnknownCMD";
-      break;
-  }
-  return cmd;
-}
-ProtoVector<WatchNode> Debugger::GetWatchnodes(const EventReply &reply) {
-  if (!reply.has_set_cmd()) {
-    MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
-    return ProtoVector<WatchNode>();
-  }
-  return reply.set_cmd().watch_nodes();
-}
-WatchCondition Debugger::GetWatchcondition(const EventReply &reply) {
-  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
-    MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
-    return WatchCondition();
-  }
-  return reply.set_cmd().watch_condition();
-}
-int32_t Debugger::GetWatchpointID(const EventReply &reply) {
-  if (!reply.has_set_cmd()) {
-    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
-    return 0;
-  }
-  return reply.set_cmd().id();
-}
-bool Debugger::GetWatchpointDelete(const EventReply &reply) {
-  if (!reply.has_set_cmd()) {
-    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
-    return false;
-  }
-  return reply.set_cmd().delete_();
-}
-ProtoVector<TensorProto> Debugger::GetTensors(const EventReply &reply) {
-  if (!reply.has_view_cmd()) {
-    MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
-    return ProtoVector<TensorProto>();
-  }
-  return reply.view_cmd().tensors();
-}
 void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id) {
   std::vector<std::tuple<std::string, bool>> check_node_list;
   std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
@@ -383,7 +323,7 @@ void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCon
 void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->remove_watchpoint(id); }
-std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) {
+std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
   std::vector<std::string> name;
   std::vector<std::string> ret_name;
   std::vector<char *> data_ptr;
@@ -391,38 +331,42 @@ std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &ten
   std::vector<TypePtr> dtype;
   std::vector<std::vector<int>> shape;
-  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name),
-                 [](TensorProto tensor) -> std::string { return tensor.node_name() + ":" + tensor.slot(); });
+  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
+  // ret_name will contain tensor names that are found in TensorLoader
+  // items in ret_name will be in the same order as tensors, if found
   debug_services_->read_nodes_tensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
   std::list<TensorProto> tensor_list;
   unsigned int result_index = 0;
+  TensorProto tensor_item;
   for (auto tensor : tensors) {
-    TensorProto tensor_item;
     tensor_item.set_node_name(tensor.node_name());
     tensor_item.set_slot(tensor.slot());
+    tensor_item.set_iter(tensor.iter());
+    tensor_item.set_truncate(tensor.truncate());
+    tensor_item.clear_tensor_content();
+    tensor_item.clear_data_type();
+    tensor_item.clear_dims();
+    // always set finished to true before big tensor splitting is supported
    tensor_item.set_finished(true);
    // return empty tensor if didn't find the requested tensor
-    if (result_index >= ret_name.size() || ret_name[result_index] != tensor.node_name() + ":" + tensor.slot()) {
+    if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
      tensor_list.push_back(tensor_item);
      continue;
    }
    tensor_item.set_tensor_content(data_ptr[result_index], data_size[result_index]);
    tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index]));
-    tensor_item.clear_dims();
    for (auto &elem : shape[result_index]) {
      tensor_item.add_dims(elem);
    }
+    // add tensor to result list and increment result_index to check next item in ret_name
    tensor_list.push_back(tensor_item);
    result_index++;
  }
  return tensor_list;
 }
@@ -432,7 +376,7 @@ void Debugger::Exit() {
   std::exit(EXIT_FAILURE);
 }
-std::list<WatchpointHit> Debugger::CheckWatchpoints() {
+std::list<WatchpointHit> Debugger::CheckWatchpoints() const {
   std::vector<std::string> name;
   std::vector<std::string> slot;
   std::vector<char *> data_ptr;
@@ -442,31 +386,23 @@ std::list<WatchpointHit> Debugger::CheckWatchpoints() {
   debug_services_->check_watchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
-  std::list<WatchpointHit> points;
+  std::list<WatchpointHit> hits;
   for (unsigned int i = 0; i < name.size(); i++) {
-    TensorProto *tensor_item;
-    tensor_item = new TensorProto();
+    WatchpointHit hit;
+    hit.set_id(watchpoint_id[i]);
+    // here TensorProto acts as a tensor indicator, not sending tensor content
+    TensorProto *tensor_item = hit.mutable_tensor();
     tensor_item->set_node_name(name[i]);
     tensor_item->set_slot(slot[i]);
-    tensor_item->set_tensor_content(data_ptr[i], data_size[i]);
+    // finished in TensorProto will always be true before we implement big tensor splitting
     tensor_item->set_finished(true);
-    WatchCondition *condition_item;
-    condition_item = new WatchCondition();
+    WatchCondition *condition_item = hit.mutable_watch_condition();
     condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
-    WatchpointHit point;
-    point.set_allocated_tensor(tensor_item);
-    point.set_allocated_watch_condition(condition_item);
-    point.set_id(watchpoint_id[i]);
-    points.push_back(point);
+    hits.push_back(hit);
   }
-  return points;
+  return hits;
 }
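Note: replacing set_allocated_* with mutable_* drops the manual new calls. With mutable_tensor() the parent WatchpointHit allocates and owns the submessage, so nothing leaks if a later step is skipped; set_allocated_* instead adopts a caller-allocated raw pointer. A minimal comparison using the generated protobuf API:

  // old style: caller allocates, message adopts the raw pointer
  WatchpointHit point;
  point.set_allocated_tensor(new TensorProto());

  // new style: message allocates lazily and owns the submessage
  WatchpointHit hit;
  TensorProto *tensor = hit.mutable_tensor();  // freed together with hit
  tensor->set_slot("0");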
 void Debugger::SendWatchpointsAndSuspend(const std::list<WatchpointHit> &points) {
@@ -481,8 +417,81 @@ void Debugger::SendWatchpointsAndSuspend(const std::list<WatchpointHit> &points)
   CommandLoop();
 }
-DebugServices *Debugger::get_debug_services() { return debug_services_.get(); }
-bool Debugger::debugger_enabled() { return debugger_enabled_; }
+DebugServices *Debugger::debug_services() const { return debug_services_.get(); }
+bool Debugger::debugger_enabled() const { return debugger_enabled_; }
+DebuggerCommand GetCommand(const EventReply &reply) {
+  DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
+  switch (reply.cmd_case()) {
+    case debugger::EventReply::CmdCase::kExit:
+      cmd = DebuggerCommand::kExitCMD;
+      break;
+    case debugger::EventReply::CmdCase::kRunCmd:
+      cmd = DebuggerCommand::kRunCMD;
+      break;
+    case debugger::EventReply::CmdCase::kSetCmd:
+      cmd = DebuggerCommand::kSetCMD;
+      break;
+    case debugger::EventReply::CmdCase::kViewCmd:
+      cmd = DebuggerCommand::kViewCMD;
+      break;
+    default:
+      MS_LOG(ERROR) << "Error: UnknownCMD";
+      break;
+  }
+  return cmd;
+}
+ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
+  if (!reply.has_set_cmd()) {
+    MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
+    return ProtoVector<WatchNode>();
+  }
+  return reply.set_cmd().watch_nodes();
+}
+WatchCondition GetWatchcondition(const EventReply &reply) {
+  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
+    MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
+    return WatchCondition();
+  }
+  return reply.set_cmd().watch_condition();
+}
+int32_t GetWatchpointID(const EventReply &reply) {
+  if (!reply.has_set_cmd()) {
+    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
+    return 0;
+  }
+  return reply.set_cmd().id();
+}
+bool GetWatchpointDelete(const EventReply &reply) {
+  if (!reply.has_set_cmd()) {
+    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
+    return false;
+  }
+  return reply.set_cmd().delete_();
+}
+ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
+  if (!reply.has_view_cmd()) {
+    MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
+    return ProtoVector<TensorProto>();
+  }
+  return reply.view_cmd().tensors();
+}
+std::string GetTensorFullName(const TensorProto &tensor) {
+  string node_name = tensor.node_name();
+  if (tensor.truncate()) {
+    // scopes in node name are separated by '/'
+    // use the name without scope if truncate is true
+    std::size_t found = node_name.find_last_of("/");
+    node_name = node_name.substr(found + 1);
+  }
+  return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
+}
 }  // namespace mindspore
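Note: GetTensorFullName produces the key used for TensorLoader lookups. A hypothetical walk-through of the scheme (names invented for illustration):

  // node_name = "Default/network/conv1.weight", slot = "0"
  // truncate = false, iter = ""      -> "Default/network/conv1.weight:0"
  // truncate = true,  iter = ""      -> "conv1.weight:0"
  // truncate = true,  iter = "prev"  -> "conv1.weight:0:prev"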
@@ -72,9 +72,9 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   // suspend the execution after a debug_op
   void PostDebugOp();
-  DebugServices *get_debug_services();
-  bool debugger_enabled();
+  DebugServices *debug_services() const;
+  bool debugger_enabled() const;
  private:
   // private constructor for singleton
@@ -92,7 +92,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   void CheckDatasetGraph();
   // serialize graph and get proto
-  GraphProto GetGraphProto();
+  GraphProto GetGraphProto() const;
   // send graph and enter command wait loop
   void SendGraphAndSuspend(const GraphProto &graph_proto);
@@ -102,16 +102,6 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   // break if RunCMD
   void CommandLoop();
-  // process reply and command type
-  DebuggerCommand GetCommand(const EventReply &reply);
-  // parse other data out of EventReply
-  ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply);
-  WatchCondition GetWatchcondition(const EventReply &reply);
-  int32_t GetWatchpointID(const EventReply &reply);
-  bool GetWatchpointDelete(const EventReply &reply);
-  ProtoVector<TensorProto> GetTensors(const EventReply &reply);
   // set what nodes and conditions to watch
   void SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id);
@@ -119,14 +109,14 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   void RemoveWatchpoint(const int32_t id);
   // load tensor for view command
-  std::list<TensorProto> LoadTensors(const ProtoVector<TensorProto> &tensors);
+  std::list<TensorProto> LoadTensors(const ProtoVector<TensorProto> &tensors) const;
   // terminate training process
   void Exit();
   // analyze tensors and check watchpoint conditions
   // return names of tensors and what condition they hit
-  std::list<WatchpointHit> CheckWatchpoints();
+  std::list<WatchpointHit> CheckWatchpoints() const;
   // send watchpoints that hit and enter command wait loop
   void SendWatchpointsAndSuspend(const std::list<WatchpointHit> &points);
@@ -155,5 +145,18 @@ ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph);
 // for getting proto DataType from Type of Tensor
 DataType GetDebuggerNumberDataType(const TypePtr &type);
+// process reply and command type
+DebuggerCommand GetCommand(const EventReply &reply);
+// parse other data out of EventReply
+ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply);
+WatchCondition GetWatchcondition(const EventReply &reply);
+int32_t GetWatchpointID(const EventReply &reply);
+bool GetWatchpointDelete(const EventReply &reply);
+ProtoVector<TensorProto> GetTensors(const EventReply &reply);
+// get the full name of a tensor, which is the name used in TensorLoader
+std::string GetTensorFullName(const TensorProto &tensor);
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_
@@ -21,6 +21,7 @@
 #include <map>
 #include <tuple>
 #include <string>
+#include <utility>
 #include "debug/tensor_data.h"
 namespace mindspore {
 class TensorLoader {
@@ -29,7 +30,15 @@ class TensorLoader {
   ~TensorLoader() {}
-  bool LoadNewTensor(std::shared_ptr<TensorData> tensor) {
+  bool LoadNewTensor(std::shared_ptr<TensorData> tensor, bool keep_prev) {
+    if (keep_prev) {
+      // add prev step tensor into current step map with ":prev" suffix
+      auto handle = prev_tensor_list_map.extract(tensor->GetName());
+      if (!handle.empty()) {
+        handle.key() = tensor->GetName() + ":prev";
+        tensor_list_map.insert(std::move(handle));
+      }
+    }
     tensor_list.push_back(tensor);
     tensor_list_map.insert({tensor->GetName(), tensor});
     return true;
@@ -53,16 +62,20 @@ class TensorLoader {
   }
   bool EmptyTensor() {
-    tensor_list_map.clear();
+    prev_tensor_list_map.clear();
+    tensor_list_map.swap(prev_tensor_list_map);
     tensor_list.clear();
     return true;
   }
+  void EmptyPrevTensor() { prev_tensor_list_map.clear(); }
   void set_iter_num(uint32_t iter_num) { this->iter_num = iter_num; }
  private:
   std::vector<std::shared_ptr<TensorData>> tensor_list;
   std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map;
+  std::map<std::string, std::shared_ptr<TensorData>> prev_tensor_list_map;
   uint32_t iter_num;
 };
 }  // namespace mindspore
......
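Note: the keep_prev path relies on C++17 std::map::extract, which detaches a map node so its key can be rewritten and the node reinserted without copying the mapped shared_ptr (hence the new <utility> include for std::move). A standalone sketch of the rekeying idiom, with hypothetical keys:

  #include <map>
  #include <string>
  #include <utility>

  int main() {
    std::map<std::string, int> prev{{"fc1.weight:0", 1}};
    std::map<std::string, int> curr;
    auto handle = prev.extract("fc1.weight:0");  // node leaves prev
    if (!handle.empty()) {
      handle.key() = "fc1.weight:0:prev";  // rekey in place
      curr.insert(std::move(handle));      // node now lives in curr
    }
    return curr.count("fc1.weight:0:prev") == 1 ? 0 : 1;
  }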
@@ -370,10 +370,10 @@ bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &file
 #ifdef ENABLE_DEBUGGER
 bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order,
                                         const std::string &host_fmt, const std::vector<int> &host_shape,
-                                        TypeId host_type, size_t slot, Debugger *debugger) const {
+                                        TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const {
   bool ret = false;
-  DebugServices *debug_services = debugger->get_debug_services();
+  DebugServices *debug_services = debugger->debug_services();
   TensorLoader *tensor_loader = debug_services->get_tensor_loader();
   if (trans_flag) {
@@ -390,7 +390,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
     tensor_data->SetExecutionOrder(execution_order);
     tensor_data->SetTensor(out_tensor);
     tensor_data->SetSlot(slot);
-    ret = tensor_loader->LoadNewTensor(tensor_data);
+    ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
   } else {
     mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape);
     size_t host_size = out_tensor->data().nbytes();
@@ -401,7 +401,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
     tensor_data->SetExecutionOrder(execution_order);
     tensor_data->SetTensor(out_tensor);
     tensor_data->SetSlot(slot);
-    ret = tensor_loader->LoadNewTensor(tensor_data);
+    ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
     if (ret_rt_memcpy != RT_ERROR_NONE) {
       MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]";
     }
......
@@ -46,7 +46,8 @@ class AscendDeviceAddress : public DeviceAddress {
 #endif
 #ifdef ENABLE_DEBUGGER
   bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt,
-                     const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const;
+                     const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger,
+                     bool keep_prev) const;
 #endif
  private:
......
@@ -322,7 +322,8 @@ void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) {
       (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
                            [](size_t inner_item) { return SizeToInt(inner_item); });
     }
-    auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger);
+    auto ret =
+      ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger, false);
     if (!ret) {
       MS_LOG(ERROR) << "LoadMemToHost: flag:" << trans_flag << ", tensor_name:" << tensor_name
                     << ", host_format:" << format << ".!";
@@ -356,7 +357,8 @@ void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger)
       (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
                            [](size_t inner_item) { return SizeToInt(inner_item); });
     }
-    auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger);
+    auto ret =
+      ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger, true);
     if (!ret) {
       MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", path:" << tensor_name
                     << ", host_format:" << format << ".!";
......
@@ -799,12 +799,13 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
 #ifdef ENABLE_DEBUGGER
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  DebugServices *debug_services = debugger_->get_debug_services();
+  DebugServices *debug_services = debugger_->debug_services();
   TensorLoader *tensor_loader = debug_services->get_tensor_loader();
   tensor_loader->EmptyTensor();
   uint32_t iter_num = tensor_loader->GetIterNum();
   tensor_loader->set_iter_num(++iter_num);
   (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get());
+  tensor_loader->EmptyPrevTensor();
 #endif
   MS_LOG(INFO) << "Finish!";
 }
......
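Note: taken together, the per-step lifecycle introduced by this commit looks roughly as follows (a simplified sketch against the classes in this diff; LoadData internals are condensed into comments):

  // one debugger iteration, as wired up in AscendSession::LoadTensor
  tensor_loader->EmptyTensor();        // last step's map is swapped into prev_tensor_list_map
  tensor_loader->set_iter_num(iter_num + 1);
  // LoadData -> LoadOutput/LoadParameters -> LoadMemToHost -> LoadNewTensor:
  //   outputs use keep_prev = false; parameters use keep_prev = true, which
  //   carries the previous value forward under the "name:slot:prev" key
  (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get());
  tensor_loader->EmptyPrevTensor();    // drop prev entries that were not carried over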