From 97e75ad0f51e21e92baeeb67884cab04d2f4c26c Mon Sep 17 00:00:00 2001
From: liym27 <33742067+liym27@users.noreply.github.com>
Date: Wed, 23 Dec 2020 17:42:54 +0800
Subject: [PATCH] [setitem] Support Tensor setitem in static mode (#29708)

1. Type of index: int, slice (step must be 1).
2. Type of value:
   (1) int32, int64, float32, bool;
   (2) numpy.array(int32, int64, float32, bool);
   (3) paddle.Tensor(int32, int64, float32, float64, bool);
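Usage sketch (not part of the diff; distilled from the unit tests added below,
so the API calls follow this patch):

    import paddle

    paddle.enable_static()

    program = paddle.static.Program()
    with paddle.static.program_guard(program):
        x = paddle.ones(shape=[2, 3, 4], dtype="float32")
        # Each assignment appends one set_value op; x is both input and output.
        x[0, 0] = 10         # int index
        x[0:2, 1:2] = 5      # slice index (step 1 only)

    exe = paddle.static.Executor(paddle.CPUPlace())
    out = exe.run(program, fetch_list=[x])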
---
 paddle/fluid/operators/set_value_op.cc        | 105 ++++
 paddle/fluid/operators/set_value_op.cu        |  24 +
 paddle/fluid/operators/set_value_op.h         | 214 ++++
 python/paddle/fluid/framework.py              |  82 +++
 .../unittests/dygraph_to_static/test_slice.py |  23 +-
 .../tests/unittests/test_set_value_op.py      | 482 ++++++++++++++++++
 6 files changed, 928 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/operators/set_value_op.cc
 create mode 100644 paddle/fluid/operators/set_value_op.cu
 create mode 100644 paddle/fluid/operators/set_value_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_set_value_op.py

diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc
new file mode 100644
index 0000000000..a928668a22
--- /dev/null
+++ b/paddle/fluid/operators/set_value_op.cc
@@ -0,0 +1,105 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/set_value_op.h"
+
+#include <string>
+
+namespace paddle {
+namespace operators {
+
+class SetValue : public framework::OperatorWithKernel {
+ public:
+  SetValue(const std::string &type, const framework::VariableNameMap &inputs,
+           const framework::VariableNameMap &outputs,
+           const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "SetValue");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetValue");
+    auto in_dims = ctx->GetInputDim("Input");
+    PADDLE_ENFORCE_LT(
+        in_dims.size(), 7,
+        platform::errors::InvalidArgument(
+            "The rank of input should be less than 7, but received %d.",
+            in_dims.size()));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
+        ctx.GetPlace());
+  }
+};
+
+class SetValueMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input", "(Tensor) Input tensor of set_value operator.");
+    AddInput("ValueTensor", "(Tensor) Value tensor of set_value operator.")
+        .AsDispensable();
+    AddOutput("Out",
+              "(Tensor) Output tensor of set_value operator. The output is the "
+              "same Tensor as input.");
+
+    AddAttr<int>("dtype", "data type of input.")
+        .InEnum(
+            {framework::proto::VarType::BOOL, framework::proto::VarType::INT32,
+             framework::proto::VarType::INT64, framework::proto::VarType::FP32,
+             framework::proto::VarType::FP64})
+        .SetDefault(framework::proto::VarType::FP32);
+    AddAttr<std::vector<int64_t>>(
+        "axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
+    AddAttr<std::vector<int64_t>>(
+        "starts",
+        "(list<int64_t>) Starting indices of corresponding axis in `axes`.");
+    AddAttr<std::vector<int64_t>>(
+        "ends",
+        "(list<int64_t>) Ending indices of corresponding axis in `axes`.");
+
+    AddAttr<std::vector<int>>("bool_values", "store the bool values")
+        .SetDefault({});
+    AddAttr<std::vector<float>>("fp32_values", "store the float32 values")
+        .SetDefault({});
+    AddAttr<std::vector<int>>("int32_values", "store the int32 values")
+        .SetDefault({});
+    AddAttr<std::vector<int64_t>>("int64_values", "store the int64 values")
+        .SetDefault({});
+
+    AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
+        .SetDefault({});
+    AddComment(R"DOC(SetValue operator.
+Assignment to a Tensor in static mode.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(
+    set_value, ops::SetValue, ops::SetValueMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OP_CPU_KERNEL(
+    set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SetValueKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>);
diff --git a/paddle/fluid/operators/set_value_op.cu b/paddle/fluid/operators/set_value_op.cu
new file mode 100644
index 0000000000..b65e1691b9
--- /dev/null
+++ b/paddle/fluid/operators/set_value_op.cu
@@ -0,0 +1,24 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/set_value_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    set_value, ops::SetValueKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SetValueKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::SetValueKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SetValueKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SetValueKernel<paddle::platform::CUDADeviceContext, bool>);
diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h
new file mode 100644
index 0000000000..e7624ed5eb
--- /dev/null
+++ b/paddle/fluid/operators/set_value_op.h
@@ -0,0 +1,214 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/assign_value_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+inline std::string GetValueName(framework::proto::VarType::Type data_type) {
+  std::string value_name;
+  switch (data_type) {
+    case framework::proto::VarType::INT32:
+      value_name = "int32_values";
+      break;
+    case framework::proto::VarType::INT64:
+      value_name = "int64_values";
+      break;
+    case framework::proto::VarType::FP32:
+      value_name = "fp32_values";
+      break;
+    case framework::proto::VarType::BOOL:
+      value_name = "bool_values";
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported data type(code %d) for SetValue operator, only "
+          "supports bool, int32, float32 and int64.",
+          data_type));
+  }
+  return value_name;
+}
+
+inline framework::DDim GetSliceDims(const framework::DDim in_dims,
+                                    const std::vector<int64_t> axes,
+                                    const std::vector<int64_t> starts,
+                                    const std::vector<int64_t> ends) {
+  framework::DDim slice_dims(in_dims);
+
+  for (size_t i = 0; i < axes.size(); ++i) {
+    int64_t axis = axes[i];
+    int64_t dim_value = in_dims[axis];
+
+    int64_t start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
+    int64_t end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
+    start = std::max(start, static_cast<int64_t>(0));
+    end = std::min(end, dim_value);
+
+    PADDLE_ENFORCE_GT(end, start,
+                      platform::errors::InvalidArgument(
+                          "end should be greater than start, but "
+                          "received end = %d, start = %d",
+                          end, start));
+    slice_dims[axis] = end - start;
+  }
+  return slice_dims;
+}
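+
+// A worked example of GetSliceDims (illustrative values, not used by the op):
+// with in_dims = [5, 6], axes = [0], starts = [-3] and ends = [100], the
+// negative start is normalized to -3 + 5 = 2 and the end is clamped to
+// min(100, 5) = 5, so slice_dims = [3, 6].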
+
+template <typename DeviceContext, typename T>
+class SetValueKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    const int rank = ctx.Output<framework::LoDTensor>("Out")->dims().size();
+
+    // TODO(liym27): Find a more elegant way to do this. C++ requires the
+    // template integer to be a compile-time constant, so dispatch on the
+    // runtime rank with a switch for now.
+    switch (rank) {
+      case 1:
+        SetValueCompute<1>(ctx);
+        break;
+      case 2:
+        SetValueCompute<2>(ctx);
+        break;
+      case 3:
+        SetValueCompute<3>(ctx);
+        break;
+      case 4:
+        SetValueCompute<4>(ctx);
+        break;
+      case 5:
+        SetValueCompute<5>(ctx);
+        break;
+      case 6:
+        SetValueCompute<6>(ctx);
+        break;
+    }
+  }
+
+ private:
+  template <size_t D>
+  void SetValueCompute(const framework::ExecutionContext& ctx) const {
+    auto* in = ctx.Input<framework::LoDTensor>("Input");
+    auto* out = ctx.Output<framework::LoDTensor>("Out");
+
+    auto dtype =
+        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
+    auto axes = ctx.Attr<std::vector<int64_t>>("axes");
+    auto starts = ctx.Attr<std::vector<int64_t>>("starts");
+    auto ends = ctx.Attr<std::vector<int64_t>>("ends");
+    auto shape = ctx.Attr<std::vector<int64_t>>("shape");
+    auto* value_tensor = ctx.Input<framework::LoDTensor>("ValueTensor");
+
+    auto in_dims = in->dims();
+    auto value_dims = framework::make_ddim(shape);
+    auto slice_dims = GetSliceDims(in_dims, axes, starts, ends);
+
+    auto place = ctx.GetPlace();
+    auto& eigen_place =
+        *ctx.template device_context<DeviceContext>().eigen_device();
+
+    // Here copy data from input to avoid data loss at PE and Graph level.
+    // TODO(liym27): Speed this up in a future version.
+    // - Q: Why not call ShareDataWith to speed it up?
+    // - A: Because calling ShareDataWith on an OP's input and output is not
+    //   supported:
+    //   https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
+    // - Q: Why not delete Input, given that the input and output are the same
+    //   Tensor at the program level?
+    // - A: If Input were deleted, the graph would become more complex: two ops
+    //   would point to the output, as in op1 -> output <- set_value. In that
+    //   case we would have to find a way to enforce the execution order of
+    //   set_value.
+    TensorCopy(*in, place, out);
+
+    Tensor slice_t(dtype), pad_t(dtype);
+    slice_t.mutable_data<T>(slice_dims, place);
+    pad_t.mutable_data<T>(in_dims, place);
+
+    auto pad_e = framework::EigenTensor<T, D>::From(pad_t, in_dims);
+    auto out_e = framework::EigenTensor<T, D>::From(*out);
+    auto slice_e = framework::EigenTensor<T, D>::From(slice_t, slice_dims);
+
+    // Step 1: Set the value of out at `_index` to zero
+    // - Step 1.1 Get a slice tensor from out
+    Eigen::array<int64_t, D> offsets, extents;
+    Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
+
+    for (size_t i = 0; i < D; ++i) {
+      offsets[i] = 0;
+      extents[i] = slice_dims[i];
+    }
+    int64_t start;
+    for (size_t i = 0; i < axes.size(); ++i) {
+      start = starts[i] < 0 ? (starts[i] + in_dims[axes[i]]) : starts[i];
+      start = std::max(start, static_cast<int64_t>(0));
+      offsets[axes[i]] = start;
+    }
+    for (size_t i = 0; i < paddings.size(); ++i) {
+      paddings[i].first = offsets[i];
+      paddings[i].second = (in_dims[i] - slice_dims[i]) - offsets[i];
+    }
+
+    slice_e.device(eigen_place) = out_e.slice(offsets, extents);
+
+    // - Step 1.2 Get the padded tensor by padding 0 around the slice tensor
+    pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));
+
+    // - Step 1.3 Set 0 at `_index` of the out tensor
+    out_e.device(eigen_place) = out_e - pad_e;
+
+    // Step 2: Build a tensor with the same shape as out, whose data at
+    // `_index` equals value_tensor and whose data outside `_index` is zero.
+
+    // - Step 2.1 Set the data of the slice tensor to 0
+    slice_e.device(eigen_place) = slice_e.constant(T(0));
+
+    // - Step 2.2 Set the slice tensor with value
+    if (value_tensor != nullptr) {
+      // ElementwiseComputeEx can do broadcasting
+      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
+          ctx, &slice_t, value_tensor, -1, SubFunctor<T>(), &slice_t);
+    } else {
+      Tensor value_t(dtype);
+      value_t.mutable_data<T>(value_dims, place);
+      auto value_name = GetValueName(dtype);
+      CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
+      value_t.Resize(value_dims);
+      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
+          ctx, &slice_t, &value_t, -1, SubFunctor<T>(), &slice_t);
+    }
+
+    // - Step 2.3 Pad the slice tensor with 0
+    pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));
+
+    // Step 3: Set the out tensor with value_tensor
+    out_e.device(eigen_place) = out_e - pad_e;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
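The kernel's three steps can be mirrored in NumPy. An illustrative sketch, not
part of the patch, assuming a 1-D float32 tensor and the assignment out[1:3] = 7:

    import numpy as np

    out = np.array([0., 1., 2., 3., 4.], dtype=np.float32)
    value = np.float32(7)

    # Step 1: zero the slice region: out -= pad(out[1:3]).
    pad = np.zeros_like(out)
    pad[1:3] = out[1:3]
    out = out - pad                # indices 1..2 are now 0

    # Step 2: pad (zeros - value), i.e. -value inside the region, 0 outside.
    slice_part = np.zeros(2, dtype=np.float32) - value   # SubFunctor: slice - value
    pad = np.zeros_like(out)
    pad[1:3] = slice_part

    # Step 3: out -= pad, so the region becomes +value.
    out = out - pad
    assert out.tolist() == [0., 7., 7., 3., 4.]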
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index a0e650e4da..d3f80bdb64 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -1817,6 +1817,88 @@ class Variable(object):
     def __getitem__(self, item):
         return _getitem_impl_(self, item)
 
+    def __setitem__(self, item, value):
+        inputs = {'Input': self}
+
+        # 1. Parse item
+        if not isinstance(item, tuple):
+            item = [item]
+
+        axes = []
+        starts = []
+        ends = []
+        max_integer = sys.maxsize
+        for dim, slice_item in enumerate(item):
+            if isinstance(slice_item, slice):
+                start = slice_item.start
+                end = slice_item.stop
+                step = slice_item.step
+
+                if start is None and end is None and step is None:
+                    continue
+
+                start = 0 if start is None else start
+                step = 1 if step is None else step
+
+                # TODO: support cases when step != 1
+                if step != 1:
+                    raise ValueError(
+                        "When assigning a value to a paddle.Tensor, only step 1 "
+                        "is supported, but received step {}.".format(step))
+                end = max_integer if end is None else end
+            else:
+                start = slice_item
+                end = slice_item + 1 if slice_item != -1 else max_integer
+            axes.append(dim)
+            starts.append(start)
+            ends.append(end)
+
+        attrs = {'axes': axes, 'starts': starts, 'ends': ends}
+
+        # 2. Parse value
+        dtype = self.dtype
+        attrs['dtype'] = dtype
+
+        # 2.1 value is an integer or float
+        if isinstance(value, (int, float)):
+            value = np.array([value])
+
+        # 2.2 value is a np.ndarray
+        if isinstance(value, np.ndarray):
+            shape = list(value.shape)
+            if dtype == core.VarDesc.VarType.BOOL:
+                value_name = "bool_values"
+                values = [bool(v) for v in value.flat]
+            elif dtype == core.VarDesc.VarType.FP32:
+                value_name = "fp32_values"
+                values = [float(v) for v in value.flat]
+            elif dtype == core.VarDesc.VarType.INT32:
+                value_name = "int32_values"
+                values = [int(v) for v in value.flat]
+            elif dtype == core.VarDesc.VarType.INT64:
+                value_name = "int64_values"
+                values = [int(v) for v in value.flat]
+            else:
+                from .data_feeder import convert_dtype
+                raise TypeError(
+                    "When assigning a numpy.ndarray, integer or float to a "
+                    "paddle.Tensor, the data type of the paddle.Tensor must "
+                    "be bool, float32, int32 or int64, but received %s." %
+                    convert_dtype(dtype))
+            attrs[value_name] = values
+            attrs["shape"] = shape
+
+        elif isinstance(value, Variable):
+            inputs["ValueTensor"] = value
+        else:
+            raise TypeError(
+                "Only an integer, float, numpy.ndarray or paddle.Tensor can "
+                "be assigned to a paddle.Tensor, but received {}.".format(
+                    type(value)))
+
+        self.block.append_op(
+            type="set_value", inputs=inputs, outputs={'Out': self}, attrs=attrs)
+        return self
+
 
 def get_all_op_protos():
     """
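How the parsing above maps Python indices to set_value attributes (worked by
hand from the code; negative ends are normalized later by GetSliceDims in the
kernel, and sys.maxsize stands in for an open end):

    x[0]           ->  axes=[0],    starts=[0],    ends=[1]
    x[0:-1, 0:2]   ->  axes=[0, 1], starts=[0, 0], ends=[-1, 2]
    x[0:, 1:2, :]  ->  axes=[0, 1], starts=[0, 1], ends=[sys.maxsize, 2]
                       (a bare ':' contributes nothing and is skipped)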
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py
index 14fa75e458..bf74299806 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py
@@ -84,14 +84,25 @@ def test_slice_in_for_loop(x, iter_num=3):
     return out
 
 
+@paddle.jit.to_static
+def test_set_value(x):
+    x = paddle.to_tensor(x)
+    x[0] = paddle.full(shape=[1], fill_value=2, dtype="float32")
+    x[1:2, 0:1] = 10
+    return x
+
+
 class TestSliceWithoutControlFlow(unittest.TestCase):
     def setUp(self):
-        self.input = np.random.random((3)).astype('int32')
+        self.init_input()
         self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
         ) else paddle.CPUPlace()
         self.init_dygraph_func()
         paddle.disable_static()
 
+    def init_input(self):
+        self.input = np.random.random((3)).astype('int32')
+
     def init_dygraph_func(self):
         self.dygraph_func = test_slice_without_control_flow
 
@@ -125,10 +136,18 @@ class TestSliceInWhileLoop(TestSliceWithoutControlFlow):
         self.dygraph_func = test_slice_in_while_loop
 
 
-class TestSliceInForLoop(TestSliceInWhileLoop):
+class TestSliceInForLoop(TestSliceWithoutControlFlow):
     def init_dygraph_func(self):
         self.dygraph_func = test_slice_in_for_loop
 
 
+class TestSetValue(TestSliceWithoutControlFlow):
+    def init_input(self):
+        self.input = np.full([3, 4, 5], 5).astype('float32')
+
+    def init_dygraph_func(self):
+        self.dygraph_func = test_set_value
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py
new file mode 100644
index 0000000000..cc5bf01b62
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py
@@ -0,0 +1,482 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Test the set_value op in static mode
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle
+
+
+class TestSetValueBase(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+        self.set_dtype()
+        self.set_value()
+        self.shape = [2, 3, 4]
+        self.data = np.ones(self.shape).astype(self.dtype)
+        self.program = paddle.static.Program()
+
+    def set_value(self):
+        self.value = 6
+
+    def set_dtype(self):
+        self.dtype = "float32"
+
+    def _call_setitem(self, x):
+        x[0, 0] = self.value
+
+    def _get_answer(self):
+        self.data[0, 0] = self.value
+
+
+class TestSetValueApi(TestSetValueBase):
+    def test_api(self):
+        with paddle.static.program_guard(self.program):
+            x = paddle.ones(shape=self.shape, dtype=self.dtype)
+            self._call_setitem(x)
+
+        exe = paddle.static.Executor(paddle.CPUPlace())
+        out = exe.run(self.program, fetch_list=[x])
+
+        self._get_answer()
+        self.assertTrue(
+            (self.data == out).all(),
+            msg="\nExpected res = \n{}, \n\nbut received : \n{}".format(
+                self.data, out))
+
+
+# 1. Test different types of item: int, Python slice
+class TestSetValueItemInt(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[0] = self.value
+
+    def _get_answer(self):
+        self.data[0] = self.value
+
+
+class TestSetValueItemSlice(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[0:2] = self.value
+
+    def _get_answer(self):
+        self.data[0:2] = self.value
+
+
+class TestSetValueItemSlice2(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[0:-1] = self.value
+
+    def _get_answer(self):
+        self.data[0:-1] = self.value
+
+
+class TestSetValueItemSlice3(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[0:-1, 0:2] = self.value
+
+    def _get_answer(self):
+        self.data[0:-1, 0:2] = self.value
+
+
+class TestSetValueItemSlice4(TestSetValueApi):
+    def _call_setitem(self, x):
+        x[0:, 1:2, :] = self.value
+
+    def _get_answer(self):
+        self.data[0:, 1:2, :] = self.value
+
+
+# 2. Test different types of value: int, float, numpy.ndarray, Tensor
+# 2.1 value is int32, int64, float32, bool
+
+
+def create_test_value_int32(parent):
+    class TestValueInt(parent):
+        def set_value(self):
+            self.value = 7
+
+        def set_dtype(self):
+            self.dtype = "int32"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueInt32")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_int32(TestSetValueItemInt)
+create_test_value_int32(TestSetValueItemSlice)
+create_test_value_int32(TestSetValueItemSlice2)
+create_test_value_int32(TestSetValueItemSlice3)
+create_test_value_int32(TestSetValueItemSlice4)
+
+
+def create_test_value_int64(parent):
+    class TestValueInt(parent):
+        def set_value(self):
+            self.value = 7
+
+        def set_dtype(self):
+            self.dtype = "int64"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueInt64")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_int64(TestSetValueItemInt)
+create_test_value_int64(TestSetValueItemSlice)
+create_test_value_int64(TestSetValueItemSlice2)
+create_test_value_int64(TestSetValueItemSlice3)
+create_test_value_int64(TestSetValueItemSlice4)
+
+
+def create_test_value_fp32(parent):
+    class TestValueInt(parent):
+        def set_value(self):
+            self.value = 3.3
+
+        def set_dtype(self):
+            self.dtype = "float32"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueFp32")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_fp32(TestSetValueItemInt)
+create_test_value_fp32(TestSetValueItemSlice)
+create_test_value_fp32(TestSetValueItemSlice2)
+create_test_value_fp32(TestSetValueItemSlice3)
+create_test_value_fp32(TestSetValueItemSlice4)
+
+
+def create_test_value_bool(parent):
+    class TestValueInt(parent):
+        def set_value(self):
+            self.value = 0
+
+        def set_dtype(self):
+            self.dtype = "bool"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueBool")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_bool(TestSetValueItemInt)
+create_test_value_bool(TestSetValueItemSlice)
+create_test_value_bool(TestSetValueItemSlice2)
+create_test_value_bool(TestSetValueItemSlice3)
+create_test_value_bool(TestSetValueItemSlice4)
+
+
+# 2.2 value is numpy.array (int32, int64, float32, bool)
+def create_test_value_numpy_int32(parent):
+    class TestValueInt(parent):
+        def set_value(self):
+            self.value = np.array([5])
+
+        def set_dtype(self):
+            self.dtype = "int32"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyInt32")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_numpy_int32(TestSetValueItemInt)
+create_test_value_numpy_int32(TestSetValueItemSlice)
+create_test_value_numpy_int32(TestSetValueItemSlice2)
+create_test_value_numpy_int32(TestSetValueItemSlice3)
+create_test_value_numpy_int32(TestSetValueItemSlice4)
+
+
+def create_test_value_numpy_int64(parent):
+    class TestValueInt(parent):
+        def set_value(self):
+            self.value = np.array([1])
+
+        def set_dtype(self):
+            self.dtype = "int64"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyInt64")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_numpy_int64(TestSetValueItemInt)
+create_test_value_numpy_int64(TestSetValueItemSlice)
+create_test_value_numpy_int64(TestSetValueItemSlice2)
+create_test_value_numpy_int64(TestSetValueItemSlice3)
+create_test_value_numpy_int64(TestSetValueItemSlice4)
+
+
+def create_test_value_numpy_fp32(parent):
+    class TestValueInt(parent):
+        def set_value(self):
+            self.value = np.array([1])
+
+        def set_dtype(self):
+            self.dtype = "float32"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyFp32")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_numpy_fp32(TestSetValueItemInt)
+create_test_value_numpy_fp32(TestSetValueItemSlice)
+create_test_value_numpy_fp32(TestSetValueItemSlice2)
+create_test_value_numpy_fp32(TestSetValueItemSlice3)
+create_test_value_numpy_fp32(TestSetValueItemSlice4)
+
+
+def create_test_value_numpy_bool(parent):
+    class TestValueInt(parent):
+        def set_value(self):
+            self.value = np.array([0])
+
+        def set_dtype(self):
+            self.dtype = "bool"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyBool")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_numpy_bool(TestSetValueItemInt)
+create_test_value_numpy_bool(TestSetValueItemSlice)
+create_test_value_numpy_bool(TestSetValueItemSlice2)
+create_test_value_numpy_bool(TestSetValueItemSlice3)
+create_test_value_numpy_bool(TestSetValueItemSlice4)
+
+
+# 2.3 value is a Paddle Tensor (int32, int64, float32, float64, bool)
+def create_test_value_tensor_int32(parent):
+    class TestValueInt(parent):
+        def set_dtype(self):
+            self.dtype = "int32"
+
+        def _call_setitem(self, x):
+            value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype)
+            x[0, 1] = value
+
+        def _get_answer(self):
+            self.data[0, 1] = 3
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorInt32")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_tensor_int32(TestSetValueItemInt)
+create_test_value_tensor_int32(TestSetValueItemSlice)
+create_test_value_tensor_int32(TestSetValueItemSlice2)
+create_test_value_tensor_int32(TestSetValueItemSlice3)
+create_test_value_tensor_int32(TestSetValueItemSlice4)
+
+
+def create_test_value_tensor_int64(parent):
+    class TestValueInt(parent):
+        def set_dtype(self):
+            self.dtype = "int64"
+
+        def _call_setitem(self, x):
+            value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype)
+            x[0, 1] = value
+
+        def _get_answer(self):
+            self.data[0, 1] = 3
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorInt64")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_tensor_int64(TestSetValueItemInt)
+create_test_value_tensor_int64(TestSetValueItemSlice)
+create_test_value_tensor_int64(TestSetValueItemSlice2)
+create_test_value_tensor_int64(TestSetValueItemSlice3)
+create_test_value_tensor_int64(TestSetValueItemSlice4)
+
+
+def create_test_value_tensor_fp32(parent):
+    class TestValueInt(parent):
+        def set_dtype(self):
+            self.dtype = "float32"
+
+        def _call_setitem(self, x):
+            value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype)
+            x[0, 1] = value
+
+        def _get_answer(self):
+            self.data[0, 1] = 3
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorFp32")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_tensor_fp32(TestSetValueItemInt)
+create_test_value_tensor_fp32(TestSetValueItemSlice)
+create_test_value_tensor_fp32(TestSetValueItemSlice2)
+create_test_value_tensor_fp32(TestSetValueItemSlice3)
+create_test_value_tensor_fp32(TestSetValueItemSlice4)
+
+
+def create_test_value_tensor_fp64(parent):
+    class TestValueInt(parent):
+        def set_dtype(self):
+            self.dtype = "float64"
+
+        def _call_setitem(self, x):
+            value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype)
+            x[0, 1] = value
+
+        def _get_answer(self):
+            self.data[0, 1] = 3
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorFp64")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_tensor_fp64(TestSetValueItemInt)
+create_test_value_tensor_fp64(TestSetValueItemSlice)
+create_test_value_tensor_fp64(TestSetValueItemSlice2)
+create_test_value_tensor_fp64(TestSetValueItemSlice3)
+create_test_value_tensor_fp64(TestSetValueItemSlice4)
+
+
+def create_test_value_tensor_bool(parent):
+    class TestValueInt(parent):
+        def set_dtype(self):
+            self.dtype = "bool"
+
+        def _call_setitem(self, x):
+            value = paddle.full(shape=[1], fill_value=False, dtype=self.dtype)
+            x[0, 1] = value
+
+        def _get_answer(self):
+            self.data[0, 1] = False
+
+    cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorBool")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_tensor_bool(TestSetValueItemInt)
+create_test_value_tensor_bool(TestSetValueItemSlice)
+create_test_value_tensor_bool(TestSetValueItemSlice2)
+create_test_value_tensor_bool(TestSetValueItemSlice3)
+create_test_value_tensor_bool(TestSetValueItemSlice4)
+
+
+# 3. Test different shapes of value
+class TestSetValueValueShape1(TestSetValueApi):
+    def set_value(self):
+        self.value = np.array([3, 4, 5, 6])  # shape is (4,)
+
+    def _call_setitem(self, x):
+        x[0] = self.value
+
+    def _get_answer(self):
+        self.data[0] = self.value
+
+
+class TestSetValueValueShape2(TestSetValueApi):
+    def set_value(self):
+        self.value = np.array([[3, 4, 5, 6]])  # shape is (1, 4)
+
+    def _call_setitem(self, x):
+        x[0:1] = self.value
+
+    def _get_answer(self):
+        self.data[0:1] = self.value
+
+
+class TestSetValueValueShape3(TestSetValueApi):
+    def set_value(self):
+        self.value = np.array(
+            [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]])  # shape is (3, 4)
+
+    def _call_setitem(self, x):
+        x[0] = self.value
+
+    def _get_answer(self):
+        self.data[0] = self.value
+
+
+class TestSetValueValueShape4(TestSetValueApi):
+    def set_value(self):
+        self.value = np.array(
+            [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]).astype(
+                self.dtype)  # shape is (3, 4)
+
+    def _call_setitem(self, x):
+        x[0] = paddle.assign(self.value)  # value is a paddle.Tensor
+
+    def _get_answer(self):
+        self.data[0] = self.value
+
+
+# 4. Test errors
+class TestError(TestSetValueBase):
+    def _value_type_error(self):
+        with self.assertRaisesRegexp(
+                TypeError,
+                "Only an integer, float, numpy.ndarray or paddle.Tensor can"):
+            x = paddle.ones(shape=self.shape, dtype=self.dtype)
+            value = [1]
+            x[0] = value
+
+    def _dtype_error(self):
+        with self.assertRaisesRegexp(
+                TypeError,
+                "When assigning a numpy.ndarray, integer or float to a "
+                "paddle.Tensor, "):
+            y = paddle.ones(shape=self.shape, dtype="float64")
+            y[0] = 1
+
+    def _step_error(self):
+        with self.assertRaisesRegexp(ValueError, "only step 1 is supported"):
+            x = paddle.ones(shape=self.shape, dtype=self.dtype)
+            x[0:1:2] = self.value
+
+    def _broadcast_mismatch(self):
+        program = paddle.static.Program()
+        with paddle.static.program_guard(program):
+            x = paddle.ones(shape=self.shape, dtype=self.dtype)
+            value = np.array([3, 4, 5, 6, 7])
+            x[0] = value
+        exe = paddle.static.Executor(paddle.CPUPlace())
+        with self.assertRaisesRegexp(ValueError,
+                                     "Broadcast dimension mismatch."):
+            exe.run(program)
+
+    def test_error(self):
+        with paddle.static.program_guard(self.program):
+            self._value_type_error()
+            self._dtype_error()
+            self._step_error()
+            self._broadcast_mismatch()
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab