From eba74ca8f784ed71c08c853fbe40ab002edabaac Mon Sep 17 00:00:00 2001 From: liuxiao Date: Wed, 6 May 2020 19:36:00 +0800 Subject: [PATCH] add ops for VM --- mindspore/ops/_grad/grad_nn_ops.py | 3 +- mindspore/ops/_op_impl/tbe/__init__.py | 2 ++ mindspore/ops/_op_impl/tbe/elu.py | 40 ++++++++++++++++++++++++ mindspore/ops/_op_impl/tbe/elu_grad.py | 43 ++++++++++++++++++++++++++ mindspore/ops/operations/nn_ops.py | 3 +- tests/ut/python/ops/test_ops.py | 2 +- 6 files changed, 89 insertions(+), 4 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/elu.py create mode 100644 mindspore/ops/_op_impl/tbe/elu_grad.py diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 153abc0fb..362bda736 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -600,7 +600,6 @@ def get_bprop_roi_align(self): sample_num = self.sample_num def bprop(inputs, rois, out, dout): - rois_shape = shape_op(rois) inputs_shape = shape_op(inputs) dx = G.ROIAlignGrad(inputs_shape, pooled_height, @@ -608,7 +607,7 @@ def get_bprop_roi_align(self): spatial_scale, sample_num, )(dout, rois) - return dx, zeros_like(rois_shape) + return dx, zeros_like(rois) return bprop diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 73afef73a..738c9ef47 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -76,6 +76,8 @@ from .strided_slice_d import _strided_slice_d_tbe from .strided_slice_grad_d import _strided_slice_grad_d_tbe from .split_d import _split_d_tbe from .exp import _exp_tbe +from .elu import _elu_tbe +from .elu_grad import _elu_grad_tbe from .div import _div_tbe from .log import _log_tbe from .floor_div import _floor_div_tbe diff --git a/mindspore/ops/_op_impl/tbe/elu.py b/mindspore/ops/_op_impl/tbe/elu.py new file mode 100644 index 000000000..9125d1472 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/elu.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Elu op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +elu_op_info = TBERegOp("Elu") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("elu.so") \ + .compute_cost(10) \ + .kernel_name("elu") \ + .partial_flag(True) \ + .op_pattern("formatAgnostic") \ + .attr("alpha", "optional", "float", "all", "1.0") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(elu_op_info) +def _elu_tbe(): + """Elu TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/elu_grad.py b/mindspore/ops/_op_impl/tbe/elu_grad.py new file mode 100644 index 000000000..c3486dd02 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/elu_grad.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""EluGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +elu_grad_op_info = TBERegOp("EluGrad") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("elu_grad.so") \ + .compute_cost(10) \ + .kernel_name("elu_grad") \ + .partial_flag(True) \ + .input(0, "grads", False, "required", "all") \ + .input(1, "activations", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(elu_grad_op_info) +def _elu_grad_tbe(): + """EluGrad TBE register""" + return diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 2a2dbe08a..7ba341fd5 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1527,7 +1527,8 @@ class L2Loss(PrimitiveWithInfer): def infer_dtype(self, x_type): validator.check_subclass("x_type", x_type, mstype.tensor, self.name) - validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name) + valid_types = [mstype.float16, mstype.float32, mstype.double] + validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name) return x_type diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 9d7e8c898..7a3d7d967 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -874,7 +874,7 @@ test_case_nn_ops = [ 'skip': ['backward']}), ('L2Loss_1', { 'block': P.L2Loss(), - 'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float16)], + 'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)], 'desc_bprop': []}), ('L2Loss_2', { 'block': P.L2Loss(), -- GitLab