提交 8672e153 编写于 作者: D danleifeng 提交者: gongweibao

elementwise broadcast function enhancement (#19536)

elementwise broadcast function enhancement
上级 a50785b0
...@@ -47,13 +47,52 @@ namespace operators { ...@@ -47,13 +47,52 @@ namespace operators {
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1 * pre=2*3, n=4*5, post=1
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*
* New parameter: *mid_flag* is added to solve m*n*k & m*1*k
* broadcast cases.
* 3. shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5)
* mid_flag should not be NULL.
* x.shape(2, 3, 20) * y.shape(2, 1, 20).broadcast(2, 3, 20)
*/ */
inline void get_mid_dims(const framework::DDim &x_dims, inline void get_mid_dims(const framework::DDim &x_dims,
const framework::DDim &y_dims, const int axis, const framework::DDim &y_dims, const int axis,
int *pre, int *n, int *post) { int *pre, int *n, int *post, int *mid_flag = NULL) {
*pre = 1; *pre = 1;
*n = 1; *n = 1;
*post = 1; *post = 1;
if (mid_flag != NULL) {
*mid_flag = 0;
int mid = 0;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
if (x_dims[i + axis] != y_dims[i]) {
// only support single y_dims[i] = 1 now.
PADDLE_ENFORCE_EQ(*mid_flag, 0,
"Broadcast support y_dims with single 1.");
PADDLE_ENFORCE_EQ(y_dims[i], 1, "Broadcast dimension mismatch.");
// m*n*k m*1*k
for (int j = 0; j < i; ++j) {
(*pre) *= y_dims[j];
}
*n = std::max(x_dims[i + axis], y_dims[i]);
*mid_flag = 1;
mid = i;
break;
}
(*n) *= y_dims[i];
}
if (*mid_flag) {
for (int i = mid + 1; i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
} else {
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
}
} else { // for fused_elementwise_activation_op. keep the old version.
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i]; (*pre) *= x_dims[i];
} }
...@@ -67,6 +106,7 @@ inline void get_mid_dims(const framework::DDim &x_dims, ...@@ -67,6 +106,7 @@ inline void get_mid_dims(const framework::DDim &x_dims,
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i]; (*post) *= x_dims[i];
} }
}
} }
inline framework::DDim trim_trailing_singular_dims( inline framework::DDim trim_trailing_singular_dims(
...@@ -171,7 +211,6 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> ...@@ -171,7 +211,6 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext>
} }
} }
} }
return *this; return *this;
} }
...@@ -268,6 +307,15 @@ class TransformFunctor { ...@@ -268,6 +307,15 @@ class TransformFunctor {
MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_); MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
} }
inline void RunMidRowWise(int n, int pre, int post) const {
platform::Transform<DeviceContext> trans;
for (int i = 0; i < pre; i++) {
trans(ctx_, x_ + i * n * post, x_ + (i + 1) * n * post,
RowwiseTransformIterator<T, DeviceContext>(y_ + i * post, post),
z_ + i * n * post, func_);
}
}
private: private:
const T *x_; const T *x_;
const T *y_; const T *y_;
...@@ -501,6 +549,88 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x, ...@@ -501,6 +549,88 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
#endif #endif
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcastMid2CPU(const T *x, const T *y, const T *out,
const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int x_offset = i * n * post + j * post + k;
int y_offset = i * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp =
dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
if (j == 0) {
dy[y_offset] = tmp;
} else {
dy[y_offset] += tmp;
}
}
}
}
}
}
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcastMid2CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = threadIdx.x;
int tid = blockIdx.x;
T val(0);
int ttid = tid;
while (true) {
int i = ttid / post;
int k = ttid % post;
if (i >= pre) break;
int x_offset = i * n * post + j * post + k;
int y_offset = i * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
}
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
if (dy) {
int h = n;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, j, h);
if (threadIdx.x == 0) {
dy[tid] = val;
}
}
}
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcastMid2CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out,
const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, n);
int gird_size = pre * post;
ElemwiseGradBroadcastMid2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}
#endif
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP> template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast( void ElemwiseGradComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim, const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
...@@ -533,23 +663,39 @@ void ElemwiseGradComputeWithBroadcast( ...@@ -533,23 +663,39 @@ void ElemwiseGradComputeWithBroadcast(
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed); auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis; axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post; int pre, n, post, mid_flag = 0;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &mid_flag);
if (post == 1) { if (mid_flag) {
int h = pre; PADDLE_ENFORCE_EQ(mid_flag, 1, "mid_flag should be no more than 1.");
int w = n; if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
ElemwiseGradBroadcastMid2CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
ElemwiseGradBroadcastMid2CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
} else if (post == 1) {
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__ #ifdef __NVCC__
ElemwiseGradBroadcast1CUDA( ElemwiseGradBroadcast1CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(), ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, dy_op, y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()), dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())); dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif #endif
} else { } else {
ElemwiseGradBroadcast1CPU( ElemwiseGradBroadcast1CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), h, w, dx_op, x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()), dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())); dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
} }
} else { } else {
...@@ -689,9 +835,12 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx, ...@@ -689,9 +835,12 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
"Axis should be in range [0, x_dims)"); "Axis should be in range [0, x_dims)");
auto y_dims = trim_trailing_singular_dims(y_dims_untrimed); auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis; axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post, mid_flag = 0;
int pre, n, post; get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &mid_flag);
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); if (mid_flag) {
functor.RunMidRowWise(n, pre, post);
return;
}
if (post == 1) { if (post == 1) {
functor.RunRowWise(n, pre); functor.RunRowWise(n, pre);
return; return;
......
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest, sys import unittest, sys
sys.path.append("../") sys.path.append("../")
from test_elementwise_max_op import * from test_elementwise_max_op import TestElementwiseMaxOp_scalar, TestElementwiseMaxOp_Vector, TestElementwiseMaxOp_broadcast_0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest, sys import unittest, sys
sys.path.append("../") sys.path.append("../")
from test_elementwise_min_op import * from test_elementwise_min_op import TestElementwiseMinOp_scalar, TestElementwiseMinOp_Vector, TestElementwiseMinOp_broadcast_0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest, sys import unittest, sys
sys.path.append("../") sys.path.append("../")
from test_elementwise_pow_op import * from test_elementwise_pow_op import TestElementwisePowOp_scalar, TestElementwisePowOp_tensor, TestElementwisePowOp_broadcast_0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest, sys import unittest, sys
sys.path.append("../") sys.path.append("../")
from test_elementwise_sub_op import * from test_elementwise_sub_op import TestElementwiseSubOp_scalar, TestElementwiseSubOp_Vector, TestElementwiseSubOp_broadcast_0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -218,6 +218,34 @@ class TestFP16ElementwiseAddOp_broadcast_4(TestFP16ElementwiseAddOp): ...@@ -218,6 +218,34 @@ class TestFP16ElementwiseAddOp_broadcast_4(TestFP16ElementwiseAddOp):
self.axis = 0 self.axis = 0
class TestElementwiseAddOp_broadcast_5(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(2, 1, 4).astype(self.dtype)
self.out = self.x + self.y
class TestFP16ElementwiseAddOp_broadcast_5(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(2, 1, 4).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.y = np.random.rand(2, 3, 1, 5).astype(self.dtype)
self.out = self.x + self.y
class TestFP16ElementwiseAddOp_broadcast_6(TestFP16ElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.y = np.random.rand(2, 3, 1, 5).astype(self.dtype)
self.out = self.x + self.y
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp): class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp):
def init_input_output(self): def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype) self.x = np.random.rand(2, 3, 4).astype(self.dtype)
......
...@@ -131,6 +131,26 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp): ...@@ -131,6 +131,26 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
} }
class TestElementwiseDivOp_broadcast_4(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [2, 1, 4]).astype("float32")
}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [2, 3, 1, 5]).astype("float32")
}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseDivOpFp16(ElementwiseDivOp): class TestElementwiseDivOpFp16(ElementwiseDivOp):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float16 self.dtype = np.float16
......
...@@ -128,5 +128,17 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp): ...@@ -128,5 +128,17 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
} }
class TestElementwiseMaxOp_broadcast_4(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_max"
x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32)
sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float32)
y = x + sgn * \
np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float32)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -55,7 +55,7 @@ class TestElementwiseMinOp_scalar(TestElementwiseOp): ...@@ -55,7 +55,7 @@ class TestElementwiseMinOp_scalar(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMaxOp_Vector(TestElementwiseOp): class TestElementwiseMinOp_Vector(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
x = np.random.random((32, )).astype("float32") x = np.random.random((32, )).astype("float32")
...@@ -65,7 +65,7 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp): ...@@ -65,7 +65,7 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp): class TestElementwiseMinOp_broadcast_0(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32)
...@@ -81,7 +81,7 @@ class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp): ...@@ -81,7 +81,7 @@ class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp):
} }
class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp): class TestElementwiseMinOp_broadcast_1(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32)
...@@ -97,7 +97,7 @@ class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp): ...@@ -97,7 +97,7 @@ class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp):
} }
class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp): class TestElementwiseMinOp_broadcast_2(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32) x = np.random.uniform(0.5, 1, (2, 3, 4)).astype(np.float32)
...@@ -112,7 +112,7 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp): ...@@ -112,7 +112,7 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp):
} }
class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp): class TestElementwiseMinOp_broadcast_3(TestElementwiseOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_min" self.op_type = "elementwise_min"
x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32) x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32)
...@@ -128,5 +128,17 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp): ...@@ -128,5 +128,17 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
} }
class TestElementwiseMinOp_broadcast_4(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_min"
x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float32)
sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float32)
y = x + sgn * \
np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float32)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -135,6 +135,26 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): ...@@ -135,6 +135,26 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
} }
class TestElementwiseMulOp_broadcast_4(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float64),
'Y': np.random.rand(2, 1, 4).astype(np.float64)
}
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
class TestElementwiseMulOp_broadcast_5(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
'X': np.random.rand(2, 3, 4, 5).astype(np.float64),
'Y': np.random.rand(2, 3, 1, 5).astype(np.float64)
}
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
class TestElementwiseMulOpFp16(ElementwiseMulOp): class TestElementwiseMulOpFp16(ElementwiseMulOp):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float16 self.dtype = np.float16
......
...@@ -104,5 +104,15 @@ class TestElementwisePowOp_broadcast_3(TestElementwisePowOp): ...@@ -104,5 +104,15 @@ class TestElementwisePowOp_broadcast_3(TestElementwisePowOp):
} }
class TestElementwisePowOp_broadcast_4(TestElementwisePowOp):
def setUp(self):
self.op_type = "elementwise_pow"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [2, 3, 1, 5]).astype("float32")
}
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -117,5 +117,15 @@ class TestElementwiseSubOp_broadcast_3(TestElementwiseOp): ...@@ -117,5 +117,15 @@ class TestElementwiseSubOp_broadcast_3(TestElementwiseOp):
} }
class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(2, 3, 4, 5).astype(np.float32),
'Y': np.random.rand(2, 3, 1, 5).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册