未验证 提交 76f87034 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Move allclose op kernel into phi (#40469)

* move allclose kernel

* remove allclose op kernel

* fix coverage failed
上级 39de9b8a
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/allclose_op.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
......@@ -23,41 +23,6 @@
namespace paddle {
namespace operators {
// CPU specialization of GetTensorValue: the tensor's data pointer is
// dereferenced directly to obtain its first element.
template <typename T>
struct GetTensorValue<platform::CPUDeviceContext, T> {
  // Returns the first element stored in `tensor`. `dev_ctx` is not used by
  // this specialization.
  T operator()(const platform::CPUDeviceContext& dev_ctx,
               const framework::Tensor& tensor) const {
    const T* first_element = tensor.data<T>();
    return *first_element;
  }
};
// CPU specialization of AllcloseFunctor.
//
// Writes a single bool into `output`: true when every element pair
// (in[i], other[i]) is considered close, false otherwise.
//
// Per element:
//   * if either value is NaN, the pair is close only when `equal_nan` is set
//     and both values are NaN;
//   * otherwise the pair is close when a == b, or |a - b| <= atol + rtol * |b|,
//     or the gap between those two quantities is at most 1e-15.
//
// Only in.numel() elements are visited; `other` is indexed with the same
// range, so the caller must supply at least that many elements in `other`.
template <typename T>
struct AllcloseFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& other,
                  const double rtol, const double atol, bool equal_nan,
                  framework::Tensor* output) {
    auto* in_a = in.data<T>();
    auto* in_b = other.data<T>();
    auto* out_data = output->mutable_data<bool>(ctx.GetPlace());
    auto num = in.numel();
    *out_data = true;
    // Index with the same type numel() returns instead of `int`, so the
    // counter cannot overflow before reaching `num` on very large tensors.
    for (decltype(num) i = 0; i < num; ++i) {
      const T a = in_a[i], b = in_b[i];
      bool val;
      if (std::isnan(a) || std::isnan(b)) {
        val = equal_nan && std::isnan(a) == std::isnan(b);
      } else {
        // left = |a - b|, right = atol + rtol * |b| (narrowed to T).
        T left = (a > b ? a - b : b - a);
        T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
        // diff = |left - right|; the 1e-15 slack tolerates rounding when the
        // difference sits right at the tolerance boundary.
        T diff = (left > right ? left - right : right - left);
        val = a == b || left <= right || diff <= 1e-15;
      }
      *out_data &= val;
    }
  }
};
class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -157,8 +122,6 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::AllcloseOpVarTypeInference);
// Registers ops::AllcloseKernel as the CPU kernel of the `allclose` operator
// for float and double element types.
// NOTE(review): `CPU` and `ops` are aliases declared outside this excerpt —
// presumably the CPU device context and paddle::operators; confirm.
REGISTER_OP_CPU_KERNEL(allclose, ops::AllcloseKernel<CPU, float>,
ops::AllcloseKernel<CPU, double>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(allclose)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Functor interface that extracts a single value of type T from `tensor`.
//
// This primary template only declares the call operator; no definition is
// given here, so each DeviceContext needs its own specialization that knows
// how to read the tensor's memory.
template <typename DeviceContext, typename T>
struct GetTensorValue {
// ctx:    device context handed to the specialization.
// tensor: tensor whose value is read.
T operator()(const platform::DeviceContext& ctx,
const framework::Tensor& tensor) const;
};
// Functor interface for the allclose computation on a given DeviceContext.
//
// This primary template only declares the call operator; device-specific
// specializations supply the definition.
//
//   ctx:       device context the computation runs on.
//   in, other: the two tensors being compared.
//   rtol/atol: relative and absolute tolerances. They are declared as double
//              so the declaration agrees with the value AllcloseKernel passes
//              in (parsed with std::stod / read as FP64) instead of silently
//              advertising a float signature.
//   equal_nan: whether two NaN values count as equal.
//   output:    tensor that receives the result.
template <typename DeviceContext, typename T>
struct AllcloseFunctor {
  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
                  const framework::Tensor& other, const double rtol,
                  const double atol, bool equal_nan,
                  framework::Tensor* output);
};
// Fluid operator kernel for `allclose`.
//
// Reads the inputs "Input" and "Other", the bool attribute "equal_nan" and
// the tolerances, then delegates the comparison to
// AllcloseFunctor<DeviceContext, T>, which writes the result into "Out".
//
// Each tolerance comes from the string attribute ("rtol" / "atol", parsed
// with std::stod) unless the matching optional tensor input ("Rtol" / "Atol")
// is present, in which case the tensor value takes precedence.
template <typename DeviceContext, typename T>
class AllcloseKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // get attrs
    bool equal_nan = ctx.Attr<bool>("equal_nan");
    // get input/output
    const auto* input = ctx.Input<Tensor>("Input");
    const auto* other = ctx.Input<Tensor>("Other");
    auto* out = ctx.Output<Tensor>("Out");

    // Attribute values are the defaults; tensor inputs override them below.
    double rtol_v = std::stod(ctx.Attr<std::string>("rtol"));
    double atol_v = std::stod(ctx.Attr<std::string>("atol"));

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    if (ctx.HasInput("Rtol")) {
      rtol_v = ReadTolerance(dev_ctx, *ctx.Input<Tensor>("Rtol"), "Rtol");
    }
    if (ctx.HasInput("Atol")) {
      atol_v = ReadTolerance(dev_ctx, *ctx.Input<Tensor>("Atol"), "Atol");
    }

    AllcloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v, atol_v,
                                        equal_nan, out);
  }

 private:
  // Validates that `tolerance` holds exactly one FP64 element and returns that
  // element. `input_name` ("Rtol" or "Atol") is only used in error messages,
  // which keeps both inputs' diagnostics identical in wording and punctuation.
  // Raises InvalidArgument when the size or the data type does not match.
  static double ReadTolerance(const DeviceContext& dev_ctx,
                              const Tensor& tolerance,
                              const char* input_name) {
    PADDLE_ENFORCE_EQ(
        tolerance.numel(), 1,
        platform::errors::InvalidArgument(
            "Input(%s) size must be 1, but get %d.", input_name,
            tolerance.numel()));
    PADDLE_ENFORCE_EQ(
        framework::TransToProtoVarType(tolerance.dtype()),
        framework::proto::VarType::FP64,
        platform::errors::InvalidArgument(
            "Input(%s) type must be double, but get %s.", input_name,
            framework::DataTypeToString(
                framework::TransToProtoVarType(tolerance.dtype()))));
    return GetTensorValue<DeviceContext, double>()(dev_ctx, tolerance);
  }
};
} // namespace operators
} // namespace paddle
......@@ -40,6 +40,13 @@ phi::Scalar MakePhiScalarFromVar(const framework::Variable& variable) {
auto expected_place = phi::TransToPhiPlace(phi::Backend::CPU);
if (variable.IsType<framework::LoDTensor>()) {
const auto& tensor = variable.Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
tensor.numel(),
1UL,
platform::errors::InvalidArgument("The DenseTensor used to construct "
"the Scalar contains more than 1 "
"value, it contains `%d` values.",
tensor.numel()));
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
/**
 * @brief Element-wise closeness test of two tensors, reduced to one bool.
 *
 * Declaration only; the definitions live in the backend-specific kernel
 * sources.
 *
 * @param dev_ctx    Device context used to allocate the output.
 * @param x          First tensor to compare.
 * @param y          Second tensor to compare.
 * @param rtol       Relative tolerance, passed as a Scalar.
 * @param atol       Absolute tolerance, passed as a Scalar.
 * @param equal_nan  Whether two NaN values are treated as equal.
 * @param out        Output tensor that receives the bool result.
 */
template <typename T, typename Context>
void AllCloseKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const Scalar& rtol,
const Scalar& atol,
bool equal_nan,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/allclose_kernel.h"
#include <cmath>
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// CPU implementation of the allclose kernel.
//
// Stores a single bool in `out`: true when every pair (x[i], y[i]) is close.
// Both tolerances must be FLOAT64 scalars, otherwise InvalidArgument is
// raised before any data is touched. x.numel() elements are visited and `y`
// is indexed over the same range.
template <typename T, typename Context>
void AllCloseKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const DenseTensor& y,
                    const Scalar& rtol,
                    const Scalar& atol,
                    bool equal_nan,
                    DenseTensor* out) {
  PADDLE_ENFORCE_EQ(
      rtol.dtype(),
      DataType::FLOAT64,
      phi::errors::InvalidArgument(
          "Input (Rtol) type must be double, but get %s.", rtol.dtype()));
  PADDLE_ENFORCE_EQ(
      atol.dtype(),
      DataType::FLOAT64,
      phi::errors::InvalidArgument(
          "Input (Atol) type must be double, but get %s.", atol.dtype()));

  const T* lhs = x.data<T>();
  const T* rhs = y.data<T>();
  const double rel_tol = rtol.to<double>();
  const double abs_tol = atol.to<double>();

  // The result starts out true and is cleared by the first mismatch.
  bool* all_close = dev_ctx.template Alloc<bool>(out);
  *all_close = true;

  const auto count = x.numel();
  for (int64_t idx = 0; idx < count; ++idx) {
    const T a = lhs[idx];
    const T b = rhs[idx];

    bool pair_close;
    if (std::isnan(a) || std::isnan(b)) {
      // A NaN on either side only matches another NaN, and only on request.
      pair_close = equal_nan && std::isnan(a) == std::isnan(b);
    } else {
      // gap = |a - b|, bound = atol + rtol * |b| (narrowed to T).
      T gap;
      if (a > b) {
        gap = a - b;
      } else {
        gap = b - a;
      }
      const T bound = abs_tol + (b > 0 ? rel_tol * b : (-rel_tol) * b);
      // slack = |gap - bound|; a slack of at most 1e-15 still counts as close.
      T slack;
      if (gap > bound) {
        slack = gap - bound;
      } else {
        slack = bound - gap;
      }
      pair_close = a == b || gap <= bound || slack <= 1e-15;
    }

    if (!pair_close) {
      *all_close = false;
    }
  }
}
} // namespace phi
// Registers phi::AllCloseKernel under the kernel name `allclose` for the CPU
// backend, all layouts, with float and double element types.
PD_REGISTER_KERNEL(
allclose, CPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
// The output at index 0 is declared as BOOL, independent of the input type.
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,30 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/allclose_op.h"
#include "paddle/phi/kernels/allclose_kernel.h"
namespace paddle {
namespace operators {
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
// CUDA specialization of GetTensorValue: copies one element of type T from
// the tensor's place on the device into a host-side variable and returns it.
template <typename T>
struct GetTensorValue<platform::CUDADeviceContext, T> {
  T operator()(const platform::CUDADeviceContext& dev_ctx,
               const framework::Tensor& tensor) const {
    const T* device_ptr = tensor.data<T>();
    T host_value;
    const auto device_place = dev_ctx.GetPlace();
    // NOTE(review): the copy is issued on dev_ctx.stream() and no explicit
    // synchronization is visible here before the value is returned — confirm
    // that memory::Copy completes the transfer for this place combination.
    memory::Copy(platform::CPUPlace(),
                 &host_value,
                 device_place,
                 device_ptr,
                 sizeof(T),
                 dev_ctx.stream());
    return host_value;
  }
};
namespace phi {
template <typename T>
__global__ void AllcloseCUDAKernel(const T* in_data, const T* other_data,
const double rtol, const double atol,
bool equal_nan, int num, bool* out_data) {
__global__ void AllcloseCUDAKernel(const T* in_data,
const T* other_data,
const double rtol,
const double atol,
bool equal_nan,
int num,
bool* out_data) {
unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
bool val;
for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
......@@ -52,33 +43,47 @@ __global__ void AllcloseCUDAKernel(const T* in_data, const T* other_data,
}
}
template <typename T>
struct AllcloseFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor& in, const framework::Tensor& other,
const double rtol, const double atol, bool equal_nan,
framework::Tensor* output) {
int num = in.numel();
const T* in_data = in.data<T>();
const T* other_data = other.data<T>();
bool* out_data = output->mutable_data<bool>(dev_ctx.GetPlace());
int block = 1024;
int grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
template <typename T, typename Context>
void AllCloseKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const Scalar& rtol,
const Scalar& atol,
bool equal_nan,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
rtol.dtype(),
DataType::FLOAT64,
phi::errors::InvalidArgument(
"Input (Rtol) type must be double, but get %s.", rtol.dtype()));
PADDLE_ENFORCE_EQ(
atol.dtype(),
DataType::FLOAT64,
phi::errors::InvalidArgument(
"Input (Atol) type must be double, but get %s.", atol.dtype()));
const T* in_data = x.data<T>();
const T* other_data = y.data<T>();
auto rtol_v = rtol.to<double>();
auto atol_v = atol.to<double>();
bool* out_data = dev_ctx.template Alloc<bool>(out);
int num = x.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid;
#ifdef PADDLE_WITH_HIP
hipMemset(out_data, true, sizeof(bool));
hipMemset(out_data, true, sizeof(bool));
#else
cudaMemset(out_data, true, sizeof(bool));
cudaMemset(out_data, true, sizeof(bool));
#endif
AllcloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, other_data, rtol, atol, equal_nan, num, out_data);
}
};
AllcloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, other_data, rtol_v, atol_v, equal_nan, num, out_data);
}
} // namespace operators
} // namespace paddle
} // namespace phi
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(allclose, ops::AllcloseKernel<CUDA, float>,
ops::AllcloseKernel<CUDA, double>);
// Registers phi::AllCloseKernel under the kernel name `allclose` for the GPU
// backend, all layouts, with float and double element types.
PD_REGISTER_KERNEL(
allclose, GPU, ALL_LAYOUT, phi::AllCloseKernel, float, double) {
// The output at index 0 is declared as BOOL, independent of the input type.
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Maps the fluid `allclose` operator onto the phi `allclose` kernel signature.
//
// Each tolerance can arrive either as a tensor input ("Rtol" / "Atol") or as
// an attribute ("rtol" / "atol"). The tensor input is chosen whenever it is
// present; otherwise the attribute name is used. The inputs are always
// {"Input", "Other"}, the third argument is always "equal_nan" and the output
// is always {"Out"}.
KernelSignature AllCloseOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* rtol_source = ctx.HasInput("Rtol") ? "Rtol" : "rtol";
  const char* atol_source = ctx.HasInput("Atol") ? "Atol" : "atol";
  return KernelSignature("allclose",
                         {"Input", "Other"},
                         {rtol_source, atol_source, "equal_nan"},
                         {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(allclose, phi::AllCloseOpArgumentMapping);
......@@ -484,5 +484,33 @@ TEST(ARG_MAP, set_value) {
"set_value");
}
// Checks the argument mapping of `allclose` for the two mixed configurations:
// one tolerance supplied as a tensor input, the other as a string attribute.
TEST(ARG_MAP, allclose) {
  {
    // Rtol comes in as a tensor input, atol as an attribute: the first mapped
    // attribute name must be the tensor name "Rtol".
    TestArgumentMappingContext rtol_input_ctx(
        {"Input", "Other", "Rtol"},
        {},
        {{"atol", paddle::any(std::string{"1e-8"})},
         {"equal_nan", paddle::any(false)}},
        {"Out"},
        {});
    auto rtol_input_signature =
        OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(rtol_input_ctx);
    ASSERT_EQ(rtol_input_signature.name, "allclose");
    auto mapped_attrs = std::get<1>(rtol_input_signature.args);
    ASSERT_EQ(mapped_attrs[0], "Rtol");
  }

  {
    // Atol comes in as a tensor input, rtol as an attribute: the second mapped
    // attribute name must be the tensor name "Atol".
    TestArgumentMappingContext atol_input_ctx(
        {"Input", "Other", "Atol"},
        {},
        {{"rtol", paddle::any(std::string{"1e-5"})},
         {"equal_nan", paddle::any(false)}},
        {"Out"},
        {});
    auto atol_input_signature =
        OpUtilsMap::Instance().GetArgumentMappingFn("allclose")(atol_input_ctx);
    ASSERT_EQ(atol_input_signature.name, "allclose");
    auto mapped_attrs = std::get<1>(atol_input_signature.args);
    ASSERT_EQ(mapped_attrs[1], "Atol");
  }
}
} // namespace tests
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册