提交 ea9b5468 编写于 作者: W WilliamLian

Fix bug in HCCL kernel info

上级 82b4cada
......@@ -16,12 +16,30 @@
#include "kernel/hccl/hccl_kernel_metadata.h"
#include <memory>
#include <set>
#include "utils/utils.h"
#include "kernel/hccl/hcom_util.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
namespace {
// Selects the format string to record for the tensor at `index` of `kernel_node`.
//
// The starting point is always the output format of the node feeding input `index`
// (AnfAlgo::GetPrevNodeOutputFormat). For every op except kReduceScatter and
// kAllGatherOpName that format is returned unchanged. For those two ops the format is
// replaced by kOpFormat_DEFAULT when either:
//   * it is kOpFormat_FRAC_NZ and the inferred shape of the feeding output has at most
//     two dimensions, or
//   * it is one of kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0.
//
// @param kernel_node  The HCCL kernel node being described.
// @param index        Index passed through to the AnfAlgo "prev node output" queries.
// @return             The format to use for this tensor.
std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) {
  // Constructed once on first use instead of on every call: the contents never change,
  // and most calls (ops other than ReduceScatter/AllGather) never consult it.
  static const std::set<std::string> kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04,
                                                              kOpFormat_C1HWNCoC0};
  // Largest rank for which a FRAC_NZ input is downgraded to the default format.
  constexpr size_t kFracNzMaxDefaultRank = 2;

  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
  auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index);
  if (op_name != kReduceScatter && op_name != kAllGatherOpName) {
    return format;
  }
  if (format == kOpFormat_FRAC_NZ &&
      AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index).size() <= kFracNzMaxDefaultRank) {
    return kOpFormat_DEFAULT;
  }
  if (kReduceNoSupportedSet.find(format) != kReduceNoSupportedSet.end()) {
    return kOpFormat_DEFAULT;
  }
  return format;
}
} // namespace
void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
const std::vector<TypeId> kHcclSupportTypes = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16,
kNumberTypeFloat32, kNumberTypeInt16};
......@@ -36,13 +54,13 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
std::vector<std::string> inputs_format{};
std::vector<TypeId> inputs_type{};
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index));
inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
inputs_type.push_back(type);
}
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_type;
for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index));
outputs_format.emplace_back(GetKernelFormat(kernel_node, output_index));
outputs_type.push_back(type);
}
auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
......
......@@ -428,5 +428,5 @@ def test_pynative_resnet50():
cost_time = end_time - start_time
print("======step: ", step, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
if step > 1:
assert cost_time < 0.5
assert cost_time < 0.3
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册