未验证 提交 5310ceab 编写于 作者: B baoachun 提交者: GitHub

add elementwise max grad op for npu (#34862)

* add elementwise max grad op for npu

* add elementwise max grad op for npu

* add elementwise max grad op for npu

* add elementwise max grad op for npu

* add elementwise max grad op for npu
上级 73321264
...@@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle { namespace paddle {
...@@ -27,21 +25,202 @@ template <typename DeviceContext, typename T> ...@@ -27,21 +25,202 @@ template <typename DeviceContext, typename T>
class ElementwiseMaxNPUKernel : public framework::OpKernel<T> { class ElementwiseMaxNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto place = ctx.GetPlace(); int axis = ctx.Attr<int>("axis");
out->mutable_data<T>(place); bool direct_compute = false;
auto x_dims = x->dims();
auto y_dims = y->dims();
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
if (x_dims.size() >= y_dims.size()) {
direct_compute =
y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
} else {
direct_compute =
x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
}
auto stream = auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
const auto& runner = NpuOpRunner("Maximum", {*x, *y}, {*out}, {}); if (direct_compute) {
runner.Run(stream); const auto& runner = NpuOpRunner("Maximum", {*x, *y}, {*out}, {});
runner.Run(stream);
} else {
Tensor transformed_x, transformed_y;
NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
&transformed_y);
const auto& runner =
NpuOpRunner("Maximum", {transformed_x, transformed_y}, {*out}, {});
runner.Run(stream);
}
}
};
template <typename DeviceContext, typename T>
class ElementwiseMaxGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
// The ascend elementwise_max_grad op only supports broadcast
// when axis is -1, and requires all the inputs must have the
// same shape when axis is not -1. For convenience, we should
// broadcast the original input x and y to transformed_x and
// transformed_x firstly, then use tmp tensor to get the op
// output, last reduce the tmp tensor shape to match the
// paddle output.
auto x_dims = x->dims();
auto y_dims = y->dims();
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
Tensor transformed_x, transformed_y;
NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
&transformed_y);
auto dout_dims = dout->dims();
auto stream = dev_ctx.stream();
framework::NPUAttributeMap attr_input = {{"grad_x", true},
{"grad_y", true}};
// Reshape info vector.
std::vector<int> reduce_axes;
if (dx && dy) {
dx->mutable_data<T>(ctx.GetPlace());
dy->mutable_data<T>(ctx.GetPlace());
Tensor tmp_dx;
tmp_dx.mutable_data<T>(dout_dims, ctx.GetPlace());
Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout_dims, ctx.GetPlace());
const auto& runner =
NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
{tmp_dx, tmp_dy}, attr_input);
runner.Run(stream);
if (x_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (x_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + x_dims.size()) ||
(dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dx, ctx.GetPlace(), dev_ctx, dx);
}
if (y_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (y_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + y_dims.size()) ||
(dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dy, ctx.GetPlace(), dev_ctx, dy);
}
} else if (dx) {
Tensor zero_tensor(dout->type());
zero_tensor.mutable_data<T>(dout_dims, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
dx->mutable_data<T>(ctx.GetPlace());
Tensor tmp_dx;
tmp_dx.mutable_data<T>(dout_dims, ctx.GetPlace());
const auto& runner =
NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
{tmp_dx, zero_tensor}, attr_input);
runner.Run(stream);
if (x_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (x_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + x_dims.size()) ||
(dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dx, ctx.GetPlace(), dev_ctx, dx);
}
} else if (dy) {
Tensor zero_tensor(dout->type());
zero_tensor.mutable_data<T>(dout_dims, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
dy->mutable_data<T>(ctx.GetPlace());
Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout_dims, ctx.GetPlace());
const auto& runner =
NpuOpRunner("MaximumGrad", {*dout, transformed_x, transformed_y},
{zero_tensor, tmp_dy}, attr_input);
runner.Run(stream);
if (y_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (y_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + y_dims.size()) ||
(dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dy, ctx.GetPlace(), dev_ctx, dy);
}
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support all outputs to be empty."));
}
} }
}; };
...@@ -49,9 +228,19 @@ class ElementwiseMaxNPUKernel : public framework::OpKernel<T> { ...@@ -49,9 +228,19 @@ class ElementwiseMaxNPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(
elementwise_max, elementwise_max,
ops::ElementwiseMaxNPUKernel<paddle::platform::NPUDeviceContext, float>, ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::ElementwiseMaxNPUKernel<paddle::platform::NPUDeviceContext, ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, float>,
paddle::platform::float16>); ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, double>,
ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, int>,
ops::ElementwiseMaxNPUKernel<plat::NPUDeviceContext, int64_t>);
REGISTER_OP_NPU_KERNEL(
elementwise_max_grad,
ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, float>,
ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, double>,
ops::ElementwiseMaxGradNPUKernel<plat::NPUDeviceContext, int>);
...@@ -26,70 +26,248 @@ paddle.enable_static() ...@@ -26,70 +26,248 @@ paddle.enable_static()
SEED = 2021 SEED = 2021
class TestElementwiseMax(OpTest): def ComputeGrad(x, y, out, axis):
def setUp(self): grad = 1 / out.size
self.set_npu() shape_x = x.shape
self.op_type = "elementwise_max" shape_y = y.shape
self.place = paddle.NPUPlace(0) shape_out = out.shape
reduce_axes_x = []
reduce_axes_y = []
if shape_x != shape_out:
if len(shape_x.shape) < len(shape_out.shape):
src_axis = axis
else:
src_axis = 0
self.init_dtype() for ax in range(len(shape_out)):
np.random.seed(SEED) if (ax < src_axis or ax >= src_axis + len(shape_x)) or (
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) shape_out[ax] > 1 and shape_x[ax - src_axis] == 1):
y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) reduce_axes_x.append(ax)
out = np.maximum(x, y)
self.inputs = { if shape_y != shape_out:
'X': OpTest.np_dtype_to_fluid_dtype(x), if len(shape_y) < len(shape_out):
'Y': OpTest.np_dtype_to_fluid_dtype(y) src_axis = axis
} else:
self.attrs = {} src_axis = 0
self.outputs = {'Out': out}
def set_npu(self): for ax in range(len(shape_out)):
self.__class__.use_npu = True if (ax < src_axis or ax >= src_axis + len(shape_y)) or (
shape_out[ax] > 1 and shape_y[ax - src_axis] == 1):
reduce_axes_y.append(ax)
def init_dtype(self): if len(reduce_axes_x) > 0:
self.dtype = np.float32 for i in reduce_axes_x:
x = np.expand_dims(x, axis=i)
def test_check_output(self): if len(reduce_axes_y) > 0:
self.check_output_with_place(self.place) for i in reduce_axes_y:
y = np.expand_dims(y, axis=i)
mask = np.sign(np.subtract(x, y))
dx = np.maximum(mask, 0) * grad
dy = np.abs(np.minimum(mask, 0) * grad)
if len(reduce_axes_x) > 0:
for i, element in enumerate(reduce_axes_x):
dx = np.add.reduce(dx, element - i)
if len(reduce_axes_y) > 0:
for i, element in enumerate(reduce_axes_y):
dy = np.add.reduce(dy, element - i)
# TODO(ascendrc): Max grad test return dx, dy
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#
class TestElementwiseMaxFp16(OpTest): class TestElementwiseMaxOp(OpTest):
def setUp(self): def setUp(self):
self.set_npu() self.set_npu()
self.op_type = "elementwise_max" self.op_type = "elementwise_max"
self.place = paddle.NPUPlace(0) self.place = paddle.NPUPlace(0)
self.init_dtype() self.init_dtype()
np.random.seed(SEED) self.init_input_output()
x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype) self.init_axis()
y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
out = np.maximum(x, y)
self.inputs = { self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x), 'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(y) 'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
} }
self.attrs = {} self.attrs = {'axis': self.axis}
self.outputs = {'Out': out} self.outputs = {'Out': self.out}
def set_npu(self): def set_npu(self):
self.__class__.use_npu = True self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def init_dtype(self): def init_dtype(self):
self.dtype = np.float16 self.dtype = np.float32
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
sgn = np.random.choice([-1, 1], [13, 17]).astype(self.dtype)
self.y = self.x + sgn * np.random.uniform(0.1, 1,
[13, 17]).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
def init_axis(self):
self.axis = -1
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5) self.check_output_with_place(self.place)
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['Y'], 'Out', no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
if self.dtype == np.float16:
return
self.check_grad_with_place(
self.place, ['X'], 'Out', no_grad_set=set("Y"))
class TestElementwiseMaxOp_int32(TestElementwiseMaxOp):
def init_dtype(self):
self.dtype = np.int32
# CTest does not support check grad for int32.
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
class TestElementwiseMaxOp_scalar(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.random_integers(-5, 5, [2, 3, 20]).astype(self.dtype)
self.y = np.array([0.5]).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxOp_vector(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.random((100, )).astype(self.dtype)
sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x + sgn * np.random.uniform(0.1, 1,
(100, )).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxOp_broadcast_0(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (100, 5, 2)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x[:, 0, 0] + sgn * \
np.random.uniform(1, 2, (100, )).astype(self.dtype)
self.out = np.maximum(self.x, self.y.reshape(100, 1, 1))
def init_axis(self):
self.axis = 0
class TestElementwiseMaxOp_broadcast_1(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 100, 3)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x[0, :, 0] + sgn * \
np.random.uniform(1, 2, (100, )).astype(self.dtype)
self.out = np.maximum(self.x, self.y.reshape(1, 100, 1))
def init_axis(self):
self.axis = 1
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwiseMaxOp_broadcast_2(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 3, 100)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (100, )).astype(self.dtype)
self.y = self.x[0, 0, :] + sgn * \
np.random.uniform(1, 2, (100, )).astype(self.dtype)
self.out = np.maximum(self.x, self.y.reshape(1, 1, 100))
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
if self.dtype == np.float16:
return
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
if self.dtype == np.float16:
return
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwiseMaxOp_broadcast_3(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 50, 2, 1)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (50, 2)).astype(self.dtype)
self.y = self.x[0, :, :, 0] + sgn * \
np.random.uniform(1, 2, (50, 2)).astype(self.dtype)
self.out = np.maximum(self.x, self.y.reshape(1, 50, 2, 1))
def init_axis(self):
self.axis = 1
class TestElementwiseMaxOp_broadcast_4(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(self.dtype)
self.y = self.x + sgn * \
np.random.uniform(1, 2, (2, 3, 1, 5)).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxOp_broadcast_5(TestElementwiseMaxOp):
def init_input_output(self):
self.x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(self.dtype)
sgn = np.random.choice([-1, 1], (2, 3, 1, 1)).astype(self.dtype)
self.y = self.x + sgn * \
np.random.uniform(1, 2, (2, 3, 1, 1)).astype(self.dtype)
self.out = np.maximum(self.x, self.y)
class TestElementwiseMaxNet(unittest.TestCase): class TestElementwiseMaxNet(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册