diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_mlu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..47a549dfcde28999fd86c1025bbcefb7a05282e7 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_add_op_mlu.cc @@ -0,0 +1,154 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +template +class ElementwiseAddMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + int axis = ctx.Attr("axis"); + const auto& x_dims = x->dims(); + const auto& y_dims = y->dims(); + axis = (axis < 0 ? (std::abs(x_dims.size() - y_dims.size()) + axis + 1) + : axis); + int max_dim = std::max(x_dims.size(), y_dims.size()); + std::vector x_dims_array(max_dim); + std::vector y_dims_array(max_dim); + std::vector out_dims_array(max_dim); + GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), + y_dims_array.data(), out_dims_array.data(), max_dim, + axis); + + MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), + ToCnnlDataType(x->type())); + MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), + ToCnnlDataType(y->type())); + MLUCnnlTensorDesc out_desc(*out); + MLUCnnlOpTensorDesc op_tensor_desc(CNNL_OP_TENSOR_ADD, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN); + + MLUCnnl::OpTensor(ctx, op_tensor_desc.get(), x_desc.get(), GetBasePtr(x), + y_desc.get(), GetBasePtr(y), out_desc.get(), + GetBasePtr(out), ToCnnlDataType()); + } +}; + +template +class ElementwiseAddGradMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis); + + MLUCnnlTensorDesc dout_desc(*dout); + if (dx) { + dx->mutable_data(ctx.GetPlace()); + if (dx->dims() != dout->dims()) { + std::vector dst_dims_vec; + std::vector reduce_axes; + auto src_dims = dx->dims(); + auto dout_dims = dout->dims(); + + int src_axis = (src_dims.size() < dout_dims.size() ? axis : 0); + for (int ax = 0; ax < dout_dims.size(); ++ax) { + if ((ax < src_axis || ax >= src_axis + src_dims.size()) || + (dout_dims[ax] > 1 && src_dims[ax - src_axis] == 1)) { + reduce_axes.push_back(ax); + } else { + dst_dims_vec.push_back(dout_dims[ax]); + } + } + if (dst_dims_vec.size() == 0) { + // x is scalar + dst_dims_vec.push_back(1); + } + + MLUCnnlReduceDesc reduction_desc( + reduce_axes, CNNL_REDUCE_ADD, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + MLUCnnlTensorDesc dx_desc(dst_dims_vec.size(), dst_dims_vec.data(), + ToCnnlDataType()); + MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(), + nullptr, dout_desc.get(), GetBasePtr(dout), 0, nullptr, + nullptr, dx_desc.get(), GetBasePtr(dx)); + } else { + framework::TensorCopy(*dout, ctx.GetPlace(), dev_ctx, dx); + } + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + if (dy->dims() != dout->dims()) { + std::vector dst_dims_vec; + std::vector reduce_axes; + auto src_dims = dy->dims(); + auto dout_dims = dout->dims(); + + int src_axis = (src_dims.size() < dout_dims.size() ? axis : 0); + for (int ax = 0; ax < dout_dims.size(); ++ax) { + if ((ax < src_axis || ax >= src_axis + src_dims.size()) || + (dout_dims[ax] > 1 && src_dims[ax - src_axis] == 1)) { + reduce_axes.push_back(ax); + } else { + dst_dims_vec.push_back(dout_dims[ax]); + } + } + if (dst_dims_vec.size() == 0) { + // y is scalar + dst_dims_vec.push_back(1); + } + + MLUCnnlReduceDesc reduction_desc( + reduce_axes, CNNL_REDUCE_ADD, ToCnnlDataType(), + CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES); + MLUCnnlTensorDesc dy_desc(dst_dims_vec.size(), dst_dims_vec.data(), + ToCnnlDataType()); + MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(), + nullptr, dout_desc.get(), GetBasePtr(dout), 0, nullptr, + nullptr, dy_desc.get(), GetBasePtr(dy)); + } else { + framework::TensorCopy(*dout, ctx.GetPlace(), dev_ctx, dy); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(elementwise_add, ops::ElementwiseAddMLUKernel, + ops::ElementwiseAddMLUKernel); +REGISTER_OP_MLU_KERNEL(elementwise_add_grad, + ops::ElementwiseAddGradMLUKernel, + ops::ElementwiseAddGradMLUKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_elementwise_add_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_add_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..5b6db6903fba0dc1d77b6026623a3bc3c101013c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_elementwise_add_op_mlu.py @@ -0,0 +1,527 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import sys +sys.path.append('..') +from op_test import OpTest, skip_check_grad_ci +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + + +class TestElementwiseAddOp(OpTest): + def set_mlu(self): + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + + def setUp(self): + self.op_type = "elementwise_add" + self.set_mlu() + self.init_dtype() + self.init_input_output() + self.init_axis() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': self.axis} + self.outputs = {'Out': self.out} + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + if self.dtype == np.float16: + return + self.check_grad_with_place( + self.place, ['X', 'Y'], 'Out', max_relative_error=0.01) + + def test_check_grad_ingore_x(self): + if self.dtype == np.float16: + return + self.check_grad_with_place( + self.place, ['Y'], + 'Out', + no_grad_set=set("X"), + max_relative_error=0.01) + + def test_check_grad_ingore_y(self): + if self.dtype == np.float16: + return + self.check_grad_with_place( + self.place, ['X'], + 'Out', + no_grad_set=set('Y'), + max_relative_error=0.01) + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.add(self.x, self.y) + + def init_dtype(self): + self.dtype = np.float32 + + def init_axis(self): + self.axis = -1 + + +class TestFP16ElementwiseAddOp(TestElementwiseAddOp): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseAddOp_scalar(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestFP16ElementwiseAddOp_scalar(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1,1) to test broadcast.") +class TestElementwiseAddOp_scalar2(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + self.out = self.x + self.y + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1,1) to test broadcast.") +class TestFP16ElementwiseAddOp_scalar2(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + self.out = self.x + self.y + + +class TestElementwiseAddOp_Vector(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.random((100, )).astype(self.dtype) + self.y = np.random.random((100, )).astype(self.dtype) + self.out = np.add(self.x, self.y) + + +class TestFP16ElementwiseAddOp_Vector(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.random((100, )).astype(self.dtype) + self.y = np.random.random((100, )).astype(self.dtype) + self.out = np.add(self.x, self.y) + + +class TestElementwiseAddOp_broadcast_0(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 3).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(100, 1, 1) + + def init_axis(self): + self.axis = 0 + + +class TestFP16ElementwiseAddOp_broadcast_0(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 3).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(100, 1, 1) + + def init_axis(self): + self.axis = 0 + + +class TestElementwiseAddOp_broadcast_1(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 100, 3).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 100, 1) + + def init_axis(self): + self.axis = 1 + + +class TestFP16ElementwiseAddOp_broadcast_1(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 100, 3).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 100, 1) + + def init_axis(self): + self.axis = 1 + + +class TestElementwiseAddOp_broadcast_2(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 100).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1, 100) + + +class TestFP16ElementwiseAddOp_broadcast_2(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 100).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1, 100) + + +class TestElementwiseAddOp_broadcast_3(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 10, 12, 1).astype(self.dtype) + self.y = np.random.rand(10, 12).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 10, 12, 1) + + def init_axis(self): + self.axis = 1 + + +class TestFP16ElementwiseAddOp_broadcast_3(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 10, 12, 3).astype(self.dtype) + self.y = np.random.rand(10, 12).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 10, 12, 1) + + def init_axis(self): + self.axis = 1 + + +class TestElementwiseAddOp_broadcast_4(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 1, 2).astype(self.dtype) + self.y = np.random.rand(100, 1).astype(self.dtype) + self.out = self.x + self.y.reshape(100, 1, 1, 1) + + def init_axis(self): + self.axis = 0 + + +class TestFP16ElementwiseAddOp_broadcast_4(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 1, 2).astype(self.dtype) + self.y = np.random.rand(100, 1).astype(self.dtype) + self.out = self.x + self.y.reshape(100, 1, 1, 1) + + def init_axis(self): + self.axis = 0 + + +class TestElementwiseAddOp_broadcast_5(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 3, 12).astype(self.dtype) + self.y = np.random.rand(10, 1, 12).astype(self.dtype) + self.out = self.x + self.y + + +class TestFP16ElementwiseAddOp_broadcast_5(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 3, 12).astype(self.dtype) + self.y = np.random.rand(10, 1, 12).astype(self.dtype) + self.out = self.x + self.y + + +class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype) + self.y = np.random.rand(2, 12, 1, 5).astype(self.dtype) + self.out = self.x + self.y + + +class TestElementwiseAddOp_broadcast_7(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(1, 1, 20, 5).astype(self.dtype) + self.y = np.random.rand(20, 5, 1, 1).astype(self.dtype) + self.out = self.x + self.y + + +class TestFP16ElementwiseAddOp_broadcast_6(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype) + self.y = np.random.rand(2, 12, 1, 5).astype(self.dtype) + self.out = self.x + self.y + + +class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 10, 12).astype(self.dtype) + self.y = np.random.rand(10, 12).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 10, 12) + + def init_axis(self): + self.axis = 1 + + +class TestFP16ElementwiseAddOp_rowwise_add_0(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 10, 12).astype(self.dtype) + self.y = np.random.rand(10, 12).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 10, 12) + + def init_axis(self): + self.axis = 1 + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 1).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1) + + def init_axis(self): + self.axis = 1 + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestFP16ElementwiseAddOp_rowwise_add_1(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 1).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1) + + def init_axis(self): + self.axis = 1 + + +class TestElementwiseAddOp_channelwise_add(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 3).astype(self.dtype) + self.y = np.random.rand(100, 1, 1).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = -1 + + +class TestFP16ElementwiseAddOp_channelwise_add(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 3).astype(self.dtype) + self.y = np.random.rand(100, 1, 1).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = -1 + + +class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 100).astype(self.dtype) + self.y = np.random.rand(1, 1, 100).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = -1 + + +class TestElementwiseFP16AddOp_commonuse_add1(TestFP16ElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 100).astype(self.dtype) + self.y = np.random.rand(1, 1, 100).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = -1 + + +class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 3, 1, 4).astype(self.dtype) + self.y = np.random.rand(10, 1, 12, 1).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = -1 + + +class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 12).astype(self.dtype) + self.y = np.random.rand(2, 2, 10, 12).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = 2 + + +class TestElementwiseAddOp_same_shape_ysize_large(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 1, 12).astype(self.dtype) + self.y = np.random.rand(10, 2, 12).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = 0 + + +class TestElementwiseAddOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # the input of elementwise_add must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.MLUPlace(0)) + y1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.MLUPlace(0)) + self.assertRaises(TypeError, fluid.layers.elementwise_add, x1, y1) + + # the input dtype of elementwise_add must be float16 or float32 + x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8") + y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8") + self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2) + + +class TestAddApi(unittest.TestCase): + def _executed_api(self, x, y, name=None): + return paddle.add(x, y, name) + + def test_name(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data(name="x", shape=[2, 3], dtype="float32") + y = fluid.data(name='y', shape=[2, 3], dtype='float32') + + y_1 = self._executed_api(x, y, name='add_res') + self.assertEqual(('add_res' in y_1.name), True) + + def test_declarative(self): + with fluid.program_guard(fluid.Program()): + + def gen_data(): + return { + "x": np.array([2, 3, 4]).astype('float32'), + "y": np.array([1, 5, 2]).astype('float32') + } + + x = fluid.data(name="x", shape=[3], dtype='float32') + y = fluid.data(name="y", shape=[3], dtype='float32') + z = self._executed_api(x, y) + + place = fluid.MLUPlace(0) + exe = fluid.Executor(place) + z_value = exe.run(feed=gen_data(), fetch_list=[z.name]) + z_expected = np.array([3., 8., 6.]) + self.assertEqual((z_value == z_expected).all(), True) + + def test_dygraph(self): + with fluid.dygraph.guard(): + np_x = np.array([2, 3, 4]).astype('float32') + np_y = np.array([1, 5, 2]).astype('float32') + x = fluid.dygraph.to_variable(np_x) + y = fluid.dygraph.to_variable(np_y) + z = self._executed_api(x, y) + np_z = z.numpy() + z_expected = np.array([3., 8., 6.]) + self.assertEqual((np_z == z_expected).all(), True) + + +class TestAddInplaceApi(TestAddApi): + def _executed_api(self, x, y, name=None): + return x.add_(y, name) + + +class TestAddInplaceBroadcastSuccess(unittest.TestCase): + def init_data(self): + self.x_numpy = np.random.rand(2, 3, 4).astype('float32') + self.y_numpy = np.random.rand(3, 4).astype('float32') + + def test_broadcast_success(self): + paddle.disable_static() + self.init_data() + x = paddle.to_tensor(self.x_numpy) + y = paddle.to_tensor(self.y_numpy) + inplace_result = x.add_(y) + numpy_result = self.x_numpy + self.y_numpy + self.assertEqual((inplace_result.numpy() == numpy_result).all(), True) + paddle.enable_static() + + +class TestAddInplaceBroadcastSuccess2(TestAddInplaceBroadcastSuccess): + def init_data(self): + self.x_numpy = np.random.rand(1, 2, 3, 1).astype('float32') + self.y_numpy = np.random.rand(3, 1).astype('float32') + + +class TestAddInplaceBroadcastSuccess3(TestAddInplaceBroadcastSuccess): + def init_data(self): + self.x_numpy = np.random.rand(2, 3, 1, 5).astype('float32') + self.y_numpy = np.random.rand(1, 3, 1, 5).astype('float32') + + +class TestAddInplaceBroadcastError(unittest.TestCase): + def init_data(self): + self.x_numpy = np.random.rand(3, 4).astype('float32') + self.y_numpy = np.random.rand(2, 3, 4).astype('float32') + + def test_broadcast_errors(self): + paddle.disable_static() + self.init_data() + x = paddle.to_tensor(self.x_numpy) + y = paddle.to_tensor(self.y_numpy) + + def broadcast_shape_error(): + x.add_(y) + + self.assertRaises(ValueError, broadcast_shape_error) + paddle.enable_static() + + +class TestAddInplaceBroadcastError2(TestAddInplaceBroadcastError): + def init_data(self): + self.x_numpy = np.random.rand(2, 1, 4).astype('float32') + self.y_numpy = np.random.rand(2, 3, 4).astype('float32') + + +class TestAddInplaceBroadcastError3(TestAddInplaceBroadcastError): + def init_data(self): + self.x_numpy = np.random.rand(5, 2, 1, 4).astype('float32') + self.y_numpy = np.random.rand(2, 3, 4).astype('float32') + + +class TestBoolAddFloatElementwiseAddop(unittest.TestCase): + def test_static_add(self): + paddle.enable_static() + a = 1.5 + b = paddle.full([4, 5, 6], True, dtype='bool') + c = a + b + self.assertTrue(c.dtype == core.VarDesc.VarType.FP32) + paddle.enable_static() + + def test_dygraph_add(self): + paddle.disable_static() + a = 1.5 + b = paddle.full([4, 5, 6], True, dtype='bool') + c = a + b + self.assertTrue(c.dtype == core.VarDesc.VarType.FP32) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main()