提交 c2b3360d 编写于 作者: Z zhoufeng

update clang format rule

上级 31a12009
......@@ -94,7 +94,7 @@ PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
PointerAlignment: Right
RawStringFormats:
- Language: Cpp
Delimiters:
......
......@@ -23,7 +23,7 @@ namespace common {
const int CACHED_STR_NUM = 1 << 8;
const int CACHED_STR_MASK = CACHED_STR_NUM - 1;
std::vector<std::string> STR_HOLDER(CACHED_STR_NUM);
const char* SafeCStr(const std::string&& str) {
const char *SafeCStr(const std::string &&str) {
static std::atomic<uint32_t> index{0};
uint32_t cur_index = index++;
cur_index = cur_index & CACHED_STR_MASK;
......
......@@ -21,16 +21,16 @@
#include <string>
#define DISABLE_COPY_AND_ASSIGN(ClassType) \
ClassType(const ClassType&) = delete; \
ClassType& operator=(const ClassType&) = delete;
ClassType(const ClassType &) = delete; \
ClassType &operator=(const ClassType &) = delete;
namespace mindspore {
namespace common {
inline const char* SafeCStr(const std::string& str) { return str.c_str(); }
const char* SafeCStr(const std::string&& str);
inline const char *SafeCStr(const std::string &str) { return str.c_str(); }
const char *SafeCStr(const std::string &&str);
static inline std::string GetEnv(const std::string& envvar) {
const char* value = ::getenv(envvar.c_str());
static inline std::string GetEnv(const std::string &envvar) {
const char *value = ::getenv(envvar.c_str());
if (value == nullptr) {
return std::string();
......
......@@ -34,11 +34,11 @@ class DecodeOp : public TensorOp {
~DecodeOp() = default;
Status Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) override;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
void Print(std::ostream& out) const override { out << "DecodeOp"; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;
void Print(std::ostream &out) const override { out << "DecodeOp"; }
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private:
bool is_rgb_format_ = true;
......
......@@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int
rnd_.seed(seed_);
}
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>>& input,
std::vector<std::shared_ptr<Tensor>>* output) {
Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) {
IO_CHECK_VECTOR(input, output);
if (input.size() != NumInput())
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5");
......@@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector<std::shared_ptr<Tenso
return Status::OK();
}
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inputs,
std::vector<TensorShape>& outputs) {
Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape> &inputs,
std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
TensorShape out = TensorShape{-1, -1};
......@@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector<TensorShape>& inp
if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
}
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) {
Status DistortBoundingBoxCropOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
outputs[0] = inputs[0];
return Status::OK();
......
......@@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp {
~DistortBoundingBoxCropOp() override = default;
void Print(std::ostream& out) const override {
void Print(std::ostream &out) const override {
out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_;
}
Status Compute(const std::vector<std::shared_ptr<Tensor>>& input,
std::vector<std::shared_ptr<Tensor>>* output) override;
Status Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) override;
uint32_t NumInput() override { return 5; }
Status OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) override;
Status OutputType(const std::vector<DataType>& inputs, std::vector<DataType>& outputs) override;
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
private:
int32_t max_attempts_;
......
......@@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ
rnd_.seed(GetSeed());
}
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std::shared_ptr<Tensor>* output) {
Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal");
......@@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr<Tensor>& input, std:
(void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width);
return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_);
}
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs, std::vector<TensorShape>& outputs) {
Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
TensorShape out = TensorShape{target_height_, target_width_};
......@@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector<TensorShape>& inputs
if (!outputs.empty()) return Status::OK();
return Status(StatusCode::kUnexpectedError, "Input has a wrong shape");
}
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) {
Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) {
double scale, aspect;
*crop_width = w_in;
*crop_height = h_in;
......
......@@ -22,7 +22,7 @@
namespace mindspore {
constexpr char PARALLEL_STRATEGY[] = "strategy";
void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false);
void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false);
} // namespace mindspore
......
......@@ -39,7 +39,7 @@
namespace mindspore {
struct ParamPtrEqual {
bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const {
bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const {
const ParameterPtr param1 = dyn_cast<Parameter>(t1);
const ParameterPtr param2 = dyn_cast<Parameter>(t2);
......@@ -52,7 +52,7 @@ struct ParamPtrEqual {
};
struct ParamPtrHasher {
std::size_t operator()(AnfNodePtr const& param) const {
std::size_t operator()(AnfNodePtr const &param) const {
const ParameterPtr parameter = dyn_cast<Parameter>(param);
if (parameter == nullptr) {
return 0;
......@@ -64,39 +64,39 @@ struct ParamPtrHasher {
class AnfExporter {
public:
explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false)
explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false)
: param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) {
func_graph_set.clear();
exported.clear();
}
virtual ~AnfExporter() {}
void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
void ExportFuncGraph(const std::string& filename, const std::vector<TaggedGraph>& graphs);
void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
void ExportFuncGraph(const std::string &filename, const std::vector<TaggedGraph> &graphs);
protected:
virtual std::string GetNodeType(const AnfNodePtr& nd);
int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr& param);
std::string DumpObject(const py::object& obj, const std::string& category) const;
std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node);
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph);
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst);
std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetPrimitiveText(const PrimitivePtr& prim);
std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value);
std::string GetNameSpaceText(const parse::NameSpacePtr& ns);
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph);
std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node,
const std::map<AnfNodePtr, int>& apply_map);
void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph);
void OutputParameters(std::ofstream& ofs, const std::vector<AnfNodePtr>& parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual>* param_map);
void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node);
void OutputCNodes(std::ofstream& ofs, const std::vector<AnfNodePtr>& nodes, const FuncGraphPtr& func_graph);
virtual std::string GetNodeType(const AnfNodePtr &nd);
int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &param, bool throw_excp = true);
int GetParamIndexFromExported(const AnfNodePtr &param);
std::string DumpObject(const py::object &obj, const std::string &category) const;
std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node);
std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph);
std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst);
std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetPrimitiveText(const PrimitivePtr &prim);
std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value);
std::string GetNameSpaceText(const parse::NameSpacePtr &ns);
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph);
std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::map<AnfNodePtr, int> &apply_map);
void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph);
void OutputParameters(std::ofstream &ofs, const std::vector<AnfNodePtr> &parameters,
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> *param_map);
void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node);
void OutputCNodes(std::ofstream &ofs, const std::vector<AnfNodePtr> &nodes, const FuncGraphPtr &func_graph);
int param_index;
OrderedSet<FuncGraphPtr> func_graph_set{};
......@@ -108,16 +108,16 @@ class AnfExporter {
abstract::AnfNodeConfigPtr node_cfg_ = nullptr;
};
void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph);
void ExportIR(const std::string& filename, const std::vector<TaggedGraph>& graphs);
void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph);
void ExportIR(const std::string &filename, const std::vector<TaggedGraph> &graphs);
std::vector<FuncGraphPtr> ImportIR(const std::string& filename);
std::vector<FuncGraphPtr> ImportIR(const std::string &filename);
std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph);
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
std::string GetOnnxProtoString(const FuncGraphPtr& func_graph);
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_
......@@ -34,7 +34,7 @@ namespace draw {
namespace {
// Only for ValueNode
std::string ValueType(const ValueNodePtr& node) {
std::string ValueType(const ValueNodePtr &node) {
if (node == nullptr) {
return "";
}
......@@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) {
return v->type_name();
}
std::string ReplaceSpecialChar(const std::string& str) {
std::string ReplaceSpecialChar(const std::string &str) {
std::ostringstream oss;
for (size_t i = 0; i < str.size(); i++) {
if (str[i] == '<') {
......@@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) {
} // namespace
// API of debug utils
void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs,
void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs,
bool is_user) {
if (sub_graphs == nullptr) {
return;
}
for (auto& nd : nodes) {
for (auto &nd : nodes) {
MS_EXCEPTION_IF_NULL(nd);
auto sub_graph = nd->func_graph();
if (sub_graph != nullptr) {
......@@ -84,16 +84,16 @@ void DrawNodes(const std::vector<AnfNodePtr>& nodes, OrderedMap<FuncGraphPtr, st
}
}
void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>>* sub_graphs) {
void DrawValueNodes(const std::vector<AnfNodePtr> &nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) {
if (sub_graphs == nullptr) {
return;
}
int dup_idx = 0;
for (auto& nd : nodes) {
for (auto& t : SuccIncoming(nd)) {
for (auto &nd : nodes) {
for (auto &t : SuccIncoming(nd)) {
MS_EXCEPTION_IF_NULL(t);
MS_EXCEPTION_IF_NULL(nd);
if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
......@@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector<AnfNodePtr>& nodes,
}
}
void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseDigraph>& digraph, bool is_user) {
void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) {
if (digraph == nullptr) {
return;
}
......@@ -120,11 +120,11 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
}
// Draw edge
for (auto& nd : nodes) {
for (auto &nd : nodes) {
auto succs = SuccIncoming(nd);
auto num = succs.size();
for (size_t i = 0; i < num; i++) {
auto& t = succs.at(i);
auto &t = succs.at(i);
MS_EXCEPTION_IF_NULL(t);
if (t->isa<ValueNode>() || t->isa<Parameter>()) {
if ((!is_user) || (i != 0)) {
......@@ -143,7 +143,7 @@ void DrawEdges(const std::vector<AnfNodePtr>& nodes, const std::shared_ptr<BaseD
}
}
void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_user) {
void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) {
if (func_graph == nullptr) {
return;
}
......@@ -169,7 +169,7 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
DrawValueNodes(nodes, &sub_graphs);
// Draw subgraph
for (const auto& gsub : sub_graphs) {
for (const auto &gsub : sub_graphs) {
digraph->SubGraph(gsub.first, gsub.second);
}
......@@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use
}
#ifdef ENABLE_DUMP_IR
void Draw(const std::string& filename, const FuncGraphPtr& func_graph) {
void Draw(const std::string &filename, const FuncGraphPtr &func_graph) {
const std::string dot_suffix = ".dot";
std::string filename_with_suffix =
(filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
DrawByOpt(filename_with_suffix, func_graph, false);
}
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) {
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
DrawByOpt(filename, func_graph, true);
}
#else
void Draw(const std::string&, const FuncGraphPtr&) {
void Draw(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
......@@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) {
<< "please recompile source to enable it. See help of building script.";
}
void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) {
void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
......@@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) {
return "plaintext";
}
std::string Graphviz::Color(const AnfNodePtr& node) {
std::string Graphviz::Color(const AnfNodePtr &node) {
if (node == nullptr) {
return "";
}
......@@ -259,7 +259,7 @@ void BaseDigraph::Start() {
buffer_ << "compound=true" << std::endl;
}
void BaseDigraph::Head(const AnfNodePtr& node, int id) {
void BaseDigraph::Head(const AnfNodePtr &node, int id) {
if (node == nullptr) {
return;
}
......@@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) {
}
}
void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) {
if (node == nullptr) {
return;
}
......@@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) {
buffer_ << ":" << idx;
}
void BaseDigraph::Tail(const FuncGraphPtr& func_graph) {
void BaseDigraph::Tail(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
......@@ -304,12 +304,12 @@ void BaseDigraph::End() {
}
}
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << "parameters_" << key << "[shape=plaintext ";
buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
buffer_ << "<tr><td>parameters</td></tr>";
int count = 0;
for (auto& parameter : key->parameters()) {
for (auto &parameter : key->parameters()) {
buffer_ << "<tr><td>";
buffer_ << parameter->ToString();
auto py_p = dyn_cast<Parameter>(parameter)->default_param();
......@@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) {
buffer_ << "</table>>,];";
}
void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub) {
void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) {
if (key == nullptr || gsub == nullptr) {
return;
}
......@@ -361,12 +361,12 @@ Digraph::~Digraph() {
if (fout_.is_open()) {
fout_.close();
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file " << filename_;
}
}
static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) {
static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) {
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
(void)str.replace(start_pos, from.length(), to);
......@@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st
return str;
}
static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph_obj);
graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
<< "'>";
......@@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "<tr><td align='left'>";
int i = 0;
for (const auto& attr : attrs) {
for (const auto &attr : attrs) {
if (i != 0) {
graph_obj->buffer() << "<br/>";
}
......@@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) {
graph_obj->buffer() << "</table>>,";
}
static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) {
return;
}
......@@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) {
}
}
static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
return;
}
......@@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) {
}
graph_obj->buffer() << ">";
int i = 0;
for (auto& attr : attrs) {
for (auto &attr : attrs) {
if (i != 0) {
graph_obj->buffer() << "<br/>";
}
......@@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() {
if (fout_.is_open()) {
fout_.close();
}
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(ERROR) << "exception when closing file " << filename_;
}
}
......
......@@ -31,9 +31,9 @@ namespace parse = mindspore::parse;
class Graphviz {
public:
Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {}
Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {}
explicit Graphviz(const std::string& name) : name_(name) {}
explicit Graphviz(const std::string &name) : name_(name) {}
virtual ~Graphviz() {}
......@@ -41,8 +41,8 @@ class Graphviz {
virtual void End() {}
virtual std::string Shape(AnfNodePtr node);
std::string Color(const AnfNodePtr& node);
std::ostringstream& buffer() { return buffer_; }
std::string Color(const AnfNodePtr &node);
std::ostringstream &buffer() { return buffer_; }
std::ostringstream buffer_;
protected:
......@@ -53,8 +53,8 @@ class Graphviz {
class BaseDigraph : public Graphviz {
public:
BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {}
explicit BaseDigraph(const std::string& name) : Graphviz(name) {}
BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {}
explicit BaseDigraph(const std::string &name) : Graphviz(name) {}
~BaseDigraph() override = default;
virtual void Node(AnfNodePtr node, int id = 0) = 0;
......@@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz {
void Start() override;
void End() override;
virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start);
void FuncGraphParameters(const FuncGraphPtr& key);
void SubGraph(const FuncGraphPtr& key, const std::shared_ptr<BaseDigraph>& gsub);
void FuncGraphParameters(const FuncGraphPtr &key);
void SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub);
const std::string& name() const { return name_; }
const std::string &name() const { return name_; }
protected:
void Head(const AnfNodePtr& node, int id = 0);
void Tail(const AnfNodePtr& node, int idx, int id = 0);
void Tail(const FuncGraphPtr& func_graph);
void Head(const AnfNodePtr &node, int id = 0);
void Tail(const AnfNodePtr &node, int idx, int id = 0);
void Tail(const FuncGraphPtr &func_graph);
};
class Digraph : public BaseDigraph {
public:
Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {}
explicit Digraph(const std::string& name) : BaseDigraph(name) {}
Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit Digraph(const std::string &name) : BaseDigraph(name) {}
~Digraph() override;
void Node(AnfNodePtr node, int id = 0) override;
......@@ -86,8 +86,8 @@ class Digraph : public BaseDigraph {
class ModelDigraph : public BaseDigraph {
public:
ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {}
explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {}
ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {}
explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {}
~ModelDigraph() override;
std::string Shape(AnfNodePtr node) override;
......@@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph {
};
// API to draw
void Draw(const std::string& filename, const FuncGraphPtr& func_graph);
void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph);
void Draw(const std::string &filename, const FuncGraphPtr &func_graph);
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph);
} // namespace draw
} // namespace mindspore
......
此差异已折叠。
......@@ -36,7 +36,7 @@ Dump::Dump()
dump_iter_(0),
cur_iter_(0) {}
bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
bool Dump::IsKernelNeedDump(const std::string &kernel_name) {
if (dump_mode_ == 0) {
// Dump All Kernels mode
return true;
......@@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) {
return false;
}
bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
bool Dump::ParseDumpConfig(const std::string &dump_config_file) {
std::ifstream jsonFile(dump_config_file);
if (!jsonFile.is_open()) {
MS_LOG(ERROR) << dump_config_file << " open failed.";
......@@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) {
return true;
}
bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) {
if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() ||
dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() ||
dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() ||
......@@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) {
return true;
}
bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto trans_flag = dumpSettings.at("trans_flag");
auto enable = dumpSettings.at("enable");
auto mode = dumpSettings.at("mode");
......@@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) {
dump_path_ = path;
dump_net_name_ = net_name;
dump_iter_ = iteration;
for (const auto& kernel : kernels) {
for (const auto &kernel : kernels) {
dump_kernels_.push_back(kernel);
}
return true;
}
bool Dump::SetDumpConfFromJsonFile() {
const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH");
if (config_path_str != nullptr) {
MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str;
} else {
......@@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() {
return ParseDumpConfig(dump_config_file);
}
bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) {
bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) {
if (filename.empty() || data == nullptr || len == 0) {
MS_LOG(ERROR) << "Incorrect parameter.";
return false;
......@@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len)
MS_LOG(ERROR) << "Open file " << realpath << " fail.";
return false;
}
(void)fd.write(reinterpret_cast<const char*>(data), SizeToLong(len));
(void)fd.write(reinterpret_cast<const char *>(data), SizeToLong(len));
fd.close();
return true;
}
bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) {
MS_EXCEPTION_IF_NULL(outpath);
auto path_split_pos = inpath.find_last_of('/');
if (path_split_pos == std::string::npos) {
......@@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) {
return true;
}
bool Dump::CreateNotExistDirs(const std::string& path) {
bool Dump::CreateNotExistDirs(const std::string &path) {
std::shared_ptr<system::FileSystem> fs = system::Env::GetFileSystem();
MS_EXCEPTION_IF_NULL(fs);
char temp_path[PATH_MAX] = {0};
......
......@@ -43,11 +43,11 @@ class Dump {
uint32_t cur_iter() const { return cur_iter_; }
bool IsKernelNeedDump(const std::string& kernel_name);
bool IsKernelNeedDump(const std::string &kernel_name);
bool SetDumpConfFromJsonFile();
static bool DumpToFile(const std::string& filename, const void* data, size_t len);
static bool DumpToFile(const std::string &filename, const void *data, size_t len);
protected:
bool dump_enable_;
......@@ -59,14 +59,14 @@ class Dump {
uint32_t cur_iter_;
std::vector<std::string> dump_kernels_;
static bool GetRealPath(const std::string& inpath, std::string* outpath);
static bool GetRealPath(const std::string &inpath, std::string *outpath);
static bool CreateNotExistDirs(const std::string& path);
static bool CreateNotExistDirs(const std::string &path);
private:
bool ParseDumpConfig(const std::string& dump_config_file);
bool IsConfigExist(const nlohmann::json& dumpSettings);
bool IsConfigValid(const nlohmann::json& dumpSettings);
bool ParseDumpConfig(const std::string &dump_config_file);
bool IsConfigExist(const nlohmann::json &dumpSettings);
bool IsConfigValid(const nlohmann::json &dumpSettings);
};
using DumpConfPtr = std::shared_ptr<Dump>;
......
......@@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h"
namespace mindspore {
std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) {
std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {
std::string temp_line = line;
if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) &&
tip != kSourceLineTipDiscard) {
......@@ -101,14 +101,14 @@ DebugInfo::DebugInfo() {
name_ = "";
}
DebugInfo::DebugInfo(const std::string& name) {
DebugInfo::DebugInfo(const std::string &name) {
InitValueFromContext();
unique_id_ = gen_unique_id();
debug_id_ = -1;
name_ = name;
}
DebugInfo::DebugInfo(const LocationPtr& loc) {
DebugInfo::DebugInfo(const LocationPtr &loc) {
InitValueFromContext();
unique_id_ = gen_unique_id();
debug_id_ = -1;
......@@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() {
}
int64_t DebugInfo::unique_id_through_copy() const {
TraceInfoPtr trace_info = const_cast<DebugInfo*>(this)->trace_info();
TraceInfoPtr trace_info = const_cast<DebugInfo *>(this)->trace_info();
if (trace_info != nullptr) {
if (trace_info->isa<TraceCopy>() && trace_info->debug_info() != nullptr) {
return trace_info->debug_info()->unique_id_through_copy();
......@@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() {
}
return DebugInfo::location();
}
void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; }
void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; }
TraceContextPtr TraceManager::CurrentContextInfo() {
if (!TraceManager::trace_context_stack_.empty()) {
......@@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() {
return nullptr;
}
void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) {
void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location);
context->set_func_name(func_name);
TraceManager::trace_context_stack_.push(context);
}
void TraceManager::DebugTrace(const LocationPtr& location) {
void TraceManager::DebugTrace(const LocationPtr &location) {
TraceContextPtr context = std::make_shared<TraceContext>(location);
TraceManager::trace_context_stack_.push(context);
}
void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
}
......@@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) {
TraceManager::trace_context_stack_.push(context);
}
void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) {
void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) {
if (trace_info == nullptr) {
MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null";
}
......
......@@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou
// Location class record the location in source code.
class Location {
public:
Location(const std::string& file_name, int line, int column, int line_end, int column_end)
Location(const std::string &file_name, int line, int column, int line_end, int column_end)
: file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {}
Location(const Location& loc)
Location(const Location &loc)
: file_name_(loc.file_name_),
line_(loc.line_),
column_(loc.column_),
......@@ -77,21 +77,21 @@ class TraceManager {
TraceManager() = default;
~TraceManager() = default;
static TraceContextPtr CurrentContextInfo();
static void DebugTrace(const std::string& func_name, const LocationPtr& location);
static void DebugTrace(const LocationPtr& location);
static void DebugTrace(const TraceInfoPtr& trace_info);
static void DebugTrace(const std::string &func_name, const LocationPtr &location);
static void DebugTrace(const LocationPtr &location);
static void DebugTrace(const TraceInfoPtr &trace_info);
// debug trace with a cloned trace info with debug_info
static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info);
static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info);
static void EndTrace();
static std::stack<TraceContextPtr> trace_context_stack_;
};
class TraceGuard {
public:
explicit TraceGuard(const std::string func_name, const LocationPtr& location) {
explicit TraceGuard(const std::string func_name, const LocationPtr &location) {
TraceManager::DebugTrace(func_name, location);
}
explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); }
explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); }
~TraceGuard() { TraceManager::EndTrace(); }
};
......@@ -106,23 +106,23 @@ class TraceContext {
public:
~TraceContext() = default;
explicit TraceContext(const LocationPtr& loc) {
explicit TraceContext(const LocationPtr &loc) {
ProcessAttributeFromContext();
location_ = loc;
}
explicit TraceContext(const std::string& func_name) {
explicit TraceContext(const std::string &func_name) {
ProcessAttributeFromContext();
func_name_ = func_name;
}
explicit TraceContext(const TraceInfoPtr& trace_info) {
explicit TraceContext(const TraceInfoPtr &trace_info) {
ProcessAttributeFromContext();
trace_info_ = trace_info;
}
void set_location(const LocationPtr& loc) { location_ = loc; }
void set_location(const LocationPtr &loc) { location_ = loc; }
LocationPtr location() { return location_; }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; }
void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; }
void set_func_name(const std::string& func_name) { func_name_ = func_name; }
void set_func_name(const std::string &func_name) { func_name_ = func_name; }
std::string func_name() { return func_name_; }
};
......@@ -130,9 +130,9 @@ class DebugInfo : public Base {
public:
DebugInfo();
explicit DebugInfo(const std::string& name);
explicit DebugInfo(const std::string &name);
explicit DebugInfo(const LocationPtr& loc);
explicit DebugInfo(const LocationPtr &loc);
virtual ~DebugInfo() = default;
MS_DECLARE_PARENT(DebugInfo, Base);
......@@ -141,12 +141,12 @@ class DebugInfo : public Base {
int64_t unique_id_through_copy() const;
std::string get_id() { return std::to_string(debug_id()); }
void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; }
void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; }
TraceInfoPtr trace_info() { return trace_info_; }
void set_location(const LocationPtr& loc) { location_ = loc; }
void set_location(const LocationPtr &loc) { location_ = loc; }
virtual LocationPtr location() { return location_; }
std::string name() { return name_; }
void set_name(const std::string& name) { name_ = name; }
void set_name(const std::string &name) { name_ = name; }
virtual std::string debug_name();
virtual std::string get_python_func_belonged() { return ""; }
......@@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo {
py_func_belonged_ = context_info->func_name();
}
}
explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) {
explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo();
py_func_belonged_ = context_info->func_name();
......@@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo {
~NodeDebugInfo() override = default;
std::string debug_name() override;
void set_node(const std::shared_ptr<AnfNode>& node) { node_ = AnfNodeWeakPtr(node); }
void set_node(const std::shared_ptr<AnfNode> &node) { node_ = AnfNodeWeakPtr(node); }
std::shared_ptr<AnfNode> get_node() const { return node_.lock(); }
void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; }
void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; }
std::string get_python_func_belonged() override { return py_func_belonged_; }
AnfNodeWeakPtr node_;
std::string py_func_belonged_;
......@@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo {
}
}
explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) {
explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) {
if (TraceManager::CurrentContextInfo() != nullptr) {
auto context_info = TraceManager::CurrentContextInfo();
py_func_name_ = context_info->func_name();
......@@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo {
std::string debug_name() override;
LocationPtr location() override;
LocationPtr deco_location() { return deco_loc_; }
void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
FuncGraphPtr get_graph() const { return func_graph_.lock(); }
void set_full_name(const std::string& name) { full_name_ = name; }
void set_full_name(const std::string &name) { full_name_ = name; }
std::string get_full_name() { return full_name_; }
void set_deco_location(const LocationPtr& deco_list_loc);
void set_deco_location(const LocationPtr &deco_list_loc);
std::string get_python_func_belonged() override { return py_func_name_; }
FuncGraphWeakPtr func_graph_;
LocationPtr deco_loc_;
......
......@@ -31,7 +31,7 @@ struct NameWithTrace {
std::string name;
std::vector<std::string> trace_labels;
};
static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) {
static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) {
switch (trace_label) {
case TraceLabelType::kShortSymbol:
return trace_info->symbol();
......@@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t
}
}
NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name;
// find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node
auto temp_info = debug_info;
......@@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe
return trace_name;
}
std::string CombineTraceTypes(const std::string& root_name, const std::vector<std::string>& trace_labels) {
std::string CombineTraceTypes(const std::string &root_name, const std::vector<std::string> &trace_labels) {
std::string tags = "";
for (auto& itr : trace_labels) {
for (auto &itr : trace_labels) {
std::string symbol = itr;
tags = tags + symbol;
}
......@@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector<st
}
// get the label name of the node debug info
std::string LabelString(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
std::string LabelString(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
NameWithTrace trace_name = RootName(debug_info, trace_label);
return CombineTraceTypes(trace_name.name, trace_name.trace_labels);
}
std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
auto temp_info = debug_info;
std::string label = "";
while (temp_info != nullptr) {
......@@ -103,9 +103,9 @@ std::string CombineUniqueID(const DebugInfoPtr& debug_info) {
}
// get trace with unique id chain
std::string LabelStringUnique(const DebugInfoPtr& debug_info) { return CombineUniqueID(debug_info); }
std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_label) {
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
return LabelStringUnique(debug_info);
}
......
......@@ -29,7 +29,7 @@ namespace label_manage {
enum class TraceLabelType { kShortSymbol, kFullName, kWithUniqueId };
TraceLabelType GetGlobalTraceLabelType();
void SetGlobalTraceLabelType(TraceLabelType label_type);
std::string Label(const DebugInfoPtr& debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_type = TraceLabelType::kShortSymbol);
} // namespace label_manage
} // namespace mindspore
......
......@@ -37,7 +37,7 @@
namespace mindspore {
// namespace to support debug trace infomation
namespace trace {
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs) {
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs) {
if (abs == nullptr) {
return "Null Abstract";
}
......@@ -69,7 +69,7 @@ std::vector<DebugInfoPtr> GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) {
return debug_with_loc_vec;
}
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) {
auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info);
if (debug_with_loc_vec.size() > 0) {
return debug_with_loc_vec[0];
......@@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) {
}
}
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) {
return "";
}
......@@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
// a trace info identifies a node transform, so we can trace the node transform through
// a link of trace info and debug info
std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceLineTip tip) {
std::string GetInfoWithAction(const std::vector<DebugInfoPtr> &info_vec, SourceLineTip tip) {
if (info_vec.size() < 1) {
return "";
}
......@@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector<DebugInfoPtr>& info_vec, SourceL
return traced_info;
}
std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) {
if (info == nullptr) {
return "";
}
......@@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) {
return "";
}
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) {
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) {
std::ostringstream oss;
if (info == nullptr) {
return "";
......@@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So
return oss.str();
}
std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) {
std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) {
std::ostringstream oss;
oss << "graph:" << graph->ToString() << " with args[";
auto params = graph->parameters();
......@@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas
return oss.str();
}
void DumpInferStack(std::ostringstream& oss) {
auto& infer_stack = GetCurrenGraphInferStack();
void DumpInferStack(std::ostringstream &oss) {
auto &infer_stack = GetCurrenGraphInferStack();
if (infer_stack.empty()) {
return;
}
......@@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) {
}
std::reverse(infer_vec.begin(), infer_vec.end());
int index = 0;
for (auto& item : infer_vec) {
for (auto &item : infer_vec) {
auto graph_infer = std::dynamic_pointer_cast<abstract::BaseFuncGraphEvaluator>(item.first);
if (graph_infer == nullptr) {
MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator";
......@@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) {
}
void TraceGraphInfer() {
auto& infer_stack = GetCurrenGraphInferStack();
auto &infer_stack = GetCurrenGraphInferStack();
std::ostringstream oss;
if (infer_stack.empty()) {
return;
......@@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter {
AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {}
~AnalyzedFuncGraphExporter() override = default;
void ExportFuncGraph(const std::string& filename, const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs);
void ExportFuncGraph(const std::string &filename, const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs);
private:
std::string GetNodeType(const AnfNodePtr& nd) override;
std::string GetNodeType(const AnfNodePtr &nd) override;
};
std::unordered_map<FuncGraphPtr, TaggedNodeMap> CalcTaggedFuncGraphs() {
std::unordered_map<FuncGraphPtr, TaggedNodeMap> tagged_func_graphs;
auto& list = GetCNodeDebugStack();
auto &list = GetCNodeDebugStack();
for (size_t i = 0; i < list.size(); ++i) {
auto node_cfg = list[i];
auto fg = node_cfg->context()->func_graph();
......@@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() {
exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack());
}
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
if (node_cfg_ == nullptr) {
return AnfExporter::GetNodeType(node);
}
......@@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) {
return oss.str();
}
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
const std::vector<abstract::AnfNodeConfigPtr>& node_cfgs) {
void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
const std::vector<abstract::AnfNodeConfigPtr> &node_cfgs) {
if (node_cfgs.empty()) {
MS_LOG(DEBUG) << "Node configs is empty";
return;
......@@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
auto tagged_func_graphs = CalcTaggedFuncGraphs();
// first output graph on the analysis stack
for (const auto& node_cfg : node_cfgs) {
for (const auto &node_cfg : node_cfgs) {
auto fg = node_cfg->context()->func_graph();
// the graph is already output, skip it
if (exported.find(fg) != exported.end()) {
......@@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename,
ofs.close();
}
void GetInferStackInfo(std::ostringstream& oss) {
void GetInferStackInfo(std::ostringstream &oss) {
MS_LOG(INFO) << "Get graph analysis information begin";
auto stack = GetCNodeDebugStack();
if (stack.empty()) {
......@@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) {
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
// trace the cnode infer debug info
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) {
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
}
......@@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An
}
}
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) {
if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
}
......@@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) {
}
}
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); }
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack() { return cnode_debug_stack; }
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; }
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack() {
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack() {
return graph_infer_stack;
}
void ClearTraceStack() {
......
......@@ -31,19 +31,19 @@
namespace mindspore {
namespace trace {
std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix,
std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine);
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
SourceLineTip tip = kSourceLineTipNextLine);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
void TraceGraphInfer();
void GetInferStackInfo(std::ostringstream& oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg);
void GetInferStackInfo(std::ostringstream &oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg);
void TraceInferCNodeLeave();
std::vector<abstract::AnfNodeConfigPtr>& GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>>& GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr& abs);
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs);
void ClearTraceStack();
} // namespace trace
} // namespace mindspore
......
......@@ -23,7 +23,7 @@
#include "pipeline/parse/python_adapter.h"
namespace mindspore {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {
if (info == nullptr) {
return "";
}
......
......@@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr<DebugInfo>;
// namespace to support intermediate representation definition
class TraceInfo : public Base {
public:
TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) {
TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) {
symbol_ = symbol;
full_name_ = full_name;
name_ = full_name_;
debug_info_ = info;
}
TraceInfo(const TraceInfo& info)
TraceInfo(const TraceInfo &info)
: Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {}
virtual ~TraceInfo() = default;
MS_DECLARE_PARENT(TraceInfo, Base);
......@@ -55,8 +55,8 @@ class TraceInfo : public Base {
virtual std::string full_name() { return full_name_; }
virtual TraceInfoPtr clone() { return shared_from_base<TraceInfo>(); }
virtual std::string action_name() { return ""; }
virtual std::string GetActionBetweenNode(const DebugInfoPtr& info);
void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; }
virtual std::string GetActionBetweenNode(const DebugInfoPtr &info);
void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; }
DebugInfoPtr debug_info() { return debug_info_; }
DebugInfoPtr DebugInfoHasLoc();
std::vector<std::pair<DebugInfoPtr, TraceInfoPtr>> GetSourceCodeDebugInfo();
......@@ -70,7 +70,7 @@ class TraceInfo : public Base {
class TracePhi : public TraceInfo {
public:
explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {}
explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {}
MS_DECLARE_PARENT(TracePhi, TraceInfo);
~TracePhi() override = default;
TraceInfoPtr clone() override { return std::make_shared<TracePhi>(*shared_from_base<TracePhi>()); }
......@@ -78,8 +78,8 @@ class TracePhi : public TraceInfo {
class TraceIfStmtTrueBranch : public TraceInfo {
public:
TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default;
explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "✓") {}
TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default;
explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {}
MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo);
~TraceIfStmtTrueBranch() override = default;
TraceInfoPtr clone() override {
......@@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo {
class TraceIfStmtFalseBranch : public TraceInfo {
public:
TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default;
explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "✗") {}
TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default;
explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {}
MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo);
~TraceIfStmtFalseBranch() override = default;
TraceInfoPtr clone() override {
......@@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo {
class TraceIfStmtAfterBranch : public TraceInfo {
public:
explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "↓") {}
explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {}
MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo);
~TraceIfStmtAfterBranch() override = default;
TraceInfoPtr clone() override {
......@@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo {
class TraceIfExpTrueBranch : public TraceInfo {
public:
explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "↰") {}
explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {}
MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo);
~TraceIfExpTrueBranch() override = default;
TraceInfoPtr clone() override {
......@@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo {
class TraceIfExpFalseBranch : public TraceInfo {
public:
explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "↱") {}
explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {}
MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo);
~TraceIfExpFalseBranch() override = default;
TraceInfoPtr clone() override {
......@@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo {
class TraceCopy : public TraceInfo {
public:
TraceCopy() : TraceInfo(nullptr, "copy", "") {}
explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {}
explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {}
MS_DECLARE_PARENT(TraceCopy, TraceInfo);
~TraceCopy() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceCopy>(*shared_from_base<TraceCopy>()); }
......@@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo {
class TraceIterator : public TraceInfo {
public:
explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {}
explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {}
MS_DECLARE_PARENT(TraceIterator, TraceInfo);
~TraceIterator() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceIterator>(*shared_from_base<TraceIterator>()); }
......@@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo {
class TraceWhileHeader : public TraceInfo {
public:
explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "⤾") {}
explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {}
MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo);
~TraceWhileHeader() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileHeader>(*shared_from_base<TraceWhileHeader>()); }
......@@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo {
class TraceWhileBody : public TraceInfo {
public:
explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "⥁") {}
explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {}
MS_DECLARE_PARENT(TraceWhileBody, TraceInfo);
~TraceWhileBody() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileBody>(*shared_from_base<TraceWhileBody>()); }
......@@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo {
class TraceWhileAfter : public TraceInfo {
public:
explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "↓") {}
explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {}
MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo);
~TraceWhileAfter() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceWhileAfter>(*shared_from_base<TraceWhileAfter>()); }
......@@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo {
class TraceForHeader : public TraceInfo {
public:
explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "⤾") {}
explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {}
MS_DECLARE_PARENT(TraceForHeader, TraceInfo);
~TraceForHeader() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForHeader>(*shared_from_base<TraceForHeader>()); }
......@@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo {
class TraceForBody : public TraceInfo {
public:
explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "⥁") {}
explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {}
MS_DECLARE_PARENT(TraceForBody, TraceInfo);
~TraceForBody() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForBody>(*shared_from_base<TraceForBody>()); }
......@@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo {
class TraceForAfter : public TraceInfo {
public:
explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "↓") {}
explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {}
MS_DECLARE_PARENT(TraceForAfter, TraceInfo);
~TraceForAfter() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); }
......@@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo {
class TraceEquiv : public TraceInfo {
public:
explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {}
explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {}
MS_DECLARE_PARENT(TraceEquiv, TraceInfo);
~TraceEquiv() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceEquiv>(*shared_from_base<TraceEquiv>()); }
......@@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo {
class TraceGradFpropApp : public TraceInfo {
public:
TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {}
explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "▲") {}
explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {}
MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo);
~TraceGradFpropApp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradFpropApp>(*shared_from_base<TraceGradFpropApp>()); }
......@@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo {
class TraceGradBpropApp : public TraceInfo {
public:
TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {}
explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "▼") {}
explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {}
MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo);
~TraceGradBpropApp() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradBpropApp>(*shared_from_base<TraceGradBpropApp>()); }
......@@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo {
class TraceGradFprop : public TraceInfo {
public:
TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {}
explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "▶") {}
explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {}
MS_DECLARE_PARENT(TraceGradFprop, TraceInfo);
~TraceGradFprop() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradFprop>(*shared_from_base<TraceGradFprop>()); }
......@@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo {
class TraceGradBprop : public TraceInfo {
public:
TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {}
explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "◀") {}
explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {}
MS_DECLARE_PARENT(TraceGradBprop, TraceInfo);
~TraceGradBprop() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradBprop>(*shared_from_base<TraceGradBprop>()); }
......@@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo {
class TraceGradSens : public TraceInfo {
public:
TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {}
explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "∇") {}
explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {}
MS_DECLARE_PARENT(TraceGradSens, TraceInfo);
~TraceGradSens() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGradSens>(*shared_from_base<TraceGradSens>()); }
......@@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo {
class TraceSpecialize : public TraceInfo {
public:
explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; }
explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; }
MS_DECLARE_PARENT(TraceSpecialize, TraceInfo);
std::string name() override { return full_name_ + counter_; }
std::string symbol() override { return counter_ + "_"; }
......@@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo {
class TraceGradOperation : public TraceInfo {
public:
explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {}
explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {}
MS_DECLARE_PARENT(TraceGradOperation, TraceInfo);
~TraceGradOperation() override = default;
TraceInfoPtr clone() override {
......@@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo {
class TraceForceBool : public TraceInfo {
public:
explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {}
explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {}
MS_DECLARE_PARENT(TraceForceBool, TraceInfo);
~TraceForceBool() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceForceBool>(*shared_from_base<TraceForceBool>()); }
......@@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo {
class TraceExpandJ : public TraceInfo {
public:
explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {}
explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {}
MS_DECLARE_PARENT(TraceExpandJ, TraceInfo);
~TraceExpandJ() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceExpandJ>(*shared_from_base<TraceExpandJ>()); }
......@@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo {
class TraceGenMetaFuncGraph : public TraceInfo {
public:
explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {}
explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {}
MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo);
~TraceGenMetaFuncGraph() override = default;
TraceInfoPtr clone() override {
......@@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo {
class TraceEvaluatorGenGraph : public TraceInfo {
public:
explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {}
explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {}
MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo);
~TraceEvaluatorGenGraph() override = default;
TraceInfoPtr clone() override {
......@@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo {
class TraceResolve : public TraceInfo {
public:
explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {}
explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {}
MS_DECLARE_PARENT(TraceResolve, TraceInfo);
~TraceResolve() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceResolve>(*shared_from_base<TraceResolve>()); }
......@@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo {
class TraceTransform : public TraceInfo {
public:
TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; }
explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") {
explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") {
transform_name_ = transform_name;
}
......@@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo {
class TraceGenerateVarArg : public TraceInfo {
public:
explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {}
explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {}
MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo);
~TraceGenerateVarArg() override = default;
TraceInfoPtr clone() override {
......@@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo {
class TraceGenerateKwArg : public TraceInfo {
public:
explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {}
explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {}
MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo);
~TraceGenerateKwArg() override = default;
TraceInfoPtr clone() override {
......@@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo {
class TraceTrasformK : public TraceInfo {
public:
explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {}
explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {}
MS_DECLARE_PARENT(TraceTrasformK, TraceInfo);
~TraceTrasformK() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceTrasformK>(*shared_from_base<TraceTrasformK>()); }
......@@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo {
class TracePartialTransform : public TraceInfo {
public:
explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {}
explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {}
MS_DECLARE_PARENT(TracePartialTransform, TraceInfo);
~TracePartialTransform() override = default;
TraceInfoPtr clone() override {
......@@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo {
class TraceGetEnv : public TraceInfo {
public:
explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {}
explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {}
MS_DECLARE_PARENT(TraceGetEnv, TraceInfo);
~TraceGetEnv() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceGetEnv>(*shared_from_base<TraceGetEnv>()); }
......@@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo {
class TraceDoSignature : public TraceInfo {
public:
explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {}
explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {}
MS_DECLARE_PARENT(TraceDoSignature, TraceInfo);
~TraceDoSignature() override = default;
TraceInfoPtr clone() override { return std::make_shared<TraceDoSignature>(*shared_from_base<TraceDoSignature>()); }
......@@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo {
class TraceCombileLikeGraphs : public TraceInfo {
public:
TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {}
explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {}
explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {}
MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo);
~TraceCombileLikeGraphs() override = default;
TraceInfoPtr clone() override {
......
......@@ -21,7 +21,7 @@
namespace mindspore {
namespace device {
namespace ascend {
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (has_malloc_) {
MS_LOG(EXCEPTION) << "Has alloc memory pool memory !";
}
......@@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
return size;
}
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) {
bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
MS_EXCEPTION_IF_NULL(addr);
has_malloc_ = false;
free_mem_size_ = total_mem_size_;
......@@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const {
size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; }
void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) {
void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) {
MS_EXCEPTION_IF_NULL(device_mem_pool_base);
device_mem_pool_base_ = device_mem_pool_base;
}
......
......@@ -26,12 +26,12 @@ namespace ascend {
class AscendMemoryPool : public DynamicMemPoolBestFit {
public:
~AscendMemoryPool() override = default;
AscendMemoryPool(const AscendMemoryPool&) = delete;
AscendMemoryPool& operator=(const AscendMemoryPool&) = delete;
AscendMemoryPool(const AscendMemoryPool &) = delete;
AscendMemoryPool &operator=(const AscendMemoryPool &) = delete;
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override;
void set_device_mem_pool_base(uint8_t* device_mem_pool_base);
size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr &addr) override;
void set_device_mem_pool_base(uint8_t *device_mem_pool_base);
void set_device_mem_pool_size(uint64_t device_mem_pool_size) {
device_mem_pool_size_ = device_mem_pool_size;
free_mem_size_ = device_mem_pool_size_;
......@@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
size_t free_mem_size() override;
size_t total_mem_size() override;
static AscendMemoryPool& GetInstance() {
static AscendMemoryPool &GetInstance() {
static AscendMemoryPool instance;
return instance;
}
......@@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit {
private:
AscendMemoryPool() = default;
bool has_malloc_{false};
uint8_t* device_mem_pool_base_{nullptr};
uint8_t *device_mem_pool_base_{nullptr};
uint64_t device_mem_pool_size_{0};
size_t free_mem_size_{0};
size_t total_mem_size_{0};
......
......@@ -39,13 +39,13 @@ using std::vector;
class AscendStreamAssign {
public:
static AscendStreamAssign& GetInstance() {
static AscendStreamAssign &GetInstance() {
static AscendStreamAssign instance; // Guaranteed to be destroyed.
return instance;
}
AscendStreamAssign(const AscendStreamAssign&) = delete;
AscendStreamAssign& operator=(const AscendStreamAssign&) = delete;
AscendStreamAssign(const AscendStreamAssign &) = delete;
AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
uint32_t GetTotalStreamNum() const;
// new stream policy
......@@ -53,19 +53,19 @@ class AscendStreamAssign {
uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
uint32_t total_event_num() const { return total_event_num_; }
void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void ResetNew();
void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr);
bool IsIndependentNode(const CNodePtr& node_ptr);
const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; }
const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; }
const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; }
void GetWaitStreams(vector<uint32_t>* wait_active_stream_list);
const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; }
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id,
void AssignStreamNew(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool IsIndependentNode(const CNodePtr &node_ptr);
const std::unordered_map<uint32_t, uint32_t> &logic_to_independent_map() { return logic_to_independent_map_; }
const std::unordered_map<uint32_t, uint32_t> &logic_to_physic_map() { return logic_to_physic_map_; }
const std::vector<std::vector<uint32_t>> &inner_parallel_streams() { return inner_parallel_streams_; }
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
const std::vector<uint32_t> &hcom_streams() { return hcom_stream_list_; }
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id,
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id);
private:
......@@ -73,30 +73,30 @@ class AscendStreamAssign {
~AscendStreamAssign() = default;
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
const CNodePtr& node);
const CNodePtr &node);
bool IsHcom(const CNodePtr& apply_kernel);
bool IsHcom(const CNodePtr &apply_kernel);
bool IsProcessed(uint32_t logic_id);
void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids);
void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index,
uint32_t* cur_stream_id);
void TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids);
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index,
uint32_t *cur_stream_id);
void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
void UpdateStreamActive(const CNodePtr& active_ptr);
void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr);
void UpdateStreamActive(const CNodePtr &active_ptr);
void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr);
bool IsTaskSink();
void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id);
void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr);
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id);
void UpdateStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr);
void SetCommonStreamNum(uint32_t cur_stream_id);
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
bool IsProcessedParallelStream(uint32_t stream_id);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
uint32_t total_common_stream_num_{0};
uint32_t total_independ_stream_num_{0};
......
......@@ -28,14 +28,14 @@ namespace device {
namespace ascend {
class PluginImpl : public PluginIntf {
public:
explicit PluginImpl(const std::string& module);
explicit PluginImpl(const std::string &module);
~PluginImpl() override = default;
int Init(const Reporter* reporter) override;
int Init(const Reporter *reporter) override;
int UnInit() override;
static Reporter* GetPluginReporter() { return reporter_; }
static Reporter *GetPluginReporter() { return reporter_; }
private:
static Reporter* reporter_;
static Reporter *reporter_;
std::string module_;
};
} // namespace ascend
......
......@@ -20,12 +20,12 @@
namespace mindspore {
namespace device {
namespace ascend {
PluginIntf* ProfilingEngineImpl::CreatePlugin() {
PluginIntf *ProfilingEngineImpl::CreatePlugin() {
MS_LOG(INFO) << "Create Plugin.";
return new (std::nothrow) PluginImpl("Framework");
}
int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) {
int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) {
if (plugin != nullptr) {
delete plugin;
}
......
......@@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf {
ProfilingEngineImpl() = default;
~ProfilingEngineImpl() override = default;
PluginIntf* CreatePlugin() override;
int ReleasePlugin(PluginIntf* plugin) override;
PluginIntf *CreatePlugin() override;
int ReleasePlugin(PluginIntf *plugin) override;
};
} // namespace ascend
} // namespace device
......
......@@ -35,7 +35,7 @@ using Json = nlohmann::json;
namespace mindspore {
namespace device {
namespace ascend {
ProfilingManager& ProfilingManager::GetInstance() {
ProfilingManager &ProfilingManager::GetInstance() {
static ProfilingManager inst;
return inst;
}
......@@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) {
}
uint64_t ProfilingManager::GetJobId() const {
const char* job_id = std::getenv("JOB_ID");
const char *job_id = std::getenv("JOB_ID");
return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0);
}
bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskId_map) const {
bool ProfilingManager::ReportProfilingData(const map<uint32_t, string> &op_taskId_map) const {
if (!IsProfiling()) {
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
return false;
......@@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size();
Msprof::Engine::ReporterData reporter_data = {};
for (const auto& iter : op_taskId_map) {
for (const auto &iter : op_taskId_map) {
auto data = iter.second + ' ' + std::to_string(iter.first) + ';';
reporter_data.deviceId = UintToInt(device_id_);
reporter_data.data = (unsigned char*)(const_cast<char*>(data.c_str()));
reporter_data.data = (unsigned char *)(const_cast<char *>(data.c_str()));
reporter_data.dataLen = data.size();
auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework"));
if (ret != 0) {
......@@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map<uint32_t, string>& op_taskI
return true;
}
static std::vector<std::string> Split(const std::string& str, const char delim) {
static std::vector<std::string> Split(const std::string &str, const char delim) {
std::vector<std::string> elems;
if (str.empty()) {
......@@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
device_id_ = device_id;
// exp: export PROFILING_MODE=true
// export PROFILING_OPTIONS=training_trace
const char* prof_options_str = std::getenv("PROFILING_OPTIONS");
const char *prof_options_str = std::getenv("PROFILING_OPTIONS");
// register Framework to profiling
int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get());
if (result != 0) {
......@@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const {
MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode.";
return true;
}
Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter();
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter();
if (reporter != nullptr) {
MS_LOG(INFO) << "report data end, ret = " << reporter->Flush();
}
......
......@@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST,
class GpuQueue {
public:
GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity);
GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity);
virtual ~GpuQueue();
void RegisterRelease(const std::function<void(void*)>& func) { host_release_ = func; }
void RegisterRelease(const std::function<void(void *)> &func) { host_release_ = func; }
inline bool IsEmpty() const { return head_ == tail_; }
inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); }
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size);
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const;
BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size);
BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const;
BlockQueueStatus_T Pop();
bool Destroy();
private:
struct NodeInfo {
std::unique_ptr<cudaEvent_t> event_;
void* host_feature_addr_;
void* host_label_addr_;
void *host_feature_addr_;
void *host_label_addr_;
};
void* buffer_;
void *buffer_;
size_t head_;
size_t tail_;
size_t feature_size_;
......@@ -61,10 +61,10 @@ class GpuQueue {
size_t capacity_;
cudaStream_t stream_;
std::unique_ptr<NodeInfo[]> node_info_;
std::function<void(void*)> host_release_;
std::function<void(void *)> host_release_;
GpuQueue(const GpuQueue&) = delete;
GpuQueue& operator=(const GpuQueue&) = delete;
GpuQueue(const GpuQueue &) = delete;
GpuQueue &operator=(const GpuQueue &) = delete;
};
class BlockingQueue {
......@@ -72,11 +72,11 @@ class BlockingQueue {
BlockingQueue() : queue_(nullptr) {}
~BlockingQueue() = default;
BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity);
void RegisterRelease(const std::function<void(void*)>& func);
BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size,
BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity);
void RegisterRelease(const std::function<void(void *)> &func);
BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size,
unsigned int timeout_in_sec);
BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size);
BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size);
BlockQueueStatus_T Pop();
bool Destroy();
......
......@@ -20,17 +20,17 @@
namespace mindspore {
namespace device {
namespace gpu {
CollectiveInitializer& CollectiveInitializer::instance() {
CollectiveInitializer &CollectiveInitializer::instance() {
static CollectiveInitializer instance = {};
return instance;
}
bool CollectiveInitializer::collective_inited() const { return collective_inited_; }
const void* CollectiveInitializer::collective_handle() const { return collective_handle_; }
const void *CollectiveInitializer::collective_handle() const { return collective_handle_; }
void CollectiveInitializer::InitCollective() {
void* handle = dlopen("libgpu_collective.so", RTLD_LAZY);
void *handle = dlopen("libgpu_collective.so", RTLD_LAZY);
if (handle == nullptr) {
MS_LOG(EXCEPTION)
<< "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not "
......
......@@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() {
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
}
bool GPUDeviceManager::CreateStream(DeviceStream* stream) {
bool GPUDeviceManager::CreateStream(DeviceStream *stream) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
gpu_streams_.emplace_back(*stream);
return true;
}
const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; }
const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; }
int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }
......@@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; }
bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; }
const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; }
const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const {
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const {
return CudaDriver::CopyDeviceMemToHost(dst, src, size);
}
bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const {
bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const {
return CudaDriver::CopyHostMemToDevice(dst, src, size);
}
} // namespace gpu
......
......@@ -37,17 +37,17 @@ class GPUDeviceManager {
uint32_t cur_device_id() const;
bool is_device_id_init() const;
bool CreateStream(DeviceStream* stream);
bool SyncStream(const DeviceStream& stream) const;
const DeviceStream& default_stream() const;
bool CreateStream(DeviceStream *stream);
bool SyncStream(const DeviceStream &stream) const;
const DeviceStream &default_stream() const;
const cudnnHandle_t& GetCudnnHandle() const;
const cublasHandle_t& GetCublasHandle() const;
const cudnnHandle_t &GetCudnnHandle() const;
const cublasHandle_t &GetCublasHandle() const;
bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const;
bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const;
static GPUDeviceManager& GetInstance() {
static GPUDeviceManager &GetInstance() {
static GPUDeviceManager instance;
return instance;
}
......@@ -55,8 +55,8 @@ class GPUDeviceManager {
private:
GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {}
~GPUDeviceManager() = default;
GPUDeviceManager(const GPUDeviceManager&) = delete;
GPUDeviceManager& operator=(const GPUDeviceManager&) = delete;
GPUDeviceManager(const GPUDeviceManager &) = delete;
GPUDeviceManager &operator=(const GPUDeviceManager &) = delete;
// default CUDA stream used for all the kernels.
DeviceStream default_stream_{nullptr};
......
......@@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() {
return true;
}
bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) {
bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) {
auto alloc_size = AllocDeviceMem(size, addr);
buffer_q_addr_ = *addr;
// Buffer queue needs to ensure that the alloc_size and size is equal.
return (alloc_size == size) ? true : false;
}
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
if (size == 0) {
MS_LOG(EXCEPTION) << "The memory alloc size is 0.";
}
......@@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) {
return alloc_size;
}
bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); }
bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); }
size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); }
......
......@@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit {
~GPUMemoryAllocator() override = default;
bool Init();
bool Finalize();
bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr);
bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr);
size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override;
bool FreeDeviceMem(const DeviceMemPtr& addr) override;
size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override;
bool FreeDeviceMem(const DeviceMemPtr &addr) override;
size_t free_mem_size() override;
size_t total_mem_size() override;
static GPUMemoryAllocator& GetInstance() {
static GPUMemoryAllocator &GetInstance() {
static GPUMemoryAllocator instance;
return instance;
}
private:
GPUMemoryAllocator() = default;
GPUMemoryAllocator(const GPUMemoryAllocator&) = delete;
GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete;
GPUMemoryAllocator(const GPUMemoryAllocator &) = delete;
GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete;
// Used to track address of data buffer queue.
DeviceMemPtr buffer_q_addr_{nullptr};
......
......@@ -33,8 +33,8 @@ namespace gpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using mindspore::kernel::KernelBuildInfo;
namespace {
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info,
const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) {
bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info,
const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
MS_EXCEPTION_IF_NULL(selected_kernel_info);
MS_EXCEPTION_IF_NULL(alternative_kernel_info);
size_t selected_input_num = selected_kernel_info->GetInputNum();
......@@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr<KernelBuildInfo>& alternative_kernel_
return true;
}
std::string SupportedTypeList(const CNodePtr& kernel_node) {
std::string SupportedTypeList(const CNodePtr &kernel_node) {
std::string supported_type_lists =
kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node));
if (!supported_type_lists.empty()) {
......@@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) {
return supported_type_lists;
}
bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBuildInfo>& selected_kernel_info) {
bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr<KernelBuildInfo> &selected_kernel_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(selected_kernel_info);
std::vector<std::shared_ptr<KernelBuildInfo>> kernel_info_list;
......@@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
}
bool match = std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
[&](const std::shared_ptr<KernelBuildInfo>& alternative_kernel_info) {
[&](const std::shared_ptr<KernelBuildInfo> &alternative_kernel_info) {
return CheckKernelInfo(alternative_kernel_info, selected_kernel_info);
});
if (!match) {
......@@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr<KernelBu
return true;
}
void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, const CNodePtr& kernel_node) {
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
auto input_kernel_node = kernel_node->input(input_index + 1);
......@@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co
}
} // namespace
void SetKernelInfo(const CNodePtr& kernel_node) {
void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> inputs_format;
std::vector<TypeId> inputs_type;
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
......
......@@ -27,7 +27,7 @@
namespace mindspore {
namespace device {
namespace gpu {
void SetKernelInfo(const CNodePtr& apply_kernel_ptr);
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
class KernelAttr {
public:
......@@ -35,24 +35,24 @@ class KernelAttr {
KernelAttr() : all_same_(false) {}
~KernelAttr() = default;
KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) {
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
input_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) {
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
output_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr& AddAllSameAttr(const bool& all_same) {
KernelAttr &AddAllSameAttr(const bool &all_same) {
all_same_ = all_same;
return *this;
}
const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; }
const bool& GetAllSame() const { return all_same_; }
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
const bool &GetAllSame() const { return all_same_; }
size_t GetInputSize() const { return input_type_.size(); }
size_t GetOutputSize() const { return output_type_.size(); }
......
......@@ -24,7 +24,7 @@
namespace mindspore {
struct TypeIdManager* TypeIdManager::Get() {
struct TypeIdManager *TypeIdManager::Get() {
static TypeIdManager manager;
return &manager;
}
......
......@@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode*>(this)->shared_from_base<AnfNode>()->debug_info());
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
}
CNode::CNode(const std::vector<AnfNodePtr>& inputs, const FuncGraphPtr& func_graph)
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
// Check if CNode is an apply with the specific Primitive.
bool CNode::IsApply(const PrimitivePtr& value) const {
bool CNode::IsApply(const PrimitivePtr &value) const {
if (value == nullptr) {
return false;
}
......@@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const {
return false;
}
void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; }
void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; }
std::string CNode::DebugString(int recursive_level) const {
std::ostringstream buffer;
......@@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const {
buffer << ToString() << "{";
bool is_first_node = true;
int idx = 0;
for (auto& node : inputs_) {
for (auto &node : inputs_) {
MS_EXCEPTION_IF_NULL(node);
if (is_first_node) {
is_first_node = false;
......@@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const {
return buffer.str();
}
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) {
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name();
......@@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() {
return fullname_with_scope_;
}
void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base<Parameter>()); }
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) {
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr) {
......@@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) {
return false;
}
PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) {
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
......@@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) {
return "";
}
bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
if (IsValueNode<Primitive>(node)) {
PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
MS_EXCEPTION_IF_NULL(value);
......@@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) {
}
namespace id_generator {
static std::unordered_map<std::string, int> node_ids;
std::string get_id(const AnfNodePtr& node) {
std::string get_id(const AnfNodePtr &node) {
auto type_name = node->type_name();
if (node_ids.find(type_name) == node_ids.end()) {
node_ids[type_name] = 0;
......
......@@ -39,15 +39,15 @@ struct is_shared_ptr<std::shared_ptr<T>> : public std::true_type {};
class Base : public std::enable_shared_from_this<Base> {
public:
constexpr Base() = default;
Base(const Base& other) : std::enable_shared_from_this<Base>(other) {}
virtual bool operator==(const Base& rhs) {
Base(const Base &other) : std::enable_shared_from_this<Base>(other) {}
virtual bool operator==(const Base &rhs) {
if (this == &rhs) {
return true;
}
return false;
}
virtual Base& operator=(const Base&) { return *this; }
virtual Base &operator=(const Base &) { return *this; }
virtual ~Base() = default;
virtual std::size_t hash() const { return tid(); }
virtual std::string ToString() const { return type_name(); }
......@@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this<Base> {
virtual const bool IsFromTypeId(uint32_t tid) const;
virtual std::string type_name() const { return "Base"; }
static uint32_t GetTypeId(const char* const type_key);
static uint32_t GetTypeId(const char *const type_key);
virtual uint32_t tid() const {
static const uint32_t tid = GetTypeId(typeid(Base).name());
return tid;
}
template <typename T,
typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type* = nullptr>
typename std::enable_if<!is_shared_ptr<T>::value && std::is_base_of<Base, T>::value, T>::type * = nullptr>
inline bool isa() const {
static const uint32_t tid = GetTypeId(typeid(T).name());
return this->IsFromTypeId(tid);
......@@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr<Base>;
using BaseWeakPtr = std::weak_ptr<Base>;
template <typename T, typename U>
inline T* cast(U* source) {
inline T *cast(U *source) {
if (source != nullptr && source->template isa<T>()) {
return static_cast<T*>(source);
return static_cast<T *>(source);
} else {
return nullptr;
}
......@@ -100,7 +100,7 @@ inline T* cast(U* source) {
template <
typename T, typename U,
typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type* = nullptr>
typename std::enable_if<std::is_base_of<Base, T>::value && std::is_base_of<Base, U>::value, T>::type * = nullptr>
inline std::shared_ptr<T> dyn_cast(const std::shared_ptr<U> r) {
if (r != nullptr && r->template isa<T>()) {
return std::static_pointer_cast<T>(r);
......@@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager {
std::mutex mutex;
std::atomic<uint32_t> type_counter{0};
std::unordered_map<std::string, uint32_t> map;
static TypeIdManager* Get();
static TypeIdManager *Get();
TypeIdManager() : mutex(), type_counter(0), map() {}
};
} // namespace mindspore
......
......@@ -48,11 +48,11 @@ std::string Keyword::ToString() const {
return buffer.str();
}
bool Keyword::operator==(const Type& other) const {
bool Keyword::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const auto& other_keyword = static_cast<const Keyword&>(other);
const auto &other_keyword = static_cast<const Keyword &>(other);
return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_);
}
......@@ -87,11 +87,11 @@ std::string Slice::ToString() const {
return buffer.str();
}
bool Slice::operator==(const Type& other) const {
bool Slice::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_slice = static_cast<const Slice&>(other);
auto other_slice = static_cast<const Slice &>(other);
return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_);
}
......@@ -122,11 +122,11 @@ std::string TensorType::DumpText() const {
}
}
bool TensorType::operator==(const Type& other) const {
bool TensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const TensorType&>(other).element_type_;
auto other_elem_type = static_cast<const TensorType &>(other).element_type_;
// When element_type_ = nullptr, which means any type of Array.
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
......@@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) {
retval_ = nullptr;
}
Function::Function(const std::vector<TypePtr>& args, const TypePtr retval)
Function::Function(const std::vector<TypePtr> &args, const TypePtr retval)
: Object(kObjectTypeFunction, false), args_(args), retval_(retval) {}
TypePtr Function::DeepCopy() const {
......@@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const {
TypePtrList args;
TypePtr retval = nullptr;
(void)std::transform(args_.begin(), args_.end(), std::back_inserter(args),
[](const TypePtr& arg) { return arg->DeepCopy(); });
[](const TypePtr &arg) { return arg->DeepCopy(); });
if (retval_ != nullptr) {
retval = retval_->DeepCopy();
}
......@@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const {
}
}
bool Function::operator==(const Type& other) const {
bool Function::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const auto& other_function = static_cast<const Function&>(other);
const auto &other_function = static_cast<const Function &>(other);
if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) {
if (*retval_ != *other_function.retval_) {
return false;
......@@ -188,7 +188,7 @@ std::string Function::ToString() const {
} else {
buffer << "Func[(";
bool begin = true;
for (auto& attr : args_) {
for (auto &attr : args_) {
if (!begin) {
buffer << ", ";
} else {
......@@ -242,34 +242,34 @@ std::string JTagged::DumpText() const {
return buffer.str();
}
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem) {
std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem) {
MS_EXCEPTION_IF_NULL(problem);
os << problem->ToString();
return os;
}
std::size_t TypeHasher::operator()(TypePtr const& type) const {
std::size_t TypeHasher::operator()(TypePtr const &type) const {
MS_EXCEPTION_IF_NULL(type);
std::size_t hash = std::hash<size_t>()(type->type_id());
return hash;
}
std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const {
std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const {
std::size_t hash_sum = 0;
for (auto& type : type_list) {
for (auto &type : type_list) {
auto type_id = static_cast<std::size_t>(type->type_id());
hash_sum = hash_combine(hash_sum, type_id);
}
return hash_sum;
}
bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const {
bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const {
MS_EXCEPTION_IF_NULL(t1);
MS_EXCEPTION_IF_NULL(t2);
return t1->type_id() == t2->type_id();
}
bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const {
bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const {
if (lhs.size() != rhs.size()) {
return false;
}
......@@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) {
namespace {
template <typename T>
TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) {
TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
TypePtr type = nullptr;
if (type_name == num_type_name) {
type = std::make_shared<T>();
......@@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_
}
auto bits = std::stoi(type_name.substr(num_type_name.size()));
type = std::make_shared<T>(bits);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what();
}
}
return type;
}
std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) {
std::vector<TypePtr> StringToVectorOfType(const std::string &type_names) {
std::vector<TypePtr> types;
if (type_names.length() == 0) {
return types;
......@@ -371,7 +371,7 @@ std::vector<TypePtr> StringToVectorOfType(const std::string& type_names) {
return types;
}
TypePtr TensorStrToType(const std::string& type_name) {
TypePtr TensorStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Tensor") {
type = std::make_shared<TensorType>();
......@@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
return nullptr;
}
type = std::make_shared<TensorType>(element_type);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
}
}
......@@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) {
return type;
}
TypePtr ListStrToType(const std::string& type_name) {
TypePtr ListStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "List") {
type = std::make_shared<List>();
......@@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) {
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; });
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) {
return nullptr;
}
type = std::make_shared<List>(element_types);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
}
}
......@@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) {
return type;
}
TypePtr TupleStrToType(const std::string& type_name) {
TypePtr TupleStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Tuple") {
type = std::make_shared<Tuple>();
......@@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) {
std::string element_strs = type_name.substr(start, end - start);
std::vector<TypePtr> element_types = StringToVectorOfType(element_strs);
bool wrong =
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; });
std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; });
if (wrong) {
return nullptr;
}
type = std::make_shared<Tuple>(element_types);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
}
}
return type;
}
TypePtr FunctionStrToType(const std::string& type_name) {
TypePtr FunctionStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "Function") {
......@@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) {
std::vector<TypePtr> args_type = StringToVectorOfType(str_args);
TypePtr retval = StringToType(str_retval);
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; });
bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; });
if (retval == nullptr || wrong) {
return nullptr;
}
type = std::make_shared<Function>(args_type, retval);
} catch (const std::exception& e) {
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what();
}
}
......@@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) {
}
} // namespace
TypePtr StringToType(const std::string& type_name) {
TypePtr StringToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>();
......@@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) {
return type;
}
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) {
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
if (x == nullptr || base_type == nullptr) {
MS_LOG(ERROR) << "Type is nullptr.";
return false;
......@@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) {
}
}
bool IsSubType(TypePtr const& t1, TypePtr const& t2) {
bool IsSubType(TypePtr const &t1, TypePtr const &t2) {
MS_EXCEPTION_IF_NULL(t1);
if (t1->type_id() == kTypeUnknown) {
return false;
......@@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) {
}
REGISTER_PYBIND_DEFINE(
typing, ([](py::module* const m) {
typing, ([](py::module *const m) {
auto m_sub = m->def_submodule("typing", "submodule for dtype");
py::enum_<TypeId>(m_sub, "TypeId");
(void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass");
(void)m_sub.def("load_type", &TypeIdToType, "load type");
(void)m_sub.def(
"dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type");
"dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type");
(void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type")
.def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_)
.def("__eq__",
[](const TypePtr& t1, const TypePtr& t2) {
[](const TypePtr &t1, const TypePtr &t2) {
if (t1 != nullptr && t2 != nullptr) {
return *t1 == *t2;
}
......@@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE(
.def("__hash__", &Type::hash)
.def("__str__", &Type::ToString)
.def("__repr__", &Type::ReprString)
.def("__deepcopy__", [](const TypePtr& t, py::dict) {
.def("__deepcopy__", [](const TypePtr &t, py::dict) {
if (t == nullptr) {
return static_cast<TypePtr>(nullptr);
}
......@@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<Bool, Type, std::shared_ptr<Bool>>(m_sub, "Bool")
.def(py::init())
.def(py::pickle(
[](const Bool&) { // __getstate__
[](const Bool &) { // __getstate__
return py::make_tuple();
},
[](const py::tuple&) { // __setstate__
[](const py::tuple &) { // __setstate__
return std::make_shared<Bool>();
}));
(void)py::class_<Int, Type, std::shared_ptr<Int>>(m_sub, "Int")
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const Int& t) { // __getstate__
[](const Int &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
......@@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const UInt& t) { // __getstate__
[](const UInt &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
......@@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init())
.def(py::init<int>(), py::arg("nbits"))
.def(py::pickle(
[](const Float& t) { // __getstate__
[](const Float &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(t.nbits()));
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
......@@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE(
.def(py::init<TypePtr>(), py::arg("element"))
.def("element_type", &TensorType::element)
.def(py::pickle(
[](const TensorType& t) { // __getstate__
[](const TensorType &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(py::int_(static_cast<int>(t.element()->type_id())));
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
......
......@@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr<String>;
class Keyword : public Object {
public:
Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {}
Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {}
~Keyword() override = default;
MS_DECLARE_PARENT(Keyword, Object)
......@@ -70,7 +70,7 @@ class Keyword : public Object {
std::string ToString() const override;
std::string DumpText() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
std::string GetKey() const { return key_; }
TypePtr GetValue() const { return value_; }
......@@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr<Keyword>;
class Slice : public Object {
public:
Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {}
Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step)
Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step)
: Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {}
~Slice() override = default;
......@@ -95,7 +95,7 @@ class Slice : public Object {
std::string ToString() const override;
std::string DumpText() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtr get_start() const { return start_; }
TypePtr get_stop() const { return stop_; }
......@@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr<Slice>;
class TensorType : public Object {
public:
TensorType() : Object(kObjectTypeTensorType) {}
explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr& element_type) { element_type_ = element_type; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override { return "tensor"; }
std::string DumpText() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
......@@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr<TensorType>;
class Function : public Object {
public:
Function();
Function(const std::vector<TypePtr>& args, const TypePtr retval);
Function(const std::vector<TypePtr> &args, const TypePtr retval);
~Function() override = default;
MS_DECLARE_PARENT(Function, Object)
......@@ -141,11 +141,11 @@ class Function : public Object {
// Add temporarily for return abstraction to avoid type checking.
bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); }
const std::vector<TypePtr>& args() const { return args_; }
const TypePtr& retval() const { return retval_; }
const std::vector<TypePtr> &args() const { return args_; }
const TypePtr &retval() const { return retval_; }
TypePtr DeepCopy() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
std::string ToString() const override;
std::string ToReprString() const override { return "function"; }
......@@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr<Function>;
class JTagged : public Object {
public:
JTagged() : Object(kObjectTypeJTagged) {}
explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {}
~JTagged() override = default;
MS_DECLARE_PARENT(JTagged, Object)
......@@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr<TypeType>;
class Problem : public Type {
public:
Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {}
explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {}
explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {}
~Problem() override = default;
MS_DECLARE_PARENT(Problem, Type)
......@@ -222,7 +222,7 @@ class Problem : public Type {
std::string ToString() const override { return kind_.name(); }
std::string DumpText() const override { return "ProblemType"; }
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Problem> problem);
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Problem> problem);
private:
Named kind_;
......@@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr<External>;
// helper template
template <class T>
TypePtr Clone(const T& t) {
TypePtr Clone(const T &t) {
return t.Clone();
}
TypePtr StringToType(const std::string& type_name);
TypePtr StringToType(const std::string &type_name);
// Judge whether x is predicate or is a subclass of predicate.
bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type);
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
// Whether t1 is identity or a subclass of t2.
bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr);
bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
struct TypeHasher {
std::size_t operator()(TypePtr const& type) const;
std::size_t operator()(TypePtr const &type) const;
};
struct TypeListHasher {
std::size_t operator()(const TypePtrList& type_list) const;
std::size_t operator()(const TypePtrList &type_list) const;
};
struct TypeEqual {
bool operator()(TypePtr const& t1, TypePtr const& t2) const;
bool operator()(TypePtr const &t1, TypePtr const &t2) const;
};
struct TypeListEqual {
bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const;
bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const;
};
extern const TypePtr kTypeExternal;
......
......@@ -24,7 +24,7 @@
#include "pybind_api/export_flags.h"
namespace mindspore {
static std::string DumpTypeVector(const std::vector<TypePtr>& elements, bool is_dumptext) {
static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {
std::ostringstream oss;
bool begin = true;
int cnt = 0;
......@@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const {
} else {
TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr& ele) { return ele->DeepCopy(); });
[](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<List>(elements);
return copy;
}
......@@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const {
return elements_[dim];
}
bool List::operator==(const Type& other) const {
bool List::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const List& other_list = static_cast<const List&>(other);
const List &other_list = static_cast<const List &>(other);
if (elements_.size() != other_list.elements_.size()) {
return false;
}
......@@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const {
return true;
}
Class::Class(const Named& tag, const ClassAttrVector& attributes,
const std::unordered_map<std::string, ValuePtr>& methods)
Class::Class(const Named &tag, const ClassAttrVector &attributes,
const std::unordered_map<std::string, ValuePtr> &methods)
: Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {}
std::string List::ToString() const {
......@@ -122,7 +122,7 @@ std::string List::DumpText() const {
return buffer.str();
}
bool Class::operator==(const Type& other) const {
bool Class::operator==(const Type &other) const {
// Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj.
return &other == this;
}
......@@ -143,7 +143,7 @@ std::string Class::ToString() const {
} else {
bool begin = true;
buffer << "cls." << tag_ << "[";
for (auto& attr : attributes_) {
for (auto &attr : attributes_) {
if (!begin) {
buffer << ", ";
} else {
......@@ -163,7 +163,7 @@ std::string Class::DumpText() const {
} else {
bool begin = true;
buffer << "Cls." << tag_ << "[";
for (auto& attr : attributes_) {
for (auto &attr : attributes_) {
if (!begin) {
buffer << ", ";
} else {
......@@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const {
} else {
TypePtrList elements;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements),
[](const TypePtr& ele) { return ele->DeepCopy(); });
[](const TypePtr &ele) { return ele->DeepCopy(); });
auto copy = std::make_shared<Tuple>(elements);
return copy;
}
}
bool Tuple::operator==(const Type& other) const {
bool Tuple::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_tuple = static_cast<const Tuple&>(other);
auto other_tuple = static_cast<const Tuple &>(other);
if (elements_.size() != other_tuple.elements_.size()) {
return false;
}
......@@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const {
std::vector<std::pair<std::string, TypePtr>> kv;
(void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, TypePtr>& item) { return std::make_pair(item.first, item.second->DeepCopy()); });
[](const std::pair<std::string, TypePtr> &item) { return std::make_pair(item.first, item.second->DeepCopy()); });
return std::make_shared<Dictionary>(kv);
}
}
......@@ -259,7 +259,7 @@ std::string Dictionary::ToString() const {
std::ostringstream buffer;
std::vector<std::string> keys;
std::vector<TypePtr> values;
for (const auto& kv : key_values_) {
for (const auto &kv : key_values_) {
keys.push_back(kv.first);
values.push_back(kv.second);
}
......@@ -276,12 +276,12 @@ std::string Dictionary::ToString() const {
std::string Dictionary::DumpText() const { return ToString(); }
bool Dictionary::operator==(const mindspore::Type& other) const {
bool Dictionary::operator==(const mindspore::Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
const auto& other_dict = static_cast<const Dictionary&>(other);
const auto &other_dict = static_cast<const Dictionary &>(other);
if (key_values_.size() != other_dict.key_values_.size()) {
return false;
}
......
......@@ -40,10 +40,10 @@ namespace mindspore {
class List : public Object {
public:
List() : Object(kObjectTypeList) {}
List(const std::initializer_list<TypePtr>& objs)
List(const std::initializer_list<TypePtr> &objs)
: Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {}
// Shadow copy;
explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {}
explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {}
~List() override {}
MS_DECLARE_PARENT(List, Object)
......@@ -51,7 +51,7 @@ class List : public Object {
TypeId generic_type_id() const override { return kObjectTypeList; }
TypePtr DeepCopy() const override;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
std::size_t size() const { return elements_.size(); }
TypePtrList elements() const { return elements_; }
std::string ToString() const override;
......@@ -68,22 +68,22 @@ using ClassAttrVector = std::vector<std::pair<std::string, TypePtr>>;
class Class : public Object {
public:
Class() : Object(kObjectTypeClass), tag_(Named("Class")) {}
Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map<std::string, ValuePtr>& methods);
Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map<std::string, ValuePtr> &methods);
~Class() override {}
MS_DECLARE_PARENT(Class, Object)
TypeId generic_type_id() const override { return kObjectTypeClass; }
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string DumpText() const override;
void set_value(const std::unordered_map<std::string, ValuePtr>& v) { attributes_value_ = v; }
void set_value(const std::unordered_map<std::string, ValuePtr> &v) { attributes_value_ = v; }
Named tag() { return tag_; }
std::unordered_map<std::string, ValuePtr> GetValue() { return attributes_value_; }
std::unordered_map<std::string, ValuePtr> methods() { return methods_; }
ClassAttrVector& GetAttributes() { return attributes_; }
ClassAttrVector &GetAttributes() { return attributes_; }
ClassAttrVector attributes_;
......@@ -99,11 +99,11 @@ class Tuple : public Object {
public:
Tuple() : Object(kObjectTypeTuple) {}
// usage : Tuple t = {std::make_shared<Bool>(), std::make_shared<Int>(32)};
Tuple(const std::initializer_list<TypePtr>& objs)
Tuple(const std::initializer_list<TypePtr> &objs)
: Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
// Shadow copy
explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {}
~Tuple() override {}
MS_DECLARE_PARENT(Tuple, Object)
......@@ -115,7 +115,7 @@ class Tuple : public Object {
std::string ToReprString() const override { return "tuple_"; }
std::string DumpText() const override;
const TypePtr operator[](size_t dim) const;
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtrList elements() const { return elements_; }
std::size_t size() const { return elements_.size(); }
......@@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr<Tuple>;
class Dictionary : public Object {
public:
Dictionary() : Object(kObjectTypeDictionary) {}
explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>>& key_values)
explicit Dictionary(const std::vector<std::pair<std::string, TypePtr>> &key_values)
: Object(kObjectTypeDictionary, false), key_values_(key_values) {}
~Dictionary() override {}
......@@ -136,7 +136,7 @@ class Dictionary : public Object {
TypeId generic_type_id() const override { return kObjectTypeDictionary; }
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string DumpText() const override;
......
......@@ -24,11 +24,11 @@
#include "pybind_api/export_flags.h"
namespace mindspore {
bool Number::operator==(const Type& other) const {
bool Number::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_number = static_cast<const Number&>(other);
auto other_number = static_cast<const Number &>(other);
return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_));
}
......
......@@ -49,12 +49,12 @@ class Number : public Object {
TypeId type_id() const override { return number_type_; }
TypeId generic_type_id() const override { return kObjectTypeNumber; }
bool operator==(const Type& other) const override;
bool operator==(const Type &other) const override;
TypePtr DeepCopy() const override { return std::make_shared<Number>(); }
std::string ToString() const override { return "Number"; }
std::string ToReprString() const override { return "number"; }
std::string DumpText() const override { return "Number"; }
std::string GetTypeName(const std::string& type_name) const {
std::string GetTypeName(const std::string &type_name) const {
std::ostringstream oss;
oss << type_name;
if (nbits() != 0) {
......
......@@ -51,7 +51,7 @@ class RefKeyType : public Object {
class RefType : public Object {
public:
RefType() : Object(kObjectTypeRef) {}
RefType(const TypePtr& subtype, const TypePtr& subtype_origin)
RefType(const TypePtr &subtype, const TypePtr &subtype_origin)
: Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {}
~RefType() override {}
MS_DECLARE_PARENT(RefType, Object)
......
......@@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) {
}
}
const char* MetaIdLabel(const TypeId& v) {
const char *MetaIdLabel(const TypeId &v) {
switch (v) {
case kTypeUnknown:
return "kTypeUnknown";
......@@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) {
}
}
const char* ObjectIdLabel(const TypeId& v) {
const char *ObjectIdLabel(const TypeId &v) {
switch (v) {
case kObjectTypeNumber:
return "kObjectTypeNumber";
......@@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) {
}
}
const char* NumberIdLabel(const TypeId& v) {
const char *NumberIdLabel(const TypeId &v) {
switch (v) {
case kNumberTypeBool:
return "kNumberTypeBool";
......@@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) {
}
}
const char* TypeIdLabel(const TypeId& v) {
const char *TypeIdLabel(const TypeId &v) {
if (v < kMetaTypeEnd) {
return MetaIdLabel(v);
} else {
......@@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) {
}
}
bool IsSameObjectType(const Type& lhs, const Type& rhs) {
bool IsSameObjectType(const Type &lhs, const Type &rhs) {
if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) {
return false;
}
return lhs.object_type() == rhs.object_type();
}
size_t GetTypeByte(const TypePtr& type_ptr) {
size_t GetTypeByte(const TypePtr &type_ptr) {
if (type_ptr && type_ptr->isa<Number>()) {
auto number = dyn_cast<Number>(type_ptr);
if (!number) {
......@@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) {
}
}
bool Type::operator==(const Value& other) const {
bool Type::operator==(const Value &other) const {
if (other.isa<Type>()) {
auto other_type = static_cast<const Type*>(&other);
auto other_type = static_cast<const Type *>(&other);
return *this == *other_type;
} else {
return false;
......@@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() {
return ptr;
}
std::ostream& operator<<(std::ostream& os, const Type& type) {
std::ostream &operator<<(std::ostream &os, const Type &type) {
os << type.ToString();
return os;
}
std::ostream& operator<<(std::ostream& os, const TypePtr type) {
std::ostream &operator<<(std::ostream &os, const TypePtr type) {
os << type->ToString();
return os;
}
......@@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const {
return false;
}
std::ostream& operator<<(std::ostream& os, const Object& obj) {
std::ostream &operator<<(std::ostream &os, const Object &obj) {
os << obj.ToString();
return os;
}
std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj) {
std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj) {
os << obj->ToString();
return os;
}
std::ostream& operator<<(std::ostream& os, const TypePtrList& types) {
std::ostream &operator<<(std::ostream &os, const TypePtrList &types) {
os << "[";
for (size_t i = 0; i < types.size(); ++i) {
if (i > 0) {
......
......@@ -95,10 +95,10 @@ enum TypeId : int {
TypeId IntBitsToTypeId(const int nbits);
TypeId UIntBitsToTypeId(const int nbits);
TypeId FloatBitsToTypeId(const int nbits);
const char* TypeIdLabel(const TypeId& v);
const char *TypeIdLabel(const TypeId &v);
TypeId NormalizeTypeId(const TypeId type_id);
bool IsSameObjectType(const Type& lhs, const Type& rhs);
size_t GetTypeByte(const TypePtr& type_ptr);
bool IsSameObjectType(const Type &lhs, const Type &rhs);
size_t GetTypeByte(const TypePtr &type_ptr);
// Base class for all types
// forward declaration.
......@@ -110,14 +110,14 @@ class Type : public Value {
~Type() override = default;
MS_DECLARE_PARENT(Type, Value)
bool operator==(const Value& other) const override;
bool operator==(const Value &other) const override;
TypeId meta_type() const { return meta_type_; }
virtual TypeId type_id() const { return meta_type_; }
virtual TypeId generic_type_id() const { return kMetaTypeType; }
virtual bool operator!=(const Type& other) const { return !(*this == other); }
virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); }
virtual bool operator!=(const Type &other) const { return !(*this == other); }
virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); }
virtual bool equal(const TypePtr other) const { return *this == *other; }
virtual TypeId object_type() const { return kTypeUnknown; }
......@@ -134,8 +134,8 @@ class Type : public Value {
bool IsUnknown() const { return (meta_type_ == kMetaTypeType); }
bool IsGeneric() const { return is_generic_; }
abstract::AbstractBasePtr ToAbstract() override;
friend std::ostream& operator<<(std::ostream& os, const Type& type);
friend std::ostream& operator<<(std::ostream& os, const TypePtr type);
friend std::ostream &operator<<(std::ostream &os, const Type &type);
friend std::ostream &operator<<(std::ostream &os, const TypePtr type);
const bool parse_info_ = true;
......@@ -163,14 +163,14 @@ class Object : public Type {
bool equal(const TypePtr other) const override;
std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); }
friend std::ostream& operator<<(std::ostream& os, const Object& obj);
friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr<Object> obj);
friend std::ostream &operator<<(std::ostream &os, const Object &obj);
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<Object> obj);
private:
const TypeId object_type_;
};
std::ostream& operator<<(std::ostream& os, const TypePtrList& types);
std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_
此差异已折叠。
......@@ -43,26 +43,26 @@ struct CloneInfo {
class Cloner {
public:
explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false,
explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false,
bool clone_all_child_graphs = true, bool clone_all_used_graphs = false,
const TraceInfoPtr& relation = std::make_shared<TraceCopy>(),
const TraceInfoPtr& target_relation = nullptr);
const TraceInfoPtr &relation = std::make_shared<TraceCopy>(),
const TraceInfoPtr &target_relation = nullptr);
~Cloner() = default;
void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr,
const AnfNodePtrList& params = {}, CloneType type = kBasic);
void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr,
const AnfNodePtrList &params = {}, CloneType type = kBasic);
void Run();
// Interfaces for specializer
AnfNodePtr CloneDisconnected(const AnfNodePtr& root);
AnfNodePtr operator[](const AnfNodePtr& node);
FuncGraphPtr operator[](const FuncGraphPtr& func_graph);
AnfNodePtr CloneDisconnected(const AnfNodePtr &root);
AnfNodePtr operator[](const AnfNodePtr &node);
FuncGraphPtr operator[](const FuncGraphPtr &func_graph);
// Map of replicate nodes and graphs
std::unordered_map<AnfNodePtr, AnfNodePtr>* cloned_node() { return &repl_node_; }
std::unordered_map<AnfNodePtr, AnfNodePtr> *cloned_node() { return &repl_node_; }
std::unordered_map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph() { return repl_func_graph_; }
// Scope of cloned graphs
void set_scope(const ScopePtr& scope) { scope_ = scope; }
void set_scope(const ScopePtr &scope) { scope_ = scope; }
const ScopePtr scope() const { return scope_; }
std::unordered_map<AnfNodePtr, AnfNodePtr> repl_node_;
......@@ -71,31 +71,31 @@ class Cloner {
void CloneNodes();
void LinkEdges();
void SetDefaults();
void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target);
void CloneValueNode(const AnfNodePtr& node);
void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target);
void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target);
void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false);
void CloneValueNodes(const FuncGraphPtr& func_graph);
void AddChildGraphs(const FuncGraphPtr& func_graph);
void AddTotalGraphs(const FuncGraphPtr& func_graph);
bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline);
void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params);
void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph);
void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph);
void GenParameters(const FuncGraphPtr& func_graph);
void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node);
ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true);
void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params,
AnfNodePtrList* const input_params);
void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params);
void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs);
void SetEdges(const FuncGraphPtr& func_graph);
void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph,
const AnfNodePtrList& params);
void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneValueNode(const AnfNodePtr &node);
void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target);
void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false);
void CloneValueNodes(const FuncGraphPtr &func_graph);
void AddChildGraphs(const FuncGraphPtr &func_graph);
void AddTotalGraphs(const FuncGraphPtr &func_graph);
bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline);
void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph);
void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph);
void GenParameters(const FuncGraphPtr &func_graph);
void CloneParameter(const ParameterPtr &param, const AnfNodePtr &node);
ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true);
void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params, AnfNodePtrList *const lift_params,
AnfNodePtrList *const input_params);
void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList &params);
void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs);
void SetEdges(const FuncGraphPtr &func_graph);
void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
const AnfNodePtrList &params);
void Lift();
void LiftParameters();
......@@ -118,17 +118,17 @@ class Cloner {
std::unordered_map<FuncGraphPtr, AnfNodePtrList> repl_func_graph_params_;
};
FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph);
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph);
AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph,
const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr);
AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr);
FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph);
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph);
ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation);
ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation);
FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph,
const TraceInfoPtr& relation = std::make_shared<TraceTransform>());
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph,
const TraceInfoPtr &relation = std::make_shared<TraceTransform>());
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_
此差异已折叠。
此差异已折叠。
......@@ -42,25 +42,25 @@ namespace mindspore {
// generate a graph corresponding to these types.
class MetaFuncGraph : public FuncGraphBase {
public:
explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); }
explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); }
~MetaFuncGraph() override = default;
MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase);
abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node);
abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node);
// Return normalized versions of the arguments.
// By default, this returns args unchanged.
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const {
virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const {
return args_spec_list;
}
const std::vector<Signature>& signatures() const { return signatures_; }
void set_signatures(const std::vector<Signature>& signatures) { signatures_ = signatures; }
const std::vector<Signature> &signatures() const { return signatures_; }
void set_signatures(const std::vector<Signature> &signatures) { signatures_ = signatures; }
// Generate a Graph for the given abstract arguments.
virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) {
virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
[](const AbstractBasePtr& arg) -> TypePtr {
[](const AbstractBasePtr &arg) -> TypePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->BuildType();
});
......@@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase {
}
// Generate a Graph for this type signature.
virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) {
virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) {
MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types.";
}
......@@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase {
std::string ToString() const override { return name_; }
std::size_t hash() const override { return tid(); }
virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; }
bool operator==(const Value& other) const override {
virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; }
bool operator==(const Value &other) const override {
if (other.isa<MetaFuncGraph>()) {
return &other == this;
} else {
......
......@@ -31,7 +31,7 @@ namespace mindspore {
namespace tensor {
void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
void DataBuf2Contiguous(const py::array &src, py::array *const dest) {
if (dest == nullptr) {
MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!";
}
......@@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) {
// MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int>& shape) : data_type_(data_type), shape_(shape) {}
MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {}
MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) {
TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) {
data_type = type_ptr->type_id();
......@@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) {
}
}
MetaTensor::MetaTensor(const MetaTensor& meta_tensor)
MetaTensor::MetaTensor(const MetaTensor &meta_tensor)
: Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {}
MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) {
if (&meta_tensor == this) {
return *this;
}
......@@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) {
return *this;
}
bool MetaTensor::operator==(const MetaTensor& meta_tensor) const {
bool MetaTensor::operator==(const MetaTensor &meta_tensor) const {
return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape();
}
......@@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) {
return type_ptr;
}
void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) {
void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) {
DeviceInfo info(format, data_type);
set_device_info(info);
}
......@@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const {
return oss.str();
}
Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) {
TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) {
data_type = type_ptr->type_id();
......@@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) {
init(data_type_, shape_, &data_);
}
Tensor::Tensor(TypeId data_type, const std::vector<int>& shape) { init(data_type, shape, &data_); }
Tensor::Tensor(TypeId data_type, const std::vector<int> &shape) { init(data_type, shape, &data_); }
Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); }
Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); }
Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); }
Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type)
Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type)
: MetaTensor(tensor), device_address_(tensor.device_address()) {
init(tensor.data_, data_type);
}
Tensor& Tensor::operator=(const Tensor& tensor) {
Tensor &Tensor::operator=(const Tensor &tensor) {
if (this != &tensor) {
MetaTensor::operator=(tensor);
dirty_ = tensor.is_dirty();
......@@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) {
return *this;
}
bool Tensor::operator==(const Tensor& tensor) const {
bool Tensor::operator==(const Tensor &tensor) const {
return (MetaTensor::operator==(tensor) && data_ == tensor.data_);
}
bool Tensor::ValueEqualPy(const py::object& other) const {
bool Tensor::ValueEqualPy(const py::object &other) const {
if (!py::isinstance<Tensor>(other)) {
MS_LOG(WARNING) << "compare other not a tensor";
return false;
......@@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const {
return ValueEqual(py::cast<Tensor>(other));
}
bool Tensor::ValueEqual(const Tensor& other) const {
bool Tensor::ValueEqual(const Tensor &other) const {
auto equal = [&other, this]() -> bool {
auto np = py::module::import("numpy");
auto equal = np.attr("equal")(data_, other.data_);
......@@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast<int>(data_type_); }
std::vector<int> Tensor::shape_c(void) const { return shape(); }
void* Tensor::data_c(bool writable) {
void *Tensor::data_c(bool writable) {
// operand of bit operation should be unsigned int.
unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_;
bool is_c_contiguous = (flags != 0) ? true : false;
......@@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) {
return data_.request(writable).ptr;
}
TypeId Tensor::GetDataType(const py::buffer_info& buf) const {
TypeId Tensor::GetDataType(const py::buffer_info &buf) const {
TypeId data_type = TypeId::kTypeUnknown;
if (buf.format.compare("e") == 0) {
data_type = TypeId::kNumberTypeFloat16;
......@@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const {
return data_type;
}
void Tensor::init(const py::array& input, const TypePtr& type_ptr) {
void Tensor::init(const py::array &input, const TypePtr &type_ptr) {
TypeId data_type = TypeId::kTypeUnknown;
if (type_ptr != nullptr) {
data_type = type_ptr->type_id();
......@@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) {
init(input, data_type);
}
void Tensor::init(const py::array& input, const TypeId& data_type) {
void Tensor::init(const py::array &input, const TypeId &data_type) {
py::buffer_info buf = input.request();
data_type_ = GetDataType(buf);
......@@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) {
}
}
void Tensor::init(TypeId data_type, const std::vector<int>& shape, py::array* const data) {
void Tensor::init(TypeId data_type, const std::vector<int> &shape, py::array *const data) {
data_type_ = data_type;
shape_ = shape;
switch (data_type) {
......@@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
return data_type_;
}
bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out,
bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out,
const TypeId out_data_type) {
if (out == nullptr) {
return false;
......@@ -458,7 +458,7 @@ py::array Tensor::data_sync() {
return data_;
}
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
// dtype should define before Tensor, because Tensor init depend dtype
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor")
.def(py::init<TypePtr, py::tuple>(), py::arg("dtype"), py::arg("shape"))
......@@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) {
.def("__repr__", &Tensor::ToStringRepr)
.def("__eq__", &Tensor::ValueEqualPy)
.def(py::pickle(
[](const Tensor& t) { // __getstate__
[](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(t.data());
},
[](const py::tuple& t) { // __setstate__
[](const py::tuple &t) { // __setstate__
if (t.size() != 1) {
throw std::runtime_error("Invalid state!");
}
......
此差异已折叠。
......@@ -18,9 +18,9 @@
#include "pipeline/static_analysis/abstract_value.h"
namespace mindspore {
bool Named::operator==(const Value& other) const {
bool Named::operator==(const Value &other) const {
if (other.isa<Named>()) {
auto other_named = static_cast<const Named&>(other);
auto other_named = static_cast<const Named &>(other);
return *this == other_named;
} else {
return false;
......
......@@ -27,18 +27,18 @@
namespace mindspore {
class Named : public Value {
public:
explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
Named(const Named& other) : Value(other) {
explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash<std::string>{}(name); }
Named(const Named &other) : Value(other) {
this->name_ = other.name_;
hash_id_ = std::hash<std::string>{}(other.name_);
}
~Named() override = default;
MS_DECLARE_PARENT(Named, Value);
const std::string& name() const { return name_; }
virtual bool operator==(const Named& other) const { return name_ == other.name(); }
bool operator==(const Value& other) const override;
Named& operator=(const Named& other) {
const std::string &name() const { return name_; }
virtual bool operator==(const Named &other) const { return name_ == other.name(); }
bool operator==(const Value &other) const override;
Named &operator=(const Named &other) {
if (&other != this) {
this->type_ = other.type_;
this->name_ = other.name_;
......@@ -50,7 +50,7 @@ class Named : public Value {
std::size_t Hash() const { return hash_id_; }
std::size_t hash() const override { return hash_id_; }
friend std::ostream& operator<<(std::ostream& os, const Named& nmd) {
friend std::ostream &operator<<(std::ostream &os, const Named &nmd) {
os << nmd.name();
return os;
}
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -21,8 +21,8 @@
#include "pipeline/parse/data_converter.h"
namespace mindspore {
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind,
const py::object& arg_default, const SignatureEnumDType& arg_dtype)
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind,
const py::object &arg_default, const SignatureEnumDType &arg_dtype)
: name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) {
if (py::isinstance<SignatureEnumKind>(arg_default) &&
py::cast<SignatureEnumKind>(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) {
......@@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag,
}
}
Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind)
Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind)
: name(arg_name),
rw(rw_tag),
kind(arg_kind),
default_value(nullptr),
dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {}
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) {
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead)
.value("RW_WRITE", SignatureEnumRW::kRWWrite)
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册