Unverified commit 28280647 authored by LutaoChu, committed by GitHub

add paddle.subtract, optimize paddle.maximum and paddle.minimum

Parent 3c2a46bd
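In short: this commit adds a NumPy-style element-wise paddle.subtract and removes the axis argument from paddle.maximum / paddle.minimum in favor of plain broadcasting. A minimal usage sketch, mirroring the docstring examples added below:

    import paddle

    x = paddle.to_tensor([[1, 2], [7, 8]], dtype='float32')
    y = paddle.to_tensor([[5, 6], [3, 4]], dtype='float32')
    print(paddle.subtract(x, y))  # [[-4., -4.], [ 4.,  4.]]
    print(paddle.maximum(x, y))   # [[ 5.,  6.], [ 7.,  8.]]
    print(paddle.minimum(x, y))   # [[ 1.,  2.], [ 3.,  4.]]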
@@ -186,6 +186,7 @@ from .tensor.math import mod #DEFINE_ALIAS
from .tensor.math import floor_mod #DEFINE_ALIAS
from .tensor.math import multiply #DEFINE_ALIAS
from .tensor.math import add #DEFINE_ALIAS
from .tensor.math import subtract #DEFINE_ALIAS
from .tensor.math import atan #DEFINE_ALIAS
from .tensor.math import logsumexp #DEFINE_ALIAS
from .tensor.math import inverse #DEFINE_ALIAS
......
@@ -16,7 +16,6 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
@@ -31,6 +30,14 @@ class ApiMaximumTest(unittest.TestCase):
self.input_x = np.random.rand(10, 15).astype("float32")
self.input_y = np.random.rand(10, 15).astype("float32")
self.input_z = np.random.rand(15).astype("float32")
self.input_a = np.array([0, np.nan, np.nan]).astype('int64')
self.input_b = np.array([2, np.inf, -np.inf]).astype('int64')
self.input_c = np.array([4, 1, 3]).astype('int64')
self.np_expected1 = np.maximum(self.input_x, self.input_y)
self.np_expected2 = np.maximum(self.input_x, self.input_z)
self.np_expected3 = np.maximum(self.input_a, self.input_c)
self.np_expected4 = np.maximum(self.input_b, self.input_c)
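        # Note: casting np.nan / np.inf to int64 is not defined by IEEE 754;
        # on common platforms both collapse to INT64_MIN, so input_a/input_b
        # become ordinary integers and expected3/expected4 exercise plain
        # int64 comparison rather than NaN/inf propagation.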
def test_static_api(self):
paddle.enable_static()
@@ -43,38 +50,64 @@ class ApiMaximumTest(unittest.TestCase):
res, = exe.run(feed={"x": self.input_x,
"y": self.input_y},
fetch_list=[result_max])
            self.assertTrue(np.allclose(res, self.np_expected1))
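            # np.allclose replaces the previous exact `==` check so the
            # comparison tolerates benign float rounding differences between
            # the Paddle kernel and the NumPy reference.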
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.static.data("z", shape=[15], dtype="float32")
            result_max = paddle.maximum(data_x, data_z)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
"z": self.input_z},
fetch_list=[result_max])
            self.assertTrue(np.allclose(res, self.np_expected2))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_a = paddle.static.data("a", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
result_max = paddle.maximum(data_a, data_c)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"a": self.input_a,
"c": self.input_c},
fetch_list=[result_max])
self.assertTrue(np.allclose(res, self.np_expected3))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_b = paddle.static.data("b", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
result_max = paddle.maximum(data_b, data_c)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"b": self.input_b,
"c": self.input_c},
fetch_list=[result_max])
self.assertTrue(np.allclose(res, self.np_expected4))
    def test_dynamic_api(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.input_x)
        y = paddle.to_tensor(self.input_y)
        z = paddle.to_tensor(self.input_z)
        a = paddle.to_tensor(self.input_a)
        b = paddle.to_tensor(self.input_b)
        c = paddle.to_tensor(self.input_c)

        res = paddle.maximum(x, y)
        res = res.numpy()
        self.assertTrue(np.allclose(res, self.np_expected1))

        # test broadcast
        res = paddle.maximum(x, z)
        res = res.numpy()
        self.assertTrue(np.allclose(res, self.np_expected2))

        res = paddle.maximum(a, c)
        res = res.numpy()
        self.assertTrue(np.allclose(res, self.np_expected3))

        res = paddle.maximum(b, c)
        res = res.numpy()
        self.assertTrue(np.allclose(res, self.np_expected4))
@@ -16,7 +16,6 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
@@ -31,6 +30,14 @@ class ApiMinimumTest(unittest.TestCase):
self.input_x = np.random.rand(10, 15).astype("float32")
self.input_y = np.random.rand(10, 15).astype("float32")
self.input_z = np.random.rand(15).astype("float32")
self.input_a = np.array([0, np.nan, np.nan]).astype('int64')
self.input_b = np.array([2, np.inf, -np.inf]).astype('int64')
self.input_c = np.array([4, 1, 3]).astype('int64')
self.np_expected1 = np.minimum(self.input_x, self.input_y)
self.np_expected2 = np.minimum(self.input_x, self.input_z)
self.np_expected3 = np.minimum(self.input_a, self.input_c)
self.np_expected4 = np.minimum(self.input_b, self.input_c)
def test_static_api(self):
paddle.enable_static()
@@ -38,43 +45,69 @@ class ApiMinimumTest(unittest.TestCase):
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_y = paddle.static.data("y", shape=[10, 15], dtype="float32")
            result_min = paddle.minimum(data_x, data_y)
            exe = paddle.static.Executor(self.place)
            res, = exe.run(feed={"x": self.input_x,
                                 "y": self.input_y},
                           fetch_list=[result_min])
            self.assertTrue(np.allclose(res, self.np_expected1))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.static.data("z", shape=[15], dtype="float32")
            result_min = paddle.minimum(data_x, data_z)
            exe = paddle.static.Executor(self.place)
            res, = exe.run(feed={"x": self.input_x,
                                 "z": self.input_z},
                           fetch_list=[result_min])
            self.assertTrue(np.allclose(res, self.np_expected2))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_a = paddle.static.data("a", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
            result_min = paddle.minimum(data_a, data_c)
            exe = paddle.static.Executor(self.place)
            res, = exe.run(feed={"a": self.input_a,
                                 "c": self.input_c},
                           fetch_list=[result_min])
self.assertTrue(np.allclose(res, self.np_expected3))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_b = paddle.static.data("b", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
            result_min = paddle.minimum(data_b, data_c)
            exe = paddle.static.Executor(self.place)
            res, = exe.run(feed={"b": self.input_b,
                                 "c": self.input_c},
                           fetch_list=[result_min])
self.assertTrue(np.allclose(res, self.np_expected4))
    def test_dynamic_api(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.input_x)
        y = paddle.to_tensor(self.input_y)
        z = paddle.to_tensor(self.input_z)
        a = paddle.to_tensor(self.input_a)
        b = paddle.to_tensor(self.input_b)
        c = paddle.to_tensor(self.input_c)

        res = paddle.minimum(x, y)
        res = res.numpy()
        self.assertTrue(np.allclose(res, self.np_expected1))

        # test broadcast
        res = paddle.minimum(x, z)
        res = res.numpy()
        self.assertTrue(np.allclose(res, self.np_expected2))

        res = paddle.minimum(a, c)
        res = res.numpy()
        self.assertTrue(np.allclose(res, self.np_expected3))

        res = paddle.minimum(b, c)
        res = res.numpy()
        self.assertTrue(np.allclose(res, self.np_expected4))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
class ApiSubtractTest(unittest.TestCase):
def setUp(self):
if core.is_compiled_with_cuda():
self.place = core.CUDAPlace(0)
else:
self.place = core.CPUPlace()
self.input_x = np.random.rand(10, 15).astype("float32")
self.input_y = np.random.rand(10, 15).astype("float32")
self.input_z = np.random.rand(15).astype("float32")
self.input_a = np.array([0, np.nan, np.nan]).astype('int64')
self.input_b = np.array([2, np.inf, -np.inf]).astype('int64')
self.input_c = np.array([4, 1, 3]).astype('int64')
self.np_expected1 = np.subtract(self.input_x, self.input_y)
self.np_expected2 = np.subtract(self.input_x, self.input_z)
self.np_expected3 = np.subtract(self.input_a, self.input_c)
self.np_expected4 = np.subtract(self.input_b, self.input_c)
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_y = paddle.static.data("y", shape=[10, 15], dtype="float32")
            result_sub = paddle.subtract(data_x, data_y)
            exe = paddle.static.Executor(self.place)
            res, = exe.run(feed={"x": self.input_x,
                                 "y": self.input_y},
                           fetch_list=[result_sub])
self.assertTrue(np.allclose(res, self.np_expected1))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.static.data("z", shape=[15], dtype="float32")
            result_sub = paddle.subtract(data_x, data_z)
            exe = paddle.static.Executor(self.place)
            res, = exe.run(feed={"x": self.input_x,
                                 "z": self.input_z},
                           fetch_list=[result_sub])
self.assertTrue(np.allclose(res, self.np_expected2))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_a = paddle.static.data("a", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
            result_sub = paddle.subtract(data_a, data_c)
            exe = paddle.static.Executor(self.place)
            res, = exe.run(feed={"a": self.input_a,
                                 "c": self.input_c},
                           fetch_list=[result_sub])
self.assertTrue(np.allclose(res, self.np_expected3))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_b = paddle.static.data("b", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
            result_sub = paddle.subtract(data_b, data_c)
            exe = paddle.static.Executor(self.place)
            res, = exe.run(feed={"b": self.input_b,
                                 "c": self.input_c},
                           fetch_list=[result_sub])
self.assertTrue(np.allclose(res, self.np_expected4))
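    # Each static-graph case above builds its graph in a fresh pair of
    # main/startup Programs via paddle.static.program_guard, so the four
    # subtract programs stay isolated from one another and from the global
    # default program.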
def test_dynamic_api(self):
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
y = paddle.to_tensor(self.input_y)
z = paddle.to_tensor(self.input_z)
a = paddle.to_tensor(self.input_a)
b = paddle.to_tensor(self.input_b)
c = paddle.to_tensor(self.input_c)
res = paddle.subtract(x, y)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected1))
# test broadcast
res = paddle.subtract(x, z)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected2))
res = paddle.subtract(a, c)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected3))
res = paddle.subtract(b, c)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected4))
@@ -148,6 +148,7 @@ from .math import mod #DEFINE_ALIAS
from .math import floor_mod #DEFINE_ALIAS
from .math import multiply #DEFINE_ALIAS
from .math import add #DEFINE_ALIAS
from .math import subtract #DEFINE_ALIAS
from .math import atan #DEFINE_ALIAS
from .math import logsumexp #DEFINE_ALIAS
from .math import inverse #DEFINE_ALIAS
......
@@ -111,6 +111,7 @@ __all__ = [
'floor_mod',
'multiply',
'add',
'subtract',
'atan',
'logsumexp',
'inverse',
@@ -286,6 +287,67 @@ def add(x, y, name=None):
return _elementwise_op(LayerHelper(op_type, **locals()))
def subtract(x, y, name=None):
    """
    Subtract two tensors element-wise. The equation is:

    .. math::
        out = x - y

    **Note**:
    ``paddle.subtract`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.

    Args:
        x (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
        y (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle

            x = paddle.to_tensor([[1, 2], [7, 8]])
            y = paddle.to_tensor([[5, 6], [3, 4]])
            res = paddle.subtract(x, y)
            print(res)
            # [[-4, -4],
            #  [ 4,  4]]

            x = paddle.to_tensor([[[1, 2, 3], [1, 2, 3]]])
            y = paddle.to_tensor([1, 0, 4])
            res = paddle.subtract(x, y)
            print(res)
            # [[[ 0, 2, -1],
            #   [ 0, 2, -1]]]

            x = paddle.to_tensor([2, np.nan, 5], dtype='float32')
            y = paddle.to_tensor([1, 4, np.nan], dtype='float32')
            res = paddle.subtract(x, y)
            print(res)
            # [ 1., nan, nan]

            x = paddle.to_tensor([5, np.inf, -np.inf], dtype='float64')
            y = paddle.to_tensor([1, 4, 5], dtype='float64')
            res = paddle.subtract(x, y)
            print(res)
            # [ 4., inf, -inf]
    """
op_type = 'elementwise_sub'
axis = -1
act = None
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals()))
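# subtract, maximum and minimum all follow one dispatch pattern: in dygraph
# (eager) mode the call goes straight to the C++ elementwise kernel, while in
# static mode an operator is appended to the current program via LayerHelper.
# A minimal sketch of that pattern as a standalone helper (hypothetical name,
# not part of this commit):
#
#     def _binary_elementwise(op_type, x, y):
#         if in_dygraph_mode():
#             # eager: run the kernel immediately and return a Tensor
#             return _elementwise_op_in_dygraph(
#                 x, y, axis=-1, act=None, op_name=op_type)
#         # static: record the op in the default program; the result is a
#         # Variable that an Executor computes later
#         return _elementwise_op(LayerHelper(op_type, **locals()))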
def divide(x, y, name=None):
"""
Divide two tensors element-wise. The equation is:
@@ -302,7 +364,7 @@ def divide(x, y, name=None):
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
        N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y.
Examples:
@@ -382,7 +444,7 @@ def remainder(x, y, name=None):
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
        N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y.
Examples:
@@ -425,7 +487,7 @@ def multiply(x, y, name=None):
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
        N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y.
Examples:
@@ -463,84 +525,118 @@ def multiply(x, y, name=None):
x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals()))
def maximum(x, y, name=None):
    """
    Compare two tensors and return a new tensor containing the element-wise maxima. The equation is:

    .. math::
        out = max(x, y)

    **Note**:
    ``paddle.maximum`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.

    Args:
        x (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
        y (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle

            x = paddle.to_tensor([[1, 2], [7, 8]])
            y = paddle.to_tensor([[3, 4], [5, 6]])
            res = paddle.maximum(x, y)
            print(res)
            # [[3, 4],
            #  [7, 8]]

            x = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            y = paddle.to_tensor([3, 0, 4])
            res = paddle.maximum(x, y)
            print(res)
            # [[3, 2, 4],
            #  [3, 2, 4]]

            x = paddle.to_tensor([2, 3, 5], dtype='float32')
            y = paddle.to_tensor([1, np.nan, np.nan], dtype='float32')
            res = paddle.maximum(x, y)
            print(res)
            # [ 2., nan, nan]

            x = paddle.to_tensor([5, 3, np.inf], dtype='float32')
            y = paddle.to_tensor([1, -np.inf, 5], dtype='float32')
            res = paddle.maximum(x, y)
            print(res)
            # [ 5., 3., inf]
    """
op_type = 'elementwise_max'
axis = -1
act = None
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals()))
def minimum(x, y, name=None):
    """
    Compare two tensors and return a new tensor containing the element-wise minima. The equation is:

    .. math::
        out = min(x, y)

    **Note**:
    ``paddle.minimum`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.

    Args:
        x (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
        y (Tensor): the input tensor, its data type should be float32, float64, int32, int64.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle

            x = paddle.to_tensor([[1, 2], [7, 8]])
            y = paddle.to_tensor([[3, 4], [5, 6]])
            res = paddle.minimum(x, y)
            print(res)
            # [[1, 2],
            #  [5, 6]]

            x = paddle.to_tensor([[[1, 2, 3], [1, 2, 3]]])
            y = paddle.to_tensor([3, 0, 4])
            res = paddle.minimum(x, y)
            print(res)
            # [[[1, 0, 3],
            #   [1, 0, 3]]]

            x = paddle.to_tensor([2, 3, 5], dtype='float32')
            y = paddle.to_tensor([1, np.nan, np.nan], dtype='float32')
            res = paddle.minimum(x, y)
            print(res)
            # [ 1., nan, nan]

            x = paddle.to_tensor([5, 3, np.inf], dtype='float64')
            y = paddle.to_tensor([1, -np.inf, 5], dtype='float64')
            res = paddle.minimum(x, y)
            print(res)
            # [ 1., -inf, 5.]
    """
op_type = 'elementwise_min'
axis = -1
act = None
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
@@ -549,11 +645,9 @@ Examples:
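# Only add and multiply still have their docstrings patched from the C++
# OpProto in the loop below; maximum, minimum and the new subtract now carry
# hand-written docstrings above, so their entries were dropped from the list
# and from proto_dict.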
for func in [
        add,
        multiply
]:
    proto_dict = {'add': 'elementwise_add', 'multiply': 'elementwise_mul'}
op_proto = OpProtoHolder.instance().get_op_proto(proto_dict[func.__name__])
additional_args_lines = [
......