未验证 提交 97e75ad0 编写于 作者: L liym27 提交者: GitHub

[setitem] Support Tensor setitem in static mode (#29708)

1. Type of index: int, slice(step must be 1).

2. Type of value: 
 (1) int32, int64, float32, bool; 
 (2) numpy.array(int32, int64, float32, bool);<Note: float64 is not supported>
 (3) paddle.Tensor(int32, int64, float32, float64, bool);
上级 24ce051a
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/set_value_op.h"
#include <string>
namespace paddle {
namespace operators {
class SetValue : public framework::OperatorWithKernel {
public:
SetValue(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "SetValue");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetValue");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.",
in_dims.size()));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class SetValueMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input", "(Tensor) Input tensor of set_value operator.");
AddInput("ValueTensor", "(Tensor) Value tensor of set_value operator.")
.AsDispensable();
AddOutput("Out",
"(Tensor) Output tensor of set_value operator. The output is the "
"same Tensor as input");
AddAttr<int>("dtype", "data type of input.")
.InEnum(
{framework::proto::VarType::BOOL, framework::proto::VarType::INT32,
framework::proto::VarType::INT64, framework::proto::VarType::FP32,
framework::proto::VarType::FP64})
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int64_t>>(
"axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
AddAttr<std::vector<int64_t>>(
"starts",
"(list<int64_t>) Starting indices of corresponding axis in `axes`");
AddAttr<std::vector<int64_t>>(
"ends",
"(list<int64_t>) Ending indices of corresponding axis in `axes`.");
AddAttr<std::vector<int>>("bool_values", "store the bool values")
.SetDefault({});
AddAttr<std::vector<float>>("fp32_values", "store the float32 values")
.SetDefault({});
AddAttr<std::vector<int>>("int32_values", "store the int32 values")
.SetDefault({});
AddAttr<std::vector<int64_t>>("int64_values", "store the int64 values")
.SetDefault({});
AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
.SetDefault({});
AddComment(R"DOC(SetValue operator.
Assignment to a Tensor in static mode.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
set_value, ops::SetValue, ops::SetValueMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
set_value, ops::SetValueKernel<paddle::platform::CPUDeviceContext, int>,
ops::SetValueKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SetValueKernel<paddle::platform::CPUDeviceContext, float>,
ops::SetValueKernel<paddle::platform::CPUDeviceContext, double>,
ops::SetValueKernel<paddle::platform::CPUDeviceContext, bool>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/set_value_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
set_value, ops::SetValueKernel<paddle::platform::CUDADeviceContext, int>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, float>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, double>,
ops::SetValueKernel<paddle::platform::CUDADeviceContext, bool>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include <utility>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline std::string GetValueName(framework::proto::VarType::Type data_type) {
std::string value_name;
switch (data_type) {
case framework::proto::VarType::INT32:
value_name = "int32_values";
break;
case framework::proto::VarType::INT64:
value_name = "int64_values";
break;
case framework::proto::VarType::FP32:
value_name = "fp32_values";
break;
case framework::proto::VarType::BOOL:
value_name = "bool_values";
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type(code %d) for SetValue operator, only "
"supports bool, int32, float32 and int64.",
data_type));
}
return value_name;
}
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<int64_t> axes,
const std::vector<int64_t> starts,
const std::vector<int64_t> ends) {
framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
int64_t axis = axes[i];
int64_t dim_value = in_dims[axis];
int64_t start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
int64_t end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, static_cast<int64_t>(0));
end = std::min(end, dim_value);
PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument(
"end should greater than start, but "
"received end = %d, start = %d",
end, start));
slice_dims[axis] = end - start;
}
return slice_dims;
}
template <typename DeviceContext, typename T>
class SetValueKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
const int rank = ctx.Output<framework::LoDTensor>("Out")->dims().size();
// TODO(liym27): A more elegent code to do this. C++ has to make template
// integer as constant, but we had better have alternative writing in the
// future.
switch (rank) {
case 1:
SetValueCompute<1>(ctx);
break;
case 2:
SetValueCompute<2>(ctx);
break;
case 3:
SetValueCompute<3>(ctx);
break;
case 4:
SetValueCompute<4>(ctx);
break;
case 5:
SetValueCompute<5>(ctx);
break;
case 6:
SetValueCompute<6>(ctx);
break;
}
}
private:
template <size_t D>
void SetValueCompute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::LoDTensor>("Input");
auto* out = ctx.Output<framework::LoDTensor>("Out");
auto dtype =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
auto starts = ctx.Attr<std::vector<int64_t>>("starts");
auto ends = ctx.Attr<std::vector<int64_t>>("ends");
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto* value_tensor = ctx.Input<framework::LoDTensor>("ValueTensor");
auto in_dims = in->dims();
auto value_dims = framework::make_ddim(shape);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends);
auto place = ctx.GetPlace();
auto& eigen_place =
*ctx.template device_context<DeviceContext>().eigen_device();
// Here copy data from input to avoid data loss at PE and Graph level.
// TODO(liym27): Speed up in the future version.
// - Q: Why don't call ShareDataWith to speed up?
// - A: Because it's not supported to ShareDataWith on OP's input and output
// https://github.com/PaddlePaddle/Paddle/wiki/ShareDataWith-and-ShareBufferWith-are-prohibited-in-OP
// - Q: Why don't delete Input, after all, the input and output are the same
// Tensor at program level?
// - A: If deleting Input, the graph will be complex, such as there will
// be two ops points to the output in graph: op1 -> output <- set_value.
// In this case, we have to find a way to handle the running order of
// set_value is what we want.
TensorCopy(*in, place, out);
Tensor slice_t(dtype), pad_t(dtype);
slice_t.mutable_data<T>(slice_dims, place);
pad_t.mutable_data<T>(in_dims, place);
auto pad_e = framework::EigenTensor<T, D>::From(pad_t, in_dims);
auto out_e = framework::EigenTensor<T, D>::From(*out);
auto slice_e = framework::EigenTensor<T, D>::From(slice_t, slice_dims);
// Step 1: Set the value of out at `_index` to zero
// - Step 1.1 Get a slice tensor from out
Eigen::array<int64_t, D> offsets, extents;
Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = slice_dims[i];
}
int64_t start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i] < 0 ? (starts[i] + in_dims[axes[i]]) : starts[i];
start = std::max(start, static_cast<int64_t>(0));
offsets[axes[i]] = start;
}
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - slice_dims[i]) - offsets[i];
}
slice_e.device(eigen_place) = out_e.slice(offsets, extents);
// - Step 1.2 Get paded tensor by padding 0 to slice tensor
pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));
// - Step 1.3 Set 0 at `_index` of out tensor
out_e.device(eigen_place) = out_e - pad_e;
// Step 2: Set a tensor with the same shape as out tensor. And its data at
// '_index' is the same as value_tensor, and data out of '_index' to zero
// - Step 2.1 Set the data of slice tensor to 0
slice_e.device(eigen_place) = slice_e.constant(T(0));
// - Step 2.2 Set slice tensor with value
if (value_tensor != nullptr) {
// ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_t, value_tensor, -1, SubFunctor<T>(), &slice_t);
} else {
Tensor value_t(dtype);
value_t.mutable_data<T>(value_dims, place);
auto value_name = GetValueName(dtype);
CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx);
value_t.Resize(value_dims);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_t, &value_t, -1, SubFunctor<T>(), &slice_t);
}
// - Step 2.3 Pad slice tensor with 0
pad_e.device(eigen_place) = slice_e.pad(paddings, T(0));
// Step 3: Set out tensor with value_tensor
out_e.device(eigen_place) = out_e - pad_e;
}
};
} // namespace operators
} // namespace paddle
...@@ -1817,6 +1817,88 @@ class Variable(object): ...@@ -1817,6 +1817,88 @@ class Variable(object):
def __getitem__(self, item): def __getitem__(self, item):
return _getitem_impl_(self, item) return _getitem_impl_(self, item)
def __setitem__(self, item, value):
inputs = {'Input': self}
# 1. Parse item
if not isinstance(item, tuple):
item = [item]
axes = []
starts = []
ends = []
max_integer = sys.maxsize
for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice):
start = slice_item.start
end = slice_item.stop
step = slice_item.step
if start is None and end is None and step is None:
continue
start = 0 if start is None else start
step = 1 if step is None else step
# TODO: support cases when step != 1
if step != 1:
raise ValueError(
"When assign a value to a paddle.Tensor, only support step is 1, "
"but received step is {}.".format(step))
end = max_integer if end is None else end
else:
start = slice_item
end = slice_item + 1 if slice_item != -1 else max_integer
axes.append(dim)
starts.append(start)
ends.append(end)
attrs = {'axes': axes, 'starts': starts, 'ends': ends}
# 2. Parse value
dtype = self.dtype
attrs['dtype'] = dtype
# 2.1 value is an integer of float
if isinstance(value, (int, float)):
value = np.array([value])
# 2.2 value is a np.ndarray
if isinstance(value, np.ndarray):
shape = list(value.shape)
if dtype == core.VarDesc.VarType.BOOL:
value_name = "bool_values"
values = [bool(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT64:
value_name = "int64_values"
values = [int(v) for v in value.flat]
else:
from .data_feeder import convert_dtype
raise TypeError(
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, float32, int32 or int64, but "
"received %s." % convert_dtype(dtype))
attrs[value_name] = values
attrs["shape"] = shape
elif isinstance(value, Variable):
inputs["ValueTensor"] = value
else:
raise TypeError(
"Only support to assign an integer, float, numpy.ndarray or "
"paddle.Tensor to a paddle.Tensor, but received {}".format(
type(value)))
self.block.append_op(
type="set_value", inputs=inputs, outputs={'Out': self}, attrs=attrs)
return self
def get_all_op_protos(): def get_all_op_protos():
""" """
......
...@@ -84,14 +84,25 @@ def test_slice_in_for_loop(x, iter_num=3): ...@@ -84,14 +84,25 @@ def test_slice_in_for_loop(x, iter_num=3):
return out return out
@paddle.jit.to_static
def test_set_value(x):
x = paddle.to_tensor(x)
x[0] = paddle.full(shape=[1], fill_value=2, dtype="float32")
x[1:2, 0:1] = 10
return x
class TestSliceWithoutControlFlow(unittest.TestCase): class TestSliceWithoutControlFlow(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.random.random((3)).astype('int32') self.init_input()
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace() ) else paddle.CPUPlace()
self.init_dygraph_func() self.init_dygraph_func()
paddle.disable_static() paddle.disable_static()
def init_input(self):
self.input = np.random.random((3)).astype('int32')
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_slice_without_control_flow self.dygraph_func = test_slice_without_control_flow
...@@ -125,10 +136,18 @@ class TestSliceInWhileLoop(TestSliceWithoutControlFlow): ...@@ -125,10 +136,18 @@ class TestSliceInWhileLoop(TestSliceWithoutControlFlow):
self.dygraph_func = test_slice_in_while_loop self.dygraph_func = test_slice_in_while_loop
class TestSliceInForLoop(TestSliceInWhileLoop): class TestSliceInForLoop(TestSliceWithoutControlFlow):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_slice_in_for_loop self.dygraph_func = test_slice_in_for_loop
class TestSetValue(TestSliceWithoutControlFlow):
def init_input(self):
self.input = np.full([3, 4, 5], 5).astype('float32')
def init_dygraph_func(self):
self.dygraph_func = test_set_value
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Test set_value op in static mode
from __future__ import print_function
import unittest
import numpy as np
import paddle
class TestSetValueBase(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.set_dtype()
self.set_value()
self.shape = [2, 3, 4]
self.data = np.ones(self.shape).astype(self.dtype)
self.program = paddle.static.Program()
def set_value(self):
self.value = 6
def set_dtype(self):
self.dtype = "float32"
def _call_setitem(self, x):
x[0, 0] = self.value
def _get_answer(self):
self.data[0, 0] = self.value
class TestSetValueApi(TestSetValueBase):
def test_api(self):
with paddle.static.program_guard(self.program):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
self._call_setitem(x)
exe = paddle.static.Executor(paddle.CPUPlace())
out = exe.run(self.program, fetch_list=[x])
self._get_answer()
self.assertTrue(
(self.data == out).all(),
msg="\nExpected res = \n{}, \n\nbut received : \n{}".format(
self.data, out))
# 1. Test different type of item: int, python slice
class TestSetValueItemInt(TestSetValueApi):
def _call_setitem(self, x):
x[0] = self.value
def _get_answer(self):
self.data[0] = self.value
class TestSetValueItemSlice(TestSetValueApi):
def _call_setitem(self, x):
x[0:2] = self.value
def _get_answer(self):
self.data[0:2] = self.value
class TestSetValueItemSlice2(TestSetValueApi):
def _call_setitem(self, x):
x[0:-1] = self.value
def _get_answer(self):
self.data[0:-1] = self.value
class TestSetValueItemSlice3(TestSetValueApi):
def _call_setitem(self, x):
x[0:-1, 0:2] = self.value
def _get_answer(self):
self.data[0:-1, 0:2] = self.value
class TestSetValueItemSlice4(TestSetValueApi):
def _call_setitem(self, x):
x[0:, 1:2, :] = self.value
def _get_answer(self):
self.data[0:, 1:2, :] = self.value
# 2. Test different type of value: int, float, numpy.ndarray, Tensor
# 2.1 value is int32, int64, float32, bool
def create_test_value_int32(parent):
class TestValueInt(parent):
def set_value(self):
self.value = 7
def set_dtype(self):
self.dtype = "int32"
cls_name = "{0}_{1}".format(parent.__name__, "ValueInt32")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_int32(TestSetValueItemInt)
create_test_value_int32(TestSetValueItemSlice)
create_test_value_int32(TestSetValueItemSlice2)
create_test_value_int32(TestSetValueItemSlice3)
create_test_value_int32(TestSetValueItemSlice4)
def create_test_value_int64(parent):
class TestValueInt(parent):
def set_value(self):
self.value = 7
def set_dtype(self):
self.dtype = "int64"
cls_name = "{0}_{1}".format(parent.__name__, "ValueInt64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_int64(TestSetValueItemInt)
create_test_value_int64(TestSetValueItemSlice)
create_test_value_int64(TestSetValueItemSlice2)
create_test_value_int64(TestSetValueItemSlice3)
create_test_value_int64(TestSetValueItemSlice4)
def create_test_value_fp32(parent):
class TestValueInt(parent):
def set_value(self):
self.value = 3.3
def set_dtype(self):
self.dtype = "float32"
cls_name = "{0}_{1}".format(parent.__name__, "ValueFp32")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_fp32(TestSetValueItemInt)
create_test_value_fp32(TestSetValueItemSlice)
create_test_value_fp32(TestSetValueItemSlice2)
create_test_value_fp32(TestSetValueItemSlice3)
create_test_value_fp32(TestSetValueItemSlice4)
def create_test_value_bool(parent):
class TestValueInt(parent):
def set_value(self):
self.value = 0
def set_dtype(self):
self.dtype = "bool"
cls_name = "{0}_{1}".format(parent.__name__, "ValueBool")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_bool(TestSetValueItemInt)
create_test_value_bool(TestSetValueItemSlice)
create_test_value_bool(TestSetValueItemSlice2)
create_test_value_bool(TestSetValueItemSlice3)
create_test_value_bool(TestSetValueItemSlice4)
# 2.2 value is numpy.array (int32, int64, float32, bool)
def create_test_value_numpy_int32(parent):
class TestValueInt(parent):
def set_value(self):
self.value = np.array([5])
def set_dtype(self):
self.dtype = "int32"
cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyInt32")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_int32(TestSetValueItemInt)
create_test_value_numpy_int32(TestSetValueItemSlice)
create_test_value_numpy_int32(TestSetValueItemSlice2)
create_test_value_numpy_int32(TestSetValueItemSlice3)
create_test_value_numpy_int32(TestSetValueItemSlice4)
def create_test_value_numpy_int64(parent):
class TestValueInt(parent):
def set_value(self):
self.value = np.array([1])
def set_dtype(self):
self.dtype = "int64"
cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyInt64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_int64(TestSetValueItemInt)
create_test_value_numpy_int64(TestSetValueItemSlice)
create_test_value_numpy_int64(TestSetValueItemSlice2)
create_test_value_numpy_int64(TestSetValueItemSlice3)
create_test_value_numpy_int64(TestSetValueItemSlice4)
def create_test_value_numpy_fp32(parent):
class TestValueInt(parent):
def set_value(self):
self.value = np.array([1])
def set_dtype(self):
self.dtype = "float32"
cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyFp32")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_fp32(TestSetValueItemInt)
create_test_value_numpy_fp32(TestSetValueItemSlice)
create_test_value_numpy_fp32(TestSetValueItemSlice2)
create_test_value_numpy_fp32(TestSetValueItemSlice3)
create_test_value_numpy_fp32(TestSetValueItemSlice4)
def create_test_value_numpy_bool(parent):
class TestValueInt(parent):
def set_value(self):
self.value = np.array([0])
def set_dtype(self):
self.dtype = "bool"
cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyBool")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_bool(TestSetValueItemInt)
create_test_value_numpy_bool(TestSetValueItemSlice)
create_test_value_numpy_bool(TestSetValueItemSlice2)
create_test_value_numpy_bool(TestSetValueItemSlice3)
create_test_value_numpy_bool(TestSetValueItemSlice4)
# 2.3 value is a Paddle Tensor (int32, int64, float32, float64, bool)
def create_test_value_tensor_int32(parent):
class TestValueInt(parent):
def set_dtype(self):
self.dtype = "int32"
def _call_setitem(self, x):
value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype)
x[0, 1] = value
def _get_answer(self):
self.data[0, 1] = 3
cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorInt32")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_tensor_int32(TestSetValueItemInt)
create_test_value_tensor_int32(TestSetValueItemSlice)
create_test_value_tensor_int32(TestSetValueItemSlice2)
create_test_value_tensor_int32(TestSetValueItemSlice3)
create_test_value_tensor_int32(TestSetValueItemSlice4)
def create_test_value_tensor_int64(parent):
class TestValueInt(parent):
def set_dtype(self):
self.dtype = "int64"
def _call_setitem(self, x):
value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype)
x[0, 1] = value
def _get_answer(self):
self.data[0, 1] = 3
cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorInt64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_tensor_int64(TestSetValueItemInt)
create_test_value_tensor_int64(TestSetValueItemSlice)
create_test_value_tensor_int64(TestSetValueItemSlice2)
create_test_value_tensor_int64(TestSetValueItemSlice3)
create_test_value_tensor_int64(TestSetValueItemSlice4)
def create_test_value_tensor_fp32(parent):
class TestValueInt(parent):
def set_dtype(self):
self.dtype = "float32"
def _call_setitem(self, x):
value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype)
x[0, 1] = value
def _get_answer(self):
self.data[0, 1] = 3
cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorFp32")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_tensor_fp32(TestSetValueItemInt)
create_test_value_tensor_fp32(TestSetValueItemSlice)
create_test_value_tensor_fp32(TestSetValueItemSlice2)
create_test_value_tensor_fp32(TestSetValueItemSlice3)
create_test_value_tensor_fp32(TestSetValueItemSlice4)
def create_test_value_tensor_fp64(parent):
class TestValueInt(parent):
def set_dtype(self):
self.dtype = "float64"
def _call_setitem(self, x):
value = paddle.full(shape=[1], fill_value=3, dtype=self.dtype)
x[0, 1] = value
def _get_answer(self):
self.data[0, 1] = 3
cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorFp64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_tensor_fp64(TestSetValueItemInt)
create_test_value_tensor_fp64(TestSetValueItemSlice)
create_test_value_tensor_fp64(TestSetValueItemSlice2)
create_test_value_tensor_fp64(TestSetValueItemSlice3)
create_test_value_tensor_fp64(TestSetValueItemSlice4)
def create_test_value_tensor_bool(parent):
class TestValueInt(parent):
def set_dtype(self):
self.dtype = "bool"
def _call_setitem(self, x):
value = paddle.full(shape=[1], fill_value=False, dtype=self.dtype)
x[0, 1] = value
def _get_answer(self):
self.data[0, 1] = False
cls_name = "{0}_{1}".format(parent.__name__, "ValueTensorBool")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_tensor_bool(TestSetValueItemInt)
create_test_value_tensor_bool(TestSetValueItemSlice)
create_test_value_tensor_bool(TestSetValueItemSlice2)
create_test_value_tensor_bool(TestSetValueItemSlice3)
create_test_value_tensor_bool(TestSetValueItemSlice4)
# 3. Test different shape of value
class TestSetValueValueShape1(TestSetValueApi):
def set_value(self):
self.value = np.array([3, 4, 5, 6]) # shape is (4,)
def _call_setitem(self, x):
x[0] = self.value
def _get_answer(self):
self.data[0] = self.value
class TestSetValueValueShape2(TestSetValueApi):
def set_value(self):
self.value = np.array([[3, 4, 5, 6]]) # shape is (1,4)
def _call_setitem(self, x):
x[0:1] = self.value
def _get_answer(self):
self.data[0:1] = self.value
class TestSetValueValueShape3(TestSetValueApi):
def set_value(self):
self.value = np.array(
[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]) # shape is (3,4)
def _call_setitem(self, x):
x[0] = self.value
def _get_answer(self):
self.data[0] = self.value
class TestSetValueValueShape4(TestSetValueApi):
def set_value(self):
self.value = np.array(
[[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]).astype(
self.dtype) # shape is (3,4)
def _call_setitem(self, x):
x[0] = paddle.assign(self.value) # x is Paddle.Tensor
def _get_answer(self):
self.data[0] = self.value
# 4. Test error
class TestError(TestSetValueBase):
def _value_type_error(self):
with self.assertRaisesRegexp(
TypeError,
"Only support to assign an integer, float, numpy.ndarray or paddle.Tensor"
):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
value = [1]
x[0] = value
def _dtype_error(self):
with self.assertRaisesRegexp(
TypeError,
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
):
y = paddle.ones(shape=self.shape, dtype="float64")
y[0] = 1
def _step_error(self):
with self.assertRaisesRegexp(ValueError, "only support step is 1"):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
x[0:1:2] = self.value
def _broadcast_mismatch(self):
program = paddle.static.Program()
with paddle.static.program_guard(program):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
value = np.array([3, 4, 5, 6, 7])
x[0] = value
exe = paddle.static.Executor(paddle.CPUPlace())
with self.assertRaisesRegexp(ValueError,
"Broadcast dimension mismatch."):
exe.run(program)
def test_error(self):
with paddle.static.program_guard(self.program):
self._value_type_error()
self._dtype_error()
self._step_error()
self._broadcast_mismatch()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册