提交 eba74ca8 编写于 作者: L liuxiao

add ops for VM

上级 e03359cc
......@@ -600,7 +600,6 @@ def get_bprop_roi_align(self):
sample_num = self.sample_num
def bprop(inputs, rois, out, dout):
rois_shape = shape_op(rois)
inputs_shape = shape_op(inputs)
dx = G.ROIAlignGrad(inputs_shape,
pooled_height,
......@@ -608,7 +607,7 @@ def get_bprop_roi_align(self):
spatial_scale,
sample_num,
)(dout, rois)
return dx, zeros_like(rois_shape)
return dx, zeros_like(rois)
return bprop
......
......@@ -76,6 +76,8 @@ from .strided_slice_d import _strided_slice_d_tbe
from .strided_slice_grad_d import _strided_slice_grad_d_tbe
from .split_d import _split_d_tbe
from .exp import _exp_tbe
from .elu import _elu_tbe
from .elu_grad import _elu_grad_tbe
from .div import _div_tbe
from .log import _log_tbe
from .floor_div import _floor_div_tbe
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Elu op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
elu_op_info = TBERegOp("Elu") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("elu.so") \
.compute_cost(10) \
.kernel_name("elu") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("alpha", "optional", "float", "all", "1.0") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(elu_op_info)
def _elu_tbe():
"""Elu TBE register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""EluGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
elu_grad_op_info = TBERegOp("EluGrad") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("elu_grad.so") \
.compute_cost(10) \
.kernel_name("elu_grad") \
.partial_flag(True) \
.input(0, "grads", False, "required", "all") \
.input(1, "activations", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(elu_grad_op_info)
def _elu_grad_tbe():
"""EluGrad TBE register"""
return
......@@ -1527,7 +1527,8 @@ class L2Loss(PrimitiveWithInfer):
def infer_dtype(self, x_type):
validator.check_subclass("x_type", x_type, mstype.tensor, self.name)
validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name)
valid_types = [mstype.float16, mstype.float32, mstype.double]
validator.check_tensor_type_same({'x_type': x_type}, valid_types, self.name)
return x_type
......
......@@ -874,7 +874,7 @@ test_case_nn_ops = [
'skip': ['backward']}),
('L2Loss_1', {
'block': P.L2Loss(),
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float16)],
'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float32)],
'desc_bprop': []}),
('L2Loss_2', {
'block': P.L2Loss(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册