diff --git a/paddle/fluid/operators/assign_op_xpu.cc b/paddle/fluid/operators/assign_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..6255b5d341e09dab55ceebe16f72b814725057b1 --- /dev/null +++ b/paddle/fluid/operators/assign_op_xpu.cc @@ -0,0 +1,161 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/assign_op.h" + +#include + +namespace paddle { +namespace framework { +class OpDesc; +class Variable; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +namespace platform { +struct CPUPlace; +struct CUDAPlace; +struct float16; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace operators { + +class AssignOp : public framework::OperatorWithKernel { + public: + AssignOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + if (ctx->HasInput("X")) { + auto type = ctx->GetInputsVarType("X")[0]; + if (type == framework::proto::VarType::SELECTED_ROWS || + type == framework::proto::VarType::LOD_TENSOR) { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (type == framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("X", /*->*/ "Out"); + } + } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) { + if (ctx->IsRuntime()) { + // The runtime output shape is determined in kernel. + return; + } else { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + } + } + } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + const framework::Variable *var = ctx.InputVar("X"); + if (var->IsType()) { + auto t_arr = var->Get(); + // NOTE(liym27): Support an empty tensor array as Input. + // And set the kernel type is float. + if (t_arr.size() == 0) { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.device_context()); + } + } + + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class AssignInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + ctx->SyncTypeAndDataType("X", "Out"); + } +}; + +class AssignKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *x = ctx.InputVar("X"); + if (x == nullptr) { + return; + } + PADDLE_ENFORCE_EQ( + ctx.HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of assign_op is not found.")); + auto *out = ctx.OutputVar("Out"); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(ctx.GetPlace()); + + framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); + } +}; + +class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LoDTensor, SelectedRows or LoDTensorArray) The input variable " + "could be LoDTensor, SelectedRows or LoDTensorArray.") + .AsDispensable(); + AddOutput("Out", + "(LoDTensor, SelectedRows or LoDTensorArray) The type of output " + "is the same as input X."); + AddComment(R"DOC(Assign Operator + +Out = X, when type in [LoDTensor/SelectedRows/LoDTensorArray] +raise error if the type is not listed above. +)DOC"); + } +}; + +template +class AssignGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("assign"); + op->SetInput("X", this->OutputGrad("Out")); + op->SetOutput("Out", this->InputGrad("X")); + } +}; + +DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, + ops::AssignKernel, int, ops::AssignKernel, + int64_t, ops::AssignKernel, bool, + ops::AssignKernel); +#endif diff --git a/paddle/fluid/operators/cast_op_xpu.cc b/paddle/fluid/operators/cast_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..56160bd297e28933f214e47ddb3275646507cc0e --- /dev/null +++ b/paddle/fluid/operators/cast_op_xpu.cc @@ -0,0 +1,69 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/cast_op.h" +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +template +class CastXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + auto in_type = static_cast( + context.Attr("in_dtype")); + auto out_type = static_cast( + context.Attr("out_dtype")); + auto* in_data = in->data(); + auto numel = in->numel(); + auto& dev_ctx = context.template device_context(); + int r = -1; + if (out_type == framework::proto::VarType::FP32) { + auto* out_data = out->mutable_data(context.GetPlace()); + r = xpu::cast(dev_ctx.x_context(), in_data, out_data, numel); + } else if (out_type == framework::proto::VarType::INT32) { + auto* out_data = out->mutable_data(context.GetPlace()); + r = xpu::cast(dev_ctx.x_context(), in_data, out_data, numel); + } else if (out_type == framework::proto::VarType::INT64) { + auto* out_data = out->mutable_data(context.GetPlace()); + r = xpu::cast(dev_ctx.x_context(), in_data, out_data, + numel); + } else { + PADDLE_THROW(platform::errors::Unavailable("Not supported cast %d -> %d", + in_type, out_type)); + } + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + cast, ops::CastXPUKernel, + ops::CastXPUKernel, + ops::CastXPUKernel); +#endif diff --git a/paddle/fluid/operators/concat_op_xpu.cc b/paddle/fluid/operators/concat_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..9c9c72c7f6f7859bf3a78501a5707953766a3541 --- /dev/null +++ b/paddle/fluid/operators/concat_op_xpu.cc @@ -0,0 +1,185 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/concat_op.h" + +#include +#include +#include + +#ifdef PADDLE_WITH_MKLDNN +#include +#endif + +#ifdef PADDLE_WITH_XPU + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +class ConcatXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto ins = ctx.MultiInput("X"); + framework::Tensor* out = ctx.Output("Out"); + int axis = ctx.Attr("axis"); + PADDLE_ENFORCE_NE(ins[0], nullptr, platform::errors::InvalidArgument( + "The input should not be null.")); + PADDLE_ENFORCE_NE(ctx.HasInput("AxisTensor"), true, + platform::errors::InvalidArgument( + "XPU donot surpport AxisTensor for now")); + axis = ComputeAxis(static_cast(axis), + static_cast(ins[0]->dims().size())); + PADDLE_ENFORCE_GE( + axis, 0, platform::errors::InvalidArgument("concat: axis shoud >= 0!")); + PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(), + platform::errors::InvalidArgument( + "concat: axis shoud < ins[0]->dims()!")); + auto place = ctx.GetPlace(); + out->mutable_data(place); + std::vector choose_idx; + int n = 0; + for (unsigned int i = 0; i < ins.size(); ++i) { + if (ins[i] && ins[i]->numel() > 0) { + choose_idx.push_back(i); + n++; + } + } + PADDLE_ENFORCE_LE(n, 8, platform::errors::InvalidArgument( + "XPU only surpport at most 8 tensors for now")); + PADDLE_ENFORCE_GT( + n, 0, platform::errors::InvalidArgument("No tensor need concat?")); + int h = 1; + int w_except_axis = 1; + for (int i = 0; i < axis; ++i) { + h *= (ins[choose_idx[0]]->dims())[i]; + } + for (int i = axis + 1; i < ins[0]->dims().size(); ++i) { + w_except_axis *= (ins[choose_idx[0]]->dims())[i]; + } + for (int i = 1; i < n; ++i) { + int hh = 1; + int ww = 1; + for (int j = 0; j < axis; ++j) { + hh *= (ins[choose_idx[i]]->dims())[j]; + } + for (int j = axis + 1; j < ins[i]->dims().size(); ++j) { + ww *= (ins[choose_idx[i]]->dims())[j]; + } + PADDLE_ENFORCE_EQ(hh, h, platform::errors::InvalidArgument( + "concat: h should be eual!")); + PADDLE_ENFORCE_EQ(ww, w_except_axis, + platform::errors::InvalidArgument( + "concat: w should be eual except for axis!")); + } + auto& dev_ctx = ctx.template device_context(); + std::unique_ptr in_w_host(new int[n]); + std::unique_ptr ptrs(new const float*[n]); + for (int i = 0; i < n; ++i) { + ptrs[i] = ins[choose_idx[i]]->data(); + in_w_host[i] = w_except_axis * (ins[choose_idx[i]]->dims())[axis]; + } + int r = + xpu::concat(dev_ctx.x_context(), h, (const int*)in_w_host.get(), + n, (const float**)ptrs.get(), out->data()); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } +}; +template +class ConcatGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto ins = ctx.MultiInput("X"); + auto out_var_names = ctx.OutputNames(framework::GradVarName("X")); + auto outs = + ctx.MultiOutput(framework::GradVarName("X")); + { + auto dx = outs; + auto x = ins; + for (size_t i = 0; i < dx.size(); ++i) { + if (dx[i] != nullptr) { + dx[i]->set_lod(x[i]->lod()); + } + } + } + PADDLE_ENFORCE_NE(ins[0], nullptr, platform::errors::InvalidArgument( + "The input should not be null.")); + auto axis = ctx.Attr("axis"); + if (ctx.HasInput("AxisTensor")) { + auto* axis_tensor = ctx.Input("AxisTensor"); + axis = GetDataFromTensor(axis_tensor)[0]; + } + axis = ComputeAxis(static_cast(axis), + static_cast(ins[0]->dims().size())); + // get output tensor that the name is not kEmptyVarName + std::vector outputs; + for (size_t j = 0; j < outs.size(); ++j) { + if (out_var_names[j] != framework::kEmptyVarName && + outs[j]->numel() != 0UL) { + outs[j]->mutable_data(ctx.GetPlace()); + outputs.push_back(outs[j]); + } else { + outputs.push_back(nullptr); + } + } + PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument( + "concat_grad: axis shoud >= 0!")); + PADDLE_ENFORCE_LT(axis, out_grad->dims().size(), + platform::errors::InvalidArgument( + "concat_grad: axis shoud < ins[0]->dims()!")); + auto out_grad_stride = framework::stride_numel(out_grad->dims()); + int n = outputs.size(); + PADDLE_ENFORCE_LE(n, 16, + platform::errors::InvalidArgument( + "XPU only surpport at most 16 tensors for now")); + int h = out_grad_stride[0] / out_grad_stride[axis]; + auto& dev_ctx = ctx.template device_context(); + std::unique_ptr in_w_host(new int[n]); + std::unique_ptr ptrs(new float*[n]); + for (int i = 0; i < n; ++i) { + auto out_stride = framework::stride_numel(outputs[i]->dims()); + ptrs[i] = outputs[i]->data(); + in_w_host[i] = out_stride[axis]; + } + int r = xpu::concat_grad(dev_ctx.x_context(), h, in_w_host.get(), n, + reinterpret_cast(ptrs.get()), + out_grad->data()); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_XPU_KERNEL( + concat, ops::ConcatXPUKernel); +REGISTER_OP_XPU_KERNEL( + concat_grad, + ops::ConcatGradXPUKernel); + +#endif diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py index 82efd66a5e54c6d98594993a36d870d30e6a3731..24a80ed2ed6ff87bab44717277f14a896096b2e7 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_op.py @@ -253,7 +253,8 @@ class TestConcatAPI(unittest.TestCase): assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1)) def test_api(self): - x_1 = paddle.fluid.data(shape=[None, 1, 4, 5], dtype='int32', name='x_1') + x_1 = paddle.fluid.data( + shape=[None, 1, 4, 5], dtype='int32', name='x_1') paddle.concat([x_1, x_1], 0) input_2 = np.random.random([2, 1, 4, 5]).astype("int32") diff --git a/python/paddle/fluid/tests/unittests/xpu/test_assign_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_assign_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..110e7bb3cbf41c43f964cfffb3d40b8eff5948d0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_assign_op_xpu.py @@ -0,0 +1,97 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys + +sys.path.append("..") +import op_test +import numpy as np +import unittest +import paddle +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.backward import append_backward + + +class TestAssignOp(op_test.OpTest): + def setUp(self): + self.op_type = "assign" + x = np.random.random(size=(100, 10)).astype('float32') + self.inputs = {'X': x} + self.outputs = {'Out': x} + + def test_forward(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_backward(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + +class TestAssignOpWithLoDTensorArray(unittest.TestCase): + def test_assign_LoDTensorArray(self): + main_program = Program() + startup_program = Program() + with program_guard(main_program): + x = fluid.data(name='x', shape=[100, 10], dtype='float32') + x.stop_gradient = False + y = fluid.layers.fill_constant( + shape=[100, 10], dtype='float32', value=1) + z = fluid.layers.elementwise_add(x=x, y=y) + i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) + init_array = fluid.layers.array_write(x=z, i=i) + array = fluid.layers.assign(init_array) + sums = fluid.layers.array_read(array=init_array, i=i) + mean = fluid.layers.mean(sums) + append_backward(mean) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + feed_x = np.random.random(size=(100, 10)).astype('float32') + ones = np.ones((100, 10)).astype('float32') + feed_add = feed_x + ones + res = exe.run(main_program, + feed={'x': feed_x}, + fetch_list=[sums.name, x.grad_name]) + self.assertTrue(np.allclose(res[0], feed_add)) + self.assertTrue(np.allclose(res[1], ones / 1000.0)) + + +class TestAssignOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The type of input must be Variable or numpy.ndarray. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.XPUPlace(0)) + self.assertRaises(TypeError, fluid.layers.assign, x1) + # When the type of input is Variable, the dtype of input must be float16, float32, float64, int32, int64, bool. + x3 = fluid.layers.data(name='x3', shape=[4], dtype="uint8") + self.assertRaises(TypeError, fluid.layers.assign, x3) + # When the type of input is numpy.ndarray, the dtype of input must be float32, int32. + x4 = np.array([[2.5, 2.5]], dtype='float64') + self.assertRaises(TypeError, fluid.layers.assign, x4) + x5 = np.array([[2.5, 2.5]], dtype='uint8') + self.assertRaises(TypeError, fluid.layers.assign, x5) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..cb64cb90e8c2c712771368fccedea734ad00883e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py @@ -0,0 +1,106 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys + +sys.path.append("..") +import op_test +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + + +class TestCastOp1(op_test.OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {'X': ipt.astype('float32')} + self.outputs = {'Out': ipt.astype('float32')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.FP32), + 'out_dtype': int(core.VarDesc.VarType.FP32) + } + self.op_type = 'cast' + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], ['Out']) + + +class TestCastOp2(op_test.OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {'X': ipt.astype('float32')} + self.outputs = {'Out': ipt.astype('float32')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.FP32), + 'out_dtype': int(core.VarDesc.VarType.FP32) + } + self.op_type = 'cast' + + def test_check_output(self): + #self.check_output(atol=1e-3) + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, atol=1e-3) + + +class TestCastOp3(op_test.OpTest): + def setUp(self): + ipt = np.random.random(size=[10, 10]) + self.inputs = {'X': ipt.astype('float32')} + self.outputs = {'Out': ipt.astype('float32')} + self.attrs = { + 'in_dtype': int(core.VarDesc.VarType.FP32), + 'out_dtype': int(core.VarDesc.VarType.FP32) + } + self.op_type = 'cast' + + def test_check_output(self): + #self.check_output(atol=1e-3) + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, atol=1e-3) + + +class TestCastOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of cast_op must be Variable. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.XPUPlace(0)) + self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32') + # The input dtype of cast_op must be float32, int32, int64. + x2 = fluid.layers.data(name='x2', shape=[4], dtype='int16') + self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32') + + def test_dtype_type(): + x4 = fluid.layers.data(name='x4', shape=[4], dtype='int32') + output = fluid.layers.cast(x=x4, dtype='int16') + + self.assertRaises(TypeError, test_dtype_type) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_concat_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_concat_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5d7134a1bad0ad963996cc7322fbb57bb58a9c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_concat_op_xpu.py @@ -0,0 +1,240 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys + +sys.path.append("..") +import unittest +import numpy as np +from op_test import OpTest, skip_check_grad_ci +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard, core +import paddle + + +class TestConcatOp(OpTest): + def setUp(self): + self.op_type = "concat" + self.dtype = self.get_dtype() + self.init_test_data() + self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]} + self.attrs = {'axis': self.axis} + if self.axis < 0: + self.actual_axis = self.axis + len(self.x0.shape) + self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0 + else: + self.actual_axis = self.axis + + self.outputs = { + 'Out': np.concatenate( + (self.x0, self.x1, self.x2), axis=self.actual_axis) + } + + def get_dtype(self): + return "float64" + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['x0'], 'Out') + self.check_grad_with_place(place, ['x1'], 'Out') + self.check_grad_with_place(place, ['x2'], 'Out') + + def init_test_data(self): + self.x0 = np.random.random((5, 1, 4, 5)).astype(self.dtype) + self.x1 = np.random.random((5, 2, 4, 5)).astype(self.dtype) + self.x2 = np.random.random((5, 3, 4, 5)).astype(self.dtype) + self.axis = 1 + + +class TestConcatOp2(TestConcatOp): + def init_test_data(self): + self.x0 = np.random.random((2, 3, 4, 5)).astype(self.dtype) + self.x1 = np.random.random((2, 3, 4, 5)).astype(self.dtype) + self.x2 = np.random.random((2, 3, 4, 5)).astype(self.dtype) + self.axis = 1 + + +@skip_check_grad_ci( + reason="The function 'check_grad' for large inputs is too slow.") +class TestConcatOp3(TestConcatOp): + def init_test_data(self): + self.x0 = np.random.random((1, 256, 170, 256)).astype(self.dtype) + self.x1 = np.random.random((1, 128, 170, 256)).astype(self.dtype) + self.x2 = np.random.random((1, 128, 170, 256)).astype(self.dtype) + self.axis = 1 + + def test_check_grad(self): + pass + + +@skip_check_grad_ci( + reason="This test will meet fetch error when there is a null grad. The detailed information is in PR#17015." +) +class TestConcatOp4(TestConcatOp): + def init_test_data(self): + self.x0 = np.random.random((2, 3, 4, 5)).astype(self.dtype) + self.x1 = np.random.random((2, 3, 4, 5)).astype(self.dtype) + self.x2 = np.random.random((0, 3, 4, 5)).astype(self.dtype) + self.axis = 0 + + def test_check_grad(self): + pass + + +class TestConcatOp5(TestConcatOp): + def init_test_data(self): + self.x0 = np.random.random((5, 1, 4, 5)).astype(self.dtype) + self.x1 = np.random.random((5, 2, 4, 5)).astype(self.dtype) + self.x2 = np.random.random((5, 3, 4, 5)).astype(self.dtype) + self.axis = -3 + + +class TestConcatOp6(TestConcatOp): + def setUp(self): + self.op_type = "concat" + self.dtype = self.get_dtype() + self.init_test_data() + self.lod = [[20, 80]] + self.out_lod = [[20, 80, 20, 80, 20, 80]] + self.inputs = { + 'X': [('x0', (self.x0, self.lod)), ('x1', (self.x1, self.lod)), + ('x2', (self.x2, self.lod))] + } + self.attrs = {'axis': self.axis} + if self.axis < 0: + self.actual_axis = self.axis + len(self.x0.shape) + self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0 + else: + self.actual_axis = self.axis + out = np.concatenate((self.x0, self.x1, self.x2), axis=self.actual_axis) + self.outputs = {'Out': (out, self.out_lod)} + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['x0'], 'Out') + self.check_grad_with_place(place, ['x1'], 'Out') + self.check_grad_with_place(place, ['x2'], 'Out') + + def init_test_data(self): + self.x0 = np.random.random([100]).astype(self.dtype) + self.x1 = np.random.random([100]).astype(self.dtype) + self.x2 = np.random.random([100]).astype(self.dtype) + self.axis = 0 + + +class TestConcatOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of concat_op should be list. + x1 = fluid.layers.data(shape=[4], dtype='int32', name='x1') + fluid.layers.concat(x1) + # The item in input must be Variable. + x2 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + x3 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + self.assertRaises(TypeError, fluid.layers.concat, [x2]) + # The input dtype of concat_op must be float16, float32, float64, int32, int64. + x4 = fluid.layers.data(shape=[4], dtype='uint8', name='x4') + x5 = fluid.layers.data(shape=[4], dtype='uint8', name='x5') + self.assertRaises(TypeError, fluid.layers.concat, [x4, x5]) + x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6') + x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7') + x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8') + fluid.layers.concat([x6, x7]) + + # The type of axis in concat_op should be int or Variable. + def test_axis_type(): + fluid.layers.concat([x6, x7], 3.2) + + self.assertRaises(TypeError, test_axis_type) + + def test_input_same_dtype(): + fluid.layers.concat([x7, x8]) + + self.assertRaises(TypeError, test_input_same_dtype) + + +class TestConcatAPI(unittest.TestCase): + def test_fluid_api(self): + x_1 = fluid.data(shape=[None, 1, 4, 5], dtype='float32', name='x_1') + fluid.layers.concat([x_1, x_1], 0) + + input_2 = np.random.random([2, 1, 4, 5]).astype("float32") + input_3 = np.random.random([2, 2, 4, 5]).astype("float32") + x_2 = fluid.data(shape=[2, 1, 4, 5], dtype='float32', name='x_2') + x_3 = fluid.data(shape=[2, 2, 4, 5], dtype='float32', name='x_3') + positive_1_int32 = fluid.layers.fill_constant([1], "float32", 1) + positive_1_int64 = fluid.layers.fill_constant([1], "float32", 1) + out_1 = fluid.layers.concat(input=[x_2, x_3], axis=1) + out_2 = fluid.layers.concat(input=[x_2, x_3], axis=1) + out_3 = fluid.layers.concat(input=[x_2, x_3], axis=1) + + exe = fluid.Executor(place=fluid.XPUPlace(0)) + [res_1, res_2, res_3] = exe.run( + fluid.default_main_program(), + feed={"x_1": input_2, + "x_2": input_2, + "x_3": input_3}, + fetch_list=[out_1, out_2, out_3]) + assert np.array_equal(res_1, np.concatenate((input_2, input_3), axis=1)) + assert np.array_equal(res_2, np.concatenate((input_2, input_3), axis=1)) + assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1)) + + def test_errors(self): + with program_guard(Program(), Program()): + # The item in input must be Variable. + x2 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.XPUPlace(0)) + x3 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.XPUPlace(0)) + self.assertRaises(TypeError, paddle.concat, [x2]) + # The input dtype of concat_op must be float32. + x4 = fluid.data(shape=[4], dtype='uint8', name='x4') + x5 = fluid.data(shape=[4], dtype='uint8', name='x5') + self.assertRaises(TypeError, fluid.layers.concat, [x4, x5]) + + # The type of axis in concat_op should be int or Variable. + x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6') + x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7') + x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8') + + def test_axis_type(): + paddle.concat([x6, x7], 3.2) + + self.assertRaises(TypeError, test_axis_type) + + def test_input_same_dtype(): + paddle.concat([x7, x8]) + + self.assertRaises(TypeError, test_input_same_dtype) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main()