提交 3c307cf4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!552 Update clang format rule

Merge pull request !552 from ZhouFeng/clang-rule-update
...@@ -94,7 +94,7 @@ PenaltyBreakString: 1000 ...@@ -94,7 +94,7 @@ PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10 PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000 PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200 PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left PointerAlignment: Right
RawStringFormats: RawStringFormats:
- Language: Cpp - Language: Cpp
Delimiters: Delimiters:
......
...@@ -23,7 +23,7 @@ namespace common { ...@@ -23,7 +23,7 @@ namespace common {
const int CACHED_STR_NUM = 1 << 8; const int CACHED_STR_NUM = 1 << 8;
const int CACHED_STR_MASK = CACHED_STR_NUM - 1; const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM); std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
const char* SafeCStr(const std::string&& str) { const char *SafeCStr(const std::string &&str) {
static std::atomic<uint32_t> index{0}; static std::atomic<uint32_t> index{0};
uint32_t cur_index = index++; uint32_t cur_index = index++;
cur_index = cur_index & CACHED_STR_MASK; cur_index = cur_index & CACHED_STR_MASK;
......
...@@ -21,16 +21,16 @@ ...@@ -21,16 +21,16 @@
#include <string> #include <string>
#define DISABLE_COPY_AND_ASSIGN(ClassType) \ #define DISABLE_COPY_AND_ASSIGN(ClassType) \
ClassType(const ClassType&) = delete; \ ClassType(const ClassType &) = delete; \
ClassType& operator=(const ClassType&) = delete; ClassType &operator=(const ClassType &) = delete;
namespace mindspore { namespace mindspore {
namespace common { namespace common {
inline const char* SafeCStr(const std::string& str) { return str.c_str(); } inline const char *SafeCStr(const std::string &str) { return str.c_str(); }
const char* SafeCStr(const std::string&& str); const char *SafeCStr(const std::string &&str);
static inline std::string GetEnv(const std::string& envvar) { static inline std::string GetEnv(const std::string &envvar) {
const char* value = ::getenv(envvar.c_str()); const char *value = ::getenv(envvar.c_str());
if (value == nullptr) { if (value == nullptr) {
return std::string(); return std::string();
......
...@@ -34,11 +34,11 @@ class DecodeOp : public TensorOp { ...@@ -34,11 +34,11 @@ class DecodeOp : public TensorOp {
~DecodeOp() = default; ~DecodeOp() = default;
Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override; Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
void Print(std::ostream& out) const override { out << "DecodeOp"; } void Print(std::ostream &out) const override { out << "DecodeOp"; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private: private:
bool is_rgb_format_ = true; bool is_rgb_format_ = true;
......
...@@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int ...@@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int
rnd_.seed(seed_); rnd_.seed(seed_);
} }
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input, Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>>* output) { std::vector<std::shared_ptr<Tensor>> *output) {
IO_CHECK_VECTOR(input, output); IO_CHECK_VECTOR(input, output);
if (input.size() != NumInput()) if (input.size() != NumInput())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5");
...@@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tenso ...@@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tenso
return Status::OK(); return Status::OK();
} }
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs, Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape> &inputs,
std::vector<TensorShape>& outputs) { std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear(); outputs.clear();
TensorShape out = TensorShape{-1, -1}; TensorShape out = TensorShape{-1, -1};
...@@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inp ...@@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inp
if (!outputs.empty()) return Status::OK(); if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
} }
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) { Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
outputs[0] = inputs[0]; outputs[0] = inputs[0];
return Status::OK(); return Status::OK();
......
...@@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp { ...@@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp {
~DistortBoundingBoxCropOp() override = default; ~DistortBoundingBoxCropOp() override = default;
void Print(std::ostream& out) const override { void Print(std::ostream &out) const override {
out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_;
} }
Status Compute(const std::vector<std::shared_ptr<Tensor>>& input, Status Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>>* output) override; std::vector<std::shared_ptr<Tensor>> *output) override;
uint32_t NumInput() override { return 5; } uint32_t NumInput() override { return 5; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override; Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override; Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private: private:
int32_t max_attempts_; int32_t max_attempts_;
......
...@@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ ...@@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
rnd_.seed(GetSeed()); rnd_.seed(GetSeed());
} }
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) { Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output); IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal");
...@@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std: ...@@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std:
(void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width);
return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_);
} }
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) { Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear(); outputs.clear();
TensorShape out = TensorShape{target_height_, target_width_}; TensorShape out = TensorShape{target_height_, target_width_};
...@@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs ...@@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs
if (!outputs.empty()) return Status::OK(); if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
} }
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) { Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) {
double scale, aspect; double scale, aspect;
*crop_width = w_in; *crop_width = w_in;
*crop_height = h_in; *crop_height = h_in;
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace mindspore { namespace mindspore {
constexpr char PARALLEL_STRATEGY[] = "strategy"; constexpr char PARALLEL_STRATEGY[] = "strategy";
void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false); void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false);
} // namespace mindspore } // namespace mindspore
......
...@@ -39,7 +39,7 @@ ...@@ -39,7 +39,7 @@
namespace mindspore { namespace mindspore {
struct ParamPtrEqual { struct ParamPtrEqual {
bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const { bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const {
const ParameterPtr param1 = dyn_cast<Parameter>(t1); const ParameterPtr param1 = dyn_cast<Parameter>(t1);
const ParameterPtr param2 = dyn_cast<Parameter>(t2); const ParameterPtr param2 = dyn_cast<Parameter>(t2);
...@@ -52,7 +52,7 @@ struct ParamPtrEqual { ...@@ -52,7 +52,7 @@ struct ParamPtrEqual {
}; };
struct ParamPtrHasher { struct ParamPtrHasher {
std::size_t operator()(AnfNodePtr const& param) const { std::size_t operator()(AnfNodePtr const &param) const {
const ParameterPtr parameter = dyn_cast<Parameter>(param); const ParameterPtr parameter = dyn_cast<Parameter>(param);
if (parameter == nullptr) { if (parameter == nullptr) {
return 0; return 0;
...@@ -64,39 +64,39 @@ struct ParamPtrHasher { ...@@ -64,39 +64,39 @@ struct ParamPtrHasher {
class AnfExporter { class AnfExporter {
public: public:
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false)
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
func_graph_set.clear(); func_graph_set.clear();
exported.clear(); exported.clear();
} }
virtual ~AnfExporter() {} virtual ~AnfExporter() {}
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs); void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs);
protected: protected:
virtual std::string GetNodeType(const AnfNodePtr& nd); virtual std::string GetNodeType(const AnfNodePtr &nd);
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr& param); int GetParamIndexFromExported(const AnfNodePtr &param);
std::string DumpObject(const py::object& obj, const std::string& category) const; std::string DumpObject(const py::object &obj, const std::string &category) const;
std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node); std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node);
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph); std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph);
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst); std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst);
std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value); std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetPrimitiveText(const PrimitivePtr& prim); std::string GetPrimitiveText(const PrimitivePtr &prim);
std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value); std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetNameSpaceText(const parse::NameSpacePtr& ns); std::string GetNameSpaceText(const parse::NameSpacePtr &ns);
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph); std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph);
std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, int>& apply_map); const std::map<AnfNodePtr, int> &apply_map);
void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph); void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters, void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map); OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map);
void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node); void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph); void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
int param_index; int param_index;
OrderedSet<FuncGraphPtr> func_graph_set{}; OrderedSet<FuncGraphPtr> func_graph_set{};
...@@ -108,16 +108,16 @@ class AnfExporter { ...@@ -108,16 +108,16 @@ class AnfExporter {
abstract::AnfNodeConfigPtr node_cfg_ = nullptr; abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
}; };
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph);
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs); void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs);
std::vector<FuncGraphPtr> ImportIR(const std::string& filename); std::vector<FuncGraphPtr> ImportIR(const std::string &filename);
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
...@@ -34,7 +34,7 @@ namespace draw { ...@@ -34,7 +34,7 @@ namespace draw {
namespace { namespace {
// Only for ValueNode // Only for ValueNode
std::string ValueType(const ValueNodePtr& node) { std::string ValueType(const ValueNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return ""; return "";
} }
...@@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) { ...@@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) {
return v->type_name(); return v->type_name();
} }
std::string ReplaceSpecialChar(const std::string& str) { std::string ReplaceSpecialChar(const std::string &str) {
std::ostringstream oss; std::ostringstream oss;
for (size_t i = 0; i < str.size(); i++) { for (size_t i = 0; i < str.size(); i++) {
if (str[i] == '<') { if (str[i] == '<') {
...@@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) { ...@@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) {
} // namespace } // namespace
// API of debug utils // API of debug utils
void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs, void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs,
bool is_user) { bool is_user) {
if (sub_graphs == nullptr) { if (sub_graphs == nullptr) {
return; return;
} }
for (auto& nd : nodes) { for (auto &nd : nodes) {
MS_EXCEPTION_IF_NULL(nd); MS_EXCEPTION_IF_NULL(nd);
auto sub_graph = nd->func_graph(); auto sub_graph = nd->func_graph();
if (sub_graph != nullptr) { if (sub_graph != nullptr) {
...@@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st ...@@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st
} }
} }
void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, void DrawValueNodes(const std::vector<AnfNodePtr> &nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) { OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) {
if (sub_graphs == nullptr) { if (sub_graphs == nullptr) {
return; return;
} }
int dup_idx = 0; int dup_idx = 0;
for (auto& nd : nodes) { for (auto &nd : nodes) {
for (auto& t : SuccIncoming(nd)) { for (auto &t : SuccIncoming(nd)) {
MS_EXCEPTION_IF_NULL(t); MS_EXCEPTION_IF_NULL(t);
MS_EXCEPTION_IF_NULL(nd); MS_EXCEPTION_IF_NULL(nd);
if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
...@@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes, ...@@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
} }
} }
void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) { void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) {
if (digraph == nullptr) { if (digraph == nullptr) {
return; return;
} }
...@@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD ...@@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
} }
// Draw edge // Draw edge
for (auto& nd : nodes) { for (auto &nd : nodes) {
auto succs = SuccIncoming(nd); auto succs = SuccIncoming(nd);
auto num = succs.size(); auto num = succs.size();
for (size_t i = 0; i < num; i++) { for (size_t i = 0; i < num; i++) {
auto& t = succs.at(i); auto &t = succs.at(i);
MS_EXCEPTION_IF_NULL(t); MS_EXCEPTION_IF_NULL(t);
if (t->isa<ValueNode>() || t->isa<Parameter>()) { if (t->isa<ValueNode>() || t->isa<Parameter>()) {
if ((!is_user) || (i != 0)) { if ((!is_user) || (i != 0)) {
...@@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD ...@@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
} }
} }
void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) { void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return; return;
} }
...@@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use ...@@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
DrawValueNodes(nodes, &sub_graphs); DrawValueNodes(nodes, &sub_graphs);
// Draw subgraph // Draw subgraph
for (const auto& gsub : sub_graphs) { for (const auto &gsub : sub_graphs) {
digraph->SubGraph(gsub.first, gsub.second); digraph->SubGraph(gsub.first, gsub.second);
} }
...@@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use ...@@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
} }
#ifdef ENABLE_DUMP_IR #ifdef ENABLE_DUMP_IR
void Draw(const std::string& filename, const FuncGraphPtr& func_graph) { void Draw(const std::string &filename, const FuncGraphPtr &func_graph) {
const std::string dot_suffix = ".dot"; const std::string dot_suffix = ".dot";
std::string filename_with_suffix = std::string filename_with_suffix =
(filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
DrawByOpt(filename_with_suffix, func_graph, false); DrawByOpt(filename_with_suffix, func_graph, false);
} }
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
DrawByOpt(filename, func_graph, true); DrawByOpt(filename, func_graph, true);
} }
#else #else
void Draw(const std::string&, const FuncGraphPtr&) { void Draw(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false; static bool already_printed = false;
if (already_printed) { if (already_printed) {
return; return;
...@@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) { ...@@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) {
<< "please recompile source to enable it. See help of building script."; << "please recompile source to enable it. See help of building script.";
} }
void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) { void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false; static bool already_printed = false;
if (already_printed) { if (already_printed) {
return; return;
...@@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) { ...@@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) {
return "plaintext"; return "plaintext";
} }
std::string Graphviz::Color(const AnfNodePtr& node) { std::string Graphviz::Color(const AnfNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return ""; return "";
} }
...@@ -259,7 +259,7 @@ void BaseDigraph::Start() { ...@@ -259,7 +259,7 @@ void BaseDigraph::Start() {
buffer_ << "compound=true" << std::endl; buffer_ << "compound=true" << std::endl;
} }
void BaseDigraph::Head(const AnfNodePtr& node, int id) { void BaseDigraph::Head(const AnfNodePtr &node, int id) {
if (node == nullptr) { if (node == nullptr) {
return; return;
} }
...@@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) { ...@@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) {
} }
} }
void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) {
if (node == nullptr) { if (node == nullptr) {
return; return;
} }
...@@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { ...@@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
buffer_ << ":" << idx; buffer_ << ":" << idx;
} }
void BaseDigraph::Tail(const FuncGraphPtr& func_graph) { void BaseDigraph::Tail(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) { if (func_graph == nullptr) {
return; return;
} }
...@@ -304,12 +304,12 @@ void BaseDigraph::End() { ...@@ -304,12 +304,12 @@ void BaseDigraph::End() {
} }
} }
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << "parameters_" << key << "[shape=plaintext "; buffer_ << "parameters_" << key << "[shape=plaintext ";
buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>"; buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
buffer_ << "<tr><td>parameters</td></tr>"; buffer_ << "<tr><td>parameters</td></tr>";
int count = 0; int count = 0;
for (auto& parameter : key->parameters()) { for (auto &parameter : key->parameters()) {
buffer_ << "<tr><td>"; buffer_ << "<tr><td>";
buffer_ << parameter->ToString(); buffer_ << parameter->ToString();
auto py_p = dyn_cast<Parameter>(parameter)->default_param(); auto py_p = dyn_cast<Parameter>(parameter)->default_param();
...@@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { ...@@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
buffer_ << "</table>>,];"; buffer_ << "</table>>,];";
} }
void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) { void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) {
if (key == nullptr || gsub == nullptr) { if (key == nullptr || gsub == nullptr) {
return; return;
} }
...@@ -361,12 +361,12 @@ Digraph::~Digraph() { ...@@ -361,12 +361,12 @@ Digraph::~Digraph() {
if (fout_.is_open()) { if (fout_.is_open()) {
fout_.close(); fout_.close();
} }
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file " << filename_; MS_LOG(ERROR) << "Exception when closing file " << filename_;
} }
} }
static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) {
size_t start_pos = 0; size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) { while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
(void)str.replace(start_pos, from.length(), to); (void)str.replace(start_pos, from.length(), to);
...@@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st ...@@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st
return str; return str;
} }
static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph_obj); MS_EXCEPTION_IF_NULL(graph_obj);
graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node) graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
<< "'>"; << "'>";
...@@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { ...@@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</td></tr>"; graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "<tr><td align='left'>"; graph_obj->buffer() << "<tr><td align='left'>";
int i = 0; int i = 0;
for (const auto& attr : attrs) { for (const auto &attr : attrs) {
if (i != 0) { if (i != 0) {
graph_obj->buffer() << "<br/>"; graph_obj->buffer() << "<br/>";
} }
...@@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { ...@@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</table>>,"; graph_obj->buffer() << "</table>>,";
} }
static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) { if (graph_obj == nullptr || node == nullptr) {
return; return;
} }
...@@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { ...@@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
} }
} }
static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr || node->size() == 0) { if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
return; return;
} }
...@@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { ...@@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
} }
graph_obj->buffer() << ">"; graph_obj->buffer() << ">";
int i = 0; int i = 0;
for (auto& attr : attrs) { for (auto &attr : attrs) {
if (i != 0) { if (i != 0) {
graph_obj->buffer() << "<br/>"; graph_obj->buffer() << "<br/>";
} }
...@@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() { ...@@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() {
if (fout_.is_open()) { if (fout_.is_open()) {
fout_.close(); fout_.close();
} }
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(ERROR) << "exception when closing file " << filename_; MS_LOG(ERROR) << "exception when closing file " << filename_;
} }
} }
......
...@@ -31,9 +31,9 @@ namespace parse = mindspore::parse; ...@@ -31,9 +31,9 @@ namespace parse = mindspore::parse;
class Graphviz { class Graphviz {
public: public:
Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {} Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {}
explicit Graphviz(const std::string& name) : name_(name) {} explicit Graphviz(const std::string &name) : name_(name) {}
virtual ~Graphviz() {} virtual ~Graphviz() {}
...@@ -41,8 +41,8 @@ class Graphviz { ...@@ -41,8 +41,8 @@ class Graphviz {
virtual void End() {} virtual void End() {}
virtual std::string Shape(AnfNodePtr node); virtual std::string Shape(AnfNodePtr node);
std::string Color(const AnfNodePtr& node); std::string Color(const AnfNodePtr &node);
std::ostringstream& buffer() { return buffer_; } std::ostringstream &buffer() { return buffer_; }
std::ostringstream buffer_; std::ostringstream buffer_;
protected: protected:
...@@ -53,8 +53,8 @@ class Graphviz { ...@@ -53,8 +53,8 @@ class Graphviz {
class BaseDigraph : public Graphviz { class BaseDigraph : public Graphviz {
public: public:
BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {} BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {}
explicit BaseDigraph(const std::string& name) : Graphviz(name) {} explicit BaseDigraph(const std::string &name) : Graphviz(name) {}
~BaseDigraph() override = default; ~BaseDigraph() override = default;
virtual void Node(AnfNodePtr node, int id = 0) = 0; virtual void Node(AnfNodePtr node, int id = 0) = 0;
...@@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz { ...@@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz {
void Start() override; void Start() override;
void End() override; void End() override;
virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start);
void FuncGraphParameters(const FuncGraphPtr& key); void FuncGraphParameters(const FuncGraphPtr &key);
void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub); void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub);
const std::string& name() const { return name_; } const std::string &name() const { return name_; }
protected: protected:
void Head(const AnfNodePtr& node, int id = 0); void Head(const AnfNodePtr &node, int id = 0);
void Tail(const AnfNodePtr& node, int idx, int id = 0); void Tail(const AnfNodePtr &node, int idx, int id = 0);
void Tail(const FuncGraphPtr& func_graph); void Tail(const FuncGraphPtr &func_graph);
}; };
class Digraph : public BaseDigraph { class Digraph : public BaseDigraph {
public: public:
Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit Digraph(const std::string& name) : BaseDigraph(name) {} explicit Digraph(const std::string &name) : BaseDigraph(name) {}
~Digraph() override; ~Digraph() override;
void Node(AnfNodePtr node, int id = 0) override; void Node(AnfNodePtr node, int id = 0) override;
...@@ -86,8 +86,8 @@ class Digraph : public BaseDigraph { ...@@ -86,8 +86,8 @@ class Digraph : public BaseDigraph {
class ModelDigraph : public BaseDigraph { class ModelDigraph : public BaseDigraph {
public: public:
ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {} explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {}
~ModelDigraph() override; ~ModelDigraph() override;
std::string Shape(AnfNodePtr node) override; std::string Shape(AnfNodePtr node) override;
...@@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph { ...@@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph {
}; };
// API to draw // API to draw
void Draw(const std::string& filename, const FuncGraphPtr& func_graph); void Draw(const std::string &filename, const FuncGraphPtr &func_graph);
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
} // namespace draw } // namespace draw
} // namespace mindspore } // namespace mindspore
......
此差异已折叠。
...@@ -36,7 +36,7 @@ Dump::Dump() ...@@ -36,7 +36,7 @@ Dump::Dump()
dump_iter_(0), dump_iter_(0),
cur_iter_(0) {} cur_iter_(0) {}
bool Dump::IsKernelNeedDump(const std::string& kernel_name) { bool Dump::IsKernelNeedDump(const std::string &kernel_name) {
if (dump_mode_ == 0) { if (dump_mode_ == 0) {
// Dump All Kernels mode // Dump All Kernels mode
return true; return true;
...@@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) { ...@@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
return false; return false;
} }
bool Dump::ParseDumpConfig(const std::string& dump_config_file) { bool Dump::ParseDumpConfig(const std::string &dump_config_file) {
std::ifstream jsonFile(dump_config_file); std::ifstream jsonFile(dump_config_file);
if (!jsonFile.is_open()) { if (!jsonFile.is_open()) {
MS_LOG(ERROR) << dump_config_file << " open failed."; MS_LOG(ERROR) << dump_config_file << " open failed.";
...@@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) { ...@@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
return true; return true;
} }
bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) {
if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() ||
dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() ||
dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() ||
...@@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { ...@@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
return true; return true;
} }
bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto trans_flag = dumpSettings.at("trans_flag"); auto trans_flag = dumpSettings.at("trans_flag");
auto enable = dumpSettings.at("enable"); auto enable = dumpSettings.at("enable");
auto mode = dumpSettings.at("mode"); auto mode = dumpSettings.at("mode");
...@@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { ...@@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
dump_path_ = path; dump_path_ = path;
dump_net_name_ = net_name; dump_net_name_ = net_name;
dump_iter_ = iteration; dump_iter_ = iteration;
for (const auto& kernel : kernels) { for (const auto &kernel : kernels) {
dump_kernels_.push_back(kernel); dump_kernels_.push_back(kernel);
} }
return true; return true;
} }
bool Dump::SetDumpConfFromJsonFile() { bool Dump::SetDumpConfFromJsonFile() {
const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
if (config_path_str != nullptr) { if (config_path_str != nullptr) {
MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
} else { } else {
...@@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() { ...@@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() {
return ParseDumpConfig(dump_config_file); return ParseDumpConfig(dump_config_file);
} }
bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) { bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) {
if (filename.empty() || data == nullptr || len == 0) { if (filename.empty() || data == nullptr || len == 0) {
MS_LOG(ERROR) << "Incorrect parameter."; MS_LOG(ERROR) << "Incorrect parameter.";
return false; return false;
...@@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) ...@@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len)
MS_LOG(ERROR) << "Open file " << realpath << " fail."; MS_LOG(ERROR) << "Open file " << realpath << " fail.";
return false; return false;
} }
(void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len)); (void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len));
fd.close(); fd.close();
return true; return true;
} }
bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) {
MS_EXCEPTION_IF_NULL(outpath); MS_EXCEPTION_IF_NULL(outpath);
auto path_split_pos = inpath.find_last_of('/'); auto path_split_pos = inpath.find_last_of('/');
if (path_split_pos == std::string::npos) { if (path_split_pos == std::string::npos) {
...@@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { ...@@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
return true; return true;
} }
bool Dump::CreateNotExistDirs(const std::string& path) { bool Dump::CreateNotExistDirs(const std::string &path) {
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem(); std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs); MS_EXCEPTION_IF_NULL(fs);
char temp_path[PATH_MAX] = {0}; char temp_path[PATH_MAX] = {0};
......
...@@ -43,11 +43,11 @@ class Dump { ...@@ -43,11 +43,11 @@ class Dump {
uint32_t cur_iter() const { return cur_iter_; } uint32_t cur_iter() const { return cur_iter_; }
bool IsKernelNeedDump(const std::string& kernel_name); bool IsKernelNeedDump(const std::string &kernel_name);
bool SetDumpConfFromJsonFile(); bool SetDumpConfFromJsonFile();
static bool DumpToFile(const std::string& filename, const void* data, size_t len); static bool DumpToFile(const std::string &filename, const void *data, size_t len);
protected: protected:
bool dump_enable_; bool dump_enable_;
...@@ -59,14 +59,14 @@ class Dump { ...@@ -59,14 +59,14 @@ class Dump {
uint32_t cur_iter_; uint32_t cur_iter_;
std::vector<std::string> dump_kernels_; std::vector<std::string> dump_kernels_;
static bool GetRealPath(const std::string& inpath, std::string* outpath); static bool GetRealPath(const std::string &inpath, std::string *outpath);
static bool CreateNotExistDirs(const std::string& path); static bool CreateNotExistDirs(const std::string &path);
private: private:
bool ParseDumpConfig(const std::string& dump_config_file); bool ParseDumpConfig(const std::string &dump_config_file);
bool IsConfigExist(const nlohmann::json& dumpSettings); bool IsConfigExist(const nlohmann::json &dumpSettings);
bool IsConfigValid(const nlohmann::json& dumpSettings); bool IsConfigValid(const nlohmann::json &dumpSettings);
}; };
using DumpConfPtr = std::shared_ptr<Dump>; using DumpConfPtr = std::shared_ptr<Dump>;
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h" #include "pipeline/parse/python_adapter.h"
namespace mindspore { namespace mindspore {
std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) { std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {
std::string temp_line = line; std::string temp_line = line;
if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) &&
tip != kSourceLineTipDiscard) { tip != kSourceLineTipDiscard) {
...@@ -101,14 +101,14 @@ DebugInfo::DebugInfo() { ...@@ -101,14 +101,14 @@ DebugInfo::DebugInfo() {
name_ = ""; name_ = "";
} }
DebugInfo::DebugInfo(const std::string& name) { DebugInfo::DebugInfo(const std::string &name) {
InitValueFromContext(); InitValueFromContext();
unique_id_ = gen_unique_id(); unique_id_ = gen_unique_id();
debug_id_ = -1; debug_id_ = -1;
name_ = name; name_ = name;
} }
DebugInfo::DebugInfo(const LocationPtr& loc) { DebugInfo::DebugInfo(const LocationPtr &loc) {
InitValueFromContext(); InitValueFromContext();
unique_id_ = gen_unique_id(); unique_id_ = gen_unique_id();
debug_id_ = -1; debug_id_ = -1;
...@@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() { ...@@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() {
} }
int64_t DebugInfo::unique_id_through_copy() const { int64_t DebugInfo::unique_id_through_copy() const {
TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info(); TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info();
if (trace_info != nullptr) { if (trace_info != nullptr) {
if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) { if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) {
return trace_info->debug_info()->unique_id_through_copy(); return trace_info->debug_info()->unique_id_through_copy();
...@@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() { ...@@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() {
} }
return DebugInfo::location(); return DebugInfo::location();
} }
void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; } void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; }
TraceContextPtr TraceManager::CurrentContextInfo() { TraceContextPtr TraceManager::CurrentContextInfo() {
if (!TraceManager::trace_context_stack_.empty()) { if (!TraceManager::trace_context_stack_.empty()) {
...@@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() { ...@@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() {
return nullptr; return nullptr;
} }
void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) { void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location); TraceContextPtr context = std::make_shared<TraceContext>(location);
context->set_func_name(func_name); context->set_func_name(func_name);
TraceManager::trace_context_stack_.push(context); TraceManager::trace_context_stack_.push(context);
} }
void TraceManager::DebugTrace(const LocationPtr& location) { void TraceManager::DebugTrace(const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location); TraceContextPtr context = std::make_shared<TraceContext>(location);
TraceManager::trace_context_stack_.push(context); TraceManager::trace_context_stack_.push(context);
} }
void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) { if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
} }
...@@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { ...@@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
TraceManager::trace_context_stack_.push(context); TraceManager::trace_context_stack_.push(context);
} }
void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) { void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) { if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
} }
......
...@@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou ...@@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou
// Location class record the location in source code. // Location class record the location in source code.
class Location { class Location {
public: public:
Location(const std::string& file_name, int line, int column, int line_end, int column_end) Location(const std::string &file_name, int line, int column, int line_end, int column_end)
: file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {}
Location(const Location& loc) Location(const Location &loc)
: file_name_(loc.file_name_), : file_name_(loc.file_name_),
line_(loc.line_), line_(loc.line_),
column_(loc.column_), column_(loc.column_),
...@@ -77,21 +77,21 @@ class TraceManager { ...@@ -77,21 +77,21 @@ class TraceManager {
TraceManager() = default; TraceManager() = default;
~TraceManager() = default; ~TraceManager() = default;
static TraceContextPtr CurrentContextInfo(); static TraceContextPtr CurrentContextInfo();
static void DebugTrace(const std::string& func_name, const LocationPtr& location); static void DebugTrace(const std::string &func_name, const LocationPtr &location);
static void DebugTrace(const LocationPtr& location); static void DebugTrace(const LocationPtr &location);
static void DebugTrace(const TraceInfoPtr& trace_info); static void DebugTrace(const TraceInfoPtr &trace_info);
// debug trace with a cloned trace info with debug_info // debug trace with a cloned trace info with debug_info
static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info); static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info);
static void EndTrace(); static void EndTrace();
static std::stack<TraceContextPtr> trace_context_stack_; static std::stack<TraceContextPtr> trace_context_stack_;
}; };
class TraceGuard { class TraceGuard {
public: public:
explicit TraceGuard(const std::string func_name, const LocationPtr& location) { explicit TraceGuard(const std::string func_name, const LocationPtr &location) {
TraceManager::DebugTrace(func_name, location); TraceManager::DebugTrace(func_name, location);
} }
explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); } explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); }
~TraceGuard() { TraceManager::EndTrace(); } ~TraceGuard() { TraceManager::EndTrace(); }
}; };
...@@ -106,23 +106,23 @@ class TraceContext { ...@@ -106,23 +106,23 @@ class TraceContext {
public: public:
~TraceContext() = default; ~TraceContext() = default;
explicit TraceContext(const LocationPtr& loc) { explicit TraceContext(const LocationPtr &loc) {
ProcessAttributeFromContext(); ProcessAttributeFromContext();
location_ = loc; location_ = loc;
} }
explicit TraceContext(const std::string& func_name) { explicit TraceContext(const std::string &func_name) {
ProcessAttributeFromContext(); ProcessAttributeFromContext();
func_name_ = func_name; func_name_ = func_name;
} }
explicit TraceContext(const TraceInfoPtr& trace_info) { explicit TraceContext(const TraceInfoPtr &trace_info) {
ProcessAttributeFromContext(); ProcessAttributeFromContext();
trace_info_ = trace_info; trace_info_ = trace_info;
} }
void set_location(const LocationPtr& loc) { location_ = loc; } void set_location(const LocationPtr &loc) { location_ = loc; }
LocationPtr location() { return location_; } LocationPtr location() { return location_; }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; } TraceInfoPtr trace_info() { return trace_info_; }
void set_func_name(const std::string& func_name) { func_name_ = func_name; } void set_func_name(const std::string &func_name) { func_name_ = func_name; }
std::string func_name() { return func_name_; } std::string func_name() { return func_name_; }
}; };
...@@ -130,9 +130,9 @@ class DebugInfo : public Base { ...@@ -130,9 +130,9 @@ class DebugInfo : public Base {
public: public:
DebugInfo(); DebugInfo();
explicit DebugInfo(const std::string& name); explicit DebugInfo(const std::string &name);
explicit DebugInfo(const LocationPtr& loc); explicit DebugInfo(const LocationPtr &loc);
virtual ~DebugInfo() = default; virtual ~DebugInfo() = default;
MS_DECLARE_PARENT(DebugInfo, Base); MS_DECLARE_PARENT(DebugInfo, Base);
...@@ -141,12 +141,12 @@ class DebugInfo : public Base { ...@@ -141,12 +141,12 @@ class DebugInfo : public Base {
int64_t unique_id_through_copy() const; int64_t unique_id_through_copy() const;
std::string get_id() { return std::to_string(debug_id()); } std::string get_id() { return std::to_string(debug_id()); }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; } TraceInfoPtr trace_info() { return trace_info_; }
void set_location(const LocationPtr& loc) { location_ = loc; } void set_location(const LocationPtr &loc) { location_ = loc; }
virtual LocationPtr location() { return location_; } virtual LocationPtr location() { return location_; }
std::string name() { return name_; } std::string name() { return name_; }
void set_name(const std::string& name) { name_ = name; } void set_name(const std::string &name) { name_ = name; }
virtual std::string debug_name(); virtual std::string debug_name();
virtual std::string get_python_func_belonged() { return ""; } virtual std::string get_python_func_belonged() { return ""; }
...@@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo { ...@@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo {
py_func_belonged_ = context_info->func_name(); py_func_belonged_ = context_info->func_name();
} }
} }
explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) { explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) { if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo(); auto context_info = TraceManager::CurrentContextInfo();
py_func_belonged_ = context_info->func_name(); py_func_belonged_ = context_info->func_name();
...@@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo { ...@@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo {
~NodeDebugInfo() override = default; ~NodeDebugInfo() override = default;
std::string debug_name() override; std::string debug_name() override;
void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); } void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); }
std::shared_ptr<AnfNode> get_node() const { return node_.lock(); } std::shared_ptr<AnfNode> get_node() const { return node_.lock(); }
void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; } void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; }
std::string get_python_func_belonged() override { return py_func_belonged_; } std::string get_python_func_belonged() override { return py_func_belonged_; }
AnfNodeWeakPtr node_; AnfNodeWeakPtr node_;
std::string py_func_belonged_; std::string py_func_belonged_;
...@@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo { ...@@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo {
} }
} }
explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) { explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) { if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo(); auto context_info = TraceManager::CurrentContextInfo();
py_func_name_ = context_info->func_name(); py_func_name_ = context_info->func_name();
...@@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo { ...@@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo {
std::string debug_name() override; std::string debug_name() override;
LocationPtr location() override; LocationPtr location() override;
LocationPtr deco_location() { return deco_loc_; } LocationPtr deco_location() { return deco_loc_; }
void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
FuncGraphPtr get_graph() const { return func_graph_.lock(); } FuncGraphPtr get_graph() const { return func_graph_.lock(); }
void set_full_name(const std::string& name) { full_name_ = name; } void set_full_name(const std::string &name) { full_name_ = name; }
std::string get_full_name() { return full_name_; } std::string get_full_name() { return full_name_; }
void set_deco_location(const LocationPtr& deco_list_loc); void set_deco_location(const LocationPtr &deco_list_loc);
std::string get_python_func_belonged() override { return py_func_name_; } std::string get_python_func_belonged() override { return py_func_name_; }
FuncGraphWeakPtr func_graph_; FuncGraphWeakPtr func_graph_;
LocationPtr deco_loc_; LocationPtr deco_loc_;
......
...@@ -31,7 +31,7 @@ struct NameWithTrace { ...@@ -31,7 +31,7 @@ struct NameWithTrace {
std::string name; std::string name;
std::vector<std::string> trace_labels; std::vector<std::string> trace_labels;
}; };
static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) { static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) {
switch (trace_label) { switch (trace_label) {
case TraceLabelType::kShortSymbol: case TraceLabelType::kShortSymbol:
return trace_info->symbol(); return trace_info->symbol();
...@@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t ...@@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t
} }
} }
NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name; NameWithTrace trace_name;
// find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node
auto temp_info = debug_info; auto temp_info = debug_info;
...@@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe ...@@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe
return trace_name; return trace_name;
} }
std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) { std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) {
std::string tags = ""; std::string tags = "";
for (auto& itr : trace_labels) { for (auto &itr : trace_labels) {
std::string symbol = itr; std::string symbol = itr;
tags = tags + symbol; tags = tags + symbol;
} }
...@@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st ...@@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st
} }
// get the label name of the node debug info // get the label name of the node debug info
std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name = RootName(debug_info, trace_label); NameWithTrace trace_name = RootName(debug_info, trace_label);
return CombineTraceTypes(trace_name.name, trace_name.trace_labels); return CombineTraceTypes(trace_name.name, trace_name.trace_labels);
} }
std::string CombineUniqueID(const DebugInfoPtr& debug_info) { std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
auto temp_info = debug_info; auto temp_info = debug_info;
std::string label = ""; std::string label = "";
while (temp_info != nullptr) { while (temp_info != nullptr) {
...@@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) { ...@@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
} }
// get trace with unique id chain // get trace with unique id chain
std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); } std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) { if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
return LabelStringUnique(debug_info); return LabelStringUnique(debug_info);
} }
......
...@@ -29,7 +29,7 @@ namespace label_manage { ...@@ -29,7 +29,7 @@ namespace label_manage {
enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId }; enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId };
TraceLabelType GetGlobalTraceLabelType(); TraceLabelType GetGlobalTraceLabelType();
void SetGlobalTraceLabelType(TraceLabelType label_type); void SetGlobalTraceLabelType(TraceLabelType label_type);
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol); std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
} // namespace label_manage } // namespace label_manage
} // namespace mindspore } // namespace mindspore
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
namespace mindspore { namespace mindspore {
// namespace to support debug trace infomation // namespace to support debug trace infomation
namespace trace { namespace trace {
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) { std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) {
if (abs == nullptr) { if (abs == nullptr) {
return "Null Abstract"; return "Null Abstract";
} }
...@@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { ...@@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) {
return debug_with_loc_vec; return debug_with_loc_vec;
} }
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) {
auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info);
if (debug_with_loc_vec.size() > 0) { if (debug_with_loc_vec.size() > 0) {
return debug_with_loc_vec[0]; return debug_with_loc_vec[0];
...@@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { ...@@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
} }
} }
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) { if (info == nullptr) {
return ""; return "";
} }
...@@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { ...@@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
// a trace info identifies a node transform, so we can trace the node transform through // a trace info identifies a node transform, so we can trace the node transform through
// a link of trace info and debug info // a link of trace info and debug info
std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) { std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) {
if (info_vec.size() < 1) { if (info_vec.size() < 1) {
return ""; return "";
} }
...@@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL ...@@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL
return traced_info; return traced_info;
} }
std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) { if (info == nullptr) {
return ""; return "";
} }
...@@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { ...@@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
return ""; return "";
} }
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) { std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) {
std::ostringstream oss; std::ostringstream oss;
if (info == nullptr) { if (info == nullptr) {
return ""; return "";
...@@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So ...@@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So
return oss.str(); return oss.str();
} }
std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) { std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) {
std::ostringstream oss; std::ostringstream oss;
oss << "graph:" << graph->ToString() << " with args["; oss << "graph:" << graph->ToString() << " with args[";
auto params = graph->parameters(); auto params = graph->parameters();
...@@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas ...@@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas
return oss.str(); return oss.str();
} }
void DumpInferStack(std::ostringstream& oss) { void DumpInferStack(std::ostringstream &oss) {
auto& infer_stack = GetCurrenGraphInferStack(); auto &infer_stack = GetCurrenGraphInferStack();
if (infer_stack.empty()) { if (infer_stack.empty()) {
return; return;
} }
...@@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) { ...@@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) {
} }
std::reverse(infer_vec.begin(), infer_vec.end()); std::reverse(infer_vec.begin(), infer_vec.end());
int index = 0; int index = 0;
for (auto& item : infer_vec) { for (auto &item : infer_vec) {
auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first); auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first);
if (graph_infer == nullptr) { if (graph_infer == nullptr) {
MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator";
...@@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) { ...@@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) {
} }
void TraceGraphInfer() { void TraceGraphInfer() {
auto& infer_stack = GetCurrenGraphInferStack(); auto &infer_stack = GetCurrenGraphInferStack();
std::ostringstream oss; std::ostringstream oss;
if (infer_stack.empty()) { if (infer_stack.empty()) {
return; return;
...@@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter { ...@@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
~AnalyzedFuncGraphExporter() override = default; ~AnalyzedFuncGraphExporter() override = default;
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs); void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs);
private: private:
std::string GetNodeType(const AnfNodePtr& nd) override; std::string GetNodeType(const AnfNodePtr &nd) override;
}; };
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() { std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs; std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
auto& list = GetCNodeDebugStack(); auto &list = GetCNodeDebugStack();
for (size_t i = 0; i < list.size(); ++i) { for (size_t i = 0; i < list.size(); ++i) {
auto node_cfg = list[i]; auto node_cfg = list[i];
auto fg = node_cfg->context()->func_graph(); auto fg = node_cfg->context()->func_graph();
...@@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() { ...@@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() {
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
} }
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
if (node_cfg_ == nullptr) { if (node_cfg_ == nullptr) {
return AnfExporter::GetNodeType(node); return AnfExporter::GetNodeType(node);
} }
...@@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { ...@@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
return oss.str(); return oss.str();
} }
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) { const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) {
if (node_cfgs.empty()) { if (node_cfgs.empty()) {
MS_LOG(DEBUG) << "Node configs is empty"; MS_LOG(DEBUG) << "Node configs is empty";
return; return;
...@@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, ...@@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
auto tagged_func_graphs = CalcTaggedFuncGraphs(); auto tagged_func_graphs = CalcTaggedFuncGraphs();
// first output graph on the analysis stack // first output graph on the analysis stack
for (const auto& node_cfg : node_cfgs) { for (const auto &node_cfg : node_cfgs) {
auto fg = node_cfg->context()->func_graph(); auto fg = node_cfg->context()->func_graph();
// the graph is already output, skip it // the graph is already output, skip it
if (exported.find(fg) != exported.end()) { if (exported.find(fg) != exported.end()) {
...@@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, ...@@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
ofs.close(); ofs.close();
} }
void GetInferStackInfo(std::ostringstream& oss) { void GetInferStackInfo(std::ostringstream &oss) {
MS_LOG(INFO) << "Get graph analysis information begin"; MS_LOG(INFO) << "Get graph analysis information begin";
auto stack = GetCNodeDebugStack(); auto stack = GetCNodeDebugStack();
if (stack.empty()) { if (stack.empty()) {
...@@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) { ...@@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) {
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
// trace the cnode infer debug info // trace the cnode infer debug info
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) { void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
if (eval == nullptr) { if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
} }
...@@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An ...@@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An
} }
} }
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) {
if (eval == nullptr) { if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
} }
...@@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { ...@@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
} }
} }
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); } void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); }
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; } std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; }
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() { std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() {
return graph_infer_stack; return graph_infer_stack;
} }
void ClearTraceStack() { void ClearTraceStack() {
......
...@@ -31,19 +31,19 @@ ...@@ -31,19 +31,19 @@
namespace mindspore { namespace mindspore {
namespace trace { namespace trace {
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine); std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
SourceLineTip tip = kSourceLineTipNextLine); SourceLineTip tip = kSourceLineTipNextLine);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info); DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
void TraceGraphInfer(); void TraceGraphInfer();
void GetInferStackInfo(std::ostringstream& oss); void GetInferStackInfo(std::ostringstream &oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node); void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval); void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg); void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg);
void TraceInferCNodeLeave(); void TraceInferCNodeLeave();
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack(); std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack(); std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs); std::string GetAbstractStr(const abstract::AbstractBasePtr &abs);
void ClearTraceStack(); void ClearTraceStack();
} // namespace trace } // namespace trace
} // namespace mindspore } // namespace mindspore
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h" #include "pipeline/parse/python_adapter.h"
namespace mindspore { namespace mindspore {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) { std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {
if (info == nullptr) { if (info == nullptr) {
return ""; return "";
} }
......
...@@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr<DebugInfo>; ...@@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr<DebugInfo>;
// namespace to support intermediate representation definition // namespace to support intermediate representation definition
class TraceInfo : public Base { class TraceInfo : public Base {
public: public:
TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) { TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) {
symbol_ = symbol; symbol_ = symbol;
full_name_ = full_name; full_name_ = full_name;
name_ = full_name_; name_ = full_name_;
debug_info_ = info; debug_info_ = info;
} }
TraceInfo(const TraceInfo& info) TraceInfo(const TraceInfo &info)
: Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {}
virtual ~TraceInfo() = default; virtual ~TraceInfo() = default;
MS_DECLARE_PARENT(TraceInfo, Base); MS_DECLARE_PARENT(TraceInfo, Base);
...@@ -55,8 +55,8 @@ class TraceInfo : public Base { ...@@ -55,8 +55,8 @@ class TraceInfo : public Base {
virtual std::string full_name() { return full_name_; } virtual std::string full_name() { return full_name_; }
virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); } virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); }
virtual std::string action_name() { return ""; } virtual std::string action_name() { return ""; }
virtual std::string GetActionBetweenNode(const DebugInfoPtr& info); virtual std::string GetActionBetweenNode(const DebugInfoPtr &info);
void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; } void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; }
DebugInfoPtr debug_info() { return debug_info_; } DebugInfoPtr debug_info() { return debug_info_; }
DebugInfoPtr DebugInfoHasLoc(); DebugInfoPtr DebugInfoHasLoc();
std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo(); std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo();
...@@ -70,7 +70,7 @@ class TraceInfo : public Base { ...@@ -70,7 +70,7 @@ class TraceInfo : public Base {
class TracePhi : public TraceInfo { class TracePhi : public TraceInfo {
public: public:
explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {} explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {}
MS_DECLARE_PARENT(TracePhi, TraceInfo); MS_DECLARE_PARENT(TracePhi, TraceInfo);
~TracePhi() override = default; ~TracePhi() override = default;
TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); } TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); }
...@@ -78,8 +78,8 @@ class TracePhi : public TraceInfo { ...@@ -78,8 +78,8 @@ class TracePhi : public TraceInfo {
class TraceIfStmtTrueBranch : public TraceInfo { class TraceIfStmtTrueBranch : public TraceInfo {
public: public:
TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default; TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default;
explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "✓") {} explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {}
MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo);
~TraceIfStmtTrueBranch() override = default; ~TraceIfStmtTrueBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo { ...@@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo {
class TraceIfStmtFalseBranch : public TraceInfo { class TraceIfStmtFalseBranch : public TraceInfo {
public: public:
TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default; TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default;
explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "✗") {} explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {}
MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo);
~TraceIfStmtFalseBranch() override = default; ~TraceIfStmtFalseBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo { ...@@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo {
class TraceIfStmtAfterBranch : public TraceInfo { class TraceIfStmtAfterBranch : public TraceInfo {
public: public:
explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "↓") {} explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {}
MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo);
~TraceIfStmtAfterBranch() override = default; ~TraceIfStmtAfterBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo { ...@@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo {
class TraceIfExpTrueBranch : public TraceInfo { class TraceIfExpTrueBranch : public TraceInfo {
public: public:
explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "↰") {} explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {}
MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo);
~TraceIfExpTrueBranch() override = default; ~TraceIfExpTrueBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo { ...@@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo {
class TraceIfExpFalseBranch : public TraceInfo { class TraceIfExpFalseBranch : public TraceInfo {
public: public:
explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "↱") {} explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {}
MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo);
~TraceIfExpFalseBranch() override = default; ~TraceIfExpFalseBranch() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo { ...@@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo {
class TraceCopy : public TraceInfo { class TraceCopy : public TraceInfo {
public: public:
TraceCopy() : TraceInfo(nullptr, "copy", "") {} TraceCopy() : TraceInfo(nullptr, "copy", "") {}
explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {} explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {}
MS_DECLARE_PARENT(TraceCopy, TraceInfo); MS_DECLARE_PARENT(TraceCopy, TraceInfo);
~TraceCopy() override = default; ~TraceCopy() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); } TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); }
...@@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo { ...@@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo {
class TraceIterator : public TraceInfo { class TraceIterator : public TraceInfo {
public: public:
explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {} explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {}
MS_DECLARE_PARENT(TraceIterator, TraceInfo); MS_DECLARE_PARENT(TraceIterator, TraceInfo);
~TraceIterator() override = default; ~TraceIterator() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); } TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); }
...@@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo { ...@@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo {
class TraceWhileHeader : public TraceInfo { class TraceWhileHeader : public TraceInfo {
public: public:
explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "⤾") {} explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {}
MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo);
~TraceWhileHeader() override = default; ~TraceWhileHeader() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); } TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); }
...@@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo { ...@@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo {
class TraceWhileBody : public TraceInfo { class TraceWhileBody : public TraceInfo {
public: public:
explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "⥁") {} explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {}
MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); MS_DECLARE_PARENT(TraceWhileBody, TraceInfo);
~TraceWhileBody() override = default; ~TraceWhileBody() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); } TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); }
...@@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo { ...@@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo {
class TraceWhileAfter : public TraceInfo { class TraceWhileAfter : public TraceInfo {
public: public:
explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "↓") {} explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {}
MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo);
~TraceWhileAfter() override = default; ~TraceWhileAfter() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); } TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); }
...@@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo { ...@@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo {
class TraceForHeader : public TraceInfo { class TraceForHeader : public TraceInfo {
public: public:
explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "⤾") {} explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {}
MS_DECLARE_PARENT(TraceForHeader, TraceInfo); MS_DECLARE_PARENT(TraceForHeader, TraceInfo);
~TraceForHeader() override = default; ~TraceForHeader() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); } TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); }
...@@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo { ...@@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo {
class TraceForBody : public TraceInfo { class TraceForBody : public TraceInfo {
public: public:
explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "⥁") {} explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {}
MS_DECLARE_PARENT(TraceForBody, TraceInfo); MS_DECLARE_PARENT(TraceForBody, TraceInfo);
~TraceForBody() override = default; ~TraceForBody() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); } TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); }
...@@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo { ...@@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo {
class TraceForAfter : public TraceInfo { class TraceForAfter : public TraceInfo {
public: public:
explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "↓") {} explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {}
MS_DECLARE_PARENT(TraceForAfter, TraceInfo); MS_DECLARE_PARENT(TraceForAfter, TraceInfo);
~TraceForAfter() override = default; ~TraceForAfter() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); } TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); }
...@@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo { ...@@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo {
class TraceEquiv : public TraceInfo { class TraceEquiv : public TraceInfo {
public: public:
explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {} explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {}
MS_DECLARE_PARENT(TraceEquiv, TraceInfo); MS_DECLARE_PARENT(TraceEquiv, TraceInfo);
~TraceEquiv() override = default; ~TraceEquiv() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); } TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); }
...@@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo { ...@@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo {
class TraceGradFpropApp : public TraceInfo { class TraceGradFpropApp : public TraceInfo {
public: public:
TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {}
explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "▲") {} explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {}
MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo);
~TraceGradFpropApp() override = default; ~TraceGradFpropApp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); }
...@@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo { ...@@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo {
class TraceGradBpropApp : public TraceInfo { class TraceGradBpropApp : public TraceInfo {
public: public:
TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {}
explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "▼") {} explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {}
MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo);
~TraceGradBpropApp() override = default; ~TraceGradBpropApp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); }
...@@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo { ...@@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo {
class TraceGradFprop : public TraceInfo { class TraceGradFprop : public TraceInfo {
public: public:
TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {}
explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "▶") {} explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {}
MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); MS_DECLARE_PARENT(TraceGradFprop, TraceInfo);
~TraceGradFprop() override = default; ~TraceGradFprop() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); }
...@@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo { ...@@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo {
class TraceGradBprop : public TraceInfo { class TraceGradBprop : public TraceInfo {
public: public:
TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {}
explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "◀") {} explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {}
MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); MS_DECLARE_PARENT(TraceGradBprop, TraceInfo);
~TraceGradBprop() override = default; ~TraceGradBprop() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); }
...@@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo { ...@@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo {
class TraceGradSens : public TraceInfo { class TraceGradSens : public TraceInfo {
public: public:
TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {}
explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "∇") {} explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {}
MS_DECLARE_PARENT(TraceGradSens, TraceInfo); MS_DECLARE_PARENT(TraceGradSens, TraceInfo);
~TraceGradSens() override = default; ~TraceGradSens() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); }
...@@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo { ...@@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo {
class TraceSpecialize : public TraceInfo { class TraceSpecialize : public TraceInfo {
public: public:
explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; }
MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); MS_DECLARE_PARENT(TraceSpecialize, TraceInfo);
std::string name() override { return full_name_ + counter_; } std::string name() override { return full_name_ + counter_; }
std::string symbol() override { return counter_ + "_"; } std::string symbol() override { return counter_ + "_"; }
...@@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo { ...@@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo {
class TraceGradOperation : public TraceInfo { class TraceGradOperation : public TraceInfo {
public: public:
explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {} explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {}
MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); MS_DECLARE_PARENT(TraceGradOperation, TraceInfo);
~TraceGradOperation() override = default; ~TraceGradOperation() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo { ...@@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo {
class TraceForceBool : public TraceInfo { class TraceForceBool : public TraceInfo {
public: public:
explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {} explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {}
MS_DECLARE_PARENT(TraceForceBool, TraceInfo); MS_DECLARE_PARENT(TraceForceBool, TraceInfo);
~TraceForceBool() override = default; ~TraceForceBool() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); } TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
...@@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo { ...@@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo {
class TraceExpandJ : public TraceInfo { class TraceExpandJ : public TraceInfo {
public: public:
explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {} explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {}
MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); MS_DECLARE_PARENT(TraceExpandJ, TraceInfo);
~TraceExpandJ() override = default; ~TraceExpandJ() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); } TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); }
...@@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo { ...@@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo {
class TraceGenMetaFuncGraph : public TraceInfo { class TraceGenMetaFuncGraph : public TraceInfo {
public: public:
explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {} explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {}
MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo);
~TraceGenMetaFuncGraph() override = default; ~TraceGenMetaFuncGraph() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo { ...@@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo {
class TraceEvaluatorGenGraph : public TraceInfo { class TraceEvaluatorGenGraph : public TraceInfo {
public: public:
explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {} explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {}
MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo);
~TraceEvaluatorGenGraph() override = default; ~TraceEvaluatorGenGraph() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo { ...@@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo {
class TraceResolve : public TraceInfo { class TraceResolve : public TraceInfo {
public: public:
explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {} explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {}
MS_DECLARE_PARENT(TraceResolve, TraceInfo); MS_DECLARE_PARENT(TraceResolve, TraceInfo);
~TraceResolve() override = default; ~TraceResolve() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); } TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); }
...@@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo { ...@@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo {
class TraceTransform : public TraceInfo { class TraceTransform : public TraceInfo {
public: public:
TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; }
explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") { explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") {
transform_name_ = transform_name; transform_name_ = transform_name;
} }
...@@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo { ...@@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo {
class TraceGenerateVarArg : public TraceInfo { class TraceGenerateVarArg : public TraceInfo {
public: public:
explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {} explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {}
MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo);
~TraceGenerateVarArg() override = default; ~TraceGenerateVarArg() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo { ...@@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo {
class TraceGenerateKwArg : public TraceInfo { class TraceGenerateKwArg : public TraceInfo {
public: public:
explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {} explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {}
MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo);
~TraceGenerateKwArg() override = default; ~TraceGenerateKwArg() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo { ...@@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo {
class TraceTrasformK : public TraceInfo { class TraceTrasformK : public TraceInfo {
public: public:
explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {} explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {}
MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); MS_DECLARE_PARENT(TraceTrasformK, TraceInfo);
~TraceTrasformK() override = default; ~TraceTrasformK() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); } TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); }
...@@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo { ...@@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo {
class TracePartialTransform : public TraceInfo { class TracePartialTransform : public TraceInfo {
public: public:
explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {} explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {}
MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); MS_DECLARE_PARENT(TracePartialTransform, TraceInfo);
~TracePartialTransform() override = default; ~TracePartialTransform() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
...@@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo { ...@@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo {
class TraceGetEnv : public TraceInfo { class TraceGetEnv : public TraceInfo {
public: public:
explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {} explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {}
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
~TraceGetEnv() override = default; ~TraceGetEnv() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); } TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); }
...@@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo { ...@@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo {
class TraceDoSignature : public TraceInfo { class TraceDoSignature : public TraceInfo {
public: public:
explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {} explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {}
MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); MS_DECLARE_PARENT(TraceDoSignature, TraceInfo);
~TraceDoSignature() override = default; ~TraceDoSignature() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); } TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); }
...@@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo { ...@@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo {
class TraceCombileLikeGraphs : public TraceInfo { class TraceCombileLikeGraphs : public TraceInfo {
public: public:
TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {}
explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {} explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {}
MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo);
~TraceCombileLikeGraphs() override = default; ~TraceCombileLikeGraphs() override = default;
TraceInfoPtr clone() override { TraceInfoPtr clone() override {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (has_malloc_) { if (has_malloc_) {
MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; MS_LOG(EXCEPTION) << "Has alloc memory pool memory !";
} }
...@@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { ...@@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
return size; return size;
} }
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) { bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
MS_EXCEPTION_IF_NULL(addr); MS_EXCEPTION_IF_NULL(addr);
has_malloc_ = false; has_malloc_ = false;
free_mem_size_ = total_mem_size_; free_mem_size_ = total_mem_size_;
...@@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const { ...@@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; }
void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) { void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
MS_EXCEPTION_IF_NULL(device_mem_pool_base); MS_EXCEPTION_IF_NULL(device_mem_pool_base);
device_mem_pool_base_ = device_mem_pool_base; device_mem_pool_base_ = device_mem_pool_base;
} }
......
...@@ -26,12 +26,12 @@ namespace ascend { ...@@ -26,12 +26,12 @@ namespace ascend {
class AscendMemoryPool : public DynamicMemPoolBestFit { class AscendMemoryPool : public DynamicMemPoolBestFit {
public: public:
~AscendMemoryPool() override = default; ~AscendMemoryPool() override = default;
AscendMemoryPool(const AscendMemoryPool&) = delete; AscendMemoryPool(const AscendMemoryPool &) = delete;
AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; AscendMemoryPool &operator=(const AscendMemoryPool &) = delete;
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override; bool FreeDeviceMem(const DeviceMemPtr &addr) override;
void set_device_mem_pool_base(uint8_t* device_mem_pool_base); void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
void set_device_mem_pool_size(uint64_t device_mem_pool_size) { void set_device_mem_pool_size(uint64_t device_mem_pool_size) {
device_mem_pool_size_ = device_mem_pool_size; device_mem_pool_size_ = device_mem_pool_size;
free_mem_size_ = device_mem_pool_size_; free_mem_size_ = device_mem_pool_size_;
...@@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { ...@@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
size_t free_mem_size() override; size_t free_mem_size() override;
size_t total_mem_size() override; size_t total_mem_size() override;
static AscendMemoryPool& GetInstance() { static AscendMemoryPool &GetInstance() {
static AscendMemoryPool instance; static AscendMemoryPool instance;
return instance; return instance;
} }
...@@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { ...@@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
private: private:
AscendMemoryPool() = default; AscendMemoryPool() = default;
bool has_malloc_{false}; bool has_malloc_{false};
uint8_t* device_mem_pool_base_{nullptr}; uint8_t *device_mem_pool_base_{nullptr};
uint64_t device_mem_pool_size_{0}; uint64_t device_mem_pool_size_{0};
size_t free_mem_size_{0}; size_t free_mem_size_{0};
size_t total_mem_size_{0}; size_t total_mem_size_{0};
......
...@@ -39,13 +39,13 @@ using std::vector; ...@@ -39,13 +39,13 @@ using std::vector;
class AscendStreamAssign { class AscendStreamAssign {
public: public:
static AscendStreamAssign& GetInstance() { static AscendStreamAssign &GetInstance() {
static AscendStreamAssign instance; // Guaranteed to be destroyed. static AscendStreamAssign instance; // Guaranteed to be destroyed.
return instance; return instance;
} }
AscendStreamAssign(const AscendStreamAssign&) = delete; AscendStreamAssign(const AscendStreamAssign &) = delete;
AscendStreamAssign& operator=(const AscendStreamAssign&) = delete; AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
uint32_t GetTotalStreamNum() const; uint32_t GetTotalStreamNum() const;
// new stream policy // new stream policy
...@@ -53,19 +53,19 @@ class AscendStreamAssign { ...@@ -53,19 +53,19 @@ class AscendStreamAssign {
uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
uint32_t total_event_num() const { return total_event_num_; } uint32_t total_event_num() const { return total_event_num_; }
void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr); void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void ResetNew(); void ResetNew();
void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool IsIndependentNode(const CNodePtr& node_ptr); bool IsIndependentNode(const CNodePtr &node_ptr);
const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; } const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; }
const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; } const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; }
const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; } const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; }
void GetWaitStreams(vector<uint32_t>* wait_active_stream_list); void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; } const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; }
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id); uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id); uint32_t stream_id);
private: private:
...@@ -73,30 +73,30 @@ class AscendStreamAssign { ...@@ -73,30 +73,30 @@ class AscendStreamAssign {
~AscendStreamAssign() = default; ~AscendStreamAssign() = default;
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
const CNodePtr& node); const CNodePtr &node);
bool IsHcom(const CNodePtr& apply_kernel); bool IsHcom(const CNodePtr &apply_kernel);
bool IsProcessed(uint32_t logic_id); bool IsProcessed(uint32_t logic_id);
void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids); void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids);
void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index,
uint32_t* cur_stream_id); uint32_t *cur_stream_id);
void RecordIdMap(uint32_t logic_id, uint32_t physic_id); void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
void UpdateStreamActive(const CNodePtr& active_ptr); void UpdateStreamActive(const CNodePtr &active_ptr);
void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr); void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr);
bool IsTaskSink(); bool IsTaskSink();
void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr); void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr); void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr); void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
void SetCommonStreamNum(uint32_t cur_stream_id); void SetCommonStreamNum(uint32_t cur_stream_id);
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool IsProcessedParallelStream(uint32_t stream_id); bool IsProcessedParallelStream(uint32_t stream_id);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams); void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr); void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
uint32_t total_common_stream_num_{0}; uint32_t total_common_stream_num_{0};
uint32_t total_independ_stream_num_{0}; uint32_t total_independ_stream_num_{0};
......
...@@ -28,14 +28,14 @@ namespace device { ...@@ -28,14 +28,14 @@ namespace device {
namespace ascend { namespace ascend {
class PluginImpl : public PluginIntf { class PluginImpl : public PluginIntf {
public: public:
explicit PluginImpl(const std::string& module); explicit PluginImpl(const std::string &module);
~PluginImpl() override = default; ~PluginImpl() override = default;
int Init(const Reporter* reporter) override; int Init(const Reporter *reporter) override;
int UnInit() override; int UnInit() override;
static Reporter* GetPluginReporter() { return reporter_; } static Reporter *GetPluginReporter() { return reporter_; }
private: private:
static Reporter* reporter_; static Reporter *reporter_;
std::string module_; std::string module_;
}; };
} // namespace ascend } // namespace ascend
......
...@@ -20,12 +20,12 @@ ...@@ -20,12 +20,12 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
PluginIntf* ProfilingEngineImpl::CreatePlugin() { PluginIntf *ProfilingEngineImpl::CreatePlugin() {
MS_LOG(INFO) << "Create Plugin."; MS_LOG(INFO) << "Create Plugin.";
return new (std::nothrow) PluginImpl("Framework"); return new (std::nothrow) PluginImpl("Framework");
} }
int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) { int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) {
if (plugin != nullptr) { if (plugin != nullptr) {
delete plugin; delete plugin;
} }
......
...@@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf { ...@@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf {
ProfilingEngineImpl() = default; ProfilingEngineImpl() = default;
~ProfilingEngineImpl() override = default; ~ProfilingEngineImpl() override = default;
PluginIntf* CreatePlugin() override; PluginIntf *CreatePlugin() override;
int ReleasePlugin(PluginIntf* plugin) override; int ReleasePlugin(PluginIntf *plugin) override;
}; };
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
......
...@@ -35,7 +35,7 @@ using Json = nlohmann::json; ...@@ -35,7 +35,7 @@ using Json = nlohmann::json;
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
ProfilingManager& ProfilingManager::GetInstance() { ProfilingManager &ProfilingManager::GetInstance() {
static ProfilingManager inst; static ProfilingManager inst;
return inst; return inst;
} }
...@@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { ...@@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) {
} }
uint64_t ProfilingManager::GetJobId() const { uint64_t ProfilingManager::GetJobId() const {
const char* job_id = std::getenv("JOB_ID"); const char *job_id = std::getenv("JOB_ID");
return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0);
} }
bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskId_map) const { bool ProfilingManager::ReportProfilingData(const map<uint32_t, string> &op_taskId_map) const {
if (!IsProfiling()) { if (!IsProfiling()) {
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
return false; return false;
...@@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI ...@@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size();
Msprof::Engine::ReporterData reporter_data = {}; Msprof::Engine::ReporterData reporter_data = {};
for (const auto& iter : op_taskId_map) { for (const auto &iter : op_taskId_map) {
auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; auto data = iter.second + ' ' + std::to_string(iter.first) + ';';
reporter_data.deviceId = UintToInt(device_id_); reporter_data.deviceId = UintToInt(device_id_);
reporter_data.data = (unsigned char*)(const_cast<char*>(data.c_str())); reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str()));
reporter_data.dataLen = data.size(); reporter_data.dataLen = data.size();
auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework"));
if (ret != 0) { if (ret != 0) {
...@@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI ...@@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
return true; return true;
} }
static std::vector<std::string> Split(const std::string& str, const char delim) { static std::vector<std::string> Split(const std::string &str, const char delim) {
std::vector<std::string> elems; std::vector<std::string> elems;
if (str.empty()) { if (str.empty()) {
...@@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { ...@@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
device_id_ = device_id; device_id_ = device_id;
// exp: export PROFILING_MODE=true // exp: export PROFILING_MODE=true
// export PROFILING_OPTIONS=training_trace // export PROFILING_OPTIONS=training_trace
const char* prof_options_str = std::getenv("PROFILING_OPTIONS"); const char *prof_options_str = std::getenv("PROFILING_OPTIONS");
// register Framework to profiling // register Framework to profiling
int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get());
if (result != 0) { if (result != 0) {
...@@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const { ...@@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const {
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
return true; return true;
} }
Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter(); Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
if (reporter != nullptr) { if (reporter != nullptr) {
MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); MS_LOG(INFO) << "report data end, ret = " << reporter->Flush();
} }
......
...@@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST, ...@@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,
class GpuQueue { class GpuQueue {
public: public:
GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity); GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity);
virtual ~GpuQueue(); virtual ~GpuQueue();
void RegisterRelease(const std::function<void(void*)>& func) { host_release_ = func; } void RegisterRelease(const std::function<void(void *)> &func) { host_release_ = func; }
inline bool IsEmpty() const { return head_ == tail_; } inline bool IsEmpty() const { return head_ == tail_; }
inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); } inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); }
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size); BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size);
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const; BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const;
BlockQueueStatus_T Pop(); BlockQueueStatus_T Pop();
bool Destroy(); bool Destroy();
private: private:
struct NodeInfo { struct NodeInfo {
std::unique_ptr<cudaEvent_t> event_; std::unique_ptr<cudaEvent_t> event_;
void* host_feature_addr_; void *host_feature_addr_;
void* host_label_addr_; void *host_label_addr_;
}; };
void* buffer_; void *buffer_;
size_t head_; size_t head_;
size_t tail_; size_t tail_;
size_t feature_size_; size_t feature_size_;
...@@ -61,10 +61,10 @@ class GpuQueue { ...@@ -61,10 +61,10 @@ class GpuQueue {
size_t capacity_; size_t capacity_;
cudaStream_t stream_; cudaStream_t stream_;
std::unique_ptr<NodeInfo[]> node_info_; std::unique_ptr<NodeInfo[]> node_info_;
std::function<void(void*)> host_release_; std::function<void(void *)> host_release_;
GpuQueue(const GpuQueue&) = delete; GpuQueue(const GpuQueue &) = delete;
GpuQueue& operator=(const GpuQueue&) = delete; GpuQueue &operator=(const GpuQueue &) = delete;
}; };
class BlockingQueue { class BlockingQueue {
...@@ -72,11 +72,11 @@ class BlockingQueue { ...@@ -72,11 +72,11 @@ class BlockingQueue {
BlockingQueue() : queue_(nullptr) {} BlockingQueue() : queue_(nullptr) {}
~BlockingQueue() = default; ~BlockingQueue() = default;
BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity); BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity);
void RegisterRelease(const std::function<void(void*)>& func); void RegisterRelease(const std::function<void(void *)> &func);
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size, BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size,
unsigned int timeout_in_sec); unsigned int timeout_in_sec);
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size); BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size);
BlockQueueStatus_T Pop(); BlockQueueStatus_T Pop();
bool Destroy(); bool Destroy();
......
...@@ -20,17 +20,17 @@ ...@@ -20,17 +20,17 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
CollectiveInitializer& CollectiveInitializer::instance() { CollectiveInitializer &CollectiveInitializer::instance() {
static CollectiveInitializer instance = {}; static CollectiveInitializer instance = {};
return instance; return instance;
} }
bool CollectiveInitializer::collective_inited() const { return collective_inited_; } bool CollectiveInitializer::collective_inited() const { return collective_inited_; }
const void* CollectiveInitializer::collective_handle() const { return collective_handle_; } const void *CollectiveInitializer::collective_handle() const { return collective_handle_; }
void CollectiveInitializer::InitCollective() { void CollectiveInitializer::InitCollective() {
void* handle = dlopen("libgpu_collective.so", RTLD_LAZY); void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
if (handle == nullptr) { if (handle == nullptr) {
MS_LOG(EXCEPTION) MS_LOG(EXCEPTION)
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not "
......
...@@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() { ...@@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() {
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
} }
bool GPUDeviceManager::CreateStream(DeviceStream* stream) { bool GPUDeviceManager::CreateStream(DeviceStream *stream) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
gpu_streams_.emplace_back(*stream); gpu_streams_.emplace_back(*stream);
return true; return true;
} }
const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; } const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; }
int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }
...@@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } ...@@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; }
bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; }
const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); } bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const { bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const {
return CudaDriver::CopyDeviceMemToHost(dst, src, size); return CudaDriver::CopyDeviceMemToHost(dst, src, size);
} }
bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const { bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const {
return CudaDriver::CopyHostMemToDevice(dst, src, size); return CudaDriver::CopyHostMemToDevice(dst, src, size);
} }
} // namespace gpu } // namespace gpu
......
...@@ -37,17 +37,17 @@ class GPUDeviceManager { ...@@ -37,17 +37,17 @@ class GPUDeviceManager {
uint32_t cur_device_id() const; uint32_t cur_device_id() const;
bool is_device_id_init() const; bool is_device_id_init() const;
bool CreateStream(DeviceStream* stream); bool CreateStream(DeviceStream *stream);
bool SyncStream(const DeviceStream& stream) const; bool SyncStream(const DeviceStream &stream) const;
const DeviceStream& default_stream() const; const DeviceStream &default_stream() const;
const cudnnHandle_t& GetCudnnHandle() const; const cudnnHandle_t &GetCudnnHandle() const;
const cublasHandle_t& GetCublasHandle() const; const cublasHandle_t &GetCublasHandle() const;
bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const; bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const; bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const;
static GPUDeviceManager& GetInstance() { static GPUDeviceManager &GetInstance() {
static GPUDeviceManager instance; static GPUDeviceManager instance;
return instance; return instance;
} }
...@@ -55,8 +55,8 @@ class GPUDeviceManager { ...@@ -55,8 +55,8 @@ class GPUDeviceManager {
private: private:
GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {}
~GPUDeviceManager() = default; ~GPUDeviceManager() = default;
GPUDeviceManager(const GPUDeviceManager&) = delete; GPUDeviceManager(const GPUDeviceManager &) = delete;
GPUDeviceManager& operator=(const GPUDeviceManager&) = delete; GPUDeviceManager &operator=(const GPUDeviceManager &) = delete;
// default CUDA stream used for all the kernels. // default CUDA stream used for all the kernels.
DeviceStream default_stream_{nullptr}; DeviceStream default_stream_{nullptr};
......
...@@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() { ...@@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() {
return true; return true;
} }
bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) { bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) {
auto alloc_size = AllocDeviceMem(size, addr); auto alloc_size = AllocDeviceMem(size, addr);
buffer_q_addr_ = *addr; buffer_q_addr_ = *addr;
// Buffer queue needs to ensure that the alloc_size and size is equal. // Buffer queue needs to ensure that the alloc_size and size is equal.
return (alloc_size == size) ? true : false; return (alloc_size == size) ? true : false;
} }
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (size == 0) { if (size == 0) {
MS_LOG(EXCEPTION) << "The memory alloc size is 0."; MS_LOG(EXCEPTION) << "The memory alloc size is 0.";
} }
...@@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { ...@@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
return alloc_size; return alloc_size;
} }
bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); } bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); }
size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); } size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); }
......
...@@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit { ...@@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit {
~GPUMemoryAllocator() override = default; ~GPUMemoryAllocator() override = default;
bool Init(); bool Init();
bool Finalize(); bool Finalize();
bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr); bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr);
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override; bool FreeDeviceMem(const DeviceMemPtr &addr) override;
size_t free_mem_size() override; size_t free_mem_size() override;
size_t total_mem_size() override; size_t total_mem_size() override;
static GPUMemoryAllocator& GetInstance() { static GPUMemoryAllocator &GetInstance() {
static GPUMemoryAllocator instance; static GPUMemoryAllocator instance;
return instance; return instance;
} }
private: private:
GPUMemoryAllocator() = default; GPUMemoryAllocator() = default;
GPUMemoryAllocator(const GPUMemoryAllocator&) = delete; GPUMemoryAllocator(const GPUMemoryAllocator &) = delete;
GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete; GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete;
// Used to track address of data buffer queue. // Used to track address of data buffer queue.
DeviceMemPtr buffer_q_addr_{nullptr}; DeviceMemPtr buffer_q_addr_{nullptr};
......
...@@ -33,8 +33,8 @@ namespace gpu { ...@@ -33,8 +33,8 @@ namespace gpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using mindspore::kernel::KernelBuildInfo; using mindspore::kernel::KernelBuildInfo;
namespace { namespace {
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info, bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) { const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
MS_EXCEPTION_IF_NULL(selected_kernel_info); MS_EXCEPTION_IF_NULL(selected_kernel_info);
MS_EXCEPTION_IF_NULL(alternative_kernel_info); MS_EXCEPTION_IF_NULL(alternative_kernel_info);
size_t selected_input_num = selected_kernel_info->GetInputNum(); size_t selected_input_num = selected_kernel_info->GetInputNum();
...@@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_ ...@@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_
return true; return true;
} }
std::string SupportedTypeList(const CNodePtr& kernel_node) { std::string SupportedTypeList(const CNodePtr &kernel_node) {
std::string supported_type_lists = std::string supported_type_lists =
kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node));
if (!supported_type_lists.empty()) { if (!supported_type_lists.empty()) {
...@@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) { ...@@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) {
return supported_type_lists; return supported_type_lists;
} }
bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) { bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(selected_kernel_info); MS_EXCEPTION_IF_NULL(selected_kernel_info);
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list; std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
...@@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu ...@@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
} }
bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(), bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&](const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info) { [&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
}); });
if (!match) { if (!match) {
...@@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu ...@@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
return true; return true;
} }
void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, const CNodePtr& kernel_node) { void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = kernel_node->input(input_index + 1); auto input_kernel_node = kernel_node->input(input_index + 1);
...@@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co ...@@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co
} }
} // namespace } // namespace
void SetKernelInfo(const CNodePtr& kernel_node) { void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format; std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type; std::vector<TypeId> inputs_type;
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder = std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace gpu { namespace gpu {
void SetKernelInfo(const CNodePtr& apply_kernel_ptr); void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
class KernelAttr { class KernelAttr {
public: public:
...@@ -35,24 +35,24 @@ class KernelAttr { ...@@ -35,24 +35,24 @@ class KernelAttr {
KernelAttr() : all_same_(false) {} KernelAttr() : all_same_(false) {}
~KernelAttr() = default; ~KernelAttr() = default;
KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
input_type_.emplace_back(ms_type, format); input_type_.emplace_back(ms_type, format);
return *this; return *this;
} }
KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
output_type_.emplace_back(ms_type, format); output_type_.emplace_back(ms_type, format);
return *this; return *this;
} }
KernelAttr& AddAllSameAttr(const bool& all_same) { KernelAttr &AddAllSameAttr(const bool &all_same) {
all_same_ = all_same; all_same_ = all_same;
return *this; return *this;
} }
const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; } const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; } const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
const bool& GetAllSame() const { return all_same_; } const bool &GetAllSame() const { return all_same_; }
size_t GetInputSize() const { return input_type_.size(); } size_t GetInputSize() const { return input_type_.size(); }
size_t GetOutputSize() const { return output_type_.size(); } size_t GetOutputSize() const { return output_type_.size(); }
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
namespace mindspore { namespace mindspore {
struct TypeIdManager* TypeIdManager::Get() { struct TypeIdManager *TypeIdManager::Get() {
static TypeIdManager manager; static TypeIdManager manager;
return &manager; return &manager;
} }
......
...@@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra ...@@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
std::string AnfNode::ToString() const { std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode*>(this)->shared_from_base<AnfNode>()->debug_info()); return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
} }
CNode::CNode(const std::vector<AnfNodePtr>& inputs, const FuncGraphPtr& func_graph) CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
// Check if CNode is an apply with the specific Primitive. // Check if CNode is an apply with the specific Primitive.
bool CNode::IsApply(const PrimitivePtr& value) const { bool CNode::IsApply(const PrimitivePtr &value) const {
if (value == nullptr) { if (value == nullptr) {
return false; return false;
} }
...@@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const { ...@@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const {
return false; return false;
} }
void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; } void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; }
std::string CNode::DebugString(int recursive_level) const { std::string CNode::DebugString(int recursive_level) const {
std::ostringstream buffer; std::ostringstream buffer;
...@@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const { ...@@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const {
buffer << ToString() << "{"; buffer << ToString() << "{";
bool is_first_node = true; bool is_first_node = true;
int idx = 0; int idx = 0;
for (auto& node : inputs_) { for (auto &node : inputs_) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (is_first_node) { if (is_first_node) {
is_first_node = false; is_first_node = false;
...@@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const { ...@@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const {
return buffer.str(); return buffer.str();
} }
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) { OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) { if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name(); << ", using the new one: " << operator_info->name();
...@@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() { ...@@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() {
return fullname_with_scope_; return fullname_with_scope_;
} }
void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<CNode>()); } void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<ValueNode>()); } void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base<Parameter>()); } void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr) { if (cnode != nullptr) {
...@@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { ...@@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) {
return false; return false;
} }
PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) { PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return nullptr; return nullptr;
} }
...@@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) { ...@@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) {
return ""; return "";
} }
bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
if (IsValueNode<Primitive>(node)) { if (IsValueNode<Primitive>(node)) {
PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node); PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
MS_EXCEPTION_IF_NULL(value); MS_EXCEPTION_IF_NULL(value);
...@@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { ...@@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
} }
namespace id_generator { namespace id_generator {
static std::unordered_map<std::string, int> node_ids; static std::unordered_map<std::string, int> node_ids;
std::string get_id(const AnfNodePtr& node) { std::string get_id(const AnfNodePtr &node) {
auto type_name = node->type_name(); auto type_name = node->type_name();
if (node_ids.find(type_name) == node_ids.end()) { if (node_ids.find(type_name) == node_ids.end()) {
node_ids[type_name] = 0; node_ids[type_name] = 0;
......
...@@ -39,15 +39,15 @@ struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {}; ...@@ -39,15 +39,15 @@ struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {};
class Base : public std::enable_shared_from_this<Base> { class Base : public std::enable_shared_from_this<Base> {
public: public:
constexpr Base() = default; constexpr Base() = default;
Base(const Base& other) : std::enable_shared_from_this<Base>(other) {} Base(const Base &other) : std::enable_shared_from_this<Base>(other) {}
virtual bool operator==(const Base& rhs) { virtual bool operator==(const Base &rhs) {
if (this == &rhs) { if (this == &rhs) {
return true; return true;
} }
return false; return false;
} }
virtual Base& operator=(const Base&) { return *this; } virtual Base &operator=(const Base &) { return *this; }
virtual ~Base() = default; virtual ~Base() = default;
virtual std::size_t hash() const { return tid(); } virtual std::size_t hash() const { return tid(); }
virtual std::string ToString() const { return type_name(); } virtual std::string ToString() const { return type_name(); }
...@@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this<Base> { ...@@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this<Base> {
virtual const bool IsFromTypeId(uint32_t tid) const; virtual const bool IsFromTypeId(uint32_t tid) const;
virtual std::string type_name() const { return "Base"; } virtual std::string type_name() const { return "Base"; }
static uint32_t GetTypeId(const char* const type_key); static uint32_t GetTypeId(const char *const type_key);
virtual uint32_t tid() const { virtual uint32_t tid() const {
static const uint32_t tid = GetTypeId(typeid(Base).name()); static const uint32_t tid = GetTypeId(typeid(Base).name());
return tid; return tid;
} }
template <typename T, template <typename T,
typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type* = nullptr> typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type * = nullptr>
inline bool isa() const { inline bool isa() const {
static const uint32_t tid = GetTypeId(typeid(T).name()); static const uint32_t tid = GetTypeId(typeid(T).name());
return this->IsFromTypeId(tid); return this->IsFromTypeId(tid);
...@@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr<Base>; ...@@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr<Base>;
using BaseWeakPtr = std::weak_ptr<Base>; using BaseWeakPtr = std::weak_ptr<Base>;
template <typename T, typename U> template <typename T, typename U>
inline T* cast(U* source) { inline T *cast(U *source) {
if (source != nullptr && source->template isa<T>()) { if (source != nullptr && source->template isa<T>()) {
return static_cast<T*>(source); return static_cast<T *>(source);
} else { } else {
return nullptr; return nullptr;
} }
...@@ -100,7 +100,7 @@ inline T* cast(U* source) { ...@@ -100,7 +100,7 @@ inline T* cast(U* source) {
template < template <
typename T, typename U, typename T, typename U,
typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type* = nullptr> typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type * = nullptr>
inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) { inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) {
if (r != nullptr && r->template isa<T>()) { if (r != nullptr && r->template isa<T>()) {
return std::static_pointer_cast<T>(r); return std::static_pointer_cast<T>(r);
...@@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager { ...@@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager {
std::mutex mutex; std::mutex mutex;
std::atomic<uint32_t> type_counter{0}; std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> map; std::unordered_map<std::string, uint32_t> map;
static TypeIdManager* Get(); static TypeIdManager *Get();
TypeIdManager() : mutex(), type_counter(0), map() {} TypeIdManager() : mutex(), type_counter(0), map() {}
}; };
} // namespace mindspore } // namespace mindspore
......
...@@ -48,11 +48,11 @@ std::string Keyword::ToString() const { ...@@ -48,11 +48,11 @@ std::string Keyword::ToString() const {
return buffer.str(); return buffer.str();
} }
bool Keyword::operator==(const Type& other) const { bool Keyword::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
const auto& other_keyword = static_cast<const Keyword&>(other); const auto &other_keyword = static_cast<const Keyword &>(other);
return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_); return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_);
} }
...@@ -87,11 +87,11 @@ std::string Slice::ToString() const { ...@@ -87,11 +87,11 @@ std::string Slice::ToString() const {
return buffer.str(); return buffer.str();
} }
bool Slice::operator==(const Type& other) const { bool Slice::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
auto other_slice = static_cast<const Slice&>(other); auto other_slice = static_cast<const Slice &>(other);
return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_); return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_);
} }
...@@ -122,11 +122,11 @@ std::string TensorType::DumpText() const { ...@@ -122,11 +122,11 @@ std::string TensorType::DumpText() const {
} }
} }
bool TensorType::operator==(const Type& other) const { bool TensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
auto other_elem_type = static_cast<const TensorType&>(other).element_type_; auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
// When element_type_ = nullptr, which means any type of Array. // When element_type_ = nullptr, which means any type of Array.
if (element_type_ == nullptr && other_elem_type == nullptr) { if (element_type_ == nullptr && other_elem_type == nullptr) {
return true; return true;
...@@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) { ...@@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) {
retval_ = nullptr; retval_ = nullptr;
} }
Function::Function(const std::vector<TypePtr>& args, const TypePtr retval) Function::Function(const std::vector<TypePtr> &args, const TypePtr retval)
: Object(kObjectTypeFunction, false), args_(args), retval_(retval) {} : Object(kObjectTypeFunction, false), args_(args), retval_(retval) {}
TypePtr Function::DeepCopy() const { TypePtr Function::DeepCopy() const {
...@@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const { ...@@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const {
TypePtrList args; TypePtrList args;
TypePtr retval = nullptr; TypePtr retval = nullptr;
(void)std::transform(args_.begin(), args_.end(), std::back_inserter(args), (void)std::transform(args_.begin(), args_.end(), std::back_inserter(args),
[](const TypePtr& arg) { return arg->DeepCopy(); }); [](const TypePtr &arg) { return arg->DeepCopy(); });
if (retval_ != nullptr) { if (retval_ != nullptr) {
retval = retval_->DeepCopy(); retval = retval_->DeepCopy();
} }
...@@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const { ...@@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const {
} }
} }
bool Function::operator==(const Type& other) const { bool Function::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
const auto& other_function = static_cast<const Function&>(other); const auto &other_function = static_cast<const Function &>(other);
if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) { if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) {
if (*retval_ != *other_function.retval_) { if (*retval_ != *other_function.retval_) {
return false; return false;
...@@ -188,7 +188,7 @@ std::string Function::ToString() const { ...@@ -188,7 +188,7 @@ std::string Function::ToString() const {
} else { } else {
buffer << "Func[("; buffer << "Func[(";
bool begin = true; bool begin = true;
for (auto& attr : args_) { for (auto &attr : args_) {
if (!begin) { if (!begin) {
buffer << ", "; buffer << ", ";
} else { } else {
...@@ -242,34 +242,34 @@ std::string JTagged::DumpText() const { ...@@ -242,34 +242,34 @@ std::string JTagged::DumpText() const {
return buffer.str(); return buffer.str();
} }
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem) { std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) {
MS_EXCEPTION_IF_NULL(problem); MS_EXCEPTION_IF_NULL(problem);
os << problem->ToString(); os << problem->ToString();
return os; return os;
} }
std::size_t TypeHasher::operator()(TypePtr const& type) const { std::size_t TypeHasher::operator()(TypePtr const &type) const {
MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(type);
std::size_t hash = std::hash<size_t>()(type->type_id()); std::size_t hash = std::hash<size_t>()(type->type_id());
return hash; return hash;
} }
std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const { std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
std::size_t hash_sum = 0; std::size_t hash_sum = 0;
for (auto& type : type_list) { for (auto &type : type_list) {
auto type_id = static_cast<std::size_t>(type->type_id()); auto type_id = static_cast<std::size_t>(type->type_id());
hash_sum = hash_combine(hash_sum, type_id); hash_sum = hash_combine(hash_sum, type_id);
} }
return hash_sum; return hash_sum;
} }
bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const { bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2); MS_EXCEPTION_IF_NULL(t2);
return t1->type_id() == t2->type_id(); return t1->type_id() == t2->type_id();
} }
bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const { bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
if (lhs.size() != rhs.size()) { if (lhs.size() != rhs.size()) {
return false; return false;
} }
...@@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) { ...@@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) {
namespace { namespace {
template <typename T> template <typename T>
TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) { TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == num_type_name) { if (type_name == num_type_name) {
type = std::make_shared<T>(); type = std::make_shared<T>();
...@@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_ ...@@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_
} }
auto bits = std::stoi(type_name.substr(num_type_name.size())); auto bits = std::stoi(type_name.substr(num_type_name.size()));
type = std::make_shared<T>(bits); type = std::make_shared<T>(bits);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what();
} }
} }
return type; return type;
} }
std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) { std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
std::vector<TypePtr> types; std::vector<TypePtr> types;
if (type_names.length() == 0) { if (type_names.length() == 0) {
return types; return types;
...@@ -371,7 +371,7 @@ std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) { ...@@ -371,7 +371,7 @@ std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) {
return types; return types;
} }
TypePtr TensorStrToType(const std::string& type_name) { TypePtr TensorStrToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == "Tensor") { if (type_name == "Tensor") {
type = std::make_shared<TensorType>(); type = std::make_shared<TensorType>();
...@@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) { ...@@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
return nullptr; return nullptr;
} }
type = std::make_shared<TensorType>(element_type); type = std::make_shared<TensorType>(element_type);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
} }
} }
...@@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) { ...@@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
return type; return type;
} }
TypePtr ListStrToType(const std::string& type_name) { TypePtr ListStrToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == "List") { if (type_name == "List") {
type = std::make_shared<List>(); type = std::make_shared<List>();
...@@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) { ...@@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) {
std::string element_strs = type_name.substr(start, end - start); std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong = bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) { if (wrong) {
return nullptr; return nullptr;
} }
type = std::make_shared<List>(element_types); type = std::make_shared<List>(element_types);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
} }
} }
...@@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) { ...@@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) {
return type; return type;
} }
TypePtr TupleStrToType(const std::string& type_name) { TypePtr TupleStrToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == "Tuple") { if (type_name == "Tuple") {
type = std::make_shared<Tuple>(); type = std::make_shared<Tuple>();
...@@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) { ...@@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) {
std::string element_strs = type_name.substr(start, end - start); std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs); std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong = bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) { if (wrong) {
return nullptr; return nullptr;
} }
type = std::make_shared<Tuple>(element_types); type = std::make_shared<Tuple>(element_types);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
} }
} }
return type; return type;
} }
TypePtr FunctionStrToType(const std::string& type_name) { TypePtr FunctionStrToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name == "Function") { if (type_name == "Function") {
...@@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) { ...@@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) {
std::vector<TypePtr> args_type = StringToVectorOfType(str_args); std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
TypePtr retval = StringToType(str_retval); TypePtr retval = StringToType(str_retval);
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; }); bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
if (retval == nullptr || wrong) { if (retval == nullptr || wrong) {
return nullptr; return nullptr;
} }
type = std::make_shared<Function>(args_type, retval); type = std::make_shared<Function>(args_type, retval);
} catch (const std::exception& e) { } catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
} }
} }
...@@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) { ...@@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) {
} }
} // namespace } // namespace
TypePtr StringToType(const std::string& type_name) { TypePtr StringToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name.compare("None") == 0) { if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>(); type = std::make_shared<TypeNone>();
...@@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) { ...@@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) {
return type; return type;
} }
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
if (x == nullptr || base_type == nullptr) { if (x == nullptr || base_type == nullptr) {
MS_LOG(ERROR) << "Type is nullptr."; MS_LOG(ERROR) << "Type is nullptr.";
return false; return false;
...@@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { ...@@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) {
} }
} }
bool IsSubType(TypePtr const& t1, TypePtr const& t2) { bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t1);
if (t1->type_id() == kTypeUnknown) { if (t1->type_id() == kTypeUnknown) {
return false; return false;
...@@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) { ...@@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) {
} }
REGISTER_PYBIND_DEFINE( REGISTER_PYBIND_DEFINE(
typing, ([](py::module* const m) { typing, ([](py::module *const m) {
auto m_sub = m->def_submodule("typing", "submodule for dtype"); auto m_sub = m->def_submodule("typing", "submodule for dtype");
py::enum_<TypeId>(m_sub, "TypeId"); py::enum_<TypeId>(m_sub, "TypeId");
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
(void)m_sub.def("load_type", &TypeIdToType, "load type"); (void)m_sub.def("load_type", &TypeIdToType, "load type");
(void)m_sub.def( (void)m_sub.def(
"dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type"); "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
.def("__eq__", .def("__eq__",
[](const TypePtr& t1, const TypePtr& t2) { [](const TypePtr &t1, const TypePtr &t2) {
if (t1 != nullptr && t2 != nullptr) { if (t1 != nullptr && t2 != nullptr) {
return *t1 == *t2; return *t1 == *t2;
} }
...@@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE( ...@@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE(
.def("__hash__", &Type::hash) .def("__hash__", &Type::hash)
.def("__str__", &Type::ToString) .def("__str__", &Type::ToString)
.def("__repr__", &Type::ReprString) .def("__repr__", &Type::ReprString)
.def("__deepcopy__", [](const TypePtr& t, py::dict) { .def("__deepcopy__", [](const TypePtr &t, py::dict) {
if (t == nullptr) { if (t == nullptr) {
return static_cast<TypePtr>(nullptr); return static_cast<TypePtr>(nullptr);
} }
...@@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE( ...@@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool") (void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
.def(py::init()) .def(py::init())
.def(py::pickle( .def(py::pickle(
[](const Bool&) { // __getstate__ [](const Bool &) { // __getstate__
return py::make_tuple(); return py::make_tuple();
}, },
[](const py::tuple&) { // __setstate__ [](const py::tuple &) { // __setstate__
return std::make_shared<Bool>(); return std::make_shared<Bool>();
})); }));
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int") (void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(
[](const Int& t) { // __getstate__ [](const Int &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits())); return py::make_tuple(py::int_(t.nbits()));
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
...@@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE( ...@@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(
[](const UInt& t) { // __getstate__ [](const UInt &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits())); return py::make_tuple(py::int_(t.nbits()));
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
...@@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE( ...@@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init()) .def(py::init())
.def(py::init<int>(), py::arg("nbits")) .def(py::init<int>(), py::arg("nbits"))
.def(py::pickle( .def(py::pickle(
[](const Float& t) { // __getstate__ [](const Float &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits())); return py::make_tuple(py::int_(t.nbits()));
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
...@@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE( ...@@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init<TypePtr>(), py::arg("element")) .def(py::init<TypePtr>(), py::arg("element"))
.def("element_type", &TensorType::element) .def("element_type", &TensorType::element)
.def(py::pickle( .def(py::pickle(
[](const TensorType& t) { // __getstate__ [](const TensorType &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id()))); return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
......
...@@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr<String>; ...@@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr<String>;
class Keyword : public Object { class Keyword : public Object {
public: public:
Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {}
Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
~Keyword() override = default; ~Keyword() override = default;
MS_DECLARE_PARENT(Keyword, Object) MS_DECLARE_PARENT(Keyword, Object)
...@@ -70,7 +70,7 @@ class Keyword : public Object { ...@@ -70,7 +70,7 @@ class Keyword : public Object {
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
std::string GetKey() const { return key_; } std::string GetKey() const { return key_; }
TypePtr GetValue() const { return value_; } TypePtr GetValue() const { return value_; }
...@@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr<Keyword>; ...@@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr<Keyword>;
class Slice : public Object { class Slice : public Object {
public: public:
Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {}
Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step) Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step)
: Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {}
~Slice() override = default; ~Slice() override = default;
...@@ -95,7 +95,7 @@ class Slice : public Object { ...@@ -95,7 +95,7 @@ class Slice : public Object {
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtr get_start() const { return start_; } TypePtr get_start() const { return start_; }
TypePtr get_stop() const { return stop_; } TypePtr get_stop() const { return stop_; }
...@@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr<Slice>; ...@@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr<Slice>;
class TensorType : public Object { class TensorType : public Object {
public: public:
TensorType() : Object(kObjectTypeTensorType) {} TensorType() : Object(kObjectTypeTensorType) {}
explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
~TensorType() override = default; ~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object) MS_DECLARE_PARENT(TensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeTensorType; } TypeId generic_type_id() const override { return kObjectTypeTensorType; }
const TypePtr element() const { return element_type_; } const TypePtr element() const { return element_type_; }
void set_element(const TypePtr& element_type) { element_type_ = element_type; } void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
std::string ToString() const override; std::string ToString() const override;
std::string ToReprString() const override { return "tensor"; } std::string ToReprString() const override { return "tensor"; }
std::string DumpText() const override; std::string DumpText() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
private: private:
TypePtr element_type_; TypePtr element_type_;
...@@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr<TensorType>; ...@@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
class Function : public Object { class Function : public Object {
public: public:
Function(); Function();
Function(const std::vector<TypePtr>& args, const TypePtr retval); Function(const std::vector<TypePtr> &args, const TypePtr retval);
~Function() override = default; ~Function() override = default;
MS_DECLARE_PARENT(Function, Object) MS_DECLARE_PARENT(Function, Object)
...@@ -141,11 +141,11 @@ class Function : public Object { ...@@ -141,11 +141,11 @@ class Function : public Object {
// Add temporarily for return abstraction to avoid type checking. // Add temporarily for return abstraction to avoid type checking.
bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); }
const std::vector<TypePtr>& args() const { return args_; } const std::vector<TypePtr> &args() const { return args_; }
const TypePtr& retval() const { return retval_; } const TypePtr &retval() const { return retval_; }
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
std::string ToString() const override; std::string ToString() const override;
std::string ToReprString() const override { return "function"; } std::string ToReprString() const override { return "function"; }
...@@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr<Function>; ...@@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr<Function>;
class JTagged : public Object { class JTagged : public Object {
public: public:
JTagged() : Object(kObjectTypeJTagged) {} JTagged() : Object(kObjectTypeJTagged) {}
explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
~JTagged() override = default; ~JTagged() override = default;
MS_DECLARE_PARENT(JTagged, Object) MS_DECLARE_PARENT(JTagged, Object)
...@@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr<TypeType>; ...@@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr<TypeType>;
class Problem : public Type { class Problem : public Type {
public: public:
Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {}
explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {} explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {}
~Problem() override = default; ~Problem() override = default;
MS_DECLARE_PARENT(Problem, Type) MS_DECLARE_PARENT(Problem, Type)
...@@ -222,7 +222,7 @@ class Problem : public Type { ...@@ -222,7 +222,7 @@ class Problem : public Type {
std::string ToString() const override { return kind_.name(); } std::string ToString() const override { return kind_.name(); }
std::string DumpText() const override { return "ProblemType"; } std::string DumpText() const override { return "ProblemType"; }
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem); friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem);
private: private:
Named kind_; Named kind_;
...@@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr<External>; ...@@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr<External>;
// helper template // helper template
template <class T> template <class T>
TypePtr Clone(const T& t) { TypePtr Clone(const T &t) {
return t.Clone(); return t.Clone();
} }
TypePtr StringToType(const std::string& type_name); TypePtr StringToType(const std::string &type_name);
// Judge whether x is predicate or is a subclass of predicate. // Judge whether x is predicate or is a subclass of predicate.
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type); bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
// Whether t1 is identity or a subclass of t2. // Whether t1 is identity or a subclass of t2.
bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr); bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
struct TypeHasher { struct TypeHasher {
std::size_t operator()(TypePtr const& type) const; std::size_t operator()(TypePtr const &type) const;
}; };
struct TypeListHasher { struct TypeListHasher {
std::size_t operator()(const TypePtrList& type_list) const; std::size_t operator()(const TypePtrList &type_list) const;
}; };
struct TypeEqual { struct TypeEqual {
bool operator()(TypePtr const& t1, TypePtr const& t2) const; bool operator()(TypePtr const &t1, TypePtr const &t2) const;
}; };
struct TypeListEqual { struct TypeListEqual {
bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const; bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const;
}; };
extern const TypePtr kTypeExternal; extern const TypePtr kTypeExternal;
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "pybind_api/export_flags.h" #include "pybind_api/export_flags.h"
namespace mindspore { namespace mindspore {
static std::string DumpTypeVector(const std::vector<TypePtr>& elements, bool is_dumptext) { static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
std::ostringstream oss; std::ostringstream oss;
bool begin = true; bool begin = true;
int cnt = 0; int cnt = 0;
...@@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const { ...@@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const {
} else { } else {
TypePtrList elements; TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr& ele) { return ele->DeepCopy(); }); [](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<List>(elements); auto copy = std::make_shared<List>(elements);
return copy; return copy;
} }
...@@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const { ...@@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const {
return elements_[dim]; return elements_[dim];
} }
bool List::operator==(const Type& other) const { bool List::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
const List& other_list = static_cast<const List&>(other); const List &other_list = static_cast<const List &>(other);
if (elements_.size() != other_list.elements_.size()) { if (elements_.size() != other_list.elements_.size()) {
return false; return false;
} }
...@@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const { ...@@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const {
return true; return true;
} }
Class::Class(const Named& tag, const ClassAttrVector& attributes, Class::Class(const Named &tag, const ClassAttrVector &attributes,
const std::unordered_map<std::string, ValuePtr>& methods) const std::unordered_map<std::string, ValuePtr> &methods)
: Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {} : Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {}
std::string List::ToString() const { std::string List::ToString() const {
...@@ -122,7 +122,7 @@ std::string List::DumpText() const { ...@@ -122,7 +122,7 @@ std::string List::DumpText() const {
return buffer.str(); return buffer.str();
} }
bool Class::operator==(const Type& other) const { bool Class::operator==(const Type &other) const {
// Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj. // Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj.
return &other == this; return &other == this;
} }
...@@ -143,7 +143,7 @@ std::string Class::ToString() const { ...@@ -143,7 +143,7 @@ std::string Class::ToString() const {
} else { } else {
bool begin = true; bool begin = true;
buffer << "cls." << tag_ << "["; buffer << "cls." << tag_ << "[";
for (auto& attr : attributes_) { for (auto &attr : attributes_) {
if (!begin) { if (!begin) {
buffer << ", "; buffer << ", ";
} else { } else {
...@@ -163,7 +163,7 @@ std::string Class::DumpText() const { ...@@ -163,7 +163,7 @@ std::string Class::DumpText() const {
} else { } else {
bool begin = true; bool begin = true;
buffer << "Cls." << tag_ << "["; buffer << "Cls." << tag_ << "[";
for (auto& attr : attributes_) { for (auto &attr : attributes_) {
if (!begin) { if (!begin) {
buffer << ", "; buffer << ", ";
} else { } else {
...@@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const { ...@@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const {
} else { } else {
TypePtrList elements; TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr& ele) { return ele->DeepCopy(); }); [](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<Tuple>(elements); auto copy = std::make_shared<Tuple>(elements);
return copy; return copy;
} }
} }
bool Tuple::operator==(const Type& other) const { bool Tuple::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
auto other_tuple = static_cast<const Tuple&>(other); auto other_tuple = static_cast<const Tuple &>(other);
if (elements_.size() != other_tuple.elements_.size()) { if (elements_.size() != other_tuple.elements_.size()) {
return false; return false;
} }
...@@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const { ...@@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const {
std::vector<std::pair<std::string, TypePtr>> kv; std::vector<std::pair<std::string, TypePtr>> kv;
(void)std::transform( (void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv), key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, TypePtr>& item) { return std::make_pair(item.first, item.second->DeepCopy()); }); [](const std::pair<std::string, TypePtr> &item) { return std::make_pair(item.first, item.second->DeepCopy()); });
return std::make_shared<Dictionary>(kv); return std::make_shared<Dictionary>(kv);
} }
} }
...@@ -259,7 +259,7 @@ std::string Dictionary::ToString() const { ...@@ -259,7 +259,7 @@ std::string Dictionary::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
std::vector<std::string> keys; std::vector<std::string> keys;
std::vector<TypePtr> values; std::vector<TypePtr> values;
for (const auto& kv : key_values_) { for (const auto &kv : key_values_) {
keys.push_back(kv.first); keys.push_back(kv.first);
values.push_back(kv.second); values.push_back(kv.second);
} }
...@@ -276,12 +276,12 @@ std::string Dictionary::ToString() const { ...@@ -276,12 +276,12 @@ std::string Dictionary::ToString() const {
std::string Dictionary::DumpText() const { return ToString(); } std::string Dictionary::DumpText() const { return ToString(); }
bool Dictionary::operator==(const mindspore::Type& other) const { bool Dictionary::operator==(const mindspore::Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
const auto& other_dict = static_cast<const Dictionary&>(other); const auto &other_dict = static_cast<const Dictionary &>(other);
if (key_values_.size() != other_dict.key_values_.size()) { if (key_values_.size() != other_dict.key_values_.size()) {
return false; return false;
} }
......
...@@ -40,10 +40,10 @@ namespace mindspore { ...@@ -40,10 +40,10 @@ namespace mindspore {
class List : public Object { class List : public Object {
public: public:
List() : Object(kObjectTypeList) {} List() : Object(kObjectTypeList) {}
List(const std::initializer_list<TypePtr>& objs) List(const std::initializer_list<TypePtr> &objs)
: Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {}
// Shadow copy; // Shadow copy;
explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {} explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {}
~List() override {} ~List() override {}
MS_DECLARE_PARENT(List, Object) MS_DECLARE_PARENT(List, Object)
...@@ -51,7 +51,7 @@ class List : public Object { ...@@ -51,7 +51,7 @@ class List : public Object {
TypeId generic_type_id() const override { return kObjectTypeList; } TypeId generic_type_id() const override { return kObjectTypeList; }
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
std::size_t size() const { return elements_.size(); } std::size_t size() const { return elements_.size(); }
TypePtrList elements() const { return elements_; } TypePtrList elements() const { return elements_; }
std::string ToString() const override; std::string ToString() const override;
...@@ -68,22 +68,22 @@ using ClassAttrVector = std::vector<std::pair<std::string, TypePtr>>; ...@@ -68,22 +68,22 @@ using ClassAttrVector = std::vector<std::pair<std::string, TypePtr>>;
class Class : public Object { class Class : public Object {
public: public:
Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} Class() : Object(kObjectTypeClass), tag_(Named("Class")) {}
Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map<std::string, ValuePtr>& methods); Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map<std::string, ValuePtr> &methods);
~Class() override {} ~Class() override {}
MS_DECLARE_PARENT(Class, Object) MS_DECLARE_PARENT(Class, Object)
TypeId generic_type_id() const override { return kObjectTypeClass; } TypeId generic_type_id() const override { return kObjectTypeClass; }
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;
void set_value(const std::unordered_map<std::string, ValuePtr>& v) { attributes_value_ = v; } void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; }
Named tag() { return tag_; } Named tag() { return tag_; }
std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; } std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; }
std::unordered_map<std::string, ValuePtr> methods() { return methods_; } std::unordered_map<std::string, ValuePtr> methods() { return methods_; }
ClassAttrVector& GetAttributes() { return attributes_; } ClassAttrVector &GetAttributes() { return attributes_; }
ClassAttrVector attributes_; ClassAttrVector attributes_;
...@@ -99,11 +99,11 @@ class Tuple : public Object { ...@@ -99,11 +99,11 @@ class Tuple : public Object {
public: public:
Tuple() : Object(kObjectTypeTuple) {} Tuple() : Object(kObjectTypeTuple) {}
// usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)}; // usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)};
Tuple(const std::initializer_list<TypePtr>& objs) Tuple(const std::initializer_list<TypePtr> &objs)
: Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
// Shadow copy // Shadow copy
explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
~Tuple() override {} ~Tuple() override {}
MS_DECLARE_PARENT(Tuple, Object) MS_DECLARE_PARENT(Tuple, Object)
...@@ -115,7 +115,7 @@ class Tuple : public Object { ...@@ -115,7 +115,7 @@ class Tuple : public Object {
std::string ToReprString() const override { return "tuple_"; } std::string ToReprString() const override { return "tuple_"; }
std::string DumpText() const override; std::string DumpText() const override;
const TypePtr operator[](size_t dim) const; const TypePtr operator[](size_t dim) const;
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtrList elements() const { return elements_; } TypePtrList elements() const { return elements_; }
std::size_t size() const { return elements_.size(); } std::size_t size() const { return elements_.size(); }
...@@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr<Tuple>; ...@@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr<Tuple>;
class Dictionary : public Object { class Dictionary : public Object {
public: public:
Dictionary() : Object(kObjectTypeDictionary) {} Dictionary() : Object(kObjectTypeDictionary) {}
explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>>& key_values) explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>> &key_values)
: Object(kObjectTypeDictionary, false), key_values_(key_values) {} : Object(kObjectTypeDictionary, false), key_values_(key_values) {}
~Dictionary() override {} ~Dictionary() override {}
...@@ -136,7 +136,7 @@ class Dictionary : public Object { ...@@ -136,7 +136,7 @@ class Dictionary : public Object {
TypeId generic_type_id() const override { return kObjectTypeDictionary; } TypeId generic_type_id() const override { return kObjectTypeDictionary; }
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override; TypePtr DeepCopy() const override;
std::string ToString() const override; std::string ToString() const override;
std::string DumpText() const override; std::string DumpText() const override;
......
...@@ -24,11 +24,11 @@ ...@@ -24,11 +24,11 @@
#include "pybind_api/export_flags.h" #include "pybind_api/export_flags.h"
namespace mindspore { namespace mindspore {
bool Number::operator==(const Type& other) const { bool Number::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) { if (!IsSameObjectType(*this, other)) {
return false; return false;
} }
auto other_number = static_cast<const Number&>(other); auto other_number = static_cast<const Number &>(other);
return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_)); return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_));
} }
......
...@@ -49,12 +49,12 @@ class Number : public Object { ...@@ -49,12 +49,12 @@ class Number : public Object {
TypeId type_id() const override { return number_type_; } TypeId type_id() const override { return number_type_; }
TypeId generic_type_id() const override { return kObjectTypeNumber; } TypeId generic_type_id() const override { return kObjectTypeNumber; }
bool operator==(const Type& other) const override; bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override { return std::make_shared<Number>(); } TypePtr DeepCopy() const override { return std::make_shared<Number>(); }
std::string ToString() const override { return "Number"; } std::string ToString() const override { return "Number"; }
std::string ToReprString() const override { return "number"; } std::string ToReprString() const override { return "number"; }
std::string DumpText() const override { return "Number"; } std::string DumpText() const override { return "Number"; }
std::string GetTypeName(const std::string& type_name) const { std::string GetTypeName(const std::string &type_name) const {
std::ostringstream oss; std::ostringstream oss;
oss << type_name; oss << type_name;
if (nbits() != 0) { if (nbits() != 0) {
......
...@@ -51,7 +51,7 @@ class RefKeyType : public Object { ...@@ -51,7 +51,7 @@ class RefKeyType : public Object {
class RefType : public Object { class RefType : public Object {
public: public:
RefType() : Object(kObjectTypeRef) {} RefType() : Object(kObjectTypeRef) {}
RefType(const TypePtr& subtype, const TypePtr& subtype_origin) RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
: Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
~RefType() override {} ~RefType() override {}
MS_DECLARE_PARENT(RefType, Object) MS_DECLARE_PARENT(RefType, Object)
......
...@@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) { ...@@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) {
} }
} }
const char* MetaIdLabel(const TypeId& v) { const char *MetaIdLabel(const TypeId &v) {
switch (v) { switch (v) {
case kTypeUnknown: case kTypeUnknown:
return "kTypeUnknown"; return "kTypeUnknown";
...@@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) { ...@@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) {
} }
} }
const char* ObjectIdLabel(const TypeId& v) { const char *ObjectIdLabel(const TypeId &v) {
switch (v) { switch (v) {
case kObjectTypeNumber: case kObjectTypeNumber:
return "kObjectTypeNumber"; return "kObjectTypeNumber";
...@@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) { ...@@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) {
} }
} }
const char* NumberIdLabel(const TypeId& v) { const char *NumberIdLabel(const TypeId &v) {
switch (v) { switch (v) {
case kNumberTypeBool: case kNumberTypeBool:
return "kNumberTypeBool"; return "kNumberTypeBool";
...@@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) { ...@@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) {
} }
} }
const char* TypeIdLabel(const TypeId& v) { const char *TypeIdLabel(const TypeId &v) {
if (v < kMetaTypeEnd) { if (v < kMetaTypeEnd) {
return MetaIdLabel(v); return MetaIdLabel(v);
} else { } else {
...@@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) { ...@@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) {
} }
} }
bool IsSameObjectType(const Type& lhs, const Type& rhs) { bool IsSameObjectType(const Type &lhs, const Type &rhs) {
if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) { if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
return false; return false;
} }
return lhs.object_type() == rhs.object_type(); return lhs.object_type() == rhs.object_type();
} }
size_t GetTypeByte(const TypePtr& type_ptr) { size_t GetTypeByte(const TypePtr &type_ptr) {
if (type_ptr && type_ptr->isa<Number>()) { if (type_ptr && type_ptr->isa<Number>()) {
auto number = dyn_cast<Number>(type_ptr); auto number = dyn_cast<Number>(type_ptr);
if (!number) { if (!number) {
...@@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) { ...@@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) {
} }
} }
bool Type::operator==(const Value& other) const { bool Type::operator==(const Value &other) const {
if (other.isa<Type>()) { if (other.isa<Type>()) {
auto other_type = static_cast<const Type*>(&other); auto other_type = static_cast<const Type *>(&other);
return *this == *other_type; return *this == *other_type;
} else { } else {
return false; return false;
...@@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() { ...@@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() {
return ptr; return ptr;
} }
std::ostream& operator<<(std::ostream& os, const Type& type) { std::ostream &operator<<(std::ostream &os, const Type &type) {
os << type.ToString(); os << type.ToString();
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const TypePtr type) { std::ostream &operator<<(std::ostream &os, const TypePtr type) {
os << type->ToString(); os << type->ToString();
return os; return os;
} }
...@@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const { ...@@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const {
return false; return false;
} }
std::ostream& operator<<(std::ostream& os, const Object& obj) { std::ostream &operator<<(std::ostream &os, const Object &obj) {
os << obj.ToString(); os << obj.ToString();
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj) { std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
os << obj->ToString(); os << obj->ToString();
return os; return os;
} }
std::ostream& operator<<(std::ostream& os, const TypePtrList& types) { std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
os << "["; os << "[";
for (size_t i = 0; i < types.size(); ++i) { for (size_t i = 0; i < types.size(); ++i) {
if (i > 0) { if (i > 0) {
......
...@@ -95,10 +95,10 @@ enum TypeId : int { ...@@ -95,10 +95,10 @@ enum TypeId : int {
TypeId IntBitsToTypeId(const int nbits); TypeId IntBitsToTypeId(const int nbits);
TypeId UIntBitsToTypeId(const int nbits); TypeId UIntBitsToTypeId(const int nbits);
TypeId FloatBitsToTypeId(const int nbits); TypeId FloatBitsToTypeId(const int nbits);
const char* TypeIdLabel(const TypeId& v); const char *TypeIdLabel(const TypeId &v);
TypeId NormalizeTypeId(const TypeId type_id); TypeId NormalizeTypeId(const TypeId type_id);
bool IsSameObjectType(const Type& lhs, const Type& rhs); bool IsSameObjectType(const Type &lhs, const Type &rhs);
size_t GetTypeByte(const TypePtr& type_ptr); size_t GetTypeByte(const TypePtr &type_ptr);
// Base class for all types // Base class for all types
// forward declaration. // forward declaration.
...@@ -110,14 +110,14 @@ class Type : public Value { ...@@ -110,14 +110,14 @@ class Type : public Value {
~Type() override = default; ~Type() override = default;
MS_DECLARE_PARENT(Type, Value) MS_DECLARE_PARENT(Type, Value)
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
TypeId meta_type() const { return meta_type_; } TypeId meta_type() const { return meta_type_; }
virtual TypeId type_id() const { return meta_type_; } virtual TypeId type_id() const { return meta_type_; }
virtual TypeId generic_type_id() const { return kMetaTypeType; } virtual TypeId generic_type_id() const { return kMetaTypeType; }
virtual bool operator!=(const Type& other) const { return !(*this == other); } virtual bool operator!=(const Type &other) const { return !(*this == other); }
virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); } virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); }
virtual bool equal(const TypePtr other) const { return *this == *other; } virtual bool equal(const TypePtr other) const { return *this == *other; }
virtual TypeId object_type() const { return kTypeUnknown; } virtual TypeId object_type() const { return kTypeUnknown; }
...@@ -134,8 +134,8 @@ class Type : public Value { ...@@ -134,8 +134,8 @@ class Type : public Value {
bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } bool IsUnknown() const { return (meta_type_ == kMetaTypeType); }
bool IsGeneric() const { return is_generic_; } bool IsGeneric() const { return is_generic_; }
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
friend std::ostream& operator<<(std::ostream& os, const Type& type); friend std::ostream &operator<<(std::ostream &os, const Type &type);
friend std::ostream& operator<<(std::ostream& os, const TypePtr type); friend std::ostream &operator<<(std::ostream &os, const TypePtr type);
const bool parse_info_ = true; const bool parse_info_ = true;
...@@ -163,14 +163,14 @@ class Object : public Type { ...@@ -163,14 +163,14 @@ class Object : public Type {
bool equal(const TypePtr other) const override; bool equal(const TypePtr other) const override;
std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); }
friend std::ostream& operator<<(std::ostream& os, const Object& obj); friend std::ostream &operator<<(std::ostream &os, const Object &obj);
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj); friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj);
private: private:
const TypeId object_type_; const TypeId object_type_;
}; };
std::ostream& operator<<(std::ostream& os, const TypePtrList& types); std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ #endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_
此差异已折叠。
...@@ -43,26 +43,26 @@ struct CloneInfo { ...@@ -43,26 +43,26 @@ struct CloneInfo {
class Cloner { class Cloner {
public: public:
explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false, explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false,
bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, bool clone_all_child_graphs = true, bool clone_all_used_graphs = false,
const TraceInfoPtr& relation = std::make_shared<TraceCopy>(), const TraceInfoPtr &relation = std::make_shared<TraceCopy>(),
const TraceInfoPtr& target_relation = nullptr); const TraceInfoPtr &target_relation = nullptr);
~Cloner() = default; ~Cloner() = default;
void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr, void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr,
const AnfNodePtrList& params = {}, CloneType type = kBasic); const AnfNodePtrList &params = {}, CloneType type = kBasic);
void Run(); void Run();
// Interfaces for specializer // Interfaces for specializer
AnfNodePtr CloneDisconnected(const AnfNodePtr& root); AnfNodePtr CloneDisconnected(const AnfNodePtr &root);
AnfNodePtr operator[](const AnfNodePtr& node); AnfNodePtr operator[](const AnfNodePtr &node);
FuncGraphPtr operator[](const FuncGraphPtr& func_graph); FuncGraphPtr operator[](const FuncGraphPtr &func_graph);
// Map of replicate nodes and graphs // Map of replicate nodes and graphs
std::unordered_map<AnfNodePtr, AnfNodePtr>* cloned_node() { return &repl_node_; } std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; }
std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; } std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; }
// Scope of cloned graphs // Scope of cloned graphs
void set_scope(const ScopePtr& scope) { scope_ = scope; } void set_scope(const ScopePtr &scope) { scope_ = scope; }
const ScopePtr scope() const { return scope_; } const ScopePtr scope() const { return scope_; }
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_; std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_;
...@@ -71,31 +71,31 @@ class Cloner { ...@@ -71,31 +71,31 @@ class Cloner {
void CloneNodes(); void CloneNodes();
void LinkEdges(); void LinkEdges();
void SetDefaults(); void SetDefaults();
void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target); void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneValueNode(const AnfNodePtr& node); void CloneValueNode(const AnfNodePtr &node);
void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target); void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target); void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false); void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false);
void CloneValueNodes(const FuncGraphPtr& func_graph); void CloneValueNodes(const FuncGraphPtr &func_graph);
void AddChildGraphs(const FuncGraphPtr& func_graph); void AddChildGraphs(const FuncGraphPtr &func_graph);
void AddTotalGraphs(const FuncGraphPtr& func_graph); void AddTotalGraphs(const FuncGraphPtr &func_graph);
bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline); bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline);
void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params); void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph); void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph);
void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void GenParameters(const FuncGraphPtr& func_graph); void GenParameters(const FuncGraphPtr &func_graph);
void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node); void CloneParameter(const ParameterPtr &param, const AnfNodePtr &node);
ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true); ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true);
void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params, void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params, AnfNodePtrList *const lift_params,
AnfNodePtrList* const input_params); AnfNodePtrList *const input_params);
void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params); void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs); void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs);
void SetEdges(const FuncGraphPtr& func_graph); void SetEdges(const FuncGraphPtr &func_graph);
void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList& params); const AnfNodePtrList &params);
void Lift(); void Lift();
void LiftParameters(); void LiftParameters();
...@@ -118,17 +118,17 @@ class Cloner { ...@@ -118,17 +118,17 @@ class Cloner {
std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_; std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_;
}; };
FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph); FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph);
AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr); const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr);
FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph); FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph);
ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation); ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation);
FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph,
const TraceInfoPtr& relation = std::make_shared<TraceTransform>()); const TraceInfoPtr &relation = std::make_shared<TraceTransform>());
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_
此差异已折叠。
此差异已折叠。
...@@ -42,25 +42,25 @@ namespace mindspore { ...@@ -42,25 +42,25 @@ namespace mindspore {
// generate a graph corresponding to these types. // generate a graph corresponding to these types.
class MetaFuncGraph : public FuncGraphBase { class MetaFuncGraph : public FuncGraphBase {
public: public:
explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); } explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); }
~MetaFuncGraph() override = default; ~MetaFuncGraph() override = default;
MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase);
abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node); abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node);
// Return normalized versions of the arguments. // Return normalized versions of the arguments.
// By default, this returns args unchanged. // By default, this returns args unchanged.
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const { virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
return args_spec_list; return args_spec_list;
} }
const std::vector<Signature>& signatures() const { return signatures_; } const std::vector<Signature> &signatures() const { return signatures_; }
void set_signatures(const std::vector<Signature>& signatures) { signatures_ = signatures; } void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
// Generate a Graph for the given abstract arguments. // Generate a Graph for the given abstract arguments.
virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) { virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
TypePtrList types; TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
[](const AbstractBasePtr& arg) -> TypePtr { [](const AbstractBasePtr &arg) -> TypePtr {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
return arg->BuildType(); return arg->BuildType();
}); });
...@@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase { ...@@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase {
} }
// Generate a Graph for this type signature. // Generate a Graph for this type signature.
virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) { virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) {
MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types.";
} }
...@@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase { ...@@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase {
std::string ToString() const override { return name_; } std::string ToString() const override { return name_; }
std::size_t hash() const override { return tid(); } std::size_t hash() const override { return tid(); }
virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; } virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; }
bool operator==(const Value& other) const override { bool operator==(const Value &other) const override {
if (other.isa<MetaFuncGraph>()) { if (other.isa<MetaFuncGraph>()) {
return &other == this; return &other == this;
} else { } else {
......
...@@ -31,7 +31,7 @@ namespace mindspore { ...@@ -31,7 +31,7 @@ namespace mindspore {
namespace tensor { namespace tensor {
void DataBuf2Contiguous(const py::array& src, py::array* const dest) { void DataBuf2Contiguous(const py::array &src, py::array *const dest) {
if (dest == nullptr) { if (dest == nullptr) {
MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!";
} }
...@@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) { ...@@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
// MetaTensor has default type_id_ which is TypeId::kTypeUnknown. // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {} MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int>& shape) : data_type_(data_type), shape_(shape) {} MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {}
MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) {
TypeId data_type = TypeId::kTypeUnknown; TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) { if (type_ptr != nullptr) {
data_type = type_ptr->type_id(); data_type = type_ptr->type_id();
...@@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { ...@@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
} }
} }
MetaTensor::MetaTensor(const MetaTensor& meta_tensor) MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
: Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {} : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
if (&meta_tensor == this) { if (&meta_tensor == this) {
return *this; return *this;
} }
...@@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { ...@@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
return *this; return *this;
} }
bool MetaTensor::operator==(const MetaTensor& meta_tensor) const { bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape(); return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
} }
...@@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) { ...@@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
return type_ptr; return type_ptr;
} }
void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) { void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) {
DeviceInfo info(format, data_type); DeviceInfo info(format, data_type);
set_device_info(info); set_device_info(info);
} }
...@@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const { ...@@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const {
return oss.str(); return oss.str();
} }
Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) {
TypeId data_type = TypeId::kTypeUnknown; TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) { if (type_ptr != nullptr) {
data_type = type_ptr->type_id(); data_type = type_ptr->type_id();
...@@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { ...@@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
init(data_type_, shape_, &data_); init(data_type_, shape_, &data_);
} }
Tensor::Tensor(TypeId data_type, const std::vector<int>& shape) { init(data_type, shape, &data_); } Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { init(data_type, shape, &data_); }
Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); } Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); }
Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); } Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); } Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); } Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); } Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type) Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
: MetaTensor(tensor), device_address_(tensor.device_address()) { : MetaTensor(tensor), device_address_(tensor.device_address()) {
init(tensor.data_, data_type); init(tensor.data_, data_type);
} }
Tensor& Tensor::operator=(const Tensor& tensor) { Tensor &Tensor::operator=(const Tensor &tensor) {
if (this != &tensor) { if (this != &tensor) {
MetaTensor::operator=(tensor); MetaTensor::operator=(tensor);
dirty_ = tensor.is_dirty(); dirty_ = tensor.is_dirty();
...@@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) { ...@@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) {
return *this; return *this;
} }
bool Tensor::operator==(const Tensor& tensor) const { bool Tensor::operator==(const Tensor &tensor) const {
return (MetaTensor::operator==(tensor) && data_ == tensor.data_); return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
} }
bool Tensor::ValueEqualPy(const py::object& other) const { bool Tensor::ValueEqualPy(const py::object &other) const {
if (!py::isinstance<Tensor>(other)) { if (!py::isinstance<Tensor>(other)) {
MS_LOG(WARNING) << "compare other not a tensor"; MS_LOG(WARNING) << "compare other not a tensor";
return false; return false;
...@@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const { ...@@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const {
return ValueEqual(py::cast<Tensor>(other)); return ValueEqual(py::cast<Tensor>(other));
} }
bool Tensor::ValueEqual(const Tensor& other) const { bool Tensor::ValueEqual(const Tensor &other) const {
auto equal = [&other, this]() -> bool { auto equal = [&other, this]() -> bool {
auto np = py::module::import("numpy"); auto np = py::module::import("numpy");
auto equal = np.attr("equal")(data_, other.data_); auto equal = np.attr("equal")(data_, other.data_);
...@@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast<int>(data_type_); } ...@@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast<int>(data_type_); }
std::vector<int> Tensor::shape_c(void) const { return shape(); } std::vector<int> Tensor::shape_c(void) const { return shape(); }
void* Tensor::data_c(bool writable) { void *Tensor::data_c(bool writable) {
// operand of bit operation should be unsigned int. // operand of bit operation should be unsigned int.
unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_;
bool is_c_contiguous = (flags != 0) ? true : false; bool is_c_contiguous = (flags != 0) ? true : false;
...@@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) { ...@@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) {
return data_.request(writable).ptr; return data_.request(writable).ptr;
} }
TypeId Tensor::GetDataType(const py::buffer_info& buf) const { TypeId Tensor::GetDataType(const py::buffer_info &buf) const {
TypeId data_type = TypeId::kTypeUnknown; TypeId data_type = TypeId::kTypeUnknown;
if (buf.format.compare("e") == 0) { if (buf.format.compare("e") == 0) {
data_type = TypeId::kNumberTypeFloat16; data_type = TypeId::kNumberTypeFloat16;
...@@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const { ...@@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const {
return data_type; return data_type;
} }
void Tensor::init(const py::array& input, const TypePtr& type_ptr) { void Tensor::init(const py::array &input, const TypePtr &type_ptr) {
TypeId data_type = TypeId::kTypeUnknown; TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) { if (type_ptr != nullptr) {
data_type = type_ptr->type_id(); data_type = type_ptr->type_id();
...@@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) { ...@@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) {
init(input, data_type); init(input, data_type);
} }
void Tensor::init(const py::array& input, const TypeId& data_type) { void Tensor::init(const py::array &input, const TypeId &data_type) {
py::buffer_info buf = input.request(); py::buffer_info buf = input.request();
data_type_ = GetDataType(buf); data_type_ = GetDataType(buf);
...@@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) { ...@@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) {
} }
} }
void Tensor::init(TypeId data_type, const std::vector<int>& shape, py::array* const data) { void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
data_type_ = data_type; data_type_ = data_type;
shape_ = shape; shape_ = shape;
switch (data_type) { switch (data_type) {
...@@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) { ...@@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
return data_type_; return data_type_;
} }
bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out, bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out,
const TypeId out_data_type) { const TypeId out_data_type) {
if (out == nullptr) { if (out == nullptr) {
return false; return false;
...@@ -458,7 +458,7 @@ py::array Tensor::data_sync() { ...@@ -458,7 +458,7 @@ py::array Tensor::data_sync() {
return data_; return data_;
} }
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
// dtype should define before Tensor, because Tensor init depend dtype // dtype should define before Tensor, because Tensor init depend dtype
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor") (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor")
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape")) .def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"))
...@@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { ...@@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) {
.def("__repr__", &Tensor::ToStringRepr) .def("__repr__", &Tensor::ToStringRepr)
.def("__eq__", &Tensor::ValueEqualPy) .def("__eq__", &Tensor::ValueEqualPy)
.def(py::pickle( .def(py::pickle(
[](const Tensor& t) { // __getstate__ [](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(t.data()); return py::make_tuple(t.data());
}, },
[](const py::tuple& t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
throw std::runtime_error("Invalid state!"); throw std::runtime_error("Invalid state!");
} }
......
此差异已折叠。
...@@ -18,9 +18,9 @@ ...@@ -18,9 +18,9 @@
#include "pipeline/static_analysis/abstract_value.h" #include "pipeline/static_analysis/abstract_value.h"
namespace mindspore { namespace mindspore {
bool Named::operator==(const Value& other) const { bool Named::operator==(const Value &other) const {
if (other.isa<Named>()) { if (other.isa<Named>()) {
auto other_named = static_cast<const Named&>(other); auto other_named = static_cast<const Named &>(other);
return *this == other_named; return *this == other_named;
} else { } else {
return false; return false;
......
...@@ -27,18 +27,18 @@ ...@@ -27,18 +27,18 @@
namespace mindspore { namespace mindspore {
class Named : public Value { class Named : public Value {
public: public:
explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); } explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
Named(const Named& other) : Value(other) { Named(const Named &other) : Value(other) {
this->name_ = other.name_; this->name_ = other.name_;
hash_id_ = std::hash<std::string>{}(other.name_); hash_id_ = std::hash<std::string>{}(other.name_);
} }
~Named() override = default; ~Named() override = default;
MS_DECLARE_PARENT(Named, Value); MS_DECLARE_PARENT(Named, Value);
const std::string& name() const { return name_; } const std::string &name() const { return name_; }
virtual bool operator==(const Named& other) const { return name_ == other.name(); } virtual bool operator==(const Named &other) const { return name_ == other.name(); }
bool operator==(const Value& other) const override; bool operator==(const Value &other) const override;
Named& operator=(const Named& other) { Named &operator=(const Named &other) {
if (&other != this) { if (&other != this) {
this->type_ = other.type_; this->type_ = other.type_;
this->name_ = other.name_; this->name_ = other.name_;
...@@ -50,7 +50,7 @@ class Named : public Value { ...@@ -50,7 +50,7 @@ class Named : public Value {
std::size_t Hash() const { return hash_id_; } std::size_t Hash() const { return hash_id_; }
std::size_t hash() const override { return hash_id_; } std::size_t hash() const override { return hash_id_; }
friend std::ostream& operator<<(std::ostream& os, const Named& nmd) { friend std::ostream &operator<<(std::ostream &os, const Named &nmd) {
os << nmd.name(); os << nmd.name();
return os; return os;
} }
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
#include "pipeline/parse/data_converter.h" #include "pipeline/parse/data_converter.h"
namespace mindspore { namespace mindspore {
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object& arg_default, const SignatureEnumDType& arg_dtype) const py::object &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
if (py::isinstance<SignatureEnumKind>(arg_default) && if (py::isinstance<SignatureEnumKind>(arg_default) &&
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
...@@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, ...@@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag,
} }
} }
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind) Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: name(arg_name), : name(arg_name),
rw(rw_tag), rw(rw_tag),
kind(arg_kind), kind(arg_kind),
default_value(nullptr), default_value(nullptr),
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) { REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic()) (void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead) .value("RW_READ", SignatureEnumRW::kRWRead)
.value("RW_WRITE", SignatureEnumRW::kRWWrite) .value("RW_WRITE", SignatureEnumRW::kRWWrite)
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册