提交 04fdeddf 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1299 CPU kernel: support multiple dtypes

Merge pull request !1299 from sunsuodong/ops_int32
......@@ -59,6 +59,7 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
TypeId dtype = kTypeUnknown;
if (IsInputNotCNode(kernel_node, input_index)) {
input_no_cnode_indexes->emplace_back(input_index);
dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
} else {
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
}
......@@ -84,22 +85,25 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<
const std::vector<TypeId> &input_types,
const std::vector<size_t> &input_not_cnode_indexes) {
if (kernel_attr.GetInputSize() != input_types.size()) {
MS_LOG(ERROR) << "Output num is not equal!";
MS_LOG(ERROR) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
return false;
}
auto input_num = input_types.size();
for (size_t i = 0; i < input_num; ++i) {
bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(),
[i](size_t index) { return index == i; });
if (is_not_cnode_idx) {
bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size());
if (have_cnode_input && is_not_cnode_idx) {
continue;
}
if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
MS_LOG(ERROR) << "reg dtype=" << kernel_attr.GetInputAttr(i).first << ", input dtype=" << input_types[i];
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
<< ", actual input dtype:" << input_types[i];
return false;
}
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
MS_LOG(ERROR) << "reg format=" << kernel_attr.GetInputAttr(i).second << ", input format=" << input_formats[i];
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second
<< ", actual input format:" << input_formats[i];
return false;
}
}
......@@ -114,17 +118,19 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> output_formats;
std::vector<TypeId> output_types;
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
auto kernel_attrs =
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
for (auto &kernel_attr : kernel_attrs) {
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) {
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types);
UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node);
for (auto &input_index : input_not_cnode_indexes) {
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first;
}
break;
}
......
......@@ -55,10 +55,12 @@ class CPUKernelRegistrar {
~CPUKernelRegistrar() = default;
};
#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) \
#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS)
#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS)
#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \
static_assert(std::is_base_of<CPUKernel, OPCLASS>::value, " must be base of CPUKernel"); \
static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_reg(#OPNAME, ATTR, \
[]() { return std::make_shared<OPCLASS>(); });
static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \
[]() { return std::make_shared<OPCLASS>(); });
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \
......
......@@ -35,10 +35,18 @@ class ReshapeCPUKernel : public CPUKernel {
// Register ReshapeCPUKernel as the CPU kernel for the Reshape, Flatten and
// ExpandDims operators. Each operator is registered twice, once per supported
// dtype: a float32 input/output attribute and an int32 input/output attribute.
// All six registrations share the same kernel class.

// Reshape: float32 -> float32
MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReshapeCPUKernel);
// Reshape: int32 -> int32
MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReshapeCPUKernel);
// Flatten: float32 -> float32
MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReshapeCPUKernel);
// Flatten: int32 -> int32
MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReshapeCPUKernel);
// ExpandDims: float32 -> float32
MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReshapeCPUKernel);
// ExpandDims: int32 -> int32
MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ReshapeCPUKernel);
} // namespace kernel
} // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册