Unverified commit f6fab559 authored by ronnywang, committed by GitHub

[NPU] add reduce_mean_op_npu and test (#34053)

* add reduce_mean_op_npu and test

* remove skipIf

* update
Parent 9ed5db28
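
For context, a minimal usage sketch (not part of this commit): it assumes a Paddle build with NPU support, an available device 0, and that paddle.mean with an axis argument lowers to the reduce_mean op in static-graph mode.

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[5, 6, 10], dtype='float32')
    out = paddle.mean(x, axis=0)  # lowers to the reduce_mean op

exe = paddle.static.Executor(paddle.NPUPlace(0))  # assumes an NPU build
exe.run(startup_prog)
res, = exe.run(main_prog,
               feed={'x': np.random.random((5, 6, 10)).astype('float32')},
               fetch_list=[out])
print(res.shape)  # (6, 10)
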
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_npu.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
template <typename T>
class NPUReduceMeanOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
bool reduce_all = ctx.Attr<bool>("reduce_all");
auto dims = ctx.Attr<std::vector<int>>("dim");
bool keep_dim = ctx.Attr<bool>("keep_dim");
auto input_dims_vec = framework::vectorize(input->dims());
if (reduce_all) {
dims.clear();
for (size_t i = 0; i < input_dims_vec.size(); i++) {
dims.push_back(static_cast<int>(i));
}
}
const auto& runner = NpuOpRunner("ReduceMeanD", {*input}, {*output},
{{"axes", dims}, {"keep_dims", keep_dim}});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
}
};

template <typename T>
class NPUReduceMeanGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    input_grad->mutable_data<T>(ctx.GetPlace());

    bool reduce_all = ctx.Attr<bool>("reduce_all");
    auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
    auto input_dims_vec = framework::vectorize(input->dims());

    int reduce_numel = 1;
    if (reduce_all) {
      reduce_dims.clear();
      for (size_t d = 0; d < input_dims_vec.size(); ++d) {
        reduce_dims.push_back(static_cast<int>(d));
      }
    }
    // Normalize negative axes and count the elements averaged over.
    for (auto& d : reduce_dims) {
      if (d < 0) {
        d = d + input_dims_vec.size();
      }
      reduce_numel *= input_dims_vec[d];
    }

    // Fill X@GRAD with 1 / reduce_numel, the gradient of the mean before
    // scaling by the upstream gradient.
    const auto& runner =
        NpuOpRunner("FillV2D", {}, {*input_grad},
                    {{"value", 1.0f / static_cast<float>(reduce_numel)},
                     {"dims", input_dims_vec}});

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
    runner.Run(stream);

    // View Out@GRAD with the reduced axes kept as size-1 dims so that it
    // broadcasts against X@GRAD.
    Tensor transformed_input_grad, transformed_out_grad;
    Tensor tmp_output_grad;
    auto tmp_output_dims_vec = input_dims_vec;
    for (auto d : reduce_dims) {
      tmp_output_dims_vec[d] = 1;
    }
    tmp_output_grad.ShareDataWith(*output_grad);
    tmp_output_grad.Resize(framework::make_ddim(tmp_output_dims_vec));

    auto& dev_ctx =
        ctx.template device_context<paddle::platform::NPUDeviceContext>();
    NpuElementWiseOpBroadcast<T>(dev_ctx, input_grad, &tmp_output_grad, 0,
                                 &transformed_input_grad,
                                 &transformed_out_grad);

    // X@GRAD = (1 / reduce_numel) * Out@GRAD, broadcast to X's shape.
    const auto& runner2 =
        NpuOpRunner("Mul", {transformed_input_grad, transformed_out_grad},
                    {*input_grad}, {});
    runner2.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(reduce_mean, ops::NPUReduceMeanOpKernel<float>);
REGISTER_OP_NPU_KERNEL(reduce_mean_grad, ops::NPUReduceMeanGradOpKernel<float>);
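
The grad kernel composes two NPU ops: FillV2D writes the constant 1 / reduce_numel into X@GRAD, then Mul scales it by Out@GRAD viewed with the reduced axes kept as size-1 dims so it broadcasts. Below is a small NumPy sketch of the same math; the function name and helper variables are ours, for illustration only.

import numpy as np

def reduce_mean_grad(x_shape, out_grad, reduce_dims):
    # Number of elements averaged over the reduced axes.
    reduce_numel = int(np.prod([x_shape[d] for d in reduce_dims]))
    # Step 1 (FillV2D): fill X@GRAD with 1 / reduce_numel.
    input_grad = np.full(x_shape, 1.0 / reduce_numel, dtype=out_grad.dtype)
    # Step 2: view Out@GRAD with the reduced axes as size-1 dims.
    keepdim_shape = list(x_shape)
    for d in reduce_dims:
        keepdim_shape[d] = 1
    # Step 3 (Mul): broadcast elementwise multiply.
    return input_grad * out_grad.reshape(keepdim_shape)

# Mean over axis 0 of a (5, 6, 10) input: every element receives grad 1/5.
x_shape = (5, 6, 10)
out_grad = np.ones((6, 10), dtype='float32')
np.testing.assert_allclose(reduce_mean_grad(x_shape, out_grad, [0]),
                           np.full(x_shape, 0.2, dtype='float32'))
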
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid

paddle.enable_static()


class TestMeanOp(OpTest):
    def set_npu(self):
        self.__class__.use_npu = True

    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.outputs = {'Out': self.inputs['X'].mean(axis=0)}

    def test_check_output(self):
        self.check_output_with_place(paddle.NPUPlace(0))

    def test_check_grad(self):
        self.check_grad_with_place(paddle.NPUPlace(0), ['X'], 'Out')


class TestMeanOp5D(TestMeanOp):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.inputs = {
            'X': np.random.random((1, 2, 5, 6, 10)).astype("float32")
        }
        self.outputs = {'Out': self.inputs['X'].mean(axis=0)}


class TestMeanOp6D(TestMeanOp):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.inputs = {
            'X': np.random.random((1, 1, 2, 5, 6, 10)).astype("float32")
        }
        self.outputs = {'Out': self.inputs['X'].mean(axis=0)}


class TestMeanOp8D(TestMeanOp):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.inputs = {
            'X': np.random.random((1, 3, 1, 2, 1, 4, 3, 10)).astype("float32")
        }
        self.attrs = {'dim': (0, 3)}
        self.outputs = {'Out': self.inputs['X'].mean(axis=(0, 3))}


class Test1DReduce(TestMeanOp):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.inputs = {'X': np.random.random(120).astype("float32")}
        self.outputs = {'Out': self.inputs['X'].mean(axis=0)}


class Test2DReduce0(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.attrs = {'dim': [0]}
        self.inputs = {'X': np.random.random((20, 10)).astype("float32")}
        self.outputs = {'Out': self.inputs['X'].mean(axis=0)}


class Test2DReduce1(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.attrs = {'dim': [1]}
        self.inputs = {'X': np.random.random((20, 10)).astype("float32")}
        self.outputs = {
            'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
        }


class Test3DReduce0(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.attrs = {'dim': [1]}
        self.inputs = {'X': np.random.random((5, 6, 7)).astype("float32")}
        self.outputs = {
            'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
        }


class Test3DReduce1(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.attrs = {'dim': [2]}
        self.inputs = {'X': np.random.random((5, 6, 7)).astype("float32")}
        self.outputs = {
            'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
        }


class Test3DReduce2(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.attrs = {'dim': [-2]}
        self.inputs = {'X': np.random.random((5, 6, 7)).astype("float32")}
        self.outputs = {
            'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
        }


class Test3DReduce3(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.attrs = {'dim': [1, 2]}
        self.inputs = {'X': np.random.random((5, 6, 7)).astype("float32")}
        self.outputs = {
            'Out': self.inputs['X'].mean(axis=tuple(self.attrs['dim']))
        }


class TestKeepDimReduce(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.attrs = {'dim': [1], 'keep_dim': True}
        self.outputs = {
            'Out': self.inputs['X'].mean(
                axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim'])
        }


class TestKeepDim8DReduce(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.inputs = {
            'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float32")
        }
        self.attrs = {'dim': (3, 4, 5), 'keep_dim': True}
        self.outputs = {
            'Out': self.inputs['X'].mean(
                axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim'])
        }


class TestReduceAll(Test1DReduce):
    def setUp(self):
        self.set_npu()
        self.op_type = "reduce_mean"
        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
        self.attrs = {'reduce_all': True}
        self.outputs = {'Out': self.inputs['X'].mean()}


if __name__ == '__main__':
    unittest.main()