From 50778ad63557e1049bd68a04c6c83bf7601170ef Mon Sep 17 00:00:00 2001 From: TTerror Date: Mon, 25 Oct 2021 11:02:02 +0800 Subject: [PATCH] add some ops to train ssd on kunlun (#36407) * add some ops to train ssd on kunlun * add some ops to train ssd on kunlun * add some ops to train ssd on kunlun * update cast op unittest * update cast op unittest * update cast op unittest * update xpu cmake * update cast unittest --- cmake/external/xpu.cmake | 2 +- cmake/operators.cmake | 2 +- paddle/fluid/operators/cast_op_xpu.cc | 80 +++--- paddle/fluid/operators/clip_op_xpu.cc | 78 +++++ .../operators/controlflow/CMakeLists.txt | 6 + .../operators/controlflow/compare_op_xpu.cc | 145 ++++++++++ paddle/fluid/operators/stack_op_xpu.cc | 2 + paddle/fluid/platform/xpu/xpu2_op_list.h | 29 ++ .../tests/unittests/xpu/test_cast_op_xpu.py | 93 +++--- .../tests/unittests/xpu/test_clip_op_xpu.py | 216 ++++++++++++++ .../unittests/xpu/test_compare_op_xpu.py | 272 ++++++++++++++++++ .../tests/unittests/xpu/test_stack_op_xpu.py | 22 ++ 12 files changed, 847 insertions(+), 100 deletions(-) create mode 100644 paddle/fluid/operators/clip_op_xpu.cc create mode 100644 paddle/fluid/operators/controlflow/compare_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_clip_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_compare_op_xpu.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 70bdc67980..11a7adbbeb 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -35,7 +35,7 @@ ELSE () ENDIF() SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") -SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210921") +SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20211020") SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 24c7d3f07f..7830cf7b50 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -299,7 +299,7 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") endif() - if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0) + if (WITH_XPU AND ${pybind_flag} EQUAL 0 AND ${xpu_cc_srcs_len} GREATER 0) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n") endif() diff --git a/paddle/fluid/operators/cast_op_xpu.cc b/paddle/fluid/operators/cast_op_xpu.cc index c7c0f81f21..c1a296f2b2 100644 --- a/paddle/fluid/operators/cast_op_xpu.cc +++ b/paddle/fluid/operators/cast_op_xpu.cc @@ -23,6 +23,9 @@ limitations under the License. 
*/ namespace paddle { namespace operators { +using var_type = framework::proto::VarType; +namespace plat = paddle::platform; + template class CastXPUKernel : public framework::OpKernel { using XPUInTDType = typename XPUTypeTrait::Type; @@ -31,53 +34,49 @@ class CastXPUKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out = context.Output("Out"); - auto in_type = static_cast( - context.Attr("in_dtype")); - auto out_type = static_cast( - context.Attr("out_dtype")); + auto in_type = static_cast(context.Attr("in_dtype")); + auto out_type = static_cast(context.Attr("out_dtype")); auto* in_data = in->data(); auto numel = in->numel(); auto& dev_ctx = context.template device_context(); int r = -1; - if (out_type == framework::proto::VarType::FP32) { - auto* out_data = out->mutable_data(context.GetPlace()); - r = xpu::cast_v2( - dev_ctx.x_context(), reinterpret_cast(in_data), - out_data, numel); - } else if (out_type == framework::proto::VarType::INT32) { - auto* out_data = out->mutable_data(context.GetPlace()); - r = xpu::cast_v2( - dev_ctx.x_context(), reinterpret_cast(in_data), - out_data, numel); - } else if (out_type == framework::proto::VarType::INT64) { - auto* out_data = out->mutable_data(context.GetPlace()); - r = xpu::cast_v2( - dev_ctx.x_context(), reinterpret_cast(in_data), - out_data, numel); - } else if ((out_type == framework::proto::VarType::BOOL) && - (in_type == framework::proto::VarType::FP32)) { - auto* out_data = out->mutable_data(context.GetPlace()); - r = xpu::cast_v2( - dev_ctx.x_context(), (const float*)in_data, - reinterpret_cast(out_data), numel); - } else if (out_type == framework::proto::VarType::FP16) { - auto* out_data = - out->mutable_data(context.GetPlace()); - r = xpu::cast_v2( - dev_ctx.x_context(), reinterpret_cast(in_data), - reinterpret_cast(out_data), numel); - - } else { - PADDLE_THROW(platform::errors::Unavailable("Not supported cast %d -> %d", - in_type, out_type)); + switch (out_type) { + case var_type::FP32: + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + out->mutable_data(context.GetPlace()), numel); + break; + case var_type::FP16: + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + reinterpret_cast( + out->mutable_data(context.GetPlace())), + numel); + break; + case var_type::INT64: + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + out->mutable_data(context.GetPlace()), numel); + break; + case var_type::INT32: + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + out->mutable_data(context.GetPlace()), numel); + break; + case var_type::BOOL: + r = xpu::cast_v2( + dev_ctx.x_context(), reinterpret_cast(in_data), + out->mutable_data(context.GetPlace()), numel); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "Not supported cast %d -> %d", in_type, out_type)); } PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, - platform::errors::External( - "XPU API return wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + platform::errors::External("XPU CAST API return wrong value[%d %s].", r, + XPUAPIErrorMsg[r])); } }; @@ -90,5 +89,6 @@ REGISTER_OP_XPU_KERNEL( ops::CastXPUKernel, ops::CastXPUKernel, - ops::CastXPUKernel); + ops::CastXPUKernel, + ops::CastXPUKernel); #endif diff --git a/paddle/fluid/operators/clip_op_xpu.cc b/paddle/fluid/operators/clip_op_xpu.cc new file mode 100644 index 0000000000..7d4b02af41 --- /dev/null +++ 
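The reworked CastXPUKernel above dispatches on out_dtype through a single switch, every branch calling xpu::cast_v2 and the failure path now reporting XPUAPIErrorMsg; FP32, FP16, INT32, INT64 and BOOL are all registered for the XPU place. A minimal dygraph sketch of exercising those dtype combinations (illustrative only: it assumes a WITH_XPU build with the card addressable as xpu:0, and falls back to CPU so it still runs elsewhere):

import numpy as np
import paddle

# Assumption: 'xpu:0' is a valid device string in a WITH_XPU build.
paddle.set_device('xpu:0' if paddle.is_compiled_with_xpu() else 'cpu')

x = paddle.to_tensor(np.random.random([10, 10]).astype('float32'))
# Each target dtype maps to one branch of the out_type switch in CastXPUKernel.
for dst in ('float16', 'int32', 'int64', 'bool'):
    y = paddle.cast(x, dst)
    print(dst, '->', y.dtype)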
b/paddle/fluid/operators/clip_op_xpu.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class ClipXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto max = static_cast(ctx.Attr("max")); + if (ctx.HasInput("Max")) { + Tensor max_cpu; + auto* max_t = ctx.Input("Max"); + auto* max_data = max_t->data(); + if (platform::is_xpu_place(max_t->place())) { + TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu); + max_data = max_cpu.data(); + } + max = max_data[0]; + } + + auto min = ctx.Attr("min"); + if (ctx.HasInput("Min")) { + Tensor min_cpu; + auto* min_t = ctx.Input("Min"); + auto* min_data = min_t->data(); + if (platform::is_xpu_place(min_t->place())) { + TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu); + min_data = min_cpu.data(); + } + min = min_data[0]; + } + + using XPUDataType = typename XPUTypeTrait::Type; + auto& dev_ctx = ctx.template device_context(); + auto x_data = reinterpret_cast(x->data()); + auto out_data = reinterpret_cast(out->data()); + int r = xpu::clip_v2(dev_ctx.x_context(), x_data, out_data, x->numel(), min, + max); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(clip_v2) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL(clip, ops::ClipXPUKernel); + +#endif diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt index 1a2df2a0c7..d2ad93bbae 100644 --- a/paddle/fluid/operators/controlflow/CMakeLists.txt +++ b/paddle/fluid/operators/controlflow/CMakeLists.txt @@ -22,3 +22,9 @@ endif() file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP(logical_and);\nUSE_OP(logical_or);\nUSE_OP(logical_xor);\nUSE_OP(logical_not);\n") file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n") + +if(WITH_XPU) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(equal, XPU);\nUSE_OP_DEVICE_KERNEL(not_equal, XPU);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(less_than, XPU);\nUSE_OP_DEVICE_KERNEL(less_equal, XPU);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(greater_than, XPU);\nUSE_OP_DEVICE_KERNEL(greater_equal, XPU);\n") +endif() diff --git a/paddle/fluid/operators/controlflow/compare_op_xpu.cc b/paddle/fluid/operators/controlflow/compare_op_xpu.cc new file mode 100644 index 0000000000..59e457caa1 --- /dev/null +++ 
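ClipXPUKernel above takes scalar bounds from the min/max attributes, and when the optional Min/Max tensor inputs are present it copies them to the host with TensorCopySync before handing the single values to xpu::clip_v2. Both paths can be driven from Python, as in this sketch (assumes a WITH_XPU build; falls back to CPU otherwise):

import numpy as np
import paddle

paddle.set_device('xpu:0' if paddle.is_compiled_with_xpu() else 'cpu')

x = paddle.to_tensor(np.random.random([4, 8]).astype('float32'))
# Scalar bounds travel as the 'min'/'max' attributes.
y1 = paddle.clip(x, min=0.2, max=0.8)
# Tensor bounds exercise the optional 'Min'/'Max' inputs, which the kernel
# copies back to the CPU before reading the bound value.
v_min = paddle.to_tensor(np.array([0.2], dtype=np.float32))
v_max = paddle.to_tensor(np.array([0.8], dtype=np.float32))
y2 = paddle.clip(x, min=v_min, max=v_max)
assert np.allclose(y1.numpy(), y2.numpy())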
b/paddle/fluid/operators/controlflow/compare_op_xpu.cc @@ -0,0 +1,145 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/controlflow/compare_op.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { + +template +void XPUCompare( + const framework::ExecutionContext& ctx, + std::function&, const std::vector&)> + func) { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + + auto x_shape = framework::vectorize(x->dims()); + auto y_shape = framework::vectorize(y->dims()); + + auto x_data = reinterpret_cast(x->data()); + auto y_data = reinterpret_cast(y->data()); + auto z_data = z->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = + ctx.template device_context(); + + int ret = func(dev_ctx.x_context(), x_data, y_data, z_data, x_shape, y_shape); + PADDLE_ENFORCE_EQ( + ret, xpu::SUCCESS, + platform::errors::External( + "XPU kernel compare op occur error[%d %s] in XPUCompare.", ret, + XPUAPIErrorMsg[ret])); +} + +template +class EqualXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUCompare(ctx, xpu::broadcast_equal); + } +}; + +template +class NotEqualXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUCompare(ctx, xpu::broadcast_not_equal); + } +}; + +template +class LessThanXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUCompare(ctx, xpu::broadcast_less_than); + } +}; + +template +class LessEqualXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUCompare(ctx, xpu::broadcast_less_equal); + } +}; + +template +class GreaterThanXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUCompare(ctx, xpu::broadcast_greater_than); + } +}; + +template +class GreaterEqualXPUKernel : public framework::OpKernel { + using XPUType = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + XPUCompare(ctx, xpu::broadcast_greater_equal); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL(equal, + ops::EqualXPUKernel, + ops::EqualXPUKernel, + ops::EqualXPUKernel); + +REGISTER_OP_XPU_KERNEL(not_equal, + ops::NotEqualXPUKernel, + ops::NotEqualXPUKernel, + ops::NotEqualXPUKernel); + +REGISTER_OP_XPU_KERNEL(less_than, + ops::LessThanXPUKernel, + 
ops::LessThanXPUKernel, + ops::LessThanXPUKernel); + +REGISTER_OP_XPU_KERNEL( + less_equal, ops::LessEqualXPUKernel, + ops::LessEqualXPUKernel, + ops::LessEqualXPUKernel); + +REGISTER_OP_XPU_KERNEL( + greater_than, ops::GreaterThanXPUKernel, + ops::GreaterThanXPUKernel, + ops::GreaterThanXPUKernel); + +REGISTER_OP_XPU_KERNEL( + greater_equal, ops::GreaterEqualXPUKernel, + ops::GreaterEqualXPUKernel, + ops::GreaterEqualXPUKernel); + +#endif diff --git a/paddle/fluid/operators/stack_op_xpu.cc b/paddle/fluid/operators/stack_op_xpu.cc index 9929df6e30..01ec4a2b16 100644 --- a/paddle/fluid/operators/stack_op_xpu.cc +++ b/paddle/fluid/operators/stack_op_xpu.cc @@ -66,5 +66,7 @@ namespace plat = paddle::platform; namespace ops = paddle::operators; REGISTER_OP_XPU_KERNEL(stack, + ops::StackXPUKernel, + ops::StackXPUKernel, ops::StackXPUKernel); #endif diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/xpu/xpu2_op_list.h index 0a9a9453b5..121d26e39d 100644 --- a/paddle/fluid/platform/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/xpu/xpu2_op_list.h @@ -119,6 +119,35 @@ XPUOpMap& get_kl2_ops() { {"slice_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, + {"equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"not_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"less_than", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"less_equal", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"greater_than", + XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"greater_equal", + XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::FP32, XPUPlace())})}, + {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"stack", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, + {"cast", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::BOOL, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, {"fill_any_like", XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py index f1ba8828f2..1633d82772 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py @@ -16,71 +16,48 @@ from __future__ import print_function import sys sys.path.append("..") -import op_test import unittest +import op_test import numpy as np import paddle import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler, Program, program_guard - -class TestCastOp1(op_test.OpTest): - def setUp(self): - ipt = 
np.random.random(size=[10, 10]) - self.inputs = {'X': ipt.astype('float32')} - self.outputs = {'Out': ipt.astype('float32')} - self.attrs = { - 'in_dtype': int(core.VarDesc.VarType.FP32), - 'out_dtype': int(core.VarDesc.VarType.FP32) - } - self.op_type = 'cast' - - def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def test_grad(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_grad_with_place(place, ['X'], ['Out']) - - -class TestCastOp2(op_test.OpTest): - def setUp(self): - ipt = np.random.random(size=[10, 10]) - self.inputs = {'X': ipt.astype('float32')} - self.outputs = {'Out': ipt.astype('float16')} - self.attrs = { - 'in_dtype': int(core.VarDesc.VarType.FP32), - 'out_dtype': int(core.VarDesc.VarType.FP16) - } - self.op_type = 'cast' - - def test_check_output(self): - #self.check_output(atol=1e-3) - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place, atol=1e-3) - - -class TestCastOp3(op_test.OpTest): - def setUp(self): - ipt = np.random.random(size=[10, 10]) - self.inputs = {'X': ipt.astype('float16')} - self.outputs = {'Out': ipt.astype('float32')} - self.attrs = { - 'in_dtype': int(core.VarDesc.VarType.FP16), - 'out_dtype': int(core.VarDesc.VarType.FP32) - } - self.op_type = 'cast' - - def test_check_output(self): - #self.check_output(atol=1e-3) - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place, atol=1e-3) +typeid_dict = { + 'int32': int(core.VarDesc.VarType.INT32), + 'int64': int(core.VarDesc.VarType.INT64), + 'float32': int(core.VarDesc.VarType.FP32), + 'float16': int(core.VarDesc.VarType.FP16), + 'bool': int(core.VarDesc.VarType.BOOL), +} + + +def create_test_class(in_typename, out_typename): + class Cls(op_test.OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {'X': ipt.astype(in_typename)} + self.outputs = {'Out': ipt.astype(in_typename).astype(out_typename)} + self.attrs = { + 'in_dtype': typeid_dict[in_typename], + 'out_dtype': typeid_dict[out_typename], + } + self.op_type = 'cast' + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + cls_name = "cast_{0}_{1}".format(in_typename, out_typename) + Cls.__name__ = cls_name + globals()[cls_name] = Cls + + +for in_type in {'float16', 'float32', 'int32', 'int64', 'bool'}: + for out_type in {'float16', 'float32', 'int32', 'int64'}: + create_test_class(in_type, out_type) class TestCastOpError(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/xpu/test_clip_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_clip_op_xpu.py new file mode 100644 index 0000000000..6c58c7ccf2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_clip_op_xpu.py @@ -0,0 +1,216 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
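The typeid_dict/create_test_class helper above replaces the three hand-written cast cases with one generated class per (in_dtype, out_dtype) pair, exposed through globals(). For reference, the pair ('float32', 'int64') expands to roughly the hand-written test below (illustrative; it reuses the module's imports, and the real class is named cast_float32_int64 by the helper):

class TestCastFP32ToINT64(op_test.OpTest):
    def setUp(self):
        ipt = np.random.random(size=[10, 10])
        self.inputs = {'X': ipt.astype('float32')}
        self.outputs = {'Out': ipt.astype('float32').astype('int64')}
        self.attrs = {
            'in_dtype': int(core.VarDesc.VarType.FP32),
            'out_dtype': int(core.VarDesc.VarType.INT64),
        }
        self.op_type = 'cast'

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place)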
+ +from __future__ import print_function + +import sys +sys.path.append("..") +import unittest +import numpy as np +import paddle.fluid.core as core +import paddle.fluid as fluid +from op_test_xpu import OpTest, XPUOpTest +import paddle +from paddle.fluid import Program, program_guard + + +class TestClipOp(XPUOpTest): + def set_xpu(self): + self.__class__.use_xpu = True + self.place = paddle.XPUPlace(0) + + def setUp(self): + self.set_xpu() + self.max_relative_error = 0.006 + + self.inputs = {} + self.initTestCase() + + self.op_type = "clip" + self.attrs = {} + self.attrs['min'] = self.min + self.attrs['max'] = self.max + if 'Min' in self.inputs: + min_v = self.inputs['Min'] + else: + min_v = self.attrs['min'] + + if 'Max' in self.inputs: + max_v = self.inputs['Max'] + else: + max_v = self.attrs['max'] + + input = np.random.random(self.shape).astype("float32") + input[np.abs(input - min_v) < self.max_relative_error] = 0.5 + input[np.abs(input - max_v) < self.max_relative_error] = 0.5 + self.inputs['X'] = input + self.outputs = {'Out': np.clip(self.inputs['X'], min_v, max_v)} + + def test_check_output(self): + paddle.enable_static() + self.check_output_with_place(self.place) + paddle.disable_static() + + def test_check_grad_normal(self): + paddle.enable_static() + self.check_grad_with_place(self.place, ['X'], 'Out') + paddle.disable_static() + + def initTestCase(self): + self.shape = (4, 10, 10) + self.max = 0.8 + self.min = 0.3 + self.inputs['Max'] = np.array([0.8]).astype('float32') + self.inputs['Min'] = np.array([0.1]).astype('float32') + + +class TestCase1(TestClipOp): + def initTestCase(self): + self.shape = (8, 16, 8) + self.max = 0.7 + self.min = 0.0 + + +class TestCase2(TestClipOp): + def initTestCase(self): + self.shape = (8, 16) + self.max = 1.0 + self.min = 0.0 + + +class TestCase3(TestClipOp): + def initTestCase(self): + self.shape = (4, 8, 16) + self.max = 0.7 + self.min = 0.2 + + +class TestCase4(TestClipOp): + def initTestCase(self): + self.shape = (4, 8, 8) + self.max = 0.7 + self.min = 0.2 + self.inputs['Max'] = np.array([0.8]).astype('float32') + self.inputs['Min'] = np.array([0.3]).astype('float32') + + +class TestCase5(TestClipOp): + def initTestCase(self): + self.shape = (4, 8, 16) + self.max = 0.5 + self.min = 0.5 + + +class TestClipOpError(unittest.TestCase): + def test_errors(self): + paddle.enable_static() + with program_guard(Program(), Program()): + input_data = np.random.random((2, 4)).astype("float32") + + def test_Variable(): + fluid.layers.clip(x=input_data, min=-1.0, max=1.0) + + self.assertRaises(TypeError, test_Variable) + + def test_dtype(): + x2 = fluid.layers.data(name='x2', shape=[1], dtype='int32') + fluid.layers.clip(x=x2, min=-1.0, max=1.0) + + self.assertRaises(TypeError, test_dtype) + paddle.disable_static() + + +class TestClipAPI(unittest.TestCase): + def _executed_api(self, x, min=None, max=None): + return paddle.clip(x, min, max) + + def test_clip(self): + paddle.enable_static() + data_shape = [1, 9, 9, 4] + data = np.random.random(data_shape).astype('float32') + images = fluid.data(name='image', shape=data_shape, dtype='float32') + min = fluid.data(name='min', shape=[1], dtype='float32') + max = fluid.data(name='max', shape=[1], dtype='float32') + + place = fluid.XPUPlace(0) if fluid.core.is_compiled_with_xpu( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + + out_1 = self._executed_api(images, min=min, max=max) + out_2 = self._executed_api(images, min=0.2, max=0.9) + out_3 = self._executed_api(images, min=0.3) + out_4 = 
self._executed_api(images, max=0.7) + out_5 = self._executed_api(images, min=min) + out_6 = self._executed_api(images, max=max) + out_7 = self._executed_api(images, max=-1.) + out_8 = self._executed_api(images) + + res1, res2, res3, res4, res5, res6, res7, res8 = exe.run( + fluid.default_main_program(), + feed={ + "image": data, + "min": np.array([0.2]).astype('float32'), + "max": np.array([0.8]).astype('float32') + }, + fetch_list=[ + out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8 + ]) + + self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8))) + self.assertTrue(np.allclose(res2, data.clip(0.2, 0.9))) + self.assertTrue(np.allclose(res3, data.clip(min=0.3))) + self.assertTrue(np.allclose(res4, data.clip(max=0.7))) + self.assertTrue(np.allclose(res5, data.clip(min=0.2))) + self.assertTrue(np.allclose(res6, data.clip(max=0.8))) + self.assertTrue(np.allclose(res7, data.clip(max=-1))) + self.assertTrue(np.allclose(res8, data)) + paddle.disable_static() + + def test_clip_dygraph(self): + paddle.disable_static() + place = fluid.XPUPlace(0) if fluid.core.is_compiled_with_xpu( + ) else fluid.CPUPlace() + paddle.disable_static(place) + data_shape = [1, 9, 9, 4] + data = np.random.random(data_shape).astype('float32') + images = paddle.to_tensor(data, dtype='float32') + v_min = paddle.to_tensor(np.array([0.2], dtype=np.float32)) + v_max = paddle.to_tensor(np.array([0.8], dtype=np.float32)) + + out_1 = self._executed_api(images, min=0.2, max=0.8) + images = paddle.to_tensor(data, dtype='float32') + out_2 = self._executed_api(images, min=0.2, max=0.9) + images = paddle.to_tensor(data, dtype='float32') + out_3 = self._executed_api(images, min=v_min, max=v_max) + + self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8))) + self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9))) + self.assertTrue(np.allclose(out_3.numpy(), data.clip(0.2, 0.8))) + + def test_errors(self): + paddle.enable_static() + x1 = fluid.data(name='x1', shape=[1], dtype="int16") + x2 = fluid.data(name='x2', shape=[1], dtype="int8") + self.assertRaises(TypeError, paddle.clip, x=x1, min=0.2, max=0.8) + self.assertRaises(TypeError, paddle.clip, x=x2, min=0.2, max=0.8) + paddle.disable_static() + + +class TestInplaceClipAPI(TestClipAPI): + def _executed_api(self, x, min=None, max=None): + return x.clip_(min, max) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_compare_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_compare_op_xpu.py new file mode 100644 index 0000000000..5496c53a42 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_compare_op_xpu.py @@ -0,0 +1,272 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
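The compare_op_xpu.cc kernels route every comparison through the broadcast_* XDNN primitives, so mismatched-but-broadcastable shapes behave the way the static-mode broadcast tests below verify. A dygraph sketch of the same behaviour (assumes a WITH_XPU build; the device string and CPU fallback are assumptions):

import numpy as np
import paddle

paddle.set_device('xpu:0' if paddle.is_compiled_with_xpu() else 'cpu')

x = paddle.to_tensor(np.arange(6).reshape([1, 2, 3]).astype('int32'))
y = paddle.to_tensor(np.arange(1, 7).reshape([1, 2, 1, 3]).astype('int32'))
# [1, 2, 3] broadcasts against [1, 2, 1, 3], mirroring test_broadcast_api_2.
print(paddle.less_than(x, y).numpy())
print(paddle.equal(x, y).numpy())
print(paddle.greater_equal(x, y).numpy())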
+ +from __future__ import print_function + +import sys +sys.path.append("..") +import unittest +import numpy as np +import paddle.fluid.core as core +import paddle.fluid as fluid +from op_test_xpu import OpTest, XPUOpTest +import paddle +from paddle.fluid import Program, program_guard + + +def create_test_class(op_type, typename, callback): + class Cls(OpTest): + def setUp(self): + a = np.random.random(size=(10, 7)).astype(typename) + b = np.random.random(size=(10, 7)).astype(typename) + c = callback(a, b) + self.inputs = {'X': a, 'Y': b} + self.outputs = {'Out': c} + self.op_type = op_type + self.use_xpu = True + self.attrs = {'use_xpu': True} + + def test_check_output(self): + paddle.enable_static() + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_errors(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[2], dtype='int32') + y = fluid.layers.data(name='y', shape=[2], dtype='int32') + a = fluid.layers.data(name='a', shape=[2], dtype='int16') + if self.op_type == "less_than": + self.assertRaises( + TypeError, + fluid.layers.less_than, + x=x, + y=y, + force_cpu=1) + op = eval("fluid.layers.%s" % self.op_type) + self.assertRaises(TypeError, op, x=x, y=y, cond=1) + self.assertRaises(TypeError, op, x=x, y=a) + self.assertRaises(TypeError, op, x=a, y=y) + + cls_name = "{0}_{1}".format(op_type, typename) + Cls.__name__ = cls_name + globals()[cls_name] = Cls + + +for _type_name in {'float32', 'int32', 'int64'}: + if _type_name == 'float64' and core.is_compiled_with_rocm(): + _type_name = 'float32' + + create_test_class('less_than', _type_name, lambda _a, _b: _a < _b) + create_test_class('less_equal', _type_name, lambda _a, _b: _a <= _b) + create_test_class('greater_than', _type_name, lambda _a, _b: _a > _b) + create_test_class('greater_equal', _type_name, lambda _a, _b: _a >= _b) + create_test_class('equal', _type_name, lambda _a, _b: _a == _b) + create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b) + + +def create_paddle_case(op_type, callback): + class PaddleCls(unittest.TestCase): + def setUp(self): + self.op_type = op_type + self.input_x = np.array([1, 2, 3, 4]).astype(np.int64) + self.input_y = np.array([1, 3, 2, 4]).astype(np.int64) + self.real_result = callback(self.input_x, self.input_y) + self.place = fluid.XPUPlace(0) if fluid.core.is_compiled_with_xpu( + ) else fluid.CPUPlace() + + def test_api(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.data(name='x', shape=[4], dtype='int64') + y = fluid.data(name='y', shape=[4], dtype='int64') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = fluid.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "y": self.input_y}, + fetch_list=[out]) + self.assertEqual((res == self.real_result).all(), True) + + def test_api_float(self): + if self.op_type == "equal": + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.data(name='x', shape=[4], dtype='int64') + y = fluid.data(name='y', shape=[1], dtype='int64') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = fluid.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "y": 1.0}, + fetch_list=[out]) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((res == self.real_result).all(), True) + + def test_dynamic_api(self): + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + y = paddle.to_tensor(self.input_y) + op = eval("paddle.%s" % 
(self.op_type)) + out = op(x, y) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_dynamic_api_int(self): + if self.op_type == "equal": + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, 1) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_dynamic_api_float(self): + if self.op_type == "equal": + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, 1.0) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_assert(self): + def test_dynamic_api_string(self): + if self.op_type == "equal": + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, "1.0") + paddle.enable_static() + + self.assertRaises(TypeError, test_dynamic_api_string) + + def test_dynamic_api_bool(self): + if self.op_type == "equal": + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + op = eval("paddle.%s" % (self.op_type)) + out = op(x, True) + self.real_result = np.array([1, 0, 0, 0]).astype(np.int64) + self.assertEqual((out.numpy() == self.real_result).all(), True) + paddle.enable_static() + + def test_broadcast_api_1(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data( + name='x', shape=[1, 2, 1, 3], dtype='int32') + y = paddle.static.data(name='y', shape=[1, 2, 3], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.int32) + input_y = np.arange(0, 6).reshape((1, 2, 3)).astype(np.int32) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_broadcast_api_2(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[1, 2, 3], dtype='int32') + y = paddle.static.data( + name='y', shape=[1, 2, 1, 3], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(0, 6).reshape((1, 2, 3)).astype(np.int32) + input_y = np.arange(1, 7).reshape((1, 2, 1, 3)).astype(np.int32) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_broadcast_api_3(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[5], dtype='int32') + y = paddle.static.data(name='y', shape=[3, 1], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.arange(0, 5).reshape((5)).astype(np.int32) + input_y = np.array([5, 3, 2]).reshape((3, 1)).astype(np.int32) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_bool_api_4(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[3, 1], dtype='bool') + y = 
paddle.static.data(name='y', shape=[3, 1], dtype='bool') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.array([True, False, True]).astype(np.bool) + input_y = np.array([True, True, False]).astype(np.bool) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_bool_broadcast_api_4(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = paddle.static.data(name='x', shape=[3, 1], dtype='bool') + y = paddle.static.data(name='y', shape=[1], dtype='bool') + op = eval("paddle.%s" % (self.op_type)) + out = op(x, y) + exe = paddle.static.Executor(self.place) + input_x = np.array([True, False, True]).astype(np.bool) + input_y = np.array([True]).astype(np.bool) + real_result = callback(input_x, input_y) + res, = exe.run(feed={"x": input_x, + "y": input_y}, + fetch_list=[out]) + self.assertEqual((res == real_result).all(), True) + + def test_attr_name(self): + paddle.enable_static() + with program_guard(Program(), Program()): + x = fluid.layers.data(name='x', shape=[4], dtype='int32') + y = fluid.layers.data(name='y', shape=[4], dtype='int32') + op = eval("paddle.%s" % (self.op_type)) + out = op(x=x, y=y, name="name_%s" % (self.op_type)) + self.assertEqual("name_%s" % (self.op_type) in out.name, True) + + cls_name = "TestCase_{}".format(op_type) + PaddleCls.__name__ = cls_name + globals()[cls_name] = PaddleCls + + +create_paddle_case('less_than', lambda _a, _b: _a < _b) +create_paddle_case('less_equal', lambda _a, _b: _a <= _b) +create_paddle_case('greater_than', lambda _a, _b: _a > _b) +create_paddle_case('greater_equal', lambda _a, _b: _a >= _b) +create_paddle_case('equal', lambda _a, _b: _a == _b) +create_paddle_case('not_equal', lambda _a, _b: _a != _b) + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_stack_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_stack_op_xpu.py index 7c546391f6..68e5a6ccdb 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_stack_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_stack_op_xpu.py @@ -97,5 +97,27 @@ class TestStackOp6(TestStackOpBase): self.axis = 3 +class TestStackOpint64(TestStackOpBase): + def initDefaultParameters(self): + self.num_inputs = 4 + self.input_dim = (5, 6, 7) + self.axis = 0 + self.dtype = 'int64' + + def initParameters(self): + self.num_inputs = 16 + + +class TestStackOpint(TestStackOpBase): + def initDefaultParameters(self): + self.num_inputs = 4 + self.input_dim = (5, 6, 7) + self.axis = 0 + self.dtype = 'int' + + def initParameters(self): + self.num_inputs = 16 + + if __name__ == '__main__': unittest.main() -- GitLab
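With the two extra registrations in stack_op_xpu.cc and the matching xpu2_op_list.h entry, stack on XPU now accepts int32 and int64 inputs in addition to float32, which the new TestStackOpint64/TestStackOpint cases cover. A quick dygraph check (sketch; assumes a WITH_XPU build and falls back to CPU otherwise):

import numpy as np
import paddle

paddle.set_device('xpu:0' if paddle.is_compiled_with_xpu() else 'cpu')

xs64 = [paddle.to_tensor(np.arange(6, dtype=np.int64).reshape([2, 3])) for _ in range(4)]
out = paddle.stack(xs64, axis=0)
print(out.shape, out.dtype)  # [4, 2, 3], paddle.int64

xs32 = [paddle.to_tensor(np.ones([2, 3], dtype=np.int32)) for _ in range(4)]
print(paddle.stack(xs32, axis=0).dtype)  # paddle.int32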