Unverified commit 5310ceab authored by baoachun, committed by GitHub

add elementwise max grad op for npu (#34862)

* add elementwise max grad op for npu

* add elementwise max grad op for npu

* add elementwise max grad op for npu

* add elementwise max grad op for npu

* add elementwise max grad op for npu
Parent 73321264
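For orientation (this note and the sketch are editorial, not part of the commit): the backward pass of an elementwise maximum routes the upstream gradient to whichever input is larger at each position; when an input was broadcast in the forward pass, its gradient is additionally summed back over the broadcast axes. A minimal NumPy sketch of the equal-shape case, matching the sign-based masks used by the test helper ComputeGrad further below:

import numpy as np

def max_grad_same_shape(x, y, dout):
    # Reference gradient of out = np.maximum(x, y) when x and y share a shape.
    dx = dout * (x > y)   # upstream gradient flows to the larger input
    dy = dout * (y > x)   # ties (x == y) receive zero, like the sign-based masks
    return dx, dy

x = np.array([1.0, 4.0, 2.0])
y = np.array([3.0, 3.0, 3.0])
dout = np.ones_like(x)
print(max_grad_same_shape(x, y, dout))   # (array([0., 1., 0.]), array([1., 0., 1.]))

The broadcast case, where gradients must be reduced back to the original input shapes, is what most of the new NPU kernel and the test changes below deal with.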
@@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
@@ -27,21 +25,202 @@ template <typename DeviceContext, typename T>
class ElementwiseMaxNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
out->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
bool direct_compute = false;
auto x_dims = x->dims();
auto y_dims = y->dims();
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
if (x_dims.size() >= y_dims.size()) {
direct_compute =
y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
} else {
direct_compute =
x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
}
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
if (direct_compute) {
const auto& runner = NpuOpRunner("Maximum", {*x, *y}, {*out}, {});
runner.Run(stream);
} else {
Tensor transformed_x, transformed_y;
NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
&transformed_y);
const auto& runner =
NpuOpRunner("Maximum", {transformed_x, transformed_y}, {*out}, {});
runner.Run(stream);
}
}
};
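The direct_compute check above decides whether the NPU Maximum op can consume x and y as-is: this holds when the lower-rank operand's shape equals the slice of the higher-rank operand's shape starting at axis, so no explicit broadcast step is needed. A small Python sketch of the same condition (hypothetical helper, for illustration only):

def is_direct_compute(x_shape, y_shape, axis):
    # Mirrors the slice_ddim comparison in ElementwiseMaxNPUKernel::Compute.
    if axis == -1:
        axis = abs(len(x_shape) - len(y_shape))
    if len(x_shape) >= len(y_shape):
        return list(y_shape) == list(x_shape[axis:])
    return list(x_shape) == list(y_shape[axis:])

assert is_direct_compute((2, 3, 4), (3, 4), -1)       # axis resolves to 1, slice matches
assert not is_direct_compute((2, 3, 4), (3, 1), 1)    # size-1 dim still needs broadcasting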
template <typename DeviceContext, typename T>
class ElementwiseMaxGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
// The ascend elementwise_max_grad op only supports broadcast
// when axis is -1, and requires all inputs to have the same
// shape when axis is not -1. For convenience, we first
// broadcast the original inputs x and y to transformed_x and
// transformed_y, then use tmp tensors to receive the op
// outputs, and finally reduce the tmp tensors to match the
// shape of the paddle outputs.
auto x_dims = x->dims();
auto y_dims = y->dims();
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
Tensor transformed_x, transformed_y;
NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
&transformed_y);
auto dout_dims = dout->dims();
auto stream = dev_ctx.stream();
framework::NPUAttributeMap attr_input = {{"grad_x", true},
{"grad_y", true}};
// Axes to reduce over when an input was broadcast to dout's shape.
std::vector<int> reduce_axes;
if (dx && dy) {
dx->mutable_data<T>(ctx.GetPlace());
dy->mutable_data<T>(ctx.GetPlace());
Tensor tmp_dx;
tmp_dx.mutable_data<T>(dout_dims, ctx.GetPlace());
Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout_dims, ctx.GetPlace());
const auto& runner =
NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
{tmp_dx, tmp_dy}, attr_input);
runner.Run(stream);
if (x_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (x_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + x_dims.size()) ||
(dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dx, ctx.GetPlace(), dev_ctx, dx);
}
if (y_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (y_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + y_dims.size()) ||
(dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dy, ctx.GetPlace(), dev_ctx, dy);
}
} else if (dx) {
Tensor zero_tensor(dout->type());
zero_tensor.mutable_data<T>(dout_dims, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
dx->mutable_data<T>(ctx.GetPlace());
Tensor tmp_dx;
tmp_dx.mutable_data<T>(dout_dims, ctx.GetPlace());
const auto& runner =
NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
{tmp_dx, zero_tensor}, attr_input);
runner.Run(stream);
if (x_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (x_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + x_dims.size()) ||
(dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dx, ctx.GetPlace(), dev_ctx, dx);
}
} else if (dy) {
Tensor zero_tensor(dout->type());
zero_tensor.mutable_data<T>(dout_dims, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
dy->mutable_data<T>(ctx.GetPlace());
Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout_dims, ctx.GetPlace());
const auto& runner =
NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
{zero_tensor, tmp_dy}, attr_input);
runner.Run(stream);
if (y_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (y_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + y_dims.size()) ||
(dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dy, ctx.GetPlace(), dev_ctx, dy);
}
} else {
PADDLE_THROW(platform::errors::Unavailable(
"At least one of the outputs (dx, dy) must be non-empty."));
}
}
};
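All three branches of the grad kernel above apply the same rule: MaximumGrad first produces gradients with dout's shape, and then every axis that the original input did not cover, or covered with size 1, is summed away by ReduceSumD so the result matches the paddle output shape. A NumPy-style sketch of that axis selection (hypothetical helper mirroring the loop in the kernel):

def broadcast_reduce_axes(in_shape, out_shape, axis):
    # Axes of a dout-shaped tensor to sum over so it reduces back to in_shape.
    src_axis = axis if len(in_shape) < len(out_shape) else 0
    axes = []
    for ax in range(len(out_shape)):
        if (ax < src_axis or ax >= src_axis + len(in_shape)) or (
                out_shape[ax] > 1 and in_shape[ax - src_axis] == 1):
            axes.append(ax)
    return axes

# y of shape (100,) broadcast into dout of shape (2, 100, 3) at axis=1:
# its gradient is summed over axes 0 and 2, recovering shape (100,).
assert broadcast_reduce_axes((100,), (2, 100, 3), 1) == [0, 2]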
@@ -49,9 +228,19 @@ class ElementwiseMaxNPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
elementwise_max,
ops::ElementwiseMaxNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ElementwiseMaxNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, float>,
ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, double>,
ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, int>,
ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, int64_t>);
REGISTER_OP_NPU_KERNEL(
elementwise_max_grad,
ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, float>,
ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, double>,
ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, int>);
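With both kernels registered, the op can be exercised end to end from Python. A hypothetical smoke test, assuming a PaddlePaddle build with Ascend NPU support and device 0 available (the device string and shapes here are assumptions, not taken from the commit):

import numpy as np
import paddle

paddle.set_device("npu:0")   # assumes an NPU-enabled build

x = paddle.to_tensor(np.random.rand(2, 3, 4).astype("float32"), stop_gradient=False)
y = paddle.to_tensor(np.random.rand(3, 4).astype("float32"), stop_gradient=False)

out = paddle.maximum(x, y)   # forward: dispatched to the elementwise_max NPU kernel
out.sum().backward()         # backward: exercises elementwise_max_grad
print(x.grad.shape, y.grad.shape)   # (2, 3, 4) and (3, 4) after the reduce-sum step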
@@ -26,70 +26,248 @@ paddle.enable_static()
SEED = 2021
class TestElementwiseMax(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "elementwise_max"
self.place = paddle.NPUPlace(0)
def ComputeGrad(x, y, out, axis):
grad = 1 / out.size
shape_x = x.shape
shape_y = y.shape
shape_out = out.shape
reduce_axes_x = []
reduce_axes_y = []
if shape_x != shape_out:
if len(shape_x) < len(shape_out):
src_axis = axis
else:
src_axis = 0
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np.maximum(x, y)
for ax in range(len(shape_out)):
if (ax < src_axis or ax >= src_axis + len(shape_x)) or (
shape_out[ax] > 1 and shape_x[ax - src_axis] == 1):
reduce_axes_x.append(ax)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.attrs = {}
self.outputs = {'Out': out}
if shape_y != shape_out:
if len(shape_y) < len(shape_out):
src_axis = axis
else:
src_axis = 0
def set_npu(self):
self.__class__.use_npu = True
for ax in range(len(shape_out)):
if (ax < src_axis or ax >= src_axis + len(shape_y)) or (
shape_out[ax] > 1 and shape_y[ax - src_axis] == 1):
reduce_axes_y.append(ax)
def init_dtype(self):
self.dtype = np.float32
if len(reduce_axes_x) > 0:
for i in reduce_axes_x:
x = np.expand_dims(x, axis=i)
def test_check_output(self):
self.check_output_with_place(self.place)
if len(reduce_axes_y) > 0:
for i in reduce_axes_y:
y = np.expand_dims(y, axis=i)
mask = np.sign(np.subtract(x, y))
dx = np.maximum(mask, 0) * grad
dy = np.abs(np.minimum(mask, 0) * grad)
if len(reduce_axes_x) > 0:
for i, element in enumerate(reduce_axes_x):
dx = np.add.reduce(dx, element - i)
if len(reduce_axes_y) > 0:
for i, element in enumerate(reduce_axes_y):
dy = np.add.reduce(dy, element - i)
# TODO(ascendrc): Max grad test
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#
return dx, dy
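ComputeGrad above is passed to check_grad_with_place as user_defined_grads by the broadcast test cases below. A small hypothetical check of what it returns for a rank-mismatched pair (illustrative arrays; assumes ComputeGrad from this file is in scope):

import numpy as np

x = np.array([[1.0, 5.0, 2.0],
              [4.0, 0.0, 3.5]], dtype=np.float32)
y = np.array([3.0, 3.0, 3.0], dtype=np.float32)   # broadcast along the last axis
out = np.maximum(x, y)                            # shape (2, 3)

dx, dy = ComputeGrad(x, y, out, axis=1)
# dx keeps x's shape (2, 3): 1/out.size where x > y, zero elsewhere.
# dy is reduced to y's shape (3,): the complementary mask summed over axis 0.
print(dx.shape, dy.shape)   # (2, 3) (3,)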
class TestElementwiseMaxFp16(OpTest):
class TestElementwiseMaxOp(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "elementwise_max"
self.place = paddle.NPUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
out = np.maximum(x, y)
self.init_input_output()
self.init_axis()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {}
self.outputs = {'Out': out}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': self.out}
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
self.dtype = np.float32
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
sgn = np.random.choice([-1, 1], [13, 17]).astype(self.dtype)
self.y = self.x + sgn * np.random.uniform(0.1, 1,
[13, 17]).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
def init_axis(self):
self.axis = -1
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5)
self.check_output_with_place(self.place)
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Y'], 'Out', no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['X'], 'Out', no_grad_set=set("Y"))
class TestElementwiseMaxOp_int32(TestElementwiseMaxOp):
def init_dtype(self):
self.dtype = np.int32
# Gradient checking is not supported for int32, so skip the grad tests.
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestElementwiseMaxOp_scalar(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype(self.dtype)
self.y = np.array([0.5]).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxOp_vector(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.random((100, )).astype(self.dtype)
sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x + sgn * np.random.uniform(0.1, 1,
(100, )).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxOp_broadcast_0(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (100, 5, 2)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x[:, 0, 0] + sgn * \
np.random.uniform(1, 2, (100, )).astype(self.dtype)
self.out = np.maximum(self.x, self.y.reshape(100, 1, 1))
def init_axis(self):
self.axis = 0
class TestElementwiseMaxOp_broadcast_1(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 100, 3)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x[0, :, 0] + sgn * \
np.random.uniform(1, 2, (100, )).astype(self.dtype)
self.out = np.maximum(self.x, self.y.reshape(1, 100, 1))
def init_axis(self):
self.axis = 1
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwiseMaxOp_broadcast_2(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x[0, 0, :] + sgn * \
np.random.uniform(1, 2, (100, )).astype(self.dtype)
self.out = np.maximum(self.x, self.y.reshape(1, 1, 100))
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
if self.dtype == np.float16:
return
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
if self.dtype == np.float16:
return
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwiseMaxOp_broadcast_3(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 50, 2, 1)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (50, 2)).astype(self.dtype)
self.y = self.x[0, :, :, 0] + sgn * \
np.random.uniform(1, 2, (50, 2)).astype(self.dtype)
self.out = np.maximum(self.x, self.y.reshape(1, 50, 2, 1))
def init_axis(self):
self.axis = 1
class TestElementwiseMaxOp_broadcast_4(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(self.dtype)
self.y = self.x + sgn * \
np.random.uniform(1, 2, (2, 3, 1, 5)).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxOp_broadcast_5(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (2, 3, 1, 1)).astype(self.dtype)
self.y = self.x + sgn * \
np.random.uniform(1, 2, (2, 3, 1, 1)).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxNet(unittest.TestCase):
......