From 76f8703445b269334035a891466a284148c26734 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Sat, 12 Mar 2022 19:39:32 +0800
Subject: [PATCH] [Phi] Move allclose op kernel into phi (#40469)

* move allclose kernel

* remove allclose op kernel

* fix coverage failed
---
 paddle/fluid/operators/allclose_op.cc     | 39 +---------
 paddle/fluid/operators/allclose_op.cu     | 84 --------------------
 paddle/fluid/operators/allclose_op.h      | 93 -----------------------
 paddle/phi/api/lib/utils/tensor_utils.cc  |  7 ++
 paddle/phi/kernels/allclose_kernel.h      | 31 ++++++++
 paddle/phi/kernels/cpu/allclose_kernel.cc | 71 +++++++++++++++++
 paddle/phi/kernels/gpu/allclose_kernel.cu | 89 ++++++++++++++++++++++
 paddle/phi/ops/compat/allclose_sig.cc     | 49 ++++++++++++
 paddle/phi/tests/ops/test_op_signature.cc | 28 +++++++
 9 files changed, 276 insertions(+), 215 deletions(-)
 delete mode 100644 paddle/fluid/operators/allclose_op.cu
 delete mode 100644 paddle/fluid/operators/allclose_op.h
 create mode 100644 paddle/phi/kernels/allclose_kernel.h
 create mode 100644 paddle/phi/kernels/cpu/allclose_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/allclose_kernel.cu
 create mode 100644 paddle/phi/ops/compat/allclose_sig.cc

diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc
index 8fb9929c39e..706a132878d 100644
--- a/paddle/fluid/operators/allclose_op.cc
+++ b/paddle/fluid/operators/allclose_op.cc
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/allclose_op.h"
 #include <cmath>
 #include <string>
+
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -23,41 +23,6 @@
 namespace paddle {
 namespace operators {
 
-template <typename T>
-struct GetTensorValue<platform::CPUDeviceContext, T> {
-  T operator()(const platform::CPUDeviceContext& dev_ctx,
-               const framework::Tensor& tensor) const {
-    return *(tensor.data<T>());
-  }
-};
-
-template <typename T>
-struct AllcloseFunctor<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& ctx,
-                  const framework::Tensor& in, const framework::Tensor& other,
-                  const double rtol, const double atol, bool equal_nan,
-                  framework::Tensor* output) {
-    auto* in_a = in.data<T>();
-    auto* in_b = other.data<T>();
-    auto* out_data = output->mutable_data<bool>(ctx.GetPlace());
-    auto num = in.numel();
-    *out_data = true;
-    for (int i = 0; i < num; i++) {
-      const T a = in_a[i], b = in_b[i];
-      bool val;
-      if (std::isnan(a) || std::isnan(b)) {
-        val = equal_nan && std::isnan(a) == std::isnan(b);
-      } else {
-        T left = (a > b ? a - b : b - a);
-        T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
-        T diff = (left > right ? left - right : right - left);
-        val = a == b || left <= right || diff <= 1e-15;
-      }
-      *out_data &= val;
-    }
-  }
-};
-
 class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -157,8 +122,6 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     ops::AllcloseOpVarTypeInference);
-REGISTER_OP_CPU_KERNEL(allclose, ops::AllcloseKernel<CPU, float>,
-                       ops::AllcloseKernel<CPU, double>);
 
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(allclose)
diff --git a/paddle/fluid/operators/allclose_op.cu b/paddle/fluid/operators/allclose_op.cu
deleted file mode 100644
index 32c90ff8fdc..00000000000
--- a/paddle/fluid/operators/allclose_op.cu
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/allclose_op.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-struct GetTensorValue<platform::CUDADeviceContext, T> {
-  T operator()(const platform::CUDADeviceContext& dev_ctx,
-               const framework::Tensor& tensor) const {
-    const T* data = tensor.data<T>();
-    T value;
-    const auto gpu_place = dev_ctx.GetPlace();
-    memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T),
-                 dev_ctx.stream());
-    return value;
-  }
-};
-
-template <typename T>
-__global__ void AllcloseCUDAKernel(const T* in_data, const T* other_data,
-                                   const double rtol, const double atol,
-                                   bool equal_nan, int num, bool* out_data) {
-  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
-  bool val;
-  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
-    const T a = in_data[i], b = other_data[i];
-    if (isnan(a) || isnan(b)) {
-      val = equal_nan && isnan(a) == isnan(b);
-    } else {
-      T left = (a > b ? a - b : b - a);
-      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
-      T diff = (left > right ? left - right : right - left);
-      val = a == b || left <= right || diff <= 1e-15;
-    }
-    if (!val) *out_data = false;
-  }
-}
-
-template <typename T>
-struct AllcloseFunctor<platform::CUDADeviceContext, T> {
-  void operator()(const platform::CUDADeviceContext& dev_ctx,
-                  const framework::Tensor& in, const framework::Tensor& other,
-                  const double rtol, const double atol, bool equal_nan,
-                  framework::Tensor* output) {
-    int num = in.numel();
-    const T* in_data = in.data<T>();
-    const T* other_data = other.data<T>();
-    bool* out_data = output->mutable_data<bool>(dev_ctx.GetPlace());
-    int block = 1024;
-    int grid = (block - 1 + num) / block;
-    grid = (grid > block) ? block : grid;
-#ifdef PADDLE_WITH_HIP
-    hipMemset(out_data, true, sizeof(bool));
-#else
-    cudaMemset(out_data, true, sizeof(bool));
-#endif
-    AllcloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
-        in_data, other_data, rtol, atol, equal_nan, num, out_data);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-using CUDA = paddle::platform::CUDADeviceContext;
-REGISTER_OP_CUDA_KERNEL(allclose, ops::AllcloseKernel<CUDA, float>,
-                        ops::AllcloseKernel<CUDA, double>);
diff --git a/paddle/fluid/operators/allclose_op.h b/paddle/fluid/operators/allclose_op.h
deleted file mode 100644
index 7a36754194a..00000000000
--- a/paddle/fluid/operators/allclose_op.h
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <cmath>
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/platform/place.h"
-
-namespace paddle {
-namespace operators {
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-struct GetTensorValue {
-  T operator()(const platform::DeviceContext& ctx,
-               const framework::Tensor& tensor) const;
-};
-
-template <typename DeviceContext, typename T>
-struct AllcloseFunctor {
-  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
-                  const framework::Tensor& other, const float rtol,
-                  const float atol, bool equal_nan, framework::Tensor* output);
-};
-
-template <typename DeviceContext, typename T>
-class AllcloseKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    // get attrs
-    bool equal_nan = ctx.Attr<bool>("equal_nan");
-    // get input/output
-    const auto* input = ctx.Input<Tensor>("Input");
-    const auto* other = ctx.Input<Tensor>("Other");
-    auto* out = ctx.Output<Tensor>("Out");
-
-    double rtol_v = std::stod(ctx.Attr<std::string>("rtol"));
-    double atol_v = std::stod(ctx.Attr<std::string>("atol"));
-
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    GetTensorValue<DeviceContext, double> get_tensor_value;
-    if (ctx.HasInput("Rtol")) {
-      const auto* rtol = ctx.Input<Tensor>("Rtol");
-      PADDLE_ENFORCE_EQ(
-          rtol->numel(), 1,
-          platform::errors::InvalidArgument(
-              "Input(Rtol) size must be 1, but get %d.", rtol->numel()));
-      PADDLE_ENFORCE_EQ(
-          framework::TransToProtoVarType(rtol->dtype()),
-          framework::proto::VarType::FP64,
-          platform::errors::InvalidArgument(
-              "Input(Rtol) type must be double, but get %s.",
-              framework::DataTypeToString(
-                  framework::TransToProtoVarType(rtol->dtype()))));
-      rtol_v = get_tensor_value(dev_ctx, *rtol);
-    }
-    if (ctx.HasInput("Atol")) {
-      const auto* atol = ctx.Input<Tensor>("Atol");
-      PADDLE_ENFORCE_EQ(
-          atol->numel(), 1,
-          platform::errors::InvalidArgument(
-              "Input(Atol) size must be 1, but get %d", atol->numel()));
-      PADDLE_ENFORCE_EQ(
-          framework::TransToProtoVarType(atol->dtype()),
-          framework::proto::VarType::FP64,
-          platform::errors::InvalidArgument(
-              "Input(Atol) type must be double, but get %s",
-              framework::DataTypeToString(
-                  framework::TransToProtoVarType(atol->dtype()))));
-      atol_v = get_tensor_value(dev_ctx, *atol);
-    }
-
-    AllcloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v,
-                                        atol_v, equal_nan, out);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/api/lib/utils/tensor_utils.cc b/paddle/phi/api/lib/utils/tensor_utils.cc
index 1c9f7c3a868..3d183ea7fee 100644
--- a/paddle/phi/api/lib/utils/tensor_utils.cc
+++ b/paddle/phi/api/lib/utils/tensor_utils.cc
@@ -40,6 +40,13 @@ phi::Scalar MakePhiScalarFromVar(const framework::Variable& variable) {
   auto expected_place = phi::TransToPhiPlace(phi::Backend::CPU);
   if (variable.IsType<framework::LoDTensor>()) {
     const auto& tensor = variable.Get<framework::LoDTensor>();
+    PADDLE_ENFORCE_EQ(
+        tensor.numel(),
+        1UL,
+        platform::errors::InvalidArgument("The DenseTensor used to construct "
+                                          "the Scalar contains more than 1 "
+                                          "value, it contains `%d` values.",
+                                          tensor.numel()));
     if (!platform::is_same_place(tensor.place(), expected_place)) {
       framework::LoDTensor tmp_tensor;
       framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
diff --git a/paddle/phi/kernels/allclose_kernel.h b/paddle/phi/kernels/allclose_kernel.h
new file mode 100644
index 00000000000..3f24078b86c
--- /dev/null
+++ b/paddle/phi/kernels/allclose_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void AllCloseKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const Scalar& rtol,
+                    const Scalar& atol,
+                    bool equal_nan,
+                    DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/allclose_kernel.cc b/paddle/phi/kernels/cpu/allclose_kernel.cc
new file mode 100644
index 00000000000..7ffeadfeed8
--- /dev/null
+++ b/paddle/phi/kernels/cpu/allclose_kernel.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/allclose_kernel.h"
+
+#include <cmath>
+
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void AllCloseKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const Scalar& rtol,
+                    const Scalar& atol,
+                    bool equal_nan,
+                    DenseTensor* out) {
+  PADDLE_ENFORCE_EQ(
+      rtol.dtype(),
+      DataType::FLOAT64,
+      phi::errors::InvalidArgument(
+          "Input (Rtol) type must be double, but get %s.", rtol.dtype()));
+  PADDLE_ENFORCE_EQ(
+      atol.dtype(),
+      DataType::FLOAT64,
+      phi::errors::InvalidArgument(
+          "Input (Atol) type must be double, but get %s.", atol.dtype()));
+
+  auto* in_a = x.data<T>();
+  auto* in_b = y.data<T>();
+  auto rtol_v = rtol.to<double>();
+  auto atol_v = atol.to<double>();
+  auto* out_data = dev_ctx.template Alloc<bool>(out);
+  *out_data = true;
+
+  auto num = x.numel();
+  for (int64_t i = 0; i < num; ++i) {
+    const T a = in_a[i], b = in_b[i];
+    bool val;
+    if (std::isnan(a) || std::isnan(b)) {
+      val = equal_nan && std::isnan(a) == std::isnan(b);
+    } else {
+      T left = (a > b ? a - b : b - a);
+      T right = atol_v + (b > 0 ? rtol_v * b : (-rtol_v) * b);
+      T diff = (left > right ? left - right : right - left);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    *out_data &= val;
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    allclose, CPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
+}
diff --git a/paddle/phi/kernels/gpu/allclose_kernel.cu b/paddle/phi/kernels/gpu/allclose_kernel.cu
new file mode 100644
index 00000000000..af2612bb10c
--- /dev/null
+++ b/paddle/phi/kernels/gpu/allclose_kernel.cu
@@ -0,0 +1,89 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/allclose_kernel.h"
+
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T>
+__global__ void AllcloseCUDAKernel(const T* in_data,
+                                   const T* other_data,
+                                   const double rtol,
+                                   const double atol,
+                                   bool equal_nan,
+                                   int num,
+                                   bool* out_data) {
+  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  bool val;
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    const T a = in_data[i], b = other_data[i];
+    if (isnan(a) || isnan(b)) {
+      val = equal_nan && isnan(a) == isnan(b);
+    } else {
+      T left = (a > b ? a - b : b - a);
+      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      T diff = (left > right ? left - right : right - left);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    if (!val) *out_data = false;
+  }
+}
+
+template <typename T, typename Context>
+void AllCloseKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const Scalar& rtol,
+                    const Scalar& atol,
+                    bool equal_nan,
+                    DenseTensor* out) {
+  PADDLE_ENFORCE_EQ(
+      rtol.dtype(),
+      DataType::FLOAT64,
+      phi::errors::InvalidArgument(
+          "Input (Rtol) type must be double, but get %s.", rtol.dtype()));
+  PADDLE_ENFORCE_EQ(
+      atol.dtype(),
+      DataType::FLOAT64,
+      phi::errors::InvalidArgument(
+          "Input (Atol) type must be double, but get %s.", atol.dtype()));
+
+  const T* in_data = x.data<T>();
+  const T* other_data = y.data<T>();
+  auto rtol_v = rtol.to<double>();
+  auto atol_v = atol.to<double>();
+  bool* out_data = dev_ctx.template Alloc<bool>(out);
+
+  int num = x.numel();
+  int block = 1024;
+  int grid = (block - 1 + num) / block;
+  grid = (grid > block) ? block : grid;
+#ifdef PADDLE_WITH_HIP
+  hipMemset(out_data, true, sizeof(bool));
+#else
+  cudaMemset(out_data, true, sizeof(bool));
+#endif
+  AllcloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
+      in_data, other_data, rtol_v, atol_v, equal_nan, num, out_data);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    allclose, GPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
+  kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
+}
diff --git a/paddle/phi/ops/compat/allclose_sig.cc b/paddle/phi/ops/compat/allclose_sig.cc
new file mode 100644
index 00000000000..e5c4fc027b5
--- /dev/null
+++ b/paddle/phi/ops/compat/allclose_sig.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature AllCloseOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.HasInput("Rtol")) {
+    if (ctx.HasInput("Atol")) {
+      return KernelSignature("allclose",
+                             {"Input", "Other"},
+                             {"Rtol", "Atol", "equal_nan"},
+                             {"Out"});
+    } else {
+      return KernelSignature("allclose",
+                             {"Input", "Other"},
+                             {"Rtol", "atol", "equal_nan"},
+                             {"Out"});
+    }
+  } else {
+    if (ctx.HasInput("Atol")) {
+      return KernelSignature("allclose",
+                             {"Input", "Other"},
+                             {"rtol", "Atol", "equal_nan"},
+                             {"Out"});
+    } else {
+      return KernelSignature("allclose",
+                             {"Input", "Other"},
+                             {"rtol", "atol", "equal_nan"},
+                             {"Out"});
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(allclose, phi::AllCloseOpArgumentMapping);
diff --git a/paddle/phi/tests/ops/test_op_signature.cc b/paddle/phi/tests/ops/test_op_signature.cc
index 88c9193a8f8..c74049e0f04 100644
--- a/paddle/phi/tests/ops/test_op_signature.cc
+++ b/paddle/phi/tests/ops/test_op_signature.cc
@@ -484,5 +484,33 @@ TEST(ARG_MAP, set_value) {
            "set_value");
 }
 
+TEST(ARG_MAP, allclose) {
+  TestArgumentMappingContext arg_case1(
+      {"Input", "Other", "Rtol"},
+      {},
+      {{"atol", paddle::any(std::string{"1e-8"})},
+       {"equal_nan", paddle::any(false)}},
+      {"Out"},
+      {});
+  auto signature1 =
+      OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case1);
+  ASSERT_EQ(signature1.name, "allclose");
+  auto attr_names1 = std::get<1>(signature1.args);
+  ASSERT_EQ(attr_names1[0], "Rtol");
+
+  TestArgumentMappingContext arg_case2(
+      {"Input", "Other", "Atol"},
+      {},
+      {{"rtol", paddle::any(std::string{"1e-5"})},
+       {"equal_nan", paddle::any(false)}},
+      {"Out"},
+      {});
+  auto signature2 =
+      OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(arg_case2);
+  ASSERT_EQ(signature2.name, "allclose");
+  auto attr_names2 = std::get<1>(signature2.args);
+  ASSERT_EQ(attr_names2[1], "Atol");
+}
+
 }  // namespace tests
 }  // namespace phi
-- 
GitLab
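For reference, the element-wise test that both the CPU and GPU kernels in this patch implement is |a - b| <= atol + rtol * |b|, with a small 1e-15 slack on the bound itself and with NaN treated as a match only when equal_nan is set and both operands are NaN. The following is a minimal, self-contained C++ sketch of that check; the helper name AllCloseRef and the example values are illustrative only and are not part of the patch.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Reference implementation of the element-wise comparison used by
    // the allclose kernels above.
    template <typename T>
    bool AllCloseRef(const std::vector<T>& a, const std::vector<T>& b,
                     double rtol, double atol, bool equal_nan) {
      bool all = true;
      for (size_t i = 0; i < a.size(); ++i) {
        const T x = a[i], y = b[i];
        bool val;
        if (std::isnan(x) || std::isnan(y)) {
          // NaNs match only when equal_nan is set and both sides are NaN.
          val = equal_nan && std::isnan(x) && std::isnan(y);
        } else {
          // |x - y| <= atol + rtol * |y|, plus a 1e-15 slack on the bound
          // to absorb rounding error in computing the bound itself.
          T left = (x > y ? x - y : y - x);
          T right = atol + (y > 0 ? rtol * y : (-rtol) * y);
          T diff = (left > right ? left - right : right - left);
          val = x == y || left <= right || diff <= 1e-15;
        }
        all &= val;
      }
      return all;
    }

    int main() {
      std::vector<double> a{1.0, 2.0}, b{1.0 + 1e-9, 2.0};
      // Prints 1: a difference of 1e-9 is within atol=1e-8 + rtol=1e-5 * |b|.
      std::printf("%d\n", AllCloseRef(a, b, 1e-5, 1e-8, false));
      return 0;
    }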