Commit 5237a9a0 authored by mindspore-ci-bot, committed by Gitee

!88 add GPU ops and apply fix_input_order patch

Merge pull request !88 from lingyunli63/implement_gpu_ops
@@ -99,8 +99,9 @@ add_definitions(-DDMLC_LOG_CUSTOMIZE=1)
 if(USE_AKG_LOG)
   add_definitions(-DUSE_AKG_LOG=1)
 endif()
-if(NOT USE_CUDA)
-  add_definitions("-DBACKEND_D")
+if(NOT USE_CUDA
+   OR ENABLE_AKG)
+  add_definitions("-DFIX_INPUT_ORDER_TVM")
 endif()
 # Generic compilation options
...
-#!/usr/bin/env python3
-# coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,16 +13,23 @@
 # limitations under the License.
 """__init__"""
+from .notequal import NotEqual
 from .equal import Equal
-from .equal import gpu_schedule_Equal
+from .greater_equal import GreaterEqual
+from .less_equal import LessEqual
 from .tile import Tile
-from .tile import gpu_schedule_Tile
 from .cast import Cast
-from .relu6 import ReLU6, gpu_schedule_ReLU6
-from .relu6_grad import ReLU6Grad, gpu_schedule_ReLU6Grad
-from .squeeze import Squeeze, gpu_schedule_Squeeze
+from .relu6 import ReLU6
+from .logical_and import LogicalAnd
+from .logical_not import LogicalNot
+from .logical_or import LogicalOr
+from .relu6_grad import ReLU6Grad
+from .squeeze import Squeeze
 from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad
-from .mean import SimpleMean, gpu_schedule_SimpleMean
-from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad
+from .mean import SimpleMean
+from .sub import Sub
 from .mul import Mul
+from .hsigmoid import HSigmoid
+from .hsigmoid_grad import HSigmoidGrad
+from .hswish import HSwish
+from .hswish_grad import HSwishGrad
-#!/usr/bin/env python3
-# coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,13 +13,13 @@
 # limitations under the License.
 """cast"""
-import logging
-import akg.tvm
-from akg.ops.math import cast
-from akg.topi.generic import schedule_elemwise
+import akg
+from akg.ops.math_gpu import cast
 import akg.topi as topi

+@akg.schedule(topi.cuda.schedule_injective)
 def Cast(x, dst_type):
     """cast."""
+    if x.dtype == "int64" and dst_type == "float16":
+        x = cast.cast(x, "float32")
     return cast.cast(x, dst_type)
-#!/usr/bin/env python3
-# coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
 #!/usr/bin/env python3
 # coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,30 +13,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """equal"""
-import akg.tvm
-from akg.ops.math import equal
-from akg.topi.generic import schedule_elemwise
+import akg
+import akg.topi as topi
+from akg.ops.math_gpu import equal

+@akg.schedule(topi.cuda.schedule_injective)
 def Equal(x, y):
-    """equal."""
+    """Equal"""
     return equal.equal(x, y)
-
-def gpu_schedule_Equal(outs):
-    """
-    gpu schedule for Equal.
-    Args:
-        outs (tvm.tensor.Tensor): outputs of compute.
-    Returns:
-        sch (schedule.Schedule): The created schedule.
-    """
-    device = 'cuda'
-    ctx = akg.tvm.context(device, 0)
-    if not ctx.exist:
-        raise SystemError("Skip because %s is not enabled" % device)
-    with akg.tvm.target.create(device):
-        sch = schedule_elemwise(outs)
-    return sch
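The pattern above repeats throughout this commit: each hand-written gpu_schedule_* helper is dropped, and the op's compute function is instead wrapped with the akg.schedule decorator, which binds a schedule builder such as topi.cuda.schedule_injective to it. The decorator itself is not part of this diff; the following is only a hypothetical sketch of the shape such a binding can take:

import functools

def schedule(sch_builder):
    """Sketch only: bind a schedule builder to an op's compute function."""
    def decorator(compute_func):
        @functools.wraps(compute_func)
        def wrapper(*args, **kwargs):
            # Run the compute definition unchanged; a framework can later read
            # wrapper.schedule_builder and call it on the op's outputs.
            return compute_func(*args, **kwargs)
        wrapper.schedule_builder = sch_builder
        return wrapper
    return decorator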
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""greater_equal"""
import akg
import akg.topi as topi
from akg.ops.math_gpu import greater_equal
@akg.schedule(topi.cuda.schedule_injective)
def GreaterEqual(x, y):
"""GreaterEqual"""
return greater_equal.greater_equal(x, y)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""hsigmoid"""
import akg.topi as topi
import akg.tvm as tvm
from akg.topi import tag
import akg
@tvm.tag_scope(tag=tag.ELEMWISE)
def topi_nn_hsigmoid(x):
    """
    topi hsigmoid.

    Args:
        x (tvm.tensor.Tensor): input tensor.

    Returns:
        tvm.tensor.Tensor, 0 for x <= -3, 1 for x >= 3, (x + 3) / 6 otherwise.
    """
    return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
                                                            tvm.if_then_else(x(*i) >= 3, 1,
                                                                             (x(*i) + 3) / 6)))

@akg.schedule(topi.cuda.schedule_injective)
def HSigmoid(x):
    """
    HSigmoid activation.

    Args:
        x (tvm.tensor.Tensor): input tensor.

    Returns:
        tvm.tensor.Tensor, has the same type and shape as x.
    """
    return topi_nn_hsigmoid(x)
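topi_nn_hsigmoid implements the standard hard-sigmoid piecewise function. A NumPy reference that mirrors the same three branches (illustrative only, not part of this commit):

import numpy as np

def hsigmoid_ref(x):
    # Same branches as topi_nn_hsigmoid: 0 for x <= -3, 1 for x >= 3,
    # (x + 3) / 6 in between.
    return np.where(x <= -3, 0.0, np.where(x >= 3, 1.0, (x + 3.0) / 6.0))

print(hsigmoid_ref(np.array([-4.0, 0.0, 4.0])))  # -> [0.  0.5 1. ]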
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSigmoid grad"""
import akg.topi as topi
import akg.tvm as tvm
import akg
@akg.schedule(topi.cuda.schedule_injective)
def HSigmoidGrad(y_grad, x):
    """
    Gradient of HSigmoid.

    Args:
        y_grad (tvm.tensor.Tensor): gradient propagated from the next layer.
        x (tvm.tensor.Tensor): forward input tensor.

    Returns:
        tvm.tensor.Tensor, y_grad / 6 where -3 < x < 3, and 0 elsewhere.
    """
    return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
                                                            tvm.if_then_else(x(*i) >= 3, 0,
                                                                             y_grad(*i) / 6)))
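The gradient matches the derivative of (x + 3) / 6, i.e. 1/6 inside the open interval (-3, 3) and 0 outside it. An equivalent NumPy check (illustrative only):

import numpy as np

def hsigmoid_grad_ref(y_grad, x):
    # Scale the incoming gradient by 1/6 in the linear region, zero elsewhere.
    return np.where((x <= -3) | (x >= 3), 0.0, y_grad / 6.0)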
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSwish"""
import akg.topi as topi
import akg.tvm as tvm
from akg.topi import tag
import akg
@tvm.tag_scope(tag=tag.ELEMWISE)
def topi_nn_HSwish(x):
    """
    topi HSwish.

    Args:
        x (tvm.tensor.Tensor): input tensor.

    Returns:
        tvm.tensor.Tensor, 0 for x <= -3, x for x >= 3, x * (x + 3) / 6 otherwise.
    """
    return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
                                                            tvm.if_then_else(x(*i) >= 3, x(*i),
                                                                             x(*i) * (x(*i) + 3) / 6)))

@akg.schedule(topi.cuda.schedule_injective)
def HSwish(x):
    """
    HSwish activation.

    Args:
        x (tvm.tensor.Tensor): input tensor.

    Returns:
        tvm.tensor.Tensor, has the same type and shape as x.
    """
    return topi_nn_HSwish(x)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSwishGrad"""
import akg.topi as topi
import akg.tvm as tvm
import akg
@akg.schedule(topi.cuda.schedule_injective)
def HSwishGrad(y_grad, x):
    """
    Gradient of HSwish.

    Args:
        y_grad (tvm.tensor.Tensor): gradient propagated from the next layer.
        x (tvm.tensor.Tensor): forward input tensor.

    Returns:
        tvm.tensor.Tensor, 0 for x <= -3, y_grad for x >= 3, and y_grad * (2 * x + 3) / 6 in between.
    """
    shape = x.shape
    res0 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, y_grad(*i) * (2 * x(*i) + 3) / 6))
    res6 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= 3, y_grad(*i), res0(*i)))
    return res6
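The two-step compute encodes the derivative of hswish: for -3 < x < 3, hswish(x) = x * (x + 3) / 6, whose derivative is (2x + 3) / 6; the derivative is 0 below -3 and 1 above 3. A NumPy reference with the same branch structure (illustrative only):

import numpy as np

def hswish_grad_ref(y_grad, x):
    inner = y_grad * (2.0 * x + 3.0) / 6.0                          # matches res0 for x > -3
    return np.where(x <= -3, 0.0, np.where(x >= 3, y_grad, inner))  # matches res6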
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""less_equal"""
import akg
import akg.topi
from akg.ops.math_gpu import less_equal
@akg.schedule(akg.topi.cuda.schedule_injective)
def LessEqual(x, y):
    """LessEqual."""
    return less_equal.less_equal(x, y)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""logical_and"""
import akg
import akg.topi as topi
from akg.ops.math_gpu import logical_and
@akg.schedule(topi.cuda.schedule_injective)
def LogicalAnd(x, y):
"""LogicalAnd."""
return logical_and.logical_and(x, y)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""logical_not"""
import akg
import akg.topi as topi
from akg.ops.math_gpu import logical_not
@akg.schedule(topi.cuda.schedule_injective)
def LogicalNot(x):
"""LogicalNot."""
return logical_not.logical_not(x)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""logical_or"""
import akg
import akg.topi as topi
from akg.ops.math_gpu import logical_or
@akg.schedule(topi.cuda.schedule_injective)
def LogicalOr(x, y):
"""LogicalOr."""
return logical_or.logical_or(x, y)
 #!/usr/bin/env python3
 # coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,21 +15,10 @@
 # limitations under the License.
 """mean op compute and schedule"""
 import akg.tvm as tvm
-from akg.ops.math.mean import mean
 from .default_schedule import DEFAULT_GPU_THREAD
-
-def Mean(x, axis=None, keepdims=True):
-    """mean."""
-    outs = mean(x, axis, keepdims)
-    # remove useless mean_output
-    if isinstance(outs, tuple):
-        outs = outs[0]
-    if outs.op.name == "mean_output":
-        outs = outs.op.input_tensors[0]
-    return outs
+from akg.ops.math_gpu.sum_value import sum_value
+import akg
+from akg.ops.math_gpu.mean import mean

 def gpu_schedule_Mean(outs):
     """
@@ -43,25 +32,28 @@ def gpu_schedule_Mean(outs):
     """
     out = outs[0] if isinstance(outs, list) else outs
-    device = "cuda"
-    with tvm.target.create(device):
-        sch = tvm.create_schedule(out.op)
-        if out.op.name == "T_divide":
-            tensor_c = out
-        else:  # squeeze
-            tensor_c = out.op.input_tensors[0]
-        tensor_b = tensor_c.op.input_tensors[0]
-        if len(tensor_c.op.axis) >= 2:
-            sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[1])
-        else:
-            sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[0])
-        bx, tx = sch[tensor_c].split(tensor_c.op.axis[0], factor=DEFAULT_GPU_THREAD)
-        sch[tensor_c].bind(bx, tvm.thread_axis("blockIdx.x"))
-        sch[tensor_c].bind(tx, tvm.thread_axis("threadIdx.x"))
+    sch = tvm.create_schedule(out.op)
+    if out.op.name == "T_divide":
+        tensor_c = out
+    else:  # squeeze
+        tensor_c = out.op.input_tensors[0]
+    tensor_b = tensor_c.op.input_tensors[0]
+    if len(tensor_c.op.axis) >= 2:
+        sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[1])
+    else:
+        sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[0])
+    bx, tx = sch[tensor_c].split(tensor_c.op.axis[0], factor=DEFAULT_GPU_THREAD)
+    sch[tensor_c].bind(bx, tvm.thread_axis("blockIdx.x"))
+    sch[tensor_c].bind(tx, tvm.thread_axis("threadIdx.x"))
     return sch

+@akg.schedule(gpu_schedule_Mean)
+def Mean(data, axis=None, keepdims=False):
+    """mean."""
+    return mean(data, axis, keepdims)

+@akg.schedule(gpu_schedule_Mean)
 def SimpleMean(x):
     """
     SimpleMean computes the mean of the input 4D Tensor over the last two axes, keeping reduced dimensions.
@@ -74,9 +66,4 @@ def SimpleMean(x):
     """
     axis = (2, 3)
     keepdims = True
-    return Mean(x, axis, keepdims)
-
-def gpu_schedule_SimpleMean(outs):
-    """gpu schedule function for SimpleMean."""
-    return gpu_schedule_Mean(outs)
+    return mean(x, axis, keepdims)
#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""mean_grad"""
import akg.tvm as tvm
import akg
from akg.ops.math import mean
from .default_schedule import DEFAULT_GPU_THREAD
def mean_ad(head, input_shape, axis, keepdims):
"""mean autodiff."""
tensor_a = tvm.placeholder(input_shape, head.dtype, "A")
tensor_b = mean.mean(tensor_a, axis, keepdims)
# remove useless mean_output
if isinstance(tensor_b, tuple):
tensor_b = tensor_b[0]
if tensor_b.op.name == "mean_output":
tensor_b = tensor_b.op.input_tensors[0]
jacs = list(akg.differentiate(tensor_b, [tensor_a], head))
return jacs[0]
def MeanGrad(y_grad, input_shape, axis=None, keepdims=True):
"""Mean Grad."""
if axis is None and not keepdims:
raise ValueError("Mean not support (axis=None && keepdims=False) now")
return mean_ad(y_grad, input_shape, axis, keepdims)
def gpu_schedule_MeanGrad(outs):
"""gpu schedule MeanGrad."""
out = outs[0] if isinstance(outs, list) else outs
device = "cuda"
with tvm.target.create(device):
sch = tvm.create_schedule(out.op)
tensor_c = out
tensor_b = tensor_c.op.input_tensors[0]
if len(tensor_c.op.axis) >= 2:
sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[1])
else:
sch[tensor_b].compute_at(sch[tensor_c], tensor_c.op.axis[0])
bx, tx = sch[tensor_c].split(tensor_c.op.axis[0], factor=DEFAULT_GPU_THREAD)
sch[tensor_c].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[tensor_c].bind(tx, tvm.thread_axis("threadIdx.x"))
return sch
def SimpleMeanGrad(HEAD, input_shape):
"""
Compute Simple Mean Grad.
Args:
HEAD (tvm.tensor.Tensor): output gradient, dy, defined in Primitive.
input_shape (Union[list[int], tuple[int]]): shape of mean input, x.shape.
Returns:
tvm.tensor.Tensor, gradient of mean input.
"""
axis = (2, 3)
keepdims = True
return MeanGrad(HEAD, input_shape, axis, keepdims)
def gpu_schedule_SimpleMeanGrad(outs):
"""
gpu schedule SimpleMeanGrad.
Args:
outs (tvm.tensor.Tensor): outputs of compute.
Returns:
sch (schedule.Schedule): The created schedule.
"""
return gpu_schedule_MeanGrad(outs)
 #!/usr/bin/env python3
 # coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,7 @@
 """mul"""
 import akg
 import akg.topi as topi
-import akg.tvm as tvm
-from akg.ops.math import mul
+from akg.ops.math_gpu import mul

 @akg.schedule(topi.cuda.schedule_injective)
 def Mul(x, y):
...
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: notequal"""
import akg
import akg.topi
from akg.ops.math_gpu import notequal
@akg.schedule(akg.topi.cuda.schedule_injective)
def NotEqual(x, y):
"""notequal."""
return notequal.notequal(x, y)
-#!/usr/bin/env python3
-# coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,14 +14,16 @@
 """relu6"""
 import akg.topi as topi
-import akg.tvm as tvm
 from akg.topi import tag
+import akg
+import akg.tvm as tvm

 @tvm.tag_scope(tag=tag.ELEMWISE)
 def topi_nn_relu6(x):
     """topi nn relu6."""
     return tvm.compute(x.shape, lambda *i: tvm.min(tvm.max(x(*i), tvm.const(0, x.dtype)), tvm.const(6, x.dtype)))

+@akg.schedule(topi.cuda.schedule_injective)
 def ReLU6(x):
     """
     Compute elementwise with function: min(max(x, 0), 6).
@@ -35,22 +35,3 @@ def ReLU6(x):
     tvm.tensor.Tensor, has same type and shape as input.
     """
     return topi_nn_relu6(x)
-
-def gpu_schedule_ReLU6(outs):
-    """
-    gpu schedule ReLU6.
-    Args:
-        outs (tvm.tensor.Tensor): outputs of compute.
-    Returns:
-        sch (schedule.Schedule): The created schedule.
-    """
-    device = 'cuda'
-    ctx = tvm.context(device, 0)
-    if not ctx.exist:
-        raise SystemError("Skip because %s is not enabled" % device)
-    with tvm.target.create(device):
-        sch = topi.cuda.schedule_elemwise(outs)
-    return sch
-#!/usr/bin/env python3
-# coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +15,9 @@
 """relu6 grad"""
 import akg.topi as topi
 import akg.tvm as tvm
+import akg

+@akg.schedule(topi.cuda.schedule_injective)
 def ReLU6Grad(y_grad, x):
     """
     Computes Gradients of Rectified Linear 6.
@@ -39,23 +39,3 @@ def ReLU6Grad(y_grad, x):
     res6 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= six, zero, res0(*i)))
     res = tvm.compute(shape, lambda *i: tvm.if_then_else(res6(*i) == zero, zero, y_grad(*i)))
     return res
-
-def gpu_schedule_ReLU6Grad(outs):
-    """
-    gpu schedule ReLU6Grad.
-    Args:
-        outs (tvm.tensor.Tensor): outputs of compute.
-    Returns:
-        sch (schedule.Schedule): The created schedule.
-    """
-    device = 'cuda'
-    ctx = tvm.context(device, 0)
-    if not ctx.exist:
-        raise SystemError("Skip because %s is not enabled" % device)
-    with tvm.target.create(device):
-        sch = topi.cuda.schedule_elemwise(outs)
-    return sch
-#!/usr/bin/env python3
-# coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,8 +14,9 @@
 """squeeze"""
 import akg.topi as topi
-import akg.tvm as tvm
+import akg

+@akg.schedule(topi.cuda.schedule_injective)
 def Squeeze(x, axis=None):
     """
     Remove the dimensions which have shape size 1.
@@ -30,23 +29,3 @@ def Squeeze(x, axis=None):
     tvm.tensor.Tensor, has the same type and element as x, but some size 1 dimensions are removed.
     """
     return topi.squeeze(x, axis)
-
-def gpu_schedule_Squeeze(outs):
-    """
-    gpu schedule Squeeze.
-    Args:
-        outs (tvm.tensor.Tensor): outputs of compute.
-    Returns:
-        sch (schedule.Schedule): The created schedule.
-    """
-    device = 'cuda'
-    ctx = tvm.context(device, 0)
-    if not ctx.exist:
-        raise SystemError("Skip because %s is not enabled" % device)
-    with tvm.target.create(device):
-        sch = topi.cuda.schedule_injective(outs)
-    return sch
-#!/usr/bin/env python3
-# coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,31 +14,31 @@
 """squeeze grad"""
 import akg.topi as topi
+import akg

-def SqueezeGrad(y_grad, x_shape, axis=None):
+def gpu_schedule_SqueezeGrad(outs):
     """
-    Computes gradients for squeeze op.
+    gpu schedule SqueezeGrad.
     Args:
-        y_grad (tvm.tensor.Tensor): the gradient needed to be propagation.
-        x_shape (Union[list, tuple]): output Tensor shape.
-        axis (Union[list, tuple, int, None], optional): eliminated axis by squeeze.
+        outs (tvm.tensor.Tensor): outputs of compute.
     Returns:
-        tvm.tensor.Tensor: output gradient.
+        sch (schedule.Schedule): The created schedule.
     """
-    return topi.reshape(y_grad, x_shape)
+    from .default_schedule import default_schedule
+    return default_schedule(outs)

-def gpu_schedule_SqueezeGrad(outs):
+@akg.schedule(gpu_schedule_SqueezeGrad)
+def SqueezeGrad(y_grad, x_shape):
     """
-    gpu schedule SqueezeGrad.
+    Computes gradients for squeeze op.
     Args:
-        outs (tvm.tensor.Tensor): outputs of compute.
+        y_grad (tvm.tensor.Tensor): the gradient needed to be propagation.
+        x_shape (Union[list, tuple]): output Tensor shape.
     Returns:
-        sch (schedule.Schedule): The created schedule.
+        tvm.tensor.Tensor: output gradient.
     """
-    from .default_schedule import default_schedule
-    return default_schedule(outs)
+    return topi.reshape(y_grad, x_shape)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""sub"""
import akg
import akg.topi as topi
from akg.ops.math_gpu import sub
@akg.schedule(topi.cuda.schedule_injective)
def Sub(x, y):
"""Sub."""
return sub.sub(x, y)
 #!/usr/bin/env python3
 # coding: utf-8
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2020 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,29 +13,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """tile"""
-import akg.tvm
-from akg.ops.array import tile
-from akg.topi.generic import schedule_elemwise
+from akg.ops.array_gpu import tile
+import akg.topi as topi
+import akg

+@akg.schedule(topi.cuda.schedule_injective)
 def Tile(x, multiples):
     """tile."""
     return tile.tile(x, multiples)
-
-def gpu_schedule_Tile(outs):
-    """
-    gpu schedule for tile.
-    Args:
-        outs (tvm.tensor.Tensor): outputs of compute.
-    Returns:
-        sch (schedule.Schedule): The created schedule.
-    """
-    device = 'cuda'
-    ctx = akg.tvm.context(device, 0)
-    if not ctx.exist:
-        raise SystemError("Skip because %s is not enabled" % device)
-    with akg.tvm.target.create(device):
-        s = schedule_elemwise(outs)
-    return s
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: tile"""
import akg.tvm
import akg.topi
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple))
def tile(data, multiples):
"""
Repeats the data in the specified dimensions according to the multiples.
Args:
data (tvm.tensor.Tensor): Tensor.
multiples (Union[list, tuple]): Elements must be int. The number of repetitions.
Returns:
tvm.tensor.Tensor, has the same dtype as data.
"""
vc_util.check_shape(data.shape)
vc_util.check_int_list(multiples, "multiples")
output = akg.topi.tile(data, multiples)
return output
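A usage sketch for the tile DSL function, assuming an akg build where akg.tvm is importable (shapes and names below are illustrative):

import akg.tvm
from akg.ops.array_gpu import tile

x = akg.tvm.placeholder((2, 3), name="x", dtype="float16")
y = tile.tile(x, (2, 2))  # each axis repeated twice: result shape (4, 6)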
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: cast"""
import akg.tvm
import akg.topi
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, str)
def cast(data, dst_type):
"""
cast data to target type.
Args:
data (tvm.tensor.Tensor): Tensor to be casted.
dst_type (str): target cast type.
Returns:
tvm.tensor.Tensor, type is dst_type.
"""
vc_util.check_shape(data.shape)
out = akg.topi.cast(data, dst_type)
return out
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: equal"""
import akg.tvm
import akg.topi
from akg.utils.dsl_create import produce_shapes
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def equal(input1, input2):
"""
check whether input1 equals to input2.
Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor. If input1 equal to input2 return True, else return False.
"""
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
shape1, shape2, shape = produce_shapes(shape1, shape2)
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
dtype = input1.dtype
# get equal compute
t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T")
f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F")
input1_bro = akg.topi.broadcast_to(input1, shape)
input2_bro = akg.topi.broadcast_to(input2, shape)
c_out = akg.tvm.compute(shape, lambda *indice: akg.tvm.expr.Select(input1_bro[indice] == input2_bro[indice],
t_value[indice], f_value[indice]), name="C")
res = akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
return res
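Because produce_shapes broadcasts both inputs to a common shape before the select, the comparison ops accept mismatched shapes. A usage sketch (illustrative; the same pattern applies to greater_equal, less_equal, and notequal below):

import akg.tvm
from akg.ops.math_gpu import equal

x = akg.tvm.placeholder((4, 1), name="x", dtype="float32")
y = akg.tvm.placeholder((1, 3), name="y", dtype="float32")
res = equal.equal(x, y)  # both inputs broadcast to (4, 3); res.dtype == "bool"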
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: greaterequal"""
import akg.tvm
import akg.topi
from akg.utils.dsl_create import produce_shapes
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def greater_equal(input1, input2):
"""
Check whether input1 greaterquals to input2.
Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor. If input1 greaterquals to input2 return True, else return False.
"""
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
shape1, shape2, shape = produce_shapes(shape1, shape2)
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
dtype = input1.dtype
# get greater_equal compute
t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T")
f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F")
input1_bro = akg.topi.broadcast_to(input1, shape)
input2_bro = akg.topi.broadcast_to(input2, shape)
c_out = akg.tvm.compute(shape, lambda *indice: akg.tvm.expr.Select(input1_bro[indice] >= input2_bro[indice],
t_value[indice], f_value[indice]), name="C")
res = akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
return res
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: lessequal"""
import akg.tvm
import akg.topi
from akg.utils.dsl_create import produce_shapes
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def less_equal(input1, input2):
"""
Check whether input1 lessequals to input2.
Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor. If input1 lessequal to input2 return True, else return False.
"""
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
shape1, shape2, shape = produce_shapes(shape1, shape2)
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
dtype = input1.dtype
# get lessequal compute
t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T")
f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F")
input1_bro = akg.topi.broadcast_to(input1, shape)
input2_bro = akg.topi.broadcast_to(input2, shape)
c_out = akg.tvm.compute(shape, lambda *indice: akg.tvm.expr.Select(input1_bro[indice] <= input2_bro[indice],
t_value[indice], f_value[indice]), name="C")
res = akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
return res
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: logical_and"""
import akg.tvm
import akg.topi
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def logical_and(input1, input2):
"""
Compute logical_and of input1 and input2.
Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor. LogicalAnd of input1 and input2.
"""
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
res = akg.topi.logical_and(input1, input2)
return res
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: logical_not"""
import akg.tvm
import akg.topi
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor)
def logical_not(input1):
"""
Compute logical_not of input1.
Args:
input1 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor.
"""
res = akg.topi.logical_not(input1)
return res
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: logical_or"""
import akg.tvm
import akg.topi
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def logical_or(input1, input2):
"""
Compute logical_or of input1 and input2.
Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor. LogicalOr of input1 and input2.
"""
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
res = akg.topi.logical_or(input1, input2)
return res
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: mean"""
import akg.topi
import akg.tvm
from akg.utils import format_transform as ft_util
from akg.utils import validation_check as vc_util
from akg.ops.math_gpu import sum_value
@vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple, int, type(None)), (bool, type(None)))
def mean(data, axis=None, keepdims=False):
"""
Computes the mean of the values of a Tensor over the whole dataset.
Args:
data (tvm.tensor.Tensor): Tensor.
axis (Union[list, tuple, int, None]): If the tuple is empty, the axis equal to None.
keepdims (bool): If keepdims equal to True, the result shape length is same to input shape length.
Returns:
tvm.tensor.Tensor, has the same type as data. If keepdims equal to True, all reduced dimensions are
retained with length 1. else these reduced axis will be eliminate.
"""
shape = [x.value for x in data.shape]
vc_util.reduce_axis_check(shape, axis)
axis = ft_util.refine_reduce_axis(data, axis)
count = 1
for i in axis:
count *= shape[i]
    output = sum_value.sum_value(data, axis, keepdims)
res = akg.topi.divide(output, count)
return res
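mean reduces with sum_value and then divides by the product of the reduced extents. A usage sketch (illustrative shapes; assumes akg.tvm is importable):

import akg.tvm
from akg.ops.math_gpu import mean

data = akg.tvm.placeholder((4, 16, 16), name="data", dtype="float32")
res = mean.mean(data, axis=(1, 2), keepdims=True)
# count = 16 * 16 = 256, so res = sum(data, axis=(1, 2), keepdims=True) / 256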
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: mul"""
import akg.topi
import akg.tvm
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def mul(l_input, r_input):
"""
Calculate x * y element-wise.
Note:
mul supports broadcasting.
Args:
l_input (tvm.tensor.Tensor): Tensor.
r_input (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor, has the same type as l_input and r_input.
"""
shape1 = [x.value for x in l_input.shape]
shape2 = [x.value for x in r_input.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
vc_util.auto_broadcast_check(shape1, shape2)
vc_util.elemwise_dtype_check(l_input.dtype, r_input.dtype)
output = akg.topi.multiply(l_input, r_input)
return output
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: notequal"""
import akg.tvm
import akg.topi
from akg.utils.dsl_create import produce_shapes
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def notequal(input1, input2):
"""
check whether input1 notequals to input2.
Args:
input1 (tvm.tensor.Tensor): Tensor.
input2 (tvm.tensor.Tensor): Tensor.
Returns:
tvm.tensor.Tensor. If input1 notequal to input2 return True, else return False.
"""
shape1 = [x.value for x in input1.shape]
shape2 = [x.value for x in input2.shape]
vc_util.check_shape(shape1)
vc_util.check_shape(shape2)
shape1, shape2, shape = produce_shapes(shape1, shape2)
vc_util.elemwise_dtype_check(input1.dtype, input2.dtype)
dtype = input1.dtype
# get notequal compute
t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T")
f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F")
input1_bro = akg.topi.broadcast_to(input1, shape)
input2_bro = akg.topi.broadcast_to(input2, shape)
c_out = akg.tvm.compute(shape, lambda *indice: akg.tvm.expr.Select(input1_bro[indice] != input2_bro[indice],
t_value[indice], f_value[indice]), name="C")
res = akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
return res
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: sub"""
import akg.topi
import akg.tvm
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
def sub(data1, data2):
"""
Computes data1 - data2 elementwise, broadcast is supported.
Args:
data1 (tvm.tensor.Tensor): Tensor.
data2 (tvm.tensor.Tensor): Tensor of same type as data1, if shape(data2) != shape(data1), broadcast will happen.
Returns:
tvm.tensor.Tensor, subtracted result, with same type as input tensors and broadcasted shape of data1 and data2.
"""
vc_util.elemwise_dtype_check(data1.dtype, data2.dtype)
vc_util.check_shape(data1.shape)
vc_util.check_shape(data2.shape)
vc_util.auto_broadcast_check(data1.shape, data2.shape)
res = akg.topi.subtract(data1, data2)
return res
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""operator dsl function: sum"""
import akg.topi
import akg.tvm
from akg.utils import format_transform as ft_util
from akg.utils import validation_check as vc_util
@vc_util.check_input_type(akg.tvm.tensor.Tensor, (list, tuple, int, type(None)), (bool, type(None)))
def sum_value(inputs, axis=None, keepdims=False):
"""
Compute the sum of elements across dimensions of a tensor.
Args:
inputs (tvm.tensor.Tensor): Tensor.
axis (Union[list, tuple, int, None]): If the list or tuple is empty, the axis equal to None.
keepdims (bool): If keepdims equal to True, the result shape length is same to input shape length.
Returns:
tvm.tensor.Tensor, has same type as input. If keepdims is True, all reduced dimensions are retained
with length 1, else these reduced axis will be eliminate.
"""
axis = ft_util.refine_reduce_axis(inputs, axis)
vc_util.check_shape(inputs.shape)
if not axis:
output = akg.topi.identity(inputs)
else:
output = akg.topi.sum(inputs, axis=axis, keepdims=keepdims)
return output
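A usage sketch for sum_value (illustrative shapes; assumes akg.tvm is importable):

import akg.tvm
from akg.ops.math_gpu import sum_value

x = akg.tvm.placeholder((8, 32), name="x", dtype="float32")
y = sum_value.sum_value(x, axis=1, keepdims=False)  # reduces axis 1 -> shape (8,)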
@@ -234,7 +234,7 @@ class HostDeviceSplitter : public IRMutator {
     }
   }
-#ifdef BACKEND_D
+#ifdef FIX_INPUT_ORDER_TVM
   std::shared_ptr<LoweredFuncNode> na = std::make_shared<LoweredFuncNode>();
   for (unsigned i = 0; i < (unsigned)args_real.size(); i++) {
     bool match = false;
...