Unverified commit f9e55dee authored by Aganlengzi, committed by GitHub

[NPU] fix arg_max and reduce_max (#42887)

* fix arg_max and reduce_max

* add arg_max unit test
Parent 21e1d10f
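For context, the user-visible behavior this commit fixes can be reproduced with a short dygraph snippet (a minimal sketch, assuming an NPU build of Paddle; the place, shapes, and values are illustrative, not taken from the PR):

```python
import numpy as np
import paddle

# Assumes a Paddle build with Ascend NPU support.
paddle.disable_static(paddle.NPUPlace(0))

# arg_max: calling paddle.argmax without an axis makes the op run with
# flatten=True, so the NPU kernel must take the argmax of a 1-D view.
x = paddle.to_tensor(np.random.random((1, 9)).astype('float32'))
idx = paddle.argmax(x)  # global argmax; this is what the new unit test checks

# reduce_max grad: with the default keepdim=False the forward output
# drops the reduced axis, so the backward pass must broadcast the
# gradient back to the input shape.
y = paddle.to_tensor(np.random.random((2, 3)).astype('float32'))
y.stop_gradient = False
m = paddle.max(y, axis=1)  # shape [2]: axis 1 removed
m.sum().backward()         # exercises the BroadcastTo path fixed below
```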
@@ -34,11 +34,18 @@ struct VisitDataArgNPUMaxFunctor {
     out.template mutable_data<Tout>(ctx.GetPlace());
     auto axis = ctx.Attr<int64_t>("axis");
     auto dtype = ctx.Attr<int>("dtype");
+    const bool& flatten = ctx.Attr<bool>("flatten");
+
+    Tensor transformed_x(x.type());
+    transformed_x.ShareDataWith(x);
+    if (flatten) {
+      transformed_x.Resize(phi::make_ddim({x.numel()}));
+    }
 
     auto stream = ctx.template device_context<NPUDeviceContext>().stream();
     NpuOpRunner runner;
     runner.SetType("ArgMaxV2")
-        .AddInput(x)
+        .AddInput(transformed_x)
         .AddInput(std::vector<int64_t>{axis})
         .AddOutput(out)
         .AddAttrDataType("dtype", dtype)
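The hunk above makes the NPU kernel honor the `flatten` attribute: `transformed_x` shares storage with `x` via `ShareDataWith` and, when `flatten` is set, is resized to a 1-D view of `x.numel()` elements before being handed to `ArgMaxV2`; previously the raw `x` was passed regardless. A rough NumPy analogue of the intended semantics (a sketch only; `npu_argmax_v2` is a hypothetical name, not a Paddle API, and it assumes the Python layer sets axis to 0 when flattening):

```python
import numpy as np

def npu_argmax_v2(x, axis, flatten):
    # Mirror of the fixed kernel: when flatten is set, take the argmax
    # over a 1-D view of the data. reshape(-1) on a contiguous array is
    # a view rather than a copy, analogous to ShareDataWith + Resize.
    transformed_x = x.reshape(-1) if flatten else x
    return np.argmax(transformed_x, axis=axis)

# e.g. npu_argmax_v2(np.arange(9).reshape(1, 9), axis=0, flatten=True) -> 8
```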
@@ -112,6 +112,8 @@ class ReduceMaxGradNPUKernel : public framework::OpKernel<T> {
     auto* x = context.Input<Tensor>("X");
     auto* out = context.Input<Tensor>("Out");
     auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto reduce_dims = context.Attr<std::vector<int>>("dim");
+    bool reduce_all = context.Attr<bool>("reduce_all");
     int in_dtype = context.Attr<int>("in_dtype");
 
     PADDLE_ENFORCE_EQ(
@@ -129,12 +131,30 @@ class ReduceMaxGradNPUKernel : public framework::OpKernel<T> {
 
     // broadcast
     auto x_dims_vec = phi::vectorize(x->dims());
+    if (reduce_all) {
+      reduce_dims.clear();
+      for (size_t d = 0; d < x_dims_vec.size(); ++d) {
+        reduce_dims.push_back(static_cast<int>(d));
+      }
+    }
+
+    Tensor tmp_out, tmp_out_grad;
+    auto tmp_out_dims_vec = x_dims_vec;
+    for (auto d : reduce_dims) {
+      tmp_out_dims_vec[d] = 1;
+    }
+
+    tmp_out.ShareDataWith(*out);
+    tmp_out.Resize(phi::make_ddim(tmp_out_dims_vec));
+    tmp_out_grad.ShareDataWith(*out_grad);
+    tmp_out_grad.Resize(phi::make_ddim(tmp_out_dims_vec));
+
     Tensor transformed_out(x->type());
     transformed_out.Resize(phi::make_ddim(x_dims_vec));
     transformed_out.mutable_data<T>(place);
     NpuOpRunner r_brd_out;
     r_brd_out.SetType("BroadcastTo")
-        .AddInput(*out)
+        .AddInput(tmp_out)
         .AddInput(std::move(x_dims_vec))
         .AddOutput(transformed_out)
         .Run(stream);
@@ -143,7 +163,7 @@ class ReduceMaxGradNPUKernel : public framework::OpKernel<T> {
     transformed_out_grad.mutable_data<T>(place);
     NpuOpRunner r_brd_out_grad;
     r_brd_out_grad.SetType("BroadcastTo")
-        .AddInput(*out_grad)
+        .AddInput(tmp_out_grad)
         .AddInput(std::move(x_dims_vec))
         .AddOutput(transformed_out_grad)
         .Run(stream);
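The reduce_max grad fix targets the broadcast step. With `keepdim=False`, `out` and `out_grad` have the reduced axes removed entirely, so feeding them straight to `BroadcastTo` with the input shape is ill-formed: broadcasting aligns trailing dimensions, so a shape `(2,)` output from reducing axis 1 of a `(2, 3)` input cannot broadcast back to `(2, 3)`. The new `tmp_out`/`tmp_out_grad` tensors share the same data but are resized to a keepdim-style shape with 1s at the reduced axes, which makes the broadcast well defined; when `reduce_all` is set, every axis is treated as reduced. A NumPy sketch of the same idea (illustrative names, not Paddle API; this covers only the broadcast step, not the subsequent max-location masking):

```python
import numpy as np

def broadcast_reduced_grad(out_grad, x_shape, reduce_dims, reduce_all):
    # Re-insert size-1 dims at the reduced axes (keepdim-style shape),
    # then broadcast back to the input shape: the BroadcastTo step.
    dims = range(len(x_shape)) if reduce_all else reduce_dims
    tmp_shape = list(x_shape)
    for d in dims:
        tmp_shape[d] = 1
    tmp = out_grad.reshape(tmp_shape)   # a view, like ShareDataWith + Resize
    return np.broadcast_to(tmp, x_shape)

# e.g. reducing axis 1 of a (2, 3) input gives out_grad of shape (2,);
# reshaping to (2, 1) first lets it broadcast cleanly to (2, 3).
```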
@@ -328,5 +328,32 @@ class TestArgMaxAPI_2(unittest.TestCase):
             run(place)
 
 
+class TestArgMaxAPI_3(unittest.TestCase):
+    def initTestCase(self):
+        self.dims = (1, 9)
+        self.dtype = 'float32'
+
+    def setUp(self):
+        self.initTestCase()
+        self.__class__.use_npu = True
+        self.place = [paddle.NPUPlace(0)]
+
+    def test_dygraph_api(self):
+        def run(place):
+            paddle.disable_static(place)
+            np.random.seed(2021)
+            numpy_input = (np.random.random(self.dims)).astype(self.dtype)
+            tensor_input = paddle.to_tensor(numpy_input)
+            numpy_output = np.argmax(numpy_input).reshape([1])
+            paddle_output = paddle.argmax(tensor_input)
+            self.assertEqual(
+                np.allclose(numpy_output, paddle_output.numpy()), True)
+            self.assertEqual(numpy_output.shape, paddle_output.numpy().shape)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+
 if __name__ == '__main__':
     unittest.main()