提交 52e925db 编写于 作者: M Margaret_wangrui

handle-get-summary-node

上级 44e3d167
...@@ -32,15 +32,16 @@ ...@@ -32,15 +32,16 @@
#include "pre_activate/common/helper.h" #include "pre_activate/common/helper.h"
#include "common/utils.h" #include "common/utils.h"
#include "ir/dtype.h" #include "ir/dtype.h"
#include "ir/anf.h"
namespace mindspore { namespace mindspore {
namespace session { namespace session {
static std::shared_ptr<std::map<tensor::TensorPtr, ParameterPtr>> python_paras_; static std::shared_ptr<std::map<PyObject *, ParameterPtr>> python_paras_;
void ClearPythonParasMap() { python_paras_ = nullptr; } void ClearPythonParasMap() { python_paras_ = nullptr; }
namespace { namespace {
const int kSummaryGetItem = 2; const int kSummaryGetItem = 2;
tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) { PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return nullptr; return nullptr;
} }
...@@ -50,14 +51,7 @@ tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) { ...@@ -50,14 +51,7 @@ tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) {
} }
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param()); auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
auto py_param = param_value->value(); auto py_param = param_value->value();
if (!py::hasattr(py_param, "default_input")) { return py_param.ptr();
return nullptr;
}
auto py_p_input = py_param.attr("default_input");
if (!py::hasattr(py_p_input, PYTHON_TENSOR_FLAG)) {
return nullptr;
}
return py_p_input.cast<std::shared_ptr<tensor::Tensor>>();
} }
void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) { void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
...@@ -375,15 +369,17 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf ...@@ -375,15 +369,17 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
ParameterPtr new_parameter = nullptr; ParameterPtr new_parameter = nullptr;
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
if (python_paras_ == nullptr) { if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<tensor::TensorPtr, ParameterPtr>>(); python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
} }
if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) != kInvalidGraphId) { if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) == kInvalidGraphId) {
new_parameter = (*python_paras_)[m_tensor]; new_parameter = (*python_paras_)[m_tensor];
} else { } else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (m_tensor != nullptr) { if (m_tensor != nullptr) {
(*python_paras_)[m_tensor] = new_parameter; (*python_paras_)[m_tensor] = new_parameter;
} }
TraceManager::EndTrace();
} }
graph_inputs->push_back(new_parameter); graph_inputs->push_back(new_parameter);
valid_inputs->push_back(valid_input); valid_inputs->push_back(valid_input);
......
...@@ -26,8 +26,6 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \ ...@@ -26,8 +26,6 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
.attr("data_format", "required", "str", "all") \ .attr("data_format", "required", "str", "all") \
.input(0, "output_backprop", False, "required", "all") \ .input(0, "output_backprop", False, "required", "all") \
.output(0, "output", False, "required", "all") \ .output(0, "output", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
.get_op_info() .get_op_info()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册