From 8672e153637a2be3aaa804bc735db7e20b6cba5c Mon Sep 17 00:00:00 2001
From: danleifeng <52735331+danleifeng@users.noreply.github.com>
Date: Wed, 4 Sep 2019 20:20:54 +0800
Subject: [PATCH] elementwise broadcast function enhancement (#19536)

elementwise broadcast function enhancement
---
 .../elementwise/elementwise_op_function.h     | 195 +++++++++++++++---
 .../ngraph/test_elementwise_max_ngraph_op.py  |   2 +-
 .../ngraph/test_elementwise_min_ngraph_op.py  |   2 +-
 .../ngraph/test_elementwise_pow_ngraph_op.py  |   2 +-
 .../ngraph/test_elementwise_sub_ngraph_op.py  |   2 +-
 .../unittests/test_elementwise_add_op.py      |  28 +++
 .../unittests/test_elementwise_div_op.py      |  20 ++
 .../unittests/test_elementwise_max_op.py      |  12 ++
 .../unittests/test_elementwise_min_op.py      |  22 +-
 .../unittests/test_elementwise_mul_op.py      |  20 ++
 .../unittests/test_elementwise_pow_op.py      |  10 +
 .../unittests/test_elementwise_sub_op.py      |  10 +
 12 files changed, 293 insertions(+), 32 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index 7d0256cc1c..59a9c3086d 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -47,25 +47,65 @@ namespace operators {
  * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
  *    pre=2*3, n=4*5, post=1
  *    x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
+ *
+ * New parameter: *mid_flag* is added to solve m*n*k & m*1*k
+ * broadcast cases.
+ * 3. shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5)
+ *    mid_flag should not be NULL.
+ *    x.shape(2, 3, 20) * y.shape(2, 1, 20).broadcast(2, 3, 20)
  */
 inline void get_mid_dims(const framework::DDim &x_dims,
                          const framework::DDim &y_dims, const int axis,
-                         int *pre, int *n, int *post) {
+                         int *pre, int *n, int *post, int *mid_flag = NULL) {
   *pre = 1;
   *n = 1;
   *post = 1;
-  for (int i = 0; i < axis; ++i) {
-    (*pre) *= x_dims[i];
-  }
+  if (mid_flag != NULL) {
+    *mid_flag = 0;
+    int mid = 0;
+    for (int i = 0; i < axis; ++i) {
+      (*pre) *= x_dims[i];
+    }
+    for (int i = 0; i < y_dims.size(); ++i) {
+      if (x_dims[i + axis] != y_dims[i]) {
+        // only support single y_dims[i] = 1 now.
+        PADDLE_ENFORCE_EQ(*mid_flag, 0,
+                          "Broadcast support y_dims with single 1.");
+        PADDLE_ENFORCE_EQ(y_dims[i], 1, "Broadcast dimension mismatch.");
+        // m*n*k m*1*k
+        for (int j = 0; j < i; ++j) {
+          (*pre) *= y_dims[j];
+        }
+        *n = std::max(x_dims[i + axis], y_dims[i]);
+        *mid_flag = 1;
+        mid = i;
+        break;
+      }
+      (*n) *= y_dims[i];
+    }
+    if (*mid_flag) {
+      for (int i = mid + 1; i < x_dims.size(); ++i) {
+        (*post) *= x_dims[i];
+      }
+    } else {
+      for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
+        (*post) *= x_dims[i];
+      }
+    }
+  } else {  // for fused_elementwise_activation_op. keep the old version.
+    for (int i = 0; i < axis; ++i) {
+      (*pre) *= x_dims[i];
+    }
 
-  for (int i = 0; i < y_dims.size(); ++i) {
-    PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
-                      "Broadcast dimension mismatch.");
-    (*n) *= y_dims[i];
-  }
+    for (int i = 0; i < y_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
+                        "Broadcast dimension mismatch.");
+      (*n) *= y_dims[i];
+    }
 
-  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
-    (*post) *= x_dims[i];
+    for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
+      (*post) *= x_dims[i];
+    }
   }
 }
 
@@ -171,7 +211,6 @@ class MidWiseTransformIterator
       }
     }
   }
-
   return *this;
 }
 
@@ -268,6 +307,15 @@ class TransformFunctor {
           MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
   }
 
+  inline void RunMidRowWise(int n, int pre, int post) const {
+    platform::Transform<DeviceContext> trans;
+    for (int i = 0; i < pre; i++) {
+      trans(ctx_, x_ + i * n * post, x_ + (i + 1) * n * post,
+            RowwiseTransformIterator<T, DeviceContext>(y_ + i * post, post),
+            z_ + i * n * post, func_);
+    }
+  }
+
  private:
   const T *x_;
   const T *y_;
@@ -501,6 +549,88 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
 
 #endif
 
+template <typename T, typename DX_OP, typename DY_OP>
+static void ElemwiseGradBroadcastMid2CPU(const T *x, const T *y, const T *out,
+                                         const T *dout, int pre, int n,
+                                         int post, DX_OP dx_op, DY_OP dy_op,
+                                         T *dx, T *dy) {
+  for (int i = 0; i < pre; ++i) {
+    for (int j = 0; j < n; ++j) {
+      for (int k = 0; k < post; ++k) {
+        int x_offset = i * n * post + j * post + k;
+        int y_offset = i * post + k;
+        if (dx != nullptr) {
+          dx[x_offset] =
+              dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+        }
+        if (dy != nullptr) {
+          T tmp =
+              dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+          if (j == 0) {
+            dy[y_offset] = tmp;
+          } else {
+            dy[y_offset] += tmp;
+          }
+        }
+      }
+    }
+  }
+}
+
+#ifdef __NVCC__
+template <typename T, typename DX_OP, typename DY_OP>
+static __global__ void ElemwiseGradBroadcastMid2CUDAKernel(
+    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
+    int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+  int j = threadIdx.x;
+  int tid = blockIdx.x;
+
+  T val(0);
+  int ttid = tid;
+
+  while (true) {
+    int i = ttid / post;
+    int k = ttid % post;
+    if (i >= pre) break;
+
+    int x_offset = i * n * post + j * post + k;
+    int y_offset = i * post + k;
+    if (dx != nullptr) {
+      dx[x_offset] =
+          dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+    }
+
+    if (dy != nullptr) {
+      val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+    }
+
+    ttid += ELEMWISE_MAX_BLOCK_DIM;
+  }
+
+  if (dy) {
+    int h = n;
+    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
+    val = paddle::platform::reduceSum(val, j, h);
+    if (threadIdx.x == 0) {
+      dy[tid] = val;
+    }
+  }
+}
+
+template <typename T, typename DX_OP, typename DY_OP>
+static void ElemwiseGradBroadcastMid2CUDA(cudaStream_t stream, const T *x,
+                                          const T *y, const T *out,
+                                          const T *dout, int pre, int n,
+                                          int post, DX_OP dx_op, DY_OP dy_op,
+                                          T *dx, T *dy) {
+  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, n);
+  int gird_size = pre * post;
+  ElemwiseGradBroadcastMid2CUDAKernel<T, DX_OP, DY_OP><<<gird_size, block_size, 0, stream>>>(
+      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
+}
+
+#endif
+
 template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
 void ElemwiseGradComputeNoBroadcast(
     const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
@@ -533,23 +663,39 @@ void ElemwiseGradComputeWithBroadcast(
   auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
   axis = (y_dim.size() == 0) ? x_dim.size() : axis;
-  int pre, n, post;
-  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
-  if (post == 1) {
-    int h = pre;
-    int w = n;
+  int pre, n, post, mid_flag = 0;
+  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &mid_flag);
+  if (mid_flag) {
+    PADDLE_ENFORCE_EQ(mid_flag, 1, "mid_flag should be no more than 1.");
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef __NVCC__
+      ElemwiseGradBroadcastMid2CUDA(
+          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
+          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
+          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+#endif
+    } else {
+      ElemwiseGradBroadcastMid2CPU(
+          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
+          dx_op, dy_op,
+          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
+    }
+  } else if (post == 1) {
     if (platform::is_gpu_place(ctx.GetPlace())) {
 #ifdef __NVCC__
       ElemwiseGradBroadcast1CUDA(
           ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
-          y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op,
+          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, dx_op, dy_op,
           dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
 #endif
     } else {
       ElemwiseGradBroadcast1CPU(
-          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op,
-          dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
+          dx_op, dy_op,
+          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
     }
   } else {
@@ -689,9 +835,12 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                    "Axis should be in range [0, x_dims)");
     auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
     axis = (y_dims.size() == 0) ? x_dims.size() : axis;
-
-    int pre, n, post;
-    get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
+    int pre, n, post, mid_flag = 0;
+    get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &mid_flag);
+    if (mid_flag) {
+      functor.RunMidRowWise(n, pre, post);
+      return;
+    }
     if (post == 1) {
       functor.RunRowWise(n, pre);
       return;
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_max_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_max_ngraph_op.py
index 30d6ab4765..c680241720 100644
--- a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_max_ngraph_op.py
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_max_ngraph_op.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 import unittest, sys
 sys.path.append("../")
 
-from test_elementwise_max_op import *
+from test_elementwise_max_op import TestElementwiseMaxOp_scalar, TestElementwiseMaxOp_Vector, TestElementwiseMaxOp_broadcast_0
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_min_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_min_ngraph_op.py
index 8fa19b3268..443445288a 100644
--- a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_min_ngraph_op.py
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_min_ngraph_op.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 import unittest, sys
 sys.path.append("../")
 
-from test_elementwise_min_op import *
+from test_elementwise_min_op import TestElementwiseMinOp_scalar, TestElementwiseMinOp_Vector, TestElementwiseMinOp_broadcast_0
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_pow_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_pow_ngraph_op.py
index 02a7abc6e2..1601de2313 100644
--- a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_pow_ngraph_op.py
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_pow_ngraph_op.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 import unittest, sys
 sys.path.append("../")
 
-from test_elementwise_pow_op import *
+from test_elementwise_pow_op import TestElementwisePowOp_scalar, TestElementwisePowOp_tensor, TestElementwisePowOp_broadcast_0
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_sub_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_sub_ngraph_op.py
index 078cf12f6d..fe29008f5e 100644
--- a/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_sub_ngraph_op.py
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_elementwise_sub_ngraph_op.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 import unittest, sys
 sys.path.append("../")
 
-from test_elementwise_sub_op import *
+from test_elementwise_sub_op import TestElementwiseSubOp_scalar, TestElementwiseSubOp_Vector, TestElementwiseSubOp_broadcast_0
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
index 5aec5d8e38..5783048f5f 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
@@ -218,6 +218,34 @@ class TestFP16ElementwiseAddOp_broadcast_4(TestFP16ElementwiseAddOp):
         self.axis = 0
 
 
+class TestElementwiseAddOp_broadcast_5(TestElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(2, 3, 4).astype(self.dtype)
+        self.y = np.random.rand(2, 1, 4).astype(self.dtype)
+        self.out = self.x + self.y
+
+
+class TestFP16ElementwiseAddOp_broadcast_5(TestFP16ElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(2, 3, 4).astype(self.dtype)
+        self.y = np.random.rand(2, 1, 4).astype(self.dtype)
+        self.out = self.x + self.y
+
+
+class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
+        self.y = np.random.rand(2, 3, 1, 5).astype(self.dtype)
+        self.out = self.x + self.y
+
+
+class TestFP16ElementwiseAddOp_broadcast_6(TestFP16ElementwiseAddOp):
+    def init_input_output(self):
+        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
+        self.y = np.random.rand(2, 3, 1, 5).astype(self.dtype)
+        self.out = self.x + self.y
+
+
 class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp):
     def init_input_output(self):
         self.x = np.random.rand(2, 3, 4).astype(self.dtype)
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
index 15d4db590e..4e679607d1 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
@@ -131,6 +131,26 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
         }
 
 
+class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp):
+    def setUp(self):
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [2, 1, 4]).astype("float32")
+        }
+        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
+
+
+class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
+    def setUp(self):
+        self.op_type = "elementwise_div"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [2, 3, 1, 5]).astype("float32")
+        }
+        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
+
+
 class TestElementwiseDivOpFp16(ElementwiseDivOp):
     def init_dtype(self):
         self.dtype = np.float16
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
index 43c58710ba..db7f5a640e 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
@@ -128,5 +128,17 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
         }
 
 
+class TestElementwiseMaxOp_broadcast_4(TestElementwiseOp):
+    def setUp(self):
+        self.op_type = "elementwise_max"
+        x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32)
+        sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float32)
+        y = x + sgn * \
+            np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float32)
+        self.inputs = {'X': x, 'Y': y}
+
+        self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
index 45c861e2c3..c1e93f6a4e 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
@@ -55,7 +55,7 @@ class TestElementwiseMinOp_scalar(TestElementwiseOp):
         self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
 
 
-class TestElementwiseMaxOp_Vector(TestElementwiseOp):
+class TestElementwiseMinOp_Vector(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_min"
         x = np.random.random((32, )).astype("float32")
@@ -65,7 +65,7 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp):
         self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
 
 
-class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp):
+class TestElementwiseMinOp_broadcast_0(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_min"
         x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32)
@@ -81,7 +81,7 @@ class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp):
         }
 
 
-class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp):
+class TestElementwiseMinOp_broadcast_1(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_min"
         x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32)
@@ -97,7 +97,7 @@ class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp):
         }
 
 
-class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp):
+class TestElementwiseMinOp_broadcast_2(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_min"
         x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32)
@@ -112,7 +112,7 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp):
         }
 
 
-class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
+class TestElementwiseMinOp_broadcast_3(TestElementwiseOp):
     def setUp(self):
         self.op_type = "elementwise_min"
         x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32)
@@ -128,5 +128,17 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
         }
 
 
+class TestElementwiseMinOp_broadcast_4(TestElementwiseOp):
+    def setUp(self):
+        self.op_type = "elementwise_min"
+        x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32)
+        sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float32)
+        y = x + sgn * \
+            np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float32)
+        self.inputs = {'X': x, 'Y': y}
+
+        self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
index 0484099188..2415aeb0cb 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
@@ -135,6 +135,26 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
         }
 
 
+class TestElementwiseMulOp_broadcast_4(ElementwiseMulOp):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.inputs = {
+            'X': np.random.rand(2, 3, 4).astype(np.float64),
+            'Y': np.random.rand(2, 1, 4).astype(np.float64)
+        }
+        self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
+
+
+class TestElementwiseMulOp_broadcast_5(ElementwiseMulOp):
+    def setUp(self):
+        self.op_type = "elementwise_mul"
+        self.inputs = {
+            'X': np.random.rand(2, 3, 4, 5).astype(np.float64),
+            'Y': np.random.rand(2, 3, 1, 5).astype(np.float64)
+        }
+        self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
+
+
 class TestElementwiseMulOpFp16(ElementwiseMulOp):
     def init_dtype(self):
         self.dtype = np.float16
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
index 0b0c7c5ecb..e6a065889c 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
@@ -104,5 +104,15 @@ class TestElementwisePowOp_broadcast_3(TestElementwisePowOp):
         }
 
 
+class TestElementwisePowOp_broadcast_4(TestElementwisePowOp):
+    def setUp(self):
+        self.op_type = "elementwise_pow"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32"),
+            'Y': np.random.uniform(0.1, 1, [2, 3, 1, 5]).astype("float32")
+        }
+        self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
index 6cb88a8bb1..e9a389bbaf 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
@@ -117,5 +117,15 @@ class TestElementwiseSubOp_broadcast_3(TestElementwiseOp):
         }
 
 
+class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
+    def setUp(self):
+        self.op_type = "elementwise_sub"
+        self.inputs = {
+            'X': np.random.rand(2, 3, 4, 5).astype(np.float32),
+            'Y': np.random.rand(2, 3, 1, 5).astype(np.float32)
+        }
+        self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
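
Note (not part of the patch): a quick way to see what the new mid_flag path enables is that Y may now match X in every dimension except a single interior axis of size 1, which is exactly the shape pattern the new broadcast_4/5/6 tests exercise. The sketch below is plain NumPy and simply mirrors the reference output those tests build (self.out = self.x + self.y); the pre/n/post values are the decomposition get_mid_dims computes for these shapes, written out by hand here rather than obtained from any API call.

import numpy as np

# Shapes taken from TestElementwiseAddOp_broadcast_5: X is (m, n, k), Y is (m, 1, k).
x = np.random.rand(2, 3, 4).astype("float32")
y = np.random.rand(2, 1, 4).astype("float32")

# Reference result, computed the same way the unit tests build 'Out'.
out = x + y
assert out.shape == (2, 3, 4)

# get_mid_dims decomposes these shapes as pre=2, n=3, post=4 with mid_flag=1,
# so the forward kernel walks pre*post "columns" and broadcasts y across the
# middle axis of length n; the same loop structure appears in RunMidRowWise
# and ElemwiseGradBroadcastMid2CPU above.
pre, n, post = 2, 3, 4
out_manual = x.reshape(pre, n, post) + y.reshape(pre, 1, post)
assert np.allclose(out, out_manual)

On the backward side, ElemwiseGradBroadcastMid2CPU/CUDA accumulate dy by summing dout over that middle axis of length n, which is why the change adds no new operators, only the extra dispatch branch in ElemwiseGradComputeWithBroadcast and ElementwiseComputeEx plus the new shape cases in the Python tests.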