未验证 提交 e913796c 编写于 作者: W WJJ1995 提交者: GitHub

[NPU] Add elementwise_pow_grad npu op (#35278)

* add elementwise_pow_grad_npu

* fixed bug for CI

* deal with comments

* fixed bug for CI

* deal with comments
上级 648e3775
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
......@@ -27,21 +28,198 @@ template <typename DeviceContext, typename T>
class ElementwisePowNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
int axis = ctx.Attr<int>("axis");
out->mutable_data<T>(place);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
bool direct_compute = false;
auto x_dims = x->dims();
auto y_dims = y->dims();
axis =
(axis < 0 ? std::abs(x_dims.size() - y_dims.size()) + axis + 1 : axis);
if (x_dims.size() >= y_dims.size()) {
direct_compute =
y_dims == framework::slice_ddim(x_dims, axis, x_dims.size());
} else {
direct_compute =
x_dims == framework::slice_ddim(y_dims, axis, y_dims.size());
}
auto stream = dev_ctx.stream();
if (direct_compute) {
const auto& runner = NpuOpRunner("Pow", {*x, *y}, {*out}, {});
runner.Run(stream);
} else {
Tensor transformed_x, transformed_y;
NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
&transformed_y);
const auto& runner =
NpuOpRunner("Pow", {transformed_x, transformed_y}, {*out}, {});
runner.Run(stream);
}
}
};
const auto& runner = NpuOpRunner("Pow", {*x, *y}, {*out}, {});
runner.Run(stream);
template <typename DeviceContext, typename T>
class ElementwisePowGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto place = ctx.GetPlace();
auto x_dims = x->dims();
auto y_dims = y->dims();
axis =
(axis < 0 ? std::abs(x_dims.size() - y_dims.size()) + axis + 1 : axis);
Tensor transformed_x, transformed_y;
NpuElementWiseOpBroadcast<T>(dev_ctx, x, y, axis, &transformed_x,
&transformed_y);
auto dout_dims = dout->dims();
auto stream = dev_ctx.stream();
// Reshape info vector.
std::vector<int> reduce_axes;
if (dx) {
Tensor zero_tensor(dout->type());
zero_tensor.mutable_data<T>(dout_dims, place);
FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
dx->mutable_data<T>(place);
Tensor tmp_dx;
tmp_dx.mutable_data<T>(dout_dims, place);
// dx = dout * y * pow(x, y - 1);
Tensor PowGrad_dx_temp1(dout->type());
PowGrad_dx_temp1.mutable_data<T>(dout->dims(), place);
const auto& runner_PowGrad_dx_temp1 =
NpuOpRunner("Mul", {*dout, transformed_y}, {PowGrad_dx_temp1}, {});
runner_PowGrad_dx_temp1.Run(stream);
Tensor one_dx(transformed_y.type());
one_dx.mutable_data<T>(transformed_y.dims(), place);
const auto& runner_one_dx =
NpuOpRunner("OnesLike", {transformed_y}, {one_dx}, {});
runner_one_dx.Run(stream);
Tensor sub_dx(transformed_y.type());
sub_dx.mutable_data<T>(transformed_y.dims(), place);
const auto& runner_sub_dx =
NpuOpRunner("Sub", {transformed_y, one_dx}, {sub_dx}, {});
runner_sub_dx.Run(stream);
Tensor PowGrad_dx_temp2(transformed_x.type());
PowGrad_dx_temp2.mutable_data<T>(transformed_x.dims(), place);
const auto& runner_PowGrad_dx_temp2 =
NpuOpRunner("Pow", {transformed_x, sub_dx}, {PowGrad_dx_temp2}, {});
runner_PowGrad_dx_temp2.Run(stream);
const auto& runner_dx = NpuOpRunner(
"Mul", {PowGrad_dx_temp1, PowGrad_dx_temp2}, {tmp_dx}, {});
runner_dx.Run(stream);
if (x_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (x_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + x_dims.size()) ||
(dout_dims[ax] > 1 && x_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dx}, {*dx},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dx, place, dev_ctx, dx);
}
}
if (dy) {
Tensor zero_tensor(dout->type());
zero_tensor.mutable_data<T>(dout_dims, place);
FillNpuTensorWithConstant<T>(&zero_tensor, static_cast<T>(0));
dy->mutable_data<T>(place);
Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout_dims, place);
// dy = dout * log(x) * pow(x, y)
Tensor PowGrad_dy_temp1(transformed_x.type());
PowGrad_dy_temp1.mutable_data<T>(transformed_x.dims(), place);
const auto& runner_PowGrad_dy_temp1 = NpuOpRunner(
"Pow", {transformed_x, transformed_y}, {PowGrad_dy_temp1}, {});
runner_PowGrad_dy_temp1.Run(stream);
Tensor one_dy(transformed_x.type());
one_dy.mutable_data<T>(transformed_x.dims(), place);
const auto& runner_one_dy =
NpuOpRunner("OnesLike", {transformed_x}, {one_dy}, {});
runner_one_dy.Run(stream);
Tensor sub_dy(transformed_x.type());
sub_dy.mutable_data<T>(transformed_x.dims(), place);
const auto& runner_sub_dy =
NpuOpRunner("Sub", {transformed_x, one_dy}, {sub_dy}, {});
runner_sub_dy.Run(stream);
Tensor log_dy(transformed_x.type());
log_dy.mutable_data<T>(transformed_x.dims(), place);
const auto& runner_log_dy = NpuOpRunner("Log1p", {sub_dy}, {log_dy}, {});
runner_log_dy.Run(stream);
Tensor PowGrad_dy_temp2(transformed_x.type());
PowGrad_dy_temp2.mutable_data<T>(transformed_x.dims(), place);
const auto& runner_PowGrad_dy_temp2 = NpuOpRunner(
"Mul", {log_dy, PowGrad_dy_temp1}, {PowGrad_dy_temp2}, {});
runner_PowGrad_dy_temp2.Run(stream);
const auto& runner_dy =
NpuOpRunner("Mul", {*dout, PowGrad_dy_temp2}, {tmp_dy}, {});
runner_dy.Run(stream);
if (y_dims != dout_dims) {
reduce_axes.clear();
int src_axis = (y_dims.size() < dout_dims.size() ? axis : 0);
for (int ax = 0; ax < dout_dims.size(); ++ax) {
if ((ax < src_axis || ax >= src_axis + y_dims.size()) ||
(dout_dims[ax] > 1 && y_dims[ax - src_axis] == 1)) {
reduce_axes.push_back(ax);
}
}
if (!reduce_axes.empty()) {
const auto& runner =
NpuOpRunner("ReduceSumD", {tmp_dy}, {*dy},
{{"axes", reduce_axes}, {"keep_dims", false}});
runner.Run(stream);
}
} else {
framework::TensorCopy(tmp_dy, place, dev_ctx, dy);
}
}
if (!dx && !dy) {
PADDLE_THROW(platform::errors::Unavailable(
"Not support all outputs to be empty."));
}
}
};
......@@ -49,9 +227,18 @@ class ElementwisePowNPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
elementwise_pow,
ops::ElementwisePowNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ElementwisePowNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
ops::ElementwisePowNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::ElementwisePowNPUKernel<plat::NPUDeviceContext, float>,
ops::ElementwisePowNPUKernel<plat::NPUDeviceContext, double>,
ops::ElementwisePowNPUKernel<plat::NPUDeviceContext, int>);
REGISTER_OP_NPU_KERNEL(
elementwise_pow_grad,
ops::ElementwisePowGradNPUKernel<plat::NPUDeviceContext, plat::float16>,
ops::ElementwisePowGradNPUKernel<plat::NPUDeviceContext, float>,
ops::ElementwisePowGradNPUKernel<plat::NPUDeviceContext, double>,
ops::ElementwisePowGradNPUKernel<plat::NPUDeviceContext, int>);
......@@ -13,19 +13,71 @@
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle
from op_test import OpTest
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
def ComputeGrad(x, y, out, axis):
grad = 1 / out.size
shape_x = x.shape
shape_y = y.shape
shape_out = out.shape
reduce_axes_x = []
reduce_axes_y = []
if shape_x != shape_out:
if len(shape_x) < len(shape_out):
src_axis = axis
else:
src_axis = 0
for ax in range(len(shape_out)):
if (ax < src_axis or ax >= src_axis + len(shape_x)) or (
shape_out[ax] > 1 and shape_x[ax - src_axis] == 1):
reduce_axes_x.append(ax)
if shape_y != shape_out:
if len(shape_y) < len(shape_out):
src_axis = axis
else:
src_axis = 0
for ax in range(len(shape_out)):
if (ax < src_axis or ax >= src_axis + len(shape_y)) or (
shape_out[ax] > 1 and shape_y[ax - src_axis] == 1):
reduce_axes_y.append(ax)
if len(reduce_axes_x) > 0:
for i in reduce_axes_x:
x = np.expand_dims(x, axis=i)
if len(reduce_axes_y) > 0:
for i in reduce_axes_y:
y = np.expand_dims(y, axis=i)
dx = y * np.power(x, y - 1) * grad
dy = np.log(x) * np.power(x, y) * grad
if len(reduce_axes_x) > 0:
for i, element in enumerate(reduce_axes_x):
dx = np.add.reduce(dx, element - i)
if len(reduce_axes_y) > 0:
for i, element in enumerate(reduce_axes_y):
dy = np.add.reduce(dy, element - i)
return dx, dy
class TestElementwisePow(OpTest):
def setUp(self):
self.set_npu()
......@@ -33,17 +85,15 @@ class TestElementwisePow(OpTest):
self.place = paddle.NPUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
out = np.power(x, y)
self.init_input_output()
self.init_axis()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {}
self.outputs = {'Out': out}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': self.out}
def set_npu(self):
self.__class__.use_npu = True
......@@ -54,44 +104,177 @@ class TestElementwisePow(OpTest):
def test_check_output(self):
self.check_output_with_place(self.place)
# TODO(ascendrc): Pow grad test
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#
def init_axis(self):
self.axis = -1
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.out = np.power(self.x, self.y)
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwisePowFp16(TestElementwisePow):
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.out = np.power(self.x, self.y)
class TestElementwisePowFp16(OpTest):
def setUp(self):
self.set_npu()
self.op_type = "elementwise_pow"
self.place = paddle.NPUPlace(0)
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
self.init_dtype()
np.random.seed(SEED)
x = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
y = np.random.uniform(1, 2, [3, 4]).astype(self.dtype)
out = np.power(x, y)
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(x),
'Y': OpTest.np_dtype_to_fluid_dtype(y)
}
self.attrs = {}
self.outputs = {'Out': out}
class TestElementwisePowDouble(TestElementwisePow):
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.out = np.power(self.x, self.y)
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float16
self.dtype = np.float64
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5)
class TestElementwisePowOp_broadcast_0(TestElementwisePow):
def init_axis(self):
self.axis = 1
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [1, 11, 17]).astype(self.dtype)
self.out = np.power(self.x, self.y)
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwisePowOp_broadcast_1(TestElementwisePow):
def init_axis(self):
self.axis = 1
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [2, 100, 1]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
self.out = np.power(self.x, self.y.reshape(1, 100, 1))
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwisePowOp_broadcast_2(TestElementwisePow):
def init_axis(self):
self.axis = 0
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(0.1, 1, [100, 3, 1]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
self.out = np.power(self.x, self.y.reshape(100, 1, 1))
def test_check_grad_normal(self):
if self.dtype == np.float16:
return
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X', 'Y'], 'Out', user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(
self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwisePowNet(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program()
......
......@@ -41,6 +41,7 @@ NEED_TO_FIX_OP_LIST = [
'elementwise_min',
'elementwise_mul',
'elementwise_sub',
'elementwise_pow',
'filter_by_instag',
'fused_elemwise_activation',
'fused_emb_seq_pool',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册