提交 17d71280 编写于 作者: G geekun

fix codex and support akg op profiling

上级 980b67d1
......@@ -221,6 +221,11 @@ std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
}
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string fg_name = "GraphKernel_";
fg_name += GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
return fg_name;
}
return func_graph->ToString();
}
MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString();
......
......@@ -167,7 +167,7 @@ std::string ProfilingUtils::GetGraphLastTbeKernelName(const std::vector<CNodePtr
std::string last_tbe_kernel_name;
// find last tbe_kernel
for (auto iter = cnode_exec_order.rbegin(); iter != cnode_exec_order.rend(); ++iter) {
if (AnfAlgo::GetKernelType(*iter) == TBE_KERNEL) {
if (AnfAlgo::GetKernelType(*iter) == TBE_KERNEL || AnfAlgo::GetKernelType(*iter) == AKG_KERNEL) {
last_tbe_kernel_name = (*iter)->fullname_with_scope();
break;
}
......@@ -319,7 +319,7 @@ void ProfilingUtils::SetGraphProfilingCNode(uint32_t graph_id, const std::vector
bool ProfilingUtils::ValidComputeGraph(NotNull<const session::KernelGraph *> graph_ptr) {
for (const auto &node : graph_ptr->execution_order()) {
if (AnfAlgo::GetKernelType(node) == TBE_KERNEL) {
if (AnfAlgo::GetKernelType(node) == TBE_KERNEL || AnfAlgo::GetKernelType(node) == AKG_KERNEL) {
return true;
}
}
......
......@@ -91,6 +91,7 @@ class PBinOperation : public PBase<PBinOperation<T, T2> > {
public:
PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false)
: prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {}
~PBinOperation() = default;
AnfNodePtr GetNode(const AnfNodePtr &node) const {
AnfNodePtr lhs = x_.GetNode(node->func_graph());
......@@ -282,6 +283,7 @@ template <typename... TArgs>
class PPrimitive : public PBase<PPrimitive<TArgs...> > {
public:
explicit PPrimitive(const PrimitivePtr &prim, const TArgs &... args) : prim_(prim), args_(args...) {}
~PPrimitive() = default;
AnfNodePtr GetNode(const AnfNodePtr &node) const {
tuple_utils::PTupleGetNode get_node(node);
......@@ -378,6 +380,7 @@ class PConstant : public PBase<PConstant<T> > {
check_value_(check_value),
is_scalar_(is_scalar) {}
~PConstant() = default;
// Sets as_node_ as the node received as argument to produce a same-shape node with GetNode
const PConstant<T> &WithShapeAs(const AnfNodePtr &node) const {
if (node == nullptr) {
......@@ -556,7 +559,9 @@ class PConstant : public PBase<PConstant<T> > {
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
if (x == nullptr) {
memset_s(data, mem_size, 0, mem_size);
if (memset_s(data, mem_size, 0, mem_size) != 0) {
return nullptr;
}
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
......@@ -588,14 +593,19 @@ class PConstant : public PBase<PConstant<T> > {
if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) {
return nullptr;
}
int ret = 0;
char *source_data = reinterpret_cast<char *>(GetPointerToTensorData(x));
if (x_tensor_ptr->DataSize() == 1) {
for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) {
memcpy_s(data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr), source_data,
GetTypeByte(tensor_type_ptr));
ret = memcpy_s(data + i * GetTypeByte(tensor_type_ptr), GetTypeByte(tensor_type_ptr), source_data,
GetTypeByte(tensor_type_ptr));
}
} else {
memcpy_s(data, mem_size, source_data, mem_size);
ret = memcpy_s(data, mem_size, source_data, mem_size);
}
if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size"
<< new_tensor_ptr->DataSize();
}
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
......@@ -615,7 +625,9 @@ class PConstant : public PBase<PConstant<T> > {
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
memset_s(data, mem_size, value, mem_size);
if (memset_s(data, mem_size, value, mem_size) != 0) {
return nullptr;
}
auto new_vnode = NewValueNode(new_tensor_ptr);
new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_vnode;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册