Unverified commit 44da9b42, authored by joeqiao12 and committed by GitHub

add reduce_min and reduce_max (#39899)

Parent: 8895379a
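
This commit wires reduce_max and reduce_min into the MLU backend through CNNL and adds the corresponding unit tests. For orientation, here is a minimal static-graph usage sketch (not part of the commit); it assumes a Paddle build with MLU support and an available MLU device, and swapping in paddle.CPUPlace() runs the same graph elsewhere:

import numpy as np
import paddle

paddle.enable_static()

# A small static graph that reduces over the last axis; paddle.max / paddle.min
# lower to the reduce_max / reduce_min operators registered below.
x = paddle.static.data(name='x', shape=[5, 6, 10], dtype='float32')
y_max = paddle.max(x, axis=-1)
y_min = paddle.min(x, axis=-1)

# Assumption: this build was compiled with MLU support and device 0 exists.
place = paddle.MLUPlace(0)
exe = paddle.static.Executor(place)

data = np.random.random((5, 6, 10)).astype('float32')
out_max, out_min = exe.run(feed={'x': data}, fetch_list=[y_max, y_min])

np.testing.assert_allclose(out_max, data.max(axis=-1), rtol=1e-6)
np.testing.assert_allclose(out_min, data.min(axis=-1), rtol=1e-6)
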
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ReduceMaxMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Output<Tensor>("Out");
    int out_dtype = context.Attr<int>("out_dtype");
    bool reduce_all = context.Attr<bool>("reduce_all");
    auto dims = context.Attr<std::vector<int>>("dim");
    auto input_dims = framework::vectorize(input->dims());
    const auto& input_dim_size = input->dims().size();
    std::vector<int> reduce_dims;
    // Normalize the reduction axes: reduce over every axis when reduce_all is
    // set, otherwise wrap negative axes by the input rank.
    if (reduce_all) {
      for (size_t i = 0; i < input_dims.size(); i++) {
        reduce_dims.push_back(static_cast<int>(i));
      }
    } else {
      for (size_t i = 0; i < dims.size(); ++i) {
        if (dims[i] < 0) {
          reduce_dims.push_back(dims[i] + input_dim_size);
        } else {
          reduce_dims.push_back(dims[i]);
        }
      }
    }

    auto place = context.GetPlace();
    framework::Tensor cast_out(input->type());
    cast_out.Resize(output->dims());
    cast_out.mutable_data<T>(place);

    // When out_dtype is -1 (unset) the output keeps the input dtype and shares
    // the buffer above; otherwise the output is allocated in the requested type.
    auto cast_out_dtype = framework::TransToProtoVarType(input->dtype());
    if (out_dtype != -1) {
      cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
    }
    if (framework::TransToProtoVarType(input->type()) != cast_out_dtype) {
      if (cast_out_dtype == framework::proto::VarType::FP32) {
        output->mutable_data<float>(place);
      } else if (cast_out_dtype == framework::proto::VarType::FP16) {
        output->mutable_data<paddle::platform::float16>(place);
      } else if (cast_out_dtype == framework::proto::VarType::INT32) {
        output->mutable_data<int32_t>(place);
      }
    } else {
      output->ShareDataWith(cast_out);
    }

    MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
                                 ToCnnlDataType(input->dtype()));
    MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
                                  ToCnnlDataType(output->dtype()));

    MLUCnnlReduceDesc reduction_desc(
        reduce_dims, CNNL_REDUCE_MAX, ToCnnlDataType<T>(),
        CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);

    MLUCnnl::Reduce(context, true /*need_workspace*/, reduction_desc.get(),
                    nullptr, input_desc.get(), GetBasePtr(input),
                    0 /*indices_size*/, nullptr, nullptr, output_desc.get(),
                    GetBasePtr(output));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_MLU_KERNEL(reduce_max, ops::ReduceMaxMLUKernel<float>,
                       ops::ReduceMaxMLUKernel<plat::float16>,
                       ops::ReduceMaxMLUKernel<int>);
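
The reduce_dims handling above follows the usual axis semantics: with reduce_all every axis is reduced, and negative entries in "dim" wrap around by the input rank. A small NumPy sketch of the same normalization (the helper name is hypothetical, for illustration only):

import numpy as np

def normalize_reduce_dims(dims, rank, reduce_all):
    # reduce_all: reduce over every axis, matching the kernel's first branch.
    if reduce_all:
        return list(range(rank))
    # Negative axes wrap around by the input rank, e.g. -1 -> rank - 1.
    return [d + rank if d < 0 else d for d in dims]

x = np.random.random((5, 6, 10)).astype('float32')
print(normalize_reduce_dims([-1], x.ndim, False))      # [2]
print(normalize_reduce_dims([-2, -1], x.ndim, False))  # [1, 2]
print(x.max(axis=tuple(normalize_reduce_dims([-2, -1], x.ndim, False))).shape)  # (5,)
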
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ReduceMinMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Output<Tensor>("Out");
    int out_dtype = context.Attr<int>("out_dtype");
    bool reduce_all = context.Attr<bool>("reduce_all");
    auto dims = context.Attr<std::vector<int>>("dim");
    auto input_dims = framework::vectorize(input->dims());
    const auto& input_dim_size = input->dims().size();
    std::vector<int> reduce_dims;
    // Normalize the reduction axes: reduce over every axis when reduce_all is
    // set, otherwise wrap negative axes by the input rank.
    if (reduce_all) {
      for (size_t i = 0; i < input_dims.size(); i++) {
        reduce_dims.push_back(static_cast<int>(i));
      }
    } else {
      for (size_t i = 0; i < dims.size(); ++i) {
        if (dims[i] < 0) {
          reduce_dims.push_back(dims[i] + input_dim_size);
        } else {
          reduce_dims.push_back(dims[i]);
        }
      }
    }

    auto place = context.GetPlace();
    framework::Tensor cast_out(input->type());
    cast_out.Resize(output->dims());
    cast_out.mutable_data<T>(place);

    // When out_dtype is -1 (unset) the output keeps the input dtype and shares
    // the buffer above; otherwise the output is allocated in the requested type.
    auto cast_out_dtype = framework::TransToProtoVarType(input->dtype());
    if (out_dtype != -1) {
      cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
    }
    if (framework::TransToProtoVarType(input->type()) != cast_out_dtype) {
      if (cast_out_dtype == framework::proto::VarType::FP32) {
        output->mutable_data<float>(place);
      } else if (cast_out_dtype == framework::proto::VarType::FP16) {
        output->mutable_data<paddle::platform::float16>(place);
      } else if (cast_out_dtype == framework::proto::VarType::INT32) {
        output->mutable_data<int32_t>(place);
      }
    } else {
      output->ShareDataWith(cast_out);
    }

    MLUCnnlTensorDesc input_desc(*input, CNNL_LAYOUT_ARRAY,
                                 ToCnnlDataType(input->dtype()));
    MLUCnnlTensorDesc output_desc(*output, CNNL_LAYOUT_ARRAY,
                                  ToCnnlDataType(output->dtype()));

    MLUCnnlReduceDesc reduction_desc(
        reduce_dims, CNNL_REDUCE_MIN, ToCnnlDataType<T>(),
        CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);

    MLUCnnl::Reduce(context, true /*need_workspace*/, reduction_desc.get(),
                    nullptr, input_desc.get(), GetBasePtr(input),
                    0 /*indices_size*/, nullptr, nullptr, output_desc.get(),
                    GetBasePtr(output));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_MLU_KERNEL(reduce_min, ops::ReduceMinMLUKernel<float>,
                       ops::ReduceMinMLUKernel<plat::float16>,
                       ops::ReduceMinMLUKernel<int>);
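
Both kernels treat out_dtype identically: -1 keeps the input dtype (the ShareDataWith branch), while any other value allocates the output in the requested type (FP32, FP16 or INT32). A NumPy sketch of the observable behaviour, which is what the *WithOutDtype* test cases below assert against:

import numpy as np

x = np.random.random((5, 6, 10)).astype('float32')

# out_dtype == -1: the reduced result keeps the input dtype.
out_default = x.max(axis=(1, 2))
assert out_default.dtype == np.float32

# out_dtype == INT32: the reduced result is produced in the requested type,
# mirroring the expected output of the *WithOutDtype_int32 tests below.
out_int32 = x.max(axis=(1, 2)).astype(np.int32)
assert out_int32.dtype == np.int32
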
@@ -541,11 +541,12 @@ class ReduceOp : public framework::OperatorWithKernel {
 #endif
     if (input_data_type == framework::proto::VarType::FP16) {
-      PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()) ||
-                            platform::is_npu_place(ctx.GetPlace()),
-                        true,
-                        platform::errors::InvalidArgument(
-                            "float16 can only be used on GPU or NPU place"));
+      PADDLE_ENFORCE_EQ(
+          platform::is_gpu_place(ctx.GetPlace()) ||
+              platform::is_npu_place(ctx.GetPlace()) ||
+              platform::is_mlu_place(ctx.GetPlace()),
+          true, platform::errors::InvalidArgument(
+                    "float16 can only be used on GPU or NPU or MLU place"));
     }
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
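
The hunk above widens the float16 device check in ReduceOp::GetExpectedKernelType so that MLU places are accepted as well; without it, the float16 test cases below would be rejected before the MLU kernel is ever reached. A minimal sketch of that float16 path in static mode (again assuming an MLU-enabled build and an attached MLU device):

import numpy as np
import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[5, 6, 10], dtype='float16')
y = paddle.max(x, axis=[-2, -1])

# Before this change, a float16 input raised InvalidArgument here on MLU;
# the place check now also accepts MLUPlace.
exe = paddle.static.Executor(paddle.MLUPlace(0))
data = np.random.random((5, 6, 10)).astype('float16')
out, = exe.run(feed={'x': data}, fetch_list=[y])
np.testing.assert_allclose(out, data.max(axis=(1, 2)), atol=1e-3)
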
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import convert_np_dtype_to_dtype_
paddle.enable_static()
@skip_check_grad_ci(
    reason="reduce_max is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestMLUReduceMaxOp(OpTest):
    """Remove Max with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_max"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {'dim': [-1]}
        self.outputs = {
            'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
        }

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def set_mlu(self):
        self.__class__.use_mlu = True
        self.place = paddle.MLUPlace(0)

    def init_dtype(self):
        self.dtype = np.float32


@skip_check_grad_ci(
    reason="reduce_max is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceMaxOpMultiAxises(TestMLUReduceMaxOp):
    """Remove Max with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_max"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {'dim': [-2, -1]}
        self.outputs = {
            'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
        }


@skip_check_grad_ci(
    reason="reduce_max is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceAll(TestMLUReduceMaxOp):
    """Remove Max with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_max"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {'reduce_all': True}
        self.outputs = {'Out': self.inputs['X'].max()}


@skip_check_grad_ci(
    reason="reduce_max is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceMaxOpWithOutDtype_int32(TestMLUReduceMaxOp):
    """Remove Max with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_max"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {
            'dim': [-2, -1],
            'out_dtype': int(core.VarDesc.VarType.INT32)
        }
        self.outputs = {
            'Out':
            self.inputs['X'].max(axis=tuple(self.attrs['dim'])).astype(np.int32)
        }

    def init_dtype(self):
        self.dtype = np.int32


@skip_check_grad_ci(
    reason="reduce_max is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceMaxOpWithOutDtype_fp16(TestMLUReduceMaxOp):
    """Remove Max with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_max"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {
            'dim': [-2, -1],
            'out_dtype': int(core.VarDesc.VarType.FP16)
        }
        self.outputs = {
            'Out': self.inputs['X'].max(
                axis=tuple(self.attrs['dim'])).astype(np.float16)
        }

    def init_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-3)


@skip_check_grad_ci(
    reason="reduce_max is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceMaxOpWithOutDtype_fp32(TestMLUReduceMaxOp):
    """Remove Max with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_max"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {
            'dim': [-2, -1],
            'out_dtype': int(core.VarDesc.VarType.FP32)
        }
        self.outputs = {
            'Out': self.inputs['X'].max(
                axis=tuple(self.attrs['dim'])).astype(np.float32)
        }

    def init_dtype(self):
        self.dtype = np.float32


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import convert_np_dtype_to_dtype_
paddle.enable_static()
@skip_check_grad_ci(
    reason="reduce_min is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestMLUReduceMinOp(OpTest):
    """Remove Min with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_min"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {'dim': [-1]}
        self.outputs = {
            'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
        }

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def set_mlu(self):
        self.__class__.use_mlu = True
        self.place = paddle.MLUPlace(0)

    def init_dtype(self):
        self.dtype = np.float32


@skip_check_grad_ci(
    reason="reduce_min is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceMinOpMultiAxises(TestMLUReduceMinOp):
    """Remove Min with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_min"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {'dim': [-2, -1]}
        self.outputs = {
            'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
        }


@skip_check_grad_ci(
    reason="reduce_min is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceAll(TestMLUReduceMinOp):
    """Remove Min with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_min"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {'reduce_all': True}
        self.outputs = {'Out': self.inputs['X'].min()}


@skip_check_grad_ci(
    reason="reduce_min is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceMinOpWithOutDtype_int32(TestMLUReduceMinOp):
    """Remove Min with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_min"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {
            'dim': [-2, -1],
            'out_dtype': int(core.VarDesc.VarType.INT32)
        }
        self.outputs = {
            'Out':
            self.inputs['X'].min(axis=tuple(self.attrs['dim'])).astype(np.int32)
        }

    def init_dtype(self):
        self.dtype = np.int32


@skip_check_grad_ci(
    reason="reduce_min is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceMinOpWithOutDtype_fp16(TestMLUReduceMinOp):
    """Remove Min with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_min"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {
            'dim': [-2, -1],
            'out_dtype': int(core.VarDesc.VarType.FP16)
        }
        self.outputs = {
            'Out': self.inputs['X'].min(
                axis=tuple(self.attrs['dim'])).astype(np.float16)
        }

    def init_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-3)


@skip_check_grad_ci(
    reason="reduce_min is a discontinuous, non-differentiable function; "
    "its gradient check is not supported by the unittest framework.")
class TestReduceMinOpWithOutDtype_fp32(TestMLUReduceMinOp):
    """Remove Min with subgradient from gradient check to confirm the success of CI."""

    def setUp(self):
        self.op_type = "reduce_min"
        self.set_mlu()
        self.init_dtype()

        self.inputs = {'X': np.random.random((5, 6, 10)).astype(self.dtype)}
        self.attrs = {
            'dim': [-2, -1],
            'out_dtype': int(core.VarDesc.VarType.FP32)
        }
        self.outputs = {
            'Out': self.inputs['X'].min(
                axis=tuple(self.attrs['dim'])).astype(np.float32)
        }

    def init_dtype(self):
        self.dtype = np.float32


if __name__ == '__main__':
    unittest.main()