Unverified commit f586110d, authored by cambriconhsq, committed by GitHub

[MLU] add mlu kernel for elementwise_max_grad (#43608)

* [MLU] add mlu kernel for elementwise_max_grad

* [MLU] modify mlu kernel elementwise_min_grad impl
Parent 2353db3a
......@@ -27,6 +27,14 @@ class ElementwiseMaxMLUKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class ElementwiseMaxGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
MLUMinMaxGradHelper<MAXIMUM_GRAD, T>(ctx);
}
};
} // namespace operators
} // namespace paddle
......@@ -34,4 +42,8 @@ namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL(elementwise_max, ops::ElementwiseMaxMLUKernel<int>,
ops::ElementwiseMaxMLUKernel<float>,
ops::ElementwiseMaxMLUKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(
elementwise_max_grad, ops::ElementwiseMaxGradMLUKernel<int>,
ops::ElementwiseMaxGradMLUKernel<float>,
ops::ElementwiseMaxGradMLUKernel<paddle::platform::float16>);
#endif
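For intuition, the gradient this new kernel computes is the usual subgradient of an elementwise maximum: the upstream gradient flows to the input that was selected, with ties going to X, mirroring the CNNL_LOGIC_OP_GE mask chosen in MLUMinMaxGradHelper below. A minimal NumPy sketch, illustrative only and not part of the patch (it ignores broadcasting, which the helper handles with a final reduce):

import numpy as np

def elementwise_max_grad_reference(x, y, dout):
    # Where x >= y the maximum selected x, so dout flows to dx; the rest goes to dy.
    mask = (x >= y).astype(dout.dtype)
    dx = dout * mask
    dy = dout - dx  # for same-shaped inputs this equals dout * (x < y)
    return dx, dy

x = np.array([1.0, 3.0, 2.0], dtype=np.float32)
y = np.array([2.0, 1.0, 2.0], dtype=np.float32)
dout = np.ones_like(x)
print(elementwise_max_grad_reference(x, y, dout))  # dx = [0, 1, 1], dy = [1, 0, 0]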
......@@ -34,92 +34,7 @@ template <typename T>
class ElementwiseMinGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
MLUMinMaxGradHelper<MINIMUM_GRAD, T>(ctx);
}
};
......
......@@ -224,6 +224,102 @@ void MLUUnaryOp(const framework::ExecutionContext& ctx) {
out_desc.get(), GetBasePtr(out));
}
// ------------------ MLUElementwiseGradOp -----------------
enum MINMAX_GRAD_FUNCTOR {
MAXIMUM_GRAD,
MINIMUM_GRAD,
};
template <MINMAX_GRAD_FUNCTOR Functor, typename Tin, typename Tout = Tin>
void MLUMinMaxGradHelper(const framework::ExecutionContext& ctx) {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
const auto& x_dims = x->dims();
const auto& y_dims = y->dims();
axis =
(axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) : axis);
int max_dim = std::max(x_dims.size(), y_dims.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
  // mask = Logic(x, y); only min & max are supported
cnnlLogicOp_t logic =
Functor == MAXIMUM_GRAD ? CNNL_LOGIC_OP_GE : CNNL_LOGIC_OP_LE;
Tensor mask(x->dtype());
mask.Resize(phi::make_ddim(out_dims_array));
mask.mutable_data<Tin>(ctx.GetPlace());
cnnlDataType_t data_type = ToCnnlDataType<Tin>();
MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), data_type);
MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), data_type);
MLUCnnlTensorDesc mask_desc(max_dim, out_dims_array.data(), data_type);
MLUCnnl::Logic(ctx, logic, x_desc.get(), GetBasePtr(x), y_desc.get(),
GetBasePtr(y), mask_desc.get(), GetBasePtr(&mask));
// dx = Mul(dz, mask)
Tensor dx_temp(x->dtype());
dx_temp.Resize(dout->dims());
dx_temp.mutable_data<Tout>(ctx.GetPlace());
MLUCnnlTensorDesc dout_desc(*dout);
MLUCnnlOpTensorDesc mul_op_desc(CNNL_OP_TENSOR_MUL, data_type,
CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor(ctx, mul_op_desc.get(), dout_desc.get(), GetBasePtr(dout),
dout_desc.get(), GetBasePtr(&mask), dout_desc.get(),
GetBasePtr(&dx_temp), data_type);
// dy = Sub(dz, dx)
Tensor dy_temp(y->dtype());
dy_temp.Resize(dout->dims());
dy_temp.mutable_data<Tout>(ctx.GetPlace());
MLUCnnlOpTensorDesc sub_op_desc(CNNL_OP_TENSOR_SUB, data_type,
CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor(ctx, sub_op_desc.get(), dout_desc.get(), GetBasePtr(dout),
dout_desc.get(), GetBasePtr(&dx_temp), dout_desc.get(),
GetBasePtr(&dy_temp), data_type);
if (dx) {
if (dx->dims() != dout->dims()) {
dx->mutable_data<Tout>(ctx.GetPlace());
std::vector<int> reduce_axes;
GetReduceAxes(axis, dx_temp.dims(), dx->dims(), &reduce_axes);
MLUCnnlReduceDesc reduction_desc(
reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
nullptr, dout_desc.get(), GetBasePtr(&dx_temp), 0,
nullptr, nullptr, dx_desc.get(), GetBasePtr(dx));
} else {
dx->ShareDataWith(dx_temp);
}
}
if (dy) {
if (dy->dims() != dout->dims()) {
dy->mutable_data<Tout>(ctx.GetPlace());
std::vector<int> reduce_axes;
GetReduceAxes(axis, dy_temp.dims(), dy->dims(), &reduce_axes);
MLUCnnlReduceDesc reduction_desc(
reduce_axes, CNNL_REDUCE_ADD, data_type, CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
MLUCnnlTensorDesc dy_desc(*dy);
MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
nullptr, dout_desc.get(), GetBasePtr(&dy_temp), 0,
nullptr, nullptr, dy_desc.get(), GetBasePtr(dy));
} else {
dy->ShareDataWith(dy_temp);
}
}
}
} // namespace operators
} // namespace paddle
#endif
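The Reduce step at the end of MLUMinMaxGradHelper exists because, under broadcasting, dx_temp and dy_temp are produced at the output shape and must be summed back to the corresponding input shape. A rough NumPy sketch of that idea; reduce_to_shape is a hypothetical stand-in for what GetReduceAxes plus MLUCnnl::Reduce do on the device:

import numpy as np

def reduce_to_shape(grad, target_shape):
    # Sum over the axes where the target was broadcast (missing dims or size-1 dims).
    while grad.ndim > len(target_shape):
        grad = grad.sum(axis=0)
    for axis, size in enumerate(target_shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

dout = np.ones((2, 3, 4), dtype=np.float32)  # gradient at the broadcast output shape
dy = reduce_to_shape(dout, (3, 1))           # y was broadcast from shape (3, 1)
print(dy.shape)                              # (3, 1); each entry is the sum of 2*4 slots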
......@@ -19,222 +19,350 @@ import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
from paddle.fluid.core import ops
paddle.enable_static()
SEED = 2022
class TestElementwiseMax(OpTest):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
        out = np.maximum(x, y)

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(x),
            'Y': OpTest.np_dtype_to_fluid_dtype(y)
        }
        self.attrs = {}
        self.outputs = {'Out': out}

    def set_mlu(self):
        self.__class__.use_mlu = True
        self.place = paddle.device.MLUPlace(0)

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place)


def ComputeGrad(x, y, out, axis):
    grad = 1 / out.size
    shape_x = x.shape
    shape_y = y.shape
    shape_out = out.shape
    reduce_axes_x = []
    reduce_axes_y = []

    if shape_x != shape_out:
        if len(shape_x) < len(shape_out):
            src_axis = axis
        else:
            src_axis = 0

        for ax in range(len(shape_out)):
            if (ax < src_axis or ax >= src_axis + len(shape_x)) or (
                    shape_out[ax] > 1 and shape_x[ax - src_axis] == 1):
                reduce_axes_x.append(ax)

    if shape_y != shape_out:
        if len(shape_y) < len(shape_out):
            src_axis = axis
        else:
            src_axis = 0

        for ax in range(len(shape_out)):
            if (ax < src_axis or ax >= src_axis + len(shape_y)) or (
                    shape_out[ax] > 1 and shape_y[ax - src_axis] == 1):
                reduce_axes_y.append(ax)

    if len(reduce_axes_x) > 0:
        for i in reduce_axes_x:
            x = np.expand_dims(x, axis=i)

    if len(reduce_axes_y) > 0:
        for i in reduce_axes_y:
            y = np.expand_dims(y, axis=i)

    mask = np.sign(np.subtract(x, y))
    dx = np.maximum(mask, 0) * grad
    dy = np.abs(np.minimum(mask, 0) * grad)

    if len(reduce_axes_x) > 0:
        for i, element in enumerate(reduce_axes_x):
            dx = np.add.reduce(dx, element - i)

    if len(reduce_axes_y) > 0:
        for i, element in enumerate(reduce_axes_y):
            dy = np.add.reduce(dy, element - i)

    return dx, dy
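A quick usage sketch for the ComputeGrad helper above (illustrative, not part of the test file; it assumes numpy is imported as np at the top of this module, and the shapes mirror the broadcast_1 case defined below):

x = np.random.uniform(0.5, 1.0, (2, 100, 3)).astype(np.float32)
y = np.random.uniform(0.1, 0.4, (100, )).astype(np.float32)
out = np.maximum(x, y.reshape(1, 100, 1))

dx, dy = ComputeGrad(x, y, out, axis=1)
assert dx.shape == (2, 100, 3)  # dx keeps x's shape
assert dy.shape == (100, )      # dy is summed back over the broadcast axes (0 and 2)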
class TestElementwiseMaxFp16(OpTest):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.init_dtype()
        np.random.seed(SEED)
        x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
        y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
        out = np.maximum(x, y)

        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(x),
            'Y': OpTest.np_dtype_to_fluid_dtype(y)
        }
        self.attrs = {}
        self.outputs = {'Out': out}

    def set_mlu(self):
        self.place = paddle.MLUPlace(0)
        self.__class__.use_mlu = True
        self.__class__.no_need_check_grad = True

    def init_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place)


class TestElementwiseMaxOp(OpTest):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.init_dtype()
        self.init_input_output()
        self.init_axis()
        self.inputs = {
            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.attrs = {'axis': self.axis}
        self.outputs = {'Out': self.out}

    def set_mlu(self):
        self.__class__.use_mlu = True
        self.place = paddle.device.MLUPlace(0)

    def init_dtype(self):
        self.dtype = np.float32

    def init_input_output(self):
        self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
        sgn = np.random.choice([-1, 1], [13, 17]).astype(self.dtype)
        self.y = self.x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype(
            self.dtype)
        self.out = np.maximum(self.x, self.y)

    def init_axis(self):
        self.axis = -1

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def test_check_grad_normal(self):
        if self.dtype == np.float16:
            self.check_grad_with_place(self.place, ['X', 'Y'],
                                       'Out',
                                       max_relative_error=0.5)
        else:
            self.check_grad_with_place(
                self.place,
                ['X', 'Y'],
                'Out',
            )

    def test_check_grad_ingore_x(self):
        if self.dtype == np.float16:
            self.check_grad_with_place(self.place, ['Y'],
                                       'Out',
                                       no_grad_set=set("X"),
                                       max_relative_error=0.9)
        else:
            self.check_grad_with_place(
                self.place,
                ['Y'],
                'Out',
                no_grad_set=set("X"),
            )

    def test_check_grad_ingore_y(self):
        if self.dtype == np.float16:
            self.check_grad_with_place(self.place, ['X'],
                                       'Out',
                                       no_grad_set=set("Y"),
                                       max_relative_error=0.1)
        else:
            self.check_grad_with_place(
                self.place,
                ['X'],
                'Out',
                no_grad_set=set("Y"),
            )


class TestTestElementwiseMax_Vector(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [100]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [100]).astype("float32")
        }
        self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseMaxInt32(OpTest):
    ......


class TestElementwiseMaxOp_int32(TestElementwiseMaxOp):

    def init_dtype(self):
        self.dtype = np.int32

    # CTest does not support check grad for int32.
    def test_check_grad_normal(self):
        pass

    def test_check_grad_ingore_x(self):
        pass

    def test_check_grad_ingore_y(self):
        pass
class TestTestElementwiseMax_broadcast_0(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [100, 3, 4]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [100]).astype("float32")
        }

        self.attrs = {'axis': 0}
        self.outputs = {
            'Out': np.maximum(self.inputs['X'],
                              self.inputs['Y'].reshape(100, 1, 1))
        }


class TestTestElementwiseMax_broadcast_1(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 100, 4]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [100]).astype("float32")
        }

        self.attrs = {'axis': 1}
        self.outputs = {
            'Out': np.maximum(self.inputs['X'],
                              self.inputs['Y'].reshape(1, 100, 1))
        }


class TestElementwiseMaxOp_FP16(TestElementwiseMaxOp):

    def init_dtype(self):
        self.dtype = np.float16


@skip_check_grad_ci(
    reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseMaxOp_scalar(TestElementwiseMaxOp):

    def init_input_output(self):
        self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype(self.dtype)
        self.y = np.array([0.5]).astype(self.dtype)
        self.out = np.maximum(self.x, self.y)


class TestElementwiseMaxOp_vector(TestElementwiseMaxOp):

    def init_input_output(self):
        self.x = np.random.random((100, )).astype(self.dtype)
        sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
        self.y = self.x + sgn * np.random.uniform(0.1, 1,
                                                  (100, )).astype(self.dtype)
        self.out = np.maximum(self.x, self.y)
class TestTestElementwiseMax_broadcast_2(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [100]).astype("float32")
        }

        self.outputs = {
            'Out': np.maximum(self.inputs['X'],
                              self.inputs['Y'].reshape(1, 1, 100))
        }


class TestTestElementwiseMax_broadcast_3(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 10, 12, 5]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [10, 12]).astype("float32")
        }

        self.attrs = {'axis': 1}
        self.outputs = {
            'Out':
            np.maximum(self.inputs['X'], self.inputs['Y'].reshape(1, 10, 12, 1))
        }


class TestTestElementwiseMax_broadcast_4(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 50]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [2, 1, 50]).astype("float32")
        }
        self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseMaxOp_broadcast_0(TestElementwiseMaxOp):

    def init_input_output(self):
        self.x = np.random.uniform(0.5, 1, (100, 5, 2)).astype(self.dtype)
        sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
        self.y = self.x[:, 0, 0] + sgn * \
            np.random.uniform(1, 2, (100, )).astype(self.dtype)
        self.out = np.maximum(self.x, self.y.reshape(100, 1, 1))

    def init_axis(self):
        self.axis = 0


class TestElementwiseMaxOp_broadcast_1(TestElementwiseMaxOp):

    def init_input_output(self):
        self.x = np.random.uniform(0.5, 1, (2, 100, 3)).astype(self.dtype)
        sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
        self.y = self.x[0, :, 0] + sgn * \
            np.random.uniform(1, 2, (100, )).astype(self.dtype)
        self.out = np.maximum(self.x, self.y.reshape(1, 100, 1))

    def init_axis(self):
        self.axis = 1

    def test_check_grad_ingore_x(self):
        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
        self.check_grad_with_place(self.place, ['Y'],
                                   'Out',
                                   no_grad_set=set("X"),
                                   user_defined_grads=[dy])

    def test_check_grad_ingore_y(self):
        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
        self.check_grad_with_place(self.place, ['X'],
                                   'Out',
                                   no_grad_set=set("Y"),
                                   user_defined_grads=[dx])
class TestTestElementwiseMax_broadcast_5(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 4, 20]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [2, 3, 1, 20]).astype("float32")
        }
        self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestTestElementwiseMax_commonuse_1(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 100]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [1, 1, 100]).astype("float32"),
        }
        self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseMaxOp_broadcast_2(TestElementwiseMaxOp):

    def init_input_output(self):
        self.x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(self.dtype)
        sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
        self.y = self.x[0, 0, :] + sgn * \
            np.random.uniform(1, 2, (100, )).astype(self.dtype)
        self.out = np.maximum(self.x, self.y.reshape(1, 1, 100))

    def test_check_grad_normal(self):
        dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
        self.check_grad_with_place(self.place, ['X', 'Y'],
                                   'Out',
                                   user_defined_grads=[dx, dy])

    def test_check_grad_ingore_x(self):
        _, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
        self.check_grad_with_place(self.place, ['Y'],
                                   'Out',
                                   no_grad_set=set("X"),
                                   user_defined_grads=[dy])

    def test_check_grad_ingore_y(self):
        dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
        self.check_grad_with_place(self.place, ['X'],
                                   'Out',
                                   no_grad_set=set("Y"),
                                   user_defined_grads=[dx])
class TestTestElementwiseMax_commonuse_2(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [30, 3, 1, 5]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [30, 1, 4, 1]).astype("float32"),
        }
        self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestTestElementwiseMax_xsize_lessthan_ysize(TestElementwiseMax):

    def setUp(self):
        self.set_mlu()
        self.op_type = "elementwise_max"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [10, 12]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [2, 3, 10, 12]).astype("float32"),
        }

        self.attrs = {'axis': 2}
        self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseMaxOp_broadcast_3(TestElementwiseMaxOp):

    def init_input_output(self):
        self.x = np.random.uniform(0.5, 1, (2, 50, 2, 1)).astype(self.dtype)
        sgn = np.random.choice([-1, 1], (50, 2)).astype(self.dtype)
        self.y = self.x[0, :, :, 0] + sgn * \
            np.random.uniform(1, 2, (50, 2)).astype(self.dtype)
        self.out = np.maximum(self.x, self.y.reshape(1, 50, 2, 1))

    def init_axis(self):
        self.axis = 1


class TestElementwiseMaxOp_broadcast_4(TestElementwiseMaxOp):

    def init_input_output(self):
        self.x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(self.dtype)
        sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(self.dtype)
        self.y = self.x + sgn * \
            np.random.uniform(1, 2, (2, 3, 1, 5)).astype(self.dtype)
        self.out = np.maximum(self.x, self.y)


class TestElementwiseMaxOp_broadcast_5(TestElementwiseMaxOp):

    def init_input_output(self):
        self.x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(self.dtype)
        sgn = np.random.choice([-1, 1], (2, 3, 1, 1)).astype(self.dtype)
        self.y = self.x + sgn * \
            np.random.uniform(1, 2, (2, 3, 1, 1)).astype(self.dtype)
        self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxNet(unittest.TestCase):
def _test(self, run_mlu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)
a_np = np.random.random(size=(32, 32)).astype('float32')
b_np = np.random.random(size=(32, 32)).astype('float32')
label_np = np.random.randint(2, size=(32, 1)).astype('int64')
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
label = paddle.static.data(name="label",
shape=[32, 1],
dtype='int64')
c = paddle.maximum(a, b)
fc_1 = fluid.layers.fc(input=c, size=128)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.reduce_mean(cost)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)
if run_mlu:
place = paddle.MLUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)
print("Start run on {}".format(place))
for epoch in range(100):
pred_res, loss_res = exe.run(main_prog,
feed={
"a": a_np,
"b": b_np,
"label": label_np
},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_mlu(self):
cpu_pred, cpu_loss = self._test(False)
mlu_pred, mlu_loss = self._test(True)
self.assertTrue(np.allclose(mlu_pred, cpu_pred))
self.assertTrue(np.allclose(mlu_loss, cpu_loss))
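Outside the static-graph comparison above, the same gradient behaviour can be spot-checked in eager mode. A minimal sketch, assuming a Paddle build with MLU support; the 'mlu:0' device string is an assumption and is left commented out:

import numpy as np
import paddle

paddle.disable_static()
# paddle.set_device('mlu:0')  # assumption: only meaningful on an MLU build
x = paddle.to_tensor(np.array([1.0, 3.0], dtype='float32'), stop_gradient=False)
y = paddle.to_tensor(np.array([2.0, 2.0], dtype='float32'), stop_gradient=False)
out = paddle.maximum(x, y)
out.sum().backward()
# The upstream gradient (ones) flows to whichever input is larger elementwise.
print(x.grad.numpy(), y.grad.numpy())  # expected: [0. 1.] and [1. 0.]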
if __name__ == '__main__':
......