Commit 99a6033a authored by mindspore-ci-bot, committed by Gitee

!3032 Refactor the akg op registers

Merge pull request !3032 from DeshiChen/0711_akg_op_register_master
@@ -103,6 +103,7 @@ class OpInfo {
     partial_flag_ = opinfo.partial_flag_;
     dynamic_format_ = opinfo.dynamic_format_;
     op_pattern_ = opinfo.op_pattern();
+    processor_ = opinfo.processor_;
     for (const auto &attr : opinfo.attrs_ptr()) {
       attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
     }
@@ -121,6 +122,7 @@ class OpInfo {
   std::string fusion_type() const { return fusion_type_; }
   std::string kernel_name() const { return kernel_name_; }
   OpPattern op_pattern() const { return op_pattern_; }
+  std::string processor() const { return processor_; }
   std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
   std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
   std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
@@ -136,6 +138,7 @@ class OpInfo {
   void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
   void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
   void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
+  void set_processor(const std::string &processor) { processor_ = processor; }
   void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
   void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
   void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
@@ -144,6 +147,10 @@ class OpInfo {
   void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
   void ClearInputs() { (void)inputs_ptr_.clear(); }
   void ClearOutputs() { (void)outputs_ptr_.clear(); }
+  bool equals_to(const std::shared_ptr<OpInfo> &other_info) const {
+    return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ &&
+           this->processor_ == other_info->processor_;
+  }
 
  private:
   std::string op_name_;
@@ -157,6 +164,7 @@ class OpInfo {
   bool partial_flag_ = false;
   bool dynamic_format_ = false;
   OpPattern op_pattern_ = kCommonPattern;
+  std::string processor_;
   std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
   std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_;
   std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_;
......
@@ -45,9 +45,10 @@ constexpr auto kAttr = "attr";
 constexpr auto kIputs = "inputs";
 constexpr auto kOutputs = "outputs";
 constexpr auto kAiCPU = "AiCPU";
+constexpr auto kAiCore = "AiCore";
+constexpr auto kCUDA = "CUDA";
 constexpr auto kTbe = "TBE";
-constexpr auto kAkg = "akg";
-constexpr auto kAutodiff = "AutoDiff";
+constexpr auto kAkg = "AKG";
 constexpr auto kName = "name";
 constexpr auto kParamType = "param_type";
 constexpr auto kDtype = "dtype";
@@ -58,6 +59,7 @@ constexpr auto kIndex = "index";
 constexpr auto kFormat = "format";
 constexpr auto kNeedCompile = "need_compile";
 constexpr auto kShape = "shape";
+constexpr auto kProcessor = "processor";
 std::vector<std::shared_ptr<OpInfo>> OpLib::op_info_;
 
 static std::string ImplTypeToStr(OpImplyType impl_type) {
@@ -81,7 +83,7 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path)
   if (imply_type_string == kTbe) {
     OpImplyType imply_type = kTBE;
     ret = DecodeOpInfo(op_json, imply_type, impl_path);
-  } else if (imply_type_string == kAutodiff) {
+  } else if (imply_type_string == kAkg) {
     OpImplyType imply_type = kAKG;
     ret = DecodeOpInfo(op_json, imply_type, impl_path);
   } else if (imply_type_string == kAiCPU) {
@@ -125,6 +127,11 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
   }
 }
 
+void OpLib::DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info) {
+  MS_EXCEPTION_IF_NULL(op_info);
+  op_info->set_processor(obj.at(kProcessor));
+}
+
 bool OpLib::RegOpFromLocalInfo() {
   MS_LOG(INFO) << "Start";
   static bool has_load = false;
@@ -179,6 +186,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI
   op_info->set_fusion_type(obj.at(kFusionType));
   if (imply_type == kTBE) {
     DecodeTBESpecificInfo(obj, op_info);
+  } else if (imply_type == kAKG) {
+    DecodeAKGSpecificInfo(obj, op_info);
   }
   auto attrs = obj.at(kAttr);
   for (const auto &attr : attrs) {
@@ -330,7 +339,12 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
   for (const auto &op_info : op_info_) {
     MS_EXCEPTION_IF_NULL(op_info);
     if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) {
-      return op_info;
+      auto akg_processor_match = [&]() {
+        return is_gpu ? op_info->processor() == kCUDA : op_info->processor() == kAiCore;
+      };
+      if (imply_type != kAKG || akg_processor_match()) {
+        return op_info;
+      }
     }
   }
   MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
@@ -363,19 +377,14 @@ bool OpLib::GetRefInfo(const std::shared_ptr<OpInfo> &op_info) {
 }
 
 bool OpLib::CheckRepetition(const std::shared_ptr<OpInfo> &op_info) {
-  bool has_register = false;
   MS_EXCEPTION_IF_NULL(op_info);
   for (const auto &exist_op_info : op_info_) {
     MS_EXCEPTION_IF_NULL(exist_op_info);
-    if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() &&
-        exist_op_info->impl_path() == op_info->impl_path()) {
-      MS_LOG(INFO) << "Op has already exist, please use other name, op name: " << op_info->op_name()
-                   << " op type: " << ImplTypeToStr(op_info->imply_type());
-      has_register = true;
-      break;
+    if (exist_op_info->equals_to(op_info)) {
+      return true;
     }
   }
-  return has_register;
+  return false;
 }
 }  // namespace kernel
 }  // namespace mindspore
@@ -44,6 +44,7 @@ class OpLib {
   static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr<OpIOInfo> &op_io,
                                 size_t index);
   static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info);
+  static void DecodeAKGSpecificInfo(const nlohmann::json &obj, const std::shared_ptr<OpInfo> &op_info);
   static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type,
                                 const std::shared_ptr<OpInfo> &op_info, const nlohmann::json &dtype_format);
   static bool GetRefInfo(const std::shared_ptr<OpInfo> &op_info);
......
@@ -32,7 +32,7 @@ Note:
 from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
-from .op_info_register import op_info_register, AkgRegOp, AiCPURegOp, TBERegOp, DataType
+from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType
 from .primitive import constexpr
 from .._c_expression import signature_rw, signature_kind
@@ -42,6 +42,6 @@ __primitive__ = [
 ]
 
 __all__ = ["get_vm_impl_fn", "vm_impl_registry",
-           "op_info_register", "AkgRegOp", "AiCPURegOp", "TBERegOp", "DataType",
+           "op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType",
            "constexpr"]
 
 __all__.extend(__primitive__)
@@ -17,7 +17,7 @@
 import platform
 from .aicpu import *
 if "Windows" not in platform.system():
-    from .akg.gpu import *
+    from .akg import *
     from .tbe import *
 
 __all__ = []
@@ -13,77 +13,6 @@
 # limitations under the License.
 # ============================================================================
-"""autodiff ops"""
-from .abs import _abs_akg
-from .add_n import _add_n_akg
-from .add import _add_akg
-from .apply_momentum import _apply_momentum_akg
-from .assign import _assign_akg
-from .inplace_assign import _inplace_assign_akg
-from .assign_add import _assign_add_akg
-from .bias_add_grad import _bias_add_grad_akg
-from .bias_add import _bias_add_akg
-from .cast import _cast_akg
-from .clear_zero import _clear_zero_akg
-from .conv_bn1 import _conv_bn1_akg
-from .conv2d_backprop_filter import _conv2d_backprop_filter_akg
-from .conv2d_backprop_input import _conv2d_backprop_input_akg
-from .conv2d import _conv2d_akg
-from .div import _div_akg
-from .equal_count import _equal_count_akg
-from .exp import _exp_akg
-from .five2four import _five2four_akg
-from .four2five import _four2five_akg
-from .fused_batch_norm_grad import _fused_batch_norm_grad_akg
-from .fused_batch_norm_infer import _fused_batch_norm_infer_akg
-from .fused_batch_norm import _fused_batch_norm_akg
-from .fused_bn1_grad import _bn1_grad_akg
-from .fused_bn1 import _fused_bn1_akg
-from .fused_bn2_grad import _bn2_grad_akg
-from .fused_bn2 import _fused_bn2_akg
-from .fused_bn3_grad import _bn3_grad_akg
-from .fused_bn3 import _fused_bn3_akg
-from .gather_v2 import _gather_v2_akg
-from .less import _less_akg
-from .log import _log_akg
-from .matmul import _matmul_akg
-from .batchmatmul import _batchmatmul_akg
-from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_akg
-from .max_pool_with_argmax import _max_pool_with_argmax_akg
-from .max import _max_akg
-from .maximum import _maximum_akg
-from .mean_grad import _mean_grad_akg
-from .mean import _mean_akg
-from .minimum import _minimum_akg
-from .mul import _mul_akg
-from .neg import _neg_akg
-from .one_hot import _one_hot_akg
-from .pow import _power_akg
-from .real_div import _real_div_akg
-from .reciprocal import _reciprocal_akg
-from .reduce_max import _reduce_max_akg
-from .reduce_mean import _reduce_mean_akg
-from .reduce_sum import _reduce_sum_akg
-from .relu_grad import _relu_grad_akg
-from .relu import _relu_akg
-from .reshape import _reshape_akg
-from .round import _round_akg
-from .rsqrt import _rsqrt_akg
-from .select import _select_akg
-from .softmax import _softmax_akg
-from .sparse_softmax_cross_entropy_with_logits import _sparse_softmax_cross_entropy_with_logits_akg
-from .sqrt import _sqrt_akg
-from .strided_slice import _strided_slice_akg
-from .sub import _sub_akg
-from .sum import _sum_akg
-from .tile import _tile_akg
-from .zeros_like import _zeros_like_akg
-from .argmax import _argmax_akg
-from .floordiv import _floor_div_akg
-from .equal import _equal_akg
-from .greater_equal import _greater_equal_akg
-from .less_equal import _less_equal_akg
-from .expand_dims import _expand_dims_akg
-from .greater import _greater_akg
-from .equiv_format import _equiv_format_akg
+"""akg ops"""
+from . import ascend
+from . import gpu
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""__init__"""
from .add import _add_akg
from .batchmatmul import _batchmatmul_akg
from .cast import _cast_akg
from .expand_dims import _expand_dims_akg
from .greater import _greater_akg
from .inplace_assign import _inplace_assign_akg
from .maximum import _maximum_akg
from .minimum import _minimum_akg
from .mul import _mul_akg
from .real_div import _real_div_akg
from .rsqrt import _rsqrt_akg
from .select import _select_akg
from .sqrt import _sqrt_akg
from .sub import _sub_akg
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TensorAdd op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("TensorAdd") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
.dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
.dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
.dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
.dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
.dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
.dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \
.get_op_info()
@op_info_register(op_info)
def _add_akg():
"""TensorAdd Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BatchMatMul op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("BatchMatMul") \
.fusion_type("OPAQUE") \
.input(0, "x1") \
.input(1, "x2") \
.output(0, "output") \
.attr("transpose_a", "optional", "bool") \
.attr("transpose_b", "optional", "bool") \
.dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
.get_op_info()
@op_info_register(op_info)
def _batchmatmul_akg():
"""BatchMatMul AKG register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cast op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Cast") \
.fusion_type("OPAQUE") \
.input(0, "x") \
.output(0, "output") \
.attr("dst_type", "required", "str") \
.dtype_format(DT.F16_Default, DT.F32_Default) \
.dtype_format(DT.F16_Default, DT.I32_Default) \
.dtype_format(DT.F32_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.I32_Default) \
.dtype_format(DT.I32_Default, DT.F16_Default) \
.dtype_format(DT.I32_Default, DT.F32_Default) \
.dtype_format(DT.BOOL_Default, DT.F16_Default) \
.dtype_format(DT.BOOL_Default, DT.F32_Default) \
.dtype_format(DT.BOOL_Default, DT.I32_Default) \
.dtype_format(DT.F16_5HD, DT.F32_5HD) \
.dtype_format(DT.F32_5HD, DT.F16_5HD) \
.dtype_format(DT.BOOL_5HD, DT.I32_5HD) \
.dtype_format(DT.BOOL_5HD, DT.F32_5HD) \
.dtype_format(DT.F16_FracNZ, DT.F32_FracNZ) \
.dtype_format(DT.F32_FracNZ, DT.F16_FracNZ) \
.dtype_format(DT.BOOL_FracNZ, DT.I32_FracNZ) \
.dtype_format(DT.BOOL_FracNZ, DT.F32_FracNZ) \
.get_op_info()
@op_info_register(op_info)
def _cast_akg():
"""Cast Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ExpandDims op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("ExpandDims") \
.fusion_type("OPAQUE") \
.input(0, "x") \
.output(0, "y") \
.attr("axis", "required", "int") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default) \
.get_op_info()
@op_info_register(op_info)
def _expand_dims_akg():
"""ExpandDims Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Greater op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Greater") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.BOOL_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.BOOL_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD, DT.BOOL_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD, DT.BOOL_5HD) \
.get_op_info()
@op_info_register(op_info)
def _greater_akg():
"""Greater Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""InplaceAssign op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("InplaceAssign") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.input(2, "z") \
.output(0, "output") \
.attr("fake_output", "optional", "bool") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
.dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
.dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
.dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
.get_op_info()
@op_info_register(op_info)
def _inplace_assign_akg():
"""InplaceAssign Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Maximum op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Maximum") \
.fusion_type("COMMREDUCE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
.get_op_info()
@op_info_register(op_info)
def _maximum_akg():
"""Maximum Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Minimum op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Minimum") \
.fusion_type("COMMREDUCE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
.dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
.dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
.dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \
.get_op_info()
@op_info_register(op_info)
def _minimum_akg():
"""Minimum Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mul op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Mul") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.attr("x_shape", "required", "listInt") \
.attr("y_shape", "required", "listInt") \
.attr("data_format", "required", "listStr") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
.dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
.dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
.dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
.get_op_info()
@op_info_register(op_info)
def _mul_akg():
"""Mul Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RealDiv op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("RealDiv") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
.dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
.get_op_info()
@op_info_register(op_info)
def _real_div_akg():
"""RealDiv Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rsqrt op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Rsqrt") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.I32_5HD, DT.I32_5HD) \
.get_op_info()
@op_info_register(op_info)
def _rsqrt_akg():
"""Rsqrt Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Select op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Select") \
.fusion_type("ELEMWISE") \
.input(0, "condition") \
.input(1, "x") \
.input(2, "y") \
.output(0, "output") \
.dtype_format(DT.BOOL_Default, DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.BOOL_Default, DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.BOOL_Default, DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.BOOL_5HD, DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.BOOL_5HD, DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.BOOL_5HD, DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
.get_op_info()
@op_info_register(op_info)
def _select_akg():
"""Select Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sqrt op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Sqrt") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.I32_5HD, DT.I32_5HD) \
.get_op_info()
@op_info_register(op_info)
def _sqrt_akg():
"""Sqrt Akg register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sub op"""
from mindspore.ops.op_info_register import op_info_register, AkgAscendRegOp, DataType as DT
op_info = AkgAscendRegOp("Sub") \
.fusion_type("ELEMWISE") \
.input(0, "x") \
.input(1, "y") \
.output(0, "output") \
.dtype_format(DT.F16_Default, DT.F16_Default, DT.F16_Default) \
.dtype_format(DT.F32_Default, DT.F32_Default, DT.F32_Default) \
.dtype_format(DT.I32_Default, DT.I32_Default, DT.I32_Default) \
.dtype_format(DT.F16_5HD, DT.F16_5HD, DT.F16_5HD) \
.dtype_format(DT.F32_5HD, DT.F32_5HD, DT.F32_5HD) \
.dtype_format(DT.I32_5HD, DT.I32_5HD, DT.I32_5HD) \
.dtype_format(DT.F16_FracZ, DT.F16_FracZ, DT.F16_FracZ) \
.dtype_format(DT.F32_FracZ, DT.F32_FracZ, DT.F32_FracZ) \
.dtype_format(DT.I32_FracZ, DT.I32_FracZ, DT.I32_FracZ) \
.dtype_format(DT.F16_FracNZ, DT.F16_FracNZ, DT.F16_FracNZ) \
.dtype_format(DT.F32_FracNZ, DT.F32_FracNZ, DT.F32_FracNZ) \
.dtype_format(DT.I32_FracNZ, DT.I32_FracNZ, DT.I32_FracNZ) \
.get_op_info()
@op_info_register(op_info)
def _sub_akg():
"""Sub Akg register"""
return
@@ -13,15 +13,16 @@
 # limitations under the License.
 
 """Cast op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-cast_op_info = AkgRegOp("Cast") \
+cast_op_info = AkgGpuRegOp("Cast") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .output(0, "output") \
     .attr("dst_type", "required", "str") \
     .dtype_format(DataType.F16_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default) \
     .dtype_format(DataType.I32_Default, DataType.F32_Default) \
     .dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
     .get_op_info()
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """Equal op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-equal_op_info = AkgRegOp("Equal") \
+equal_op_info = AkgGpuRegOp("Equal") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .input(1, "y") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """GreaterEqual op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-greater_equal_op_info = AkgRegOp("GreaterEqual") \
+greater_equal_op_info = AkgGpuRegOp("GreaterEqual") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .input(1, "y") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """HSigmoid op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-hsigmoid_op_info = AkgRegOp("HSigmoid") \
+hsigmoid_op_info = AkgGpuRegOp("HSigmoid") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .output(0, "output") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """HSigmoidGrad op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-hsigmoidgrad_op_info = AkgRegOp("HSigmoidGrad") \
+hsigmoidgrad_op_info = AkgGpuRegOp("HSigmoidGrad") \
     .fusion_type("OPAQUE") \
     .input(0, "y_grad") \
     .input(1, "x") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """HSwish op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-hswish_op_info = AkgRegOp("HSwish") \
+hswish_op_info = AkgGpuRegOp("HSwish") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .output(0, "output") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """HSwishGrad op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-hswish_grad_op_info = AkgRegOp("HSwishGrad") \
+hswish_grad_op_info = AkgGpuRegOp("HSwishGrad") \
     .fusion_type("OPAQUE") \
     .input(0, "y_grad") \
    .input(1, "x") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """LessEqual op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-lessequal_op_info = AkgRegOp("LessEqual") \
+lessequal_op_info = AkgGpuRegOp("LessEqual") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .input(1, "y") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """LogicalAnd op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-logicaland_op_info = AkgRegOp("LogicalAnd") \
+logicaland_op_info = AkgGpuRegOp("LogicalAnd") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .input(1, "y") \
@@ -23,6 +23,7 @@ logicaland_op_info = AkgRegOp("LogicalAnd") \
     .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
     .get_op_info()
 
+
 @op_info_register(logicaland_op_info)
 def _logical_and_akg():
     """LogicalAnd register"""
......
@@ -13,15 +13,16 @@
 # limitations under the License.
 
 """LogicalNot op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-logical_not_op_info = AkgRegOp("LogicalNot") \
+logical_not_op_info = AkgGpuRegOp("LogicalNot") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .output(0, "output") \
     .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
     .get_op_info()
 
+
 @op_info_register(logical_not_op_info)
 def _logical_not_akg():
     """LogicalNot AutoDiff register"""
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """LogicalOr op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-logicalor_op_info = AkgRegOp("LogicalOr") \
+logicalor_op_info = AkgGpuRegOp("LogicalOr") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .input(1, "y") \
@@ -23,6 +23,7 @@ logicalor_op_info = AkgRegOp("LogicalOr") \
     .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
     .get_op_info()
 
+
 @op_info_register(logicalor_op_info)
 def _logical_or_akg():
     """LogicalOr register"""
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """SimpleMean op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-mean_op_info = AkgRegOp("SimpleMean") \
+mean_op_info = AkgGpuRegOp("SimpleMean") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .output(0, "output") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """SimpleMeanGrad op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-mean_grad_op_info = AkgRegOp("SimpleMeanGrad") \
+mean_grad_op_info = AkgGpuRegOp("SimpleMeanGrad") \
     .fusion_type("OPAQUE") \
     .input(0, "HEAD") \
     .output(0, "output") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """Mul op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-mul_op_info = AkgRegOp("Mul") \
+mul_op_info = AkgGpuRegOp("Mul") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .input(1, "y") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """NotEqual op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-notequal_op_info = AkgRegOp("NotEqual") \
+notequal_op_info = AkgGpuRegOp("NotEqual") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .input(1, "y") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """ReLU6 op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-relu_op_info = AkgRegOp("ReLU6") \
+relu_op_info = AkgGpuRegOp("ReLU6") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .output(0, "output") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """ReLU6Grad op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-relu_grad_op_info = AkgRegOp("ReLU6Grad") \
+relu_grad_op_info = AkgGpuRegOp("ReLU6Grad") \
     .fusion_type("OPAQUE") \
     .input(0, "y_grad") \
     .input(1, "x") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """Squeeze op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-squeeze_op_info = AkgRegOp("Squeeze") \
+squeeze_op_info = AkgGpuRegOp("Squeeze") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .output(0, "output") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """SqueezeGrad op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-squeeze_grad_op_info = AkgRegOp("SqueezeGrad") \
+squeeze_grad_op_info = AkgGpuRegOp("SqueezeGrad") \
     .fusion_type("OPAQUE") \
     .input(0, "y_grad") \
     .output(0, "output") \
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """Sub op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-sub_op_info = AkgRegOp("Sub") \
+sub_op_info = AkgGpuRegOp("Sub") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .input(1, "y") \
@@ -25,6 +25,7 @@ sub_op_info = AkgRegOp("Sub") \
     .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
     .get_op_info()
 
+
 @op_info_register(sub_op_info)
 def _sub_akg():
     """Sub AutoDiff register"""
......
@@ -13,9 +13,9 @@
 # limitations under the License.
 
 """Tile op"""
-from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType
+from mindspore.ops.op_info_register import op_info_register, AkgGpuRegOp, DataType
 
-tile_op_info = AkgRegOp("Tile") \
+tile_op_info = AkgGpuRegOp("Tile") \
     .fusion_type("OPAQUE") \
     .input(0, "x") \
     .output(0, "output") \
......
@@ -215,10 +215,10 @@ class RegOp:
 class AkgRegOp(RegOp):
     """Class for Akg op info register."""
 
-    def __init__(self, op_name):
+    def __init__(self, op_name, processor):
         super(AkgRegOp, self).__init__(op_name)
-        self.imply_type = "AutoDiff"
-        self.processor = "cuda"
+        self.imply_type = "AKG"
+        self.processor = processor
 
     def input(self, index=None, name=None, **kwargs):
         """
@@ -270,6 +270,16 @@ class AkgRegOp(RegOp):
         return self
 
+
+class AkgGpuRegOp(AkgRegOp):
+    def __init__(self, op_name):
+        super(AkgGpuRegOp, self).__init__(op_name, "CUDA")
+
+
+class AkgAscendRegOp(AkgRegOp):
+    def __init__(self, op_name):
+        super(AkgAscendRegOp, self).__init__(op_name, "AiCore")
+
 class AiCPURegOp(RegOp):
     """Class for AiCPU op info register"""
......