diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 5738decdb8f5019cfcbec48f508f32b3fa73e5e4..bf752a5df370a671a38597762b21c721d767a4b6 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -242,3 +242,5 @@ from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe from .confusion_matrix import _confusion_matrix_tbe from .broadcast_to import _broadcast_to_tbe +from .strided_read import _strided_read_tbe +from .strided_write import _strided_write_tbe diff --git a/mindspore/ops/_op_impl/tbe/strided_read.py b/mindspore/ops/_op_impl/tbe/strided_read.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebd29f8f26e1cd6a7413a85575a65705f7f5290 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/strided_read.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""StridedRead op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +strided_read_op_info = TBERegOp("StridedRead") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("strided_read.so") \ + .compute_cost(10) \ + .kernel_name("strided_read") \ + .partial_flag(True) \ + .attr("axis", "required", "int", "all") \ + .attr("stride", "required", "int", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ + .get_op_info() + + +@op_info_register(strided_read_op_info) +def _strided_read_tbe(): + """StridedRead TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/strided_write.py b/mindspore/ops/_op_impl/tbe/strided_write.py new file mode 100644 index 0000000000000000000000000000000000000000..feda752b284b77f881539f3b4173d53a95f2c190 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/strided_write.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""StridedWrite op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +strided_write_op_info = TBERegOp("StridedWrite") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("strided_write.so") \ + .compute_cost(10) \ + .kernel_name("strided_write") \ + .partial_flag(True) \ + .attr("axis", "required", "int", "all") \ + .attr("stride", "required", "int", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ + .get_op_info() + + +@op_info_register(strided_write_op_info) +def _strided_write_tbe(): + """StridedWrite TBE register""" + return