Unverified · Commit d4668938 authored by H huangxu96, committed by GitHub

Allclose op (#27891)

* Still has bugs.

* Fixed an allclose_op bug that mishandled some fp64 input cases.

* Improved CUDA kernel performance.

* Changed CUDA code.

* Fixed a bug in the CUDA kernel that could not handle large-dimension inputs, and added a unit test for it.

* Added a test case for float32 input.
Parent commit: 975bd887
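For reference, the elementwise rule that both rewritten kernels apply (see the CPU and CUDA functors in the diff below) can be restated in plain NumPy. This is an illustrative sketch of the kernel logic, not code from the commit:

```python
import numpy as np

# Sketch of the elementwise comparison rule mirrored from AllcloseFunctor below.
def allclose_sketch(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    result = True
    for x, y in zip(a.ravel(), b.ravel()):
        if np.isnan(x) or np.isnan(y):
            val = equal_nan and (np.isnan(x) == np.isnan(y))
        else:
            left = abs(x - y)
            right = atol + rtol * abs(y)
            # The extra 1e-15 slack absorbs rounding error in the tolerance
            # itself, which broke some fp64 cases before this commit.
            val = x == y or left <= right or abs(left - right) <= 1e-15
        result = result and val
    return result

print(allclose_sketch(np.array([10000., 1e-07]), np.array([10000.1, 1e-08])))  # False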
@@ -13,12 +13,49 @@
 // limitations under the License.
 #include "paddle/fluid/operators/allclose_op.h"
+#include <cmath>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/enforce.h"
 namespace paddle {
 namespace operators {
+
+template <typename T>
+struct GetTensorValue<platform::CPUDeviceContext, T> {
+  T operator()(const platform::CPUDeviceContext& dev_ctx,
+               const framework::Tensor& tensor) const {
+    return *(tensor.data<T>());
+  }
+};
+
+template <typename T>
+struct AllcloseFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& other,
+                  const double rtol, const double atol, bool equal_nan,
+                  framework::Tensor* output) {
+    auto* in_a = in.data<T>();
+    auto* in_b = other.data<T>();
+    auto* out_data = output->mutable_data<bool>(ctx.GetPlace());
+    auto num = in.numel();
+    *out_data = true;
+    for (int i = 0; i < num; i++) {
+      const T a = in_a[i], b = in_b[i];
+      bool val;
+      if (std::isnan(a) || std::isnan(b)) {
+        val = equal_nan && std::isnan(a) == std::isnan(b);
+      } else {
+        T left = (a > b ? a - b : b - a);
+        T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+        T diff = (left > right ? left - right : right - left);
+        val = a == b || left <= right || diff <= 1e-15;
+      }
+      *out_data &= val;
+    }
+  }
+};
+
 class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -26,12 +63,9 @@ class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
              "The input tensor, it's data type should be float32, float64.");
     AddInput("Other",
              "The input tensor, it's data type should be float32, float64.");
+    AddInput("Rtol", "The relative tolerance.");
+    AddInput("Atol", "The absolute tolerance.");
     AddOutput("Out", "The output tensor, it's data type is bool.");
-    AddAttr<float>("rtol", "The relative tolerance. Default: :math:`1e-5` .")
-        .SetDefault(1e-5);
-    AddAttr<float>("atol", "The absolute tolerance. Default: :math:`1e-8` .")
-        .SetDefault(1e-8);
     AddAttr<bool>("equal_nan",
                   "If :math:`True` , then two :math:`NaNs` will be "
                   "compared as equal. Default: :math:`False` .")
@@ -54,16 +88,12 @@ class AllcloseOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
-                      platform::errors::NotFound(
-                          "Input(Input) of allclose op should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Other"), true,
-                      platform::errors::NotFound(
-                          "Input(Other) of allclose op should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      platform::errors::NotFound(
-                          "The output(Out) of allclose op must not be null."));
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Allclose");
+    OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Allclose");
+    OP_INOUT_CHECK(ctx->HasInput("Rtol"), "Input", "Rtol", "Allclose");
+    OP_INOUT_CHECK(ctx->HasInput("Atol"), "Input", "Atol", "Allclose");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Allclose");
     auto input_dim = ctx->GetInputDim("Input");
     auto other_dim = ctx->GetInputDim("Other");
@@ -96,7 +126,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
+      const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
         ctx.device_context());
@@ -105,7 +135,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
 class AllcloseOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext *ctx) const override {
+  void operator()(framework::InferVarTypeContext* ctx) const override {
     ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
   }
 };
......
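The .cc changes above replace the `rtol`/`atol` float attributes with float64 tensor inputs `Rtol`/`Atol`, which avoids silently truncating tolerances to single precision. A quick illustration of the truncation this removes (my own example, not from the commit):

```python
import numpy as np

# 1e-05 is not exactly representable in binary; rounding it to float32
# changes its value, so a float attribute cannot carry an fp64 tolerance.
print(float(np.float32(1e-05)))                # 9.999999747378752e-06
print(np.float32(1e-05) == np.float64(1e-05))  # False
```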
@@ -12,12 +12,70 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#define EIGEN_USE_GPU
+#include <cuda_runtime.h>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/allclose_op.h"
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct GetTensorValue<platform::CUDADeviceContext, T> {
+  T operator()(const platform::CUDADeviceContext& dev_ctx,
+               const framework::Tensor& tensor) const {
+    const T* data = tensor.data<T>();
+    T value;
+    const auto gpu_place =
+        BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
+    memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T),
+                 dev_ctx.stream());
+    return value;
+  }
+};
+
+template <typename T>
+__global__ void AllcloseCUDAKernel(const T* in_data, const T* other_data,
+                                   const double rtol, const double atol,
+                                   bool equal_nan, int num, bool* out_data) {
+  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  bool val;
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    const T a = in_data[i], b = other_data[i];
+    if (isnan(a) || isnan(b)) {
+      val = equal_nan && isnan(a) == isnan(b);
+    } else {
+      T left = (a > b ? a - b : b - a);
+      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      T diff = (left > right ? left - right : right - left);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    if (!val) *out_data = false;
+  }
+}
+
+template <typename T>
+struct AllcloseFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& dev_ctx,
+                  const framework::Tensor& in, const framework::Tensor& other,
+                  const double rtol, const double atol, bool equal_nan,
+                  framework::Tensor* output) {
+    int num = in.numel();
+    const T* in_data = in.data<T>();
+    const T* other_data = other.data<T>();
+    bool* out_data = output->mutable_data<bool>(dev_ctx.GetPlace());
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+    grid = (grid > block) ? block : grid;
+    cudaMemset(out_data, true, sizeof(bool));
+    AllcloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
+        in_data, other_data, rtol, atol, equal_nan, num, out_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(allclose, ops::AllcloseKernel<CUDA, float>,
......
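The launch configuration above caps the grid at 1024 blocks and relies on the grid-stride loop to cover any remaining elements, which is what fixes the large-dimension bug from the commit message; the output is cudaMemset to true and threads only ever write false, so no reduction or atomics are needed. A sketch of the launch arithmetic, assuming the 2048x1024 test input:

```python
# Launch math for the capped-grid, grid-stride scheme used above.
num = 2048 * 1024          # elements in the large-dim unit test
block = 1024               # threads per block
grid = min((num + block - 1) // block, block)  # grid capped at 1024 blocks
stride = block * grid      # each thread visits i, i + stride, i + 2*stride, ...
print(grid, stride)        # 1024 1048576 -> two loop iterations per thread
```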
@@ -22,38 +22,38 @@ namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+struct GetTensorValue {
+  T operator()(const platform::DeviceContext& ctx,
+               const framework::Tensor& tensor) const;
+};
+
+template <typename DeviceContext, typename T>
+struct AllcloseFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
+                  const framework::Tensor& other, const double rtol,
+                  const double atol, bool equal_nan, framework::Tensor* output);
+};
+
 template <typename DeviceContext, typename T>
 class AllcloseKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     // get attrs
-    float rtol = ctx.Attr<float>("rtol");
-    float atol = ctx.Attr<float>("atol");
     bool equal_nan = ctx.Attr<bool>("equal_nan");
     // get input/output
-    auto* input = ctx.Input<Tensor>("Input");
-    auto* other = ctx.Input<Tensor>("Other");
+    const auto* input = ctx.Input<Tensor>("Input");
+    const auto* other = ctx.Input<Tensor>("Other");
+    const auto* rtol = ctx.Input<Tensor>("Rtol");
+    const auto* atol = ctx.Input<Tensor>("Atol");
     auto* out = ctx.Output<Tensor>("Out");
-    out->mutable_data<bool>(ctx.GetPlace());
-    // get place
-    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-    auto input_v = framework::EigenVector<T>::Flatten(*input);
-    auto other_v = framework::EigenVector<T>::Flatten(*other);
-    auto out_v = framework::EigenScalar<bool>::From(*out);
-    auto left = (input_v - other_v).abs();
-    auto right = static_cast<T>(atol) + static_cast<T>(rtol) * other_v.abs();
-    auto compare_res = left <= right;
-    if (equal_nan) {
-      auto input_nan = input_v.isnan();
-      auto other_nan = other_v.isnan();
-      out_v.device(place) =
-          (input_nan == other_nan).all() && (compare_res != input_nan).all();
-    } else {
-      out_v.device(place) = compare_res.all();
-    }
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    GetTensorValue<DeviceContext, double> get_tensor_value;
+    double rtol_v = get_tensor_value(dev_ctx, *rtol);
+    double atol_v = get_tensor_value(dev_ctx, *atol);
+    AllcloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v,
+                                        atol_v, equal_nan, out);
   }
 };
......
@@ -22,19 +22,20 @@ class TestAllcloseOp(OpTest):
     def set_args(self):
         self.input = np.array([10000., 1e-07]).astype("float32")
         self.other = np.array([10000.1, 1e-08]).astype("float32")
-        self.rtol = 1e-05
-        self.atol = 1e-08
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
         self.equal_nan = False

     def setUp(self):
         self.set_args()
         self.op_type = "allclose"
-        self.inputs = {'Input': self.input, 'Other': self.other}
-        self.attrs = {
-            'rtol': self.rtol,
-            'atol': self.atol,
-            'equal_nan': self.equal_nan
-        }
+        self.inputs = {
+            'Input': self.input,
+            'Other': self.other,
+            "Rtol": self.rtol,
+            "Atol": self.atol
+        }
+        self.attrs = {'equal_nan': self.equal_nan}
         self.outputs = {
             'Out': np.array([
                 np.allclose(
@@ -54,8 +55,8 @@ class TestAllcloseOpSmallNum(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([10000., 1e-08]).astype("float32")
         self.other = np.array([10000.1, 1e-09]).astype("float32")
-        self.rtol = 1e-05
-        self.atol = 1e-08
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
         self.equal_nan = False
@@ -63,8 +64,8 @@ class TestAllcloseOpNanFalse(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([1.0, float('nan')]).astype("float32")
         self.other = np.array([1.0, float('nan')]).astype("float32")
-        self.rtol = 1e-05
-        self.atol = 1e-08
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
         self.equal_nan = False
@@ -72,8 +73,8 @@ class TestAllcloseOpNanTrue(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([1.0, float('nan')]).astype("float32")
         self.other = np.array([1.0, float('nan')]).astype("float32")
-        self.rtol = 1e-05
-        self.atol = 1e-08
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
         self.equal_nan = True
@@ -130,5 +131,33 @@ class TestAllcloseError(unittest.TestCase):
         self.assertRaises(TypeError, test_equal_nan)

+class TestAllcloseOpFloat32(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float32")
+        self.other = np.array([10]).astype("float32")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+
+class TestAllcloseOpFloat64(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float64")
+        self.other = np.array([10]).astype("float64")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+
+class TestAllcloseOpLargeDimInput(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array(np.zeros([2048, 1024])).astype("float64")
+        self.other = np.array(np.zeros([2048, 1024])).astype("float64")
+        self.input[-1][-1] = 100
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
+        self.equal_nan = False
+
+
 if __name__ == "__main__":
     unittest.main()
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from . import to_tensor
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_type, check_variable_and_dtype
 from ..fluid.layers.layer_function_generator import templatedoc
@@ -95,8 +96,8 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     Args:
         x(Tensor): ${input_comment}.
         y(Tensor): ${other_comment}.
-        rtol(rtoltype, optional): ${rtol_comment}.
-        atol(atoltype, optional): ${atol_comment}.
+        rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
+        atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): ${equal_nan_comment}.
         name (str, optional): Name for the operation. For more information, please
             refer to :ref:`api_guide_Name`. Default: None.
@@ -142,7 +143,9 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     """
     if in_dygraph_mode():
-        return core.ops.allclose(x, y, 'rtol', rtol, 'atol', atol, 'equal_nan',
-                                 equal_nan)
+        rtol_tensor = to_tensor(rtol, dtype='float64')
+        atol_tensor = to_tensor(atol, dtype='float64')
+        return core.ops.allclose(x, y, rtol_tensor, atol_tensor, 'equal_nan',
+                                 equal_nan)

     check_variable_and_dtype(x, "input", ['float32', 'float64'], 'allclose')
@@ -152,11 +155,26 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     check_type(equal_nan, 'equal_nan', bool, 'allclose')

     helper = LayerHelper("allclose", **locals())
+    rtol_var = helper.create_global_variable(
+        name=fluid.unique_name.generate('rtol'),
+        persistable=True,
+        dtype='float64',
+        shape=[1])
+    helper.set_variable_initializer(
+        rtol_var, initializer=fluid.initializer.ConstantInitializer(rtol))
+    atol_var = helper.create_global_variable(
+        name=fluid.unique_name.generate('atol'),
+        persistable=True,
+        dtype='float64',
+        shape=[1])
+    helper.set_variable_initializer(
+        atol_var, initializer=fluid.initializer.ConstantInitializer(atol))
     out = helper.create_variable_for_type_inference(dtype='bool')
-    inputs = {'Input': x, 'Other': y}
+    inputs = {'Input': x, 'Other': y, 'Rtol': rtol_var, 'Atol': atol_var}
     outputs = {'Out': out}
-    attrs = {'rtol': rtol, 'atol': atol, 'equal_nan': equal_nan}
+    attrs = {'equal_nan': equal_nan}
     helper.append_op(
         type='allclose', inputs=inputs, outputs=outputs, attrs=attrs)
......
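The public Python signature is unchanged by this commit; rtol and atol are simply carried as float64 tensors internally. A usage sketch with values borrowed from the unit test above, assuming a paddle build where `paddle.to_tensor` and `paddle.allclose` are available:

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.array([10000., 1e-07]).astype("float32"))
y = paddle.to_tensor(np.array([10000.1, 1e-08]).astype("float32"))
# Same call as before this commit; tolerances become float64 tensors internally.
result = paddle.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
print(result.numpy())  # [False]
```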