Commit ff719bda authored by lingyunli63

rm mean op

Parent 5237a9a0
@@ -26,7 +26,6 @@ from .logical_or import LogicalOr
 from .relu6_grad import ReLU6Grad
 from .squeeze import Squeeze
 from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad
-from .mean import SimpleMean
 from .sub import Sub
 from .mul import Mul
 from .hsigmoid import HSigmoid
......
#!/usr/bin/env python3
# coding: utf-8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""mean op compute and schedule"""
from .default_schedule import DEFAULT_GPU_THREAD
from akg.ops.math_gpu.sum_value import sum_value
import akg
from akg.ops.math_gpu.mean import mean


def gpu_schedule_Mean(outs):
    """
    GPU schedule function for mean.

    Args:
        outs (tvm.tensor.Tensor): outputs of compute.

    Returns:
        sch (schedule.Schedule): The created schedule.
    """
    out = outs[0] if isinstance(outs, list) else outs
    sch = tvm.create_schedule(out.op)
    # The output is either the divide stage itself or a squeeze over it.
    if out.op.name == "T_divide":
        tensor_c = out
    else:  # squeeze
        tensor_c = out.op.input_tensors[0]

    # Fuse the sum stage into the divide stage at its innermost usable axis.
    tensor_b = tensor_c.op.input_tensors[0]
    if len(tensor_c.op.axis) >= 2:
        sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[1])
    else:
        sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[0])

    # Split the outer axis and bind the pieces to CUDA blocks and threads.
    bx, tx = sch[tensor_c].split(tensor_c.op.axis[0], factor=DEFAULT_GPU_THREAD)
    sch[tensor_c].bind(bx, tvm.thread_axis("blockIdx.x"))
    sch[tensor_c].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch
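

# Illustrative only (not part of the original commit): a minimal sketch of the
# split-and-bind pattern used by gpu_schedule_Mean, shown on a toy elementwise
# compute. It assumes the TVM-style API exposed through akg.tvm; the helper
# name _split_bind_sketch and the toy shapes are hypothetical.
def _split_bind_sketch():
    n = 1024
    a = tvm.placeholder((n,), name="A")
    b = tvm.compute((n,), lambda i: a[i] * 2, name="B")
    s = tvm.create_schedule(b.op)
    # Outer pieces become CUDA blocks, inner pieces become CUDA threads.
    bx, tx = s[b].split(b.op.axis[0], factor=32)
    s[b].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[b].bind(tx, tvm.thread_axis("threadIdx.x"))
    return s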


@akg.schedule(gpu_schedule_Mean)
def Mean(data, axis=None, keepdims=False):
    """Mean of data along the given axis, scheduled on GPU by gpu_schedule_Mean."""
    return mean(data, axis, keepdims)


@akg.schedule(gpu_schedule_Mean)
def SimpleMean(x):
    """
    SimpleMean computes the mean of the input 4D tensor over its last two axes, keeping the reduced dimensions.

    Args:
        x (tvm.tensor.Tensor): Tensor of type float16, float32.

    Returns:
        tvm.tensor.Tensor with the same type as x; for an input of shape (a, b, c, d) the output shape is (a, b, 1, 1).
    """
    axis = (2, 3)
    keepdims = True
    return mean(x, axis, keepdims)
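

# Illustrative only (not part of the original commit): what SimpleMean computes,
# checked with NumPy on a hypothetical (a, b, c, d) = (2, 3, 4, 5) input.
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
out = x.mean(axis=(2, 3), keepdims=True)
print(out.shape)  # (2, 3, 1, 1): the last two axes are reduced, dims are kept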
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: mean"""
import akg.topi
import akg.tvm
from akg.utils import format_transform as ft_util
from akg.utils import validation_check as vc_util
from akg.ops.math_gpu import sum_value


@vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple, int, type(None)), (bool, type(None)))
def mean(data, axis=None, keepdims=False):
    """
    Compute the arithmetic mean of a tensor's values along the given axis.

    Args:
        data (tvm.tensor.Tensor): Input tensor.
        axis (Union[list, tuple, int, None]): Axes to reduce over; an empty tuple is treated the same as None.
        keepdims (bool): If True, the result has the same rank as the input, with the reduced axes kept at length 1.

    Returns:
        tvm.tensor.Tensor with the same type as data. If keepdims is True, all reduced dimensions are
        retained with length 1; otherwise the reduced axes are eliminated.
    """
    shape = [x.value for x in data.shape]
    vc_util.reduce_axis_check(shape, axis)
    axis = ft_util.refine_reduce_axis(data, axis)

    # Mean = sum over the reduced axes, divided by the number of reduced elements.
    count = 1
    for i in axis:
        count *= shape[i]
    output, _ = sum_value.sum_value(data, axis, keepdims)
    res = akg.topi.divide(output, count)
    return res
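

# Illustrative only (not part of the original commit): the dsl above implements
# mean as a sum followed by a divide by the reduced-element count. A NumPy
# sketch of the same decomposition, assuming axis=(2, 3):
import numpy as np

data = np.random.rand(2, 3, 4, 5).astype(np.float32)
axis = (2, 3)
count = 1
for i in axis:
    count *= data.shape[i]  # mirrors the count loop in mean()
res = data.sum(axis=axis) / count
assert np.allclose(res, data.mean(axis=axis))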