提交 dd9bf09f 编写于 作者: N nhussain

added FillOp for #119 - special Ops

上级 2005ecc2
......@@ -38,6 +38,7 @@
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/kernels/data/fill_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
......@@ -350,6 +351,10 @@ void bindTensorOps2(py::module *m) {
*m, "OneHotOp", "Tensor operation to apply one hot encoding. Takes number of classes.")
.def(py::init<int32_t>());
(void)py::class_<FillOp, TensorOp, std::shared_ptr<FillOp>>(
*m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.")
.def(py::init<std::shared_ptr<Tensor>>());
(void)py::class_<RandomRotationOp, TensorOp, std::shared_ptr<RandomRotationOp>>(
*m, "RandomRotationOp",
"Tensor operation to apply RandomRotation."
......
......@@ -5,4 +5,4 @@ add_library(kernels-data OBJECT
one_hot_op.cc
type_cast_op.cc
to_float16_op.cc
)
fill_op.cc)
......@@ -23,6 +23,7 @@
#include "dataset/core/tensor_shape.h"
#include "dataset/core/data_type.h"
#include "dataset/core/pybind_support.h"
#include "dataset/kernels/data/type_cast_op.h"
namespace mindspore {
namespace dataset {
......@@ -78,6 +79,7 @@ Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_pt
Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, dsize_t num_classes) {
input->Squeeze();
if (input->Rank() > 1) { // We expect the input to be int he first dimension
RETURN_STATUS_UNEXPECTED("One hot only supports scalars or 1D shape Tensors.");
}
......@@ -106,11 +108,121 @@ Status OneHotEncoding(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *ou
}
}
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value) {
CHECK_FAIL_RETURN_UNEXPECTED(!((fill_value->type() == DataType::DE_STRING) && (input->type() != DataType::DE_STRING)),
"Types do not match");
CHECK_FAIL_RETURN_UNEXPECTED(fill_value->shape() == TensorShape({}), "fill_value is not a scalar");
std::shared_ptr<Tensor> out;
const DataType &to = input->type();
std::unique_ptr<TypeCastOp> op(new TypeCastOp(to));
std::shared_ptr<Tensor> fill_output;
op->Compute(fill_value, &fill_output);
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, TensorImpl::kFlexible, input->shape(), input->type()));
switch (input->type().value()) {
case DataType::DE_BOOL: {
bool value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<bool>(value);
break;
}
case DataType::DE_INT8: {
int8_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<int8_t>(value);
break;
}
case DataType::DE_UINT8: {
uint8_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<uint8_t>(value);
break;
}
case DataType::DE_UINT16: {
uint16_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<uint16_t>(value);
break;
}
case DataType::DE_INT16: {
int16_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<int16_t>(value);
break;
}
case DataType::DE_UINT32: {
uint32_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<uint32_t>(value);
break;
}
case DataType::DE_INT32: {
int32_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<int32_t>(value);
break;
}
case DataType::DE_UINT64: {
uint64_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<uint64_t>(value);
break;
}
case DataType::DE_INT64: {
int64_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<int64_t>(value);
break;
}
case DataType::DE_FLOAT16: {
int64_t value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<float>(value);
break;
}
case DataType::DE_FLOAT32: {
float value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<float>(value);
break;
}
case DataType::DE_FLOAT64: {
double value = 0;
RETURN_IF_NOT_OK(fill_output->GetItemAt(&value, {}));
out->Fill<double>(value);
break;
}
case DataType::DE_STRING: {
std::vector<std::string> strings;
std::string_view fill_string_view;
RETURN_IF_NOT_OK(fill_value->GetItemAt(&fill_string_view, {}));
std::string fill_string = std::string(fill_string_view);
for (int i = 0; i < input->shape().NumOfElements(); i++) {
strings.emplace_back(fill_string);
}
RETURN_IF_NOT_OK(Tensor::CreateTensor(&out, strings, input->shape()));
break;
}
case DataType::DE_UNKNOWN: {
RETURN_STATUS_UNEXPECTED("FillOp does not support input of this type.");
break;
}
}
*output = out;
return Status::OK();
}
template <typename FROM, typename TO>
void Cast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
auto in_itr = input->begin<FROM>();
auto out_itr = (*output)->begin<TO>();
auto out_end = (*output)->end<TO>();
for (; out_itr != out_end; static_cast<void>(in_itr++), static_cast<void>(out_itr++))
*out_itr = static_cast<TO>(*in_itr);
}
......
......@@ -43,6 +43,13 @@ Status OneHotEncodingUnsigned(const std::shared_ptr<Tensor> &input, std::shared_
Status OneHotEncodingSigned(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, dsize_t num_classes,
int64_t index);
// Returns a tensor of shape input filled with the passed fill_value
// @param input Tensor
// @param output Tensor. The shape and type of the output tensor is same as input
// @param fill_value Tensor. A scalar tensor used to fill the output tensor
Status Fill(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, std::shared_ptr<Tensor> fill_value);
// Returns a type changed input tensor.
// Example: if input tensor is float64, the output will the specified dataType. See DataTypes.cpp
// @param input Tensor
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/data/fill_op.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/data/data_utils.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
Status FillOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
Status s = Fill(input, output, fill_value_);
return s;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_DATA_FILL_OP_H_
#define DATASET_KERNELS_DATA_FILL_OP_H_
#include <string>
#include <vector>
#include <memory>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
class FillOp : public TensorOp {
public:
explicit FillOp(std::shared_ptr<Tensor> value) : fill_value_(value) {}
~FillOp() override = default;
void Print(std::ostream &out) const override { out << "FillOp"; }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
private:
std::shared_ptr<Tensor> fill_value_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_FILL_OP_H
......@@ -15,9 +15,9 @@
"""
This module c_transforms provides common operations, including OneHotOp and TypeCast.
"""
import numpy as np
import mindspore._c_dataengine as cde
from .validators import check_num_classes, check_de_type
from .validators import check_num_classes, check_de_type, check_fill_value
from ..core.datatypes import mstype_to_detype
......@@ -35,6 +35,22 @@ class OneHot(cde.OneHotOp):
super().__init__(num_classes)
class Fill(cde.FillOp):
"""
Tensor operation to create a tensor filled with passed scalar value.
The output tensor will have the same shape and type as the input tensor.
Args:
fill_value (python types (str, int, float, or bool)) : scalar value
to fill created tensor with.
"""
@check_fill_value
def __init__(self, fill_value):
print(fill_value)
super().__init__(cde.Tensor(np.array(fill_value)))
class TypeCast(cde.TypeCastOp):
"""
Tensor operation to cast to a given MindSpore data type.
......
......@@ -17,7 +17,6 @@
from functools import wraps
from mindspore._c_expression import typing
# POS_INT_MIN is used to limit values from starting from 0
POS_INT_MIN = 1
UINT8_MAX = 255
......@@ -159,6 +158,25 @@ def check_num_classes(method):
return new_method
def check_fill_value(method):
"""Wrapper method to check the parameters of fill value."""
@wraps(method)
def new_method(self, *args, **kwargs):
fill_value = (list(args) + [None])[0]
if "fill_value" in kwargs:
fill_value = kwargs.get("fill_value")
if fill_value is None:
raise ValueError("fill_value is not provided.")
if not isinstance(fill_value, (str, float, bool, int)):
raise TypeError("fill_value must be either a primitive python str, float, bool, or int")
kwargs["fill_value"] = fill_value
return method(self, **kwargs)
return new_method
def check_de_type(method):
"""Wrapper method to check the parameters of data type."""
......
......@@ -72,6 +72,7 @@ SET(DE_UT_SRCS
tokenizer_op_test.cc
gnn_graph_test.cc
coco_op_test.cc
fill_op_test.cc
)
add_executable(de_ut_tests ${DE_UT_SRCS})
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "dataset/kernels/data/fill_op.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestFillOp : public UT::Common {
protected:
MindDataTestFillOp() {}
};
TEST_F(MindDataTestFillOp, TestOp) {
MS_LOG(INFO) << "Doing MindDataTestFillOp-TestOp.";
uint64_t labels[3] = {1, 1, 2};
TensorShape shape({3});
std::shared_ptr<Tensor> input =
std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
TensorShape fill_shape({});
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64));
fill_tensor->SetItemAt<uint64_t>({}, 4);
std::shared_ptr<Tensor> output;
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
Status s = op->Compute(input, &output);
uint64_t out[3] = {4, 4, 4};
std::shared_ptr<Tensor> expected =
std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));
EXPECT_TRUE(s.IsOk());
ASSERT_TRUE(output->shape() == expected->shape());
ASSERT_TRUE(output->type() == expected->type());
MS_LOG(DEBUG) << *output << std::endl;
MS_LOG(DEBUG) << *expected << std::endl;
ASSERT_TRUE(*output == *expected);
MS_LOG(INFO) << "MindDataTestFillOp-TestOp end.";
}
TEST_F(MindDataTestFillOp, TestCasting) {
MS_LOG(INFO) << "Doing MindDataTestFillOp-TestCasting.";
uint64_t labels[3] = {0, 1, 2};
TensorShape shape({3});
std::shared_ptr<Tensor> input =
std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
TensorShape fill_shape({});
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32));
fill_tensor->SetItemAt<float>({}, 2.0);
std::shared_ptr<Tensor> output;
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
Status s = op->Compute(input, &output);
uint64_t out[3] = {2, 2, 2};
std::shared_ptr<Tensor> expected =
std::make_shared<Tensor>(TensorShape{3}, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(out));
ASSERT_TRUE(output->shape() == expected->shape());
ASSERT_TRUE(output->type() == expected->type());
EXPECT_TRUE(s.IsOk());
MS_LOG(DEBUG) << *output << std::endl;
MS_LOG(DEBUG) << *expected << std::endl;
ASSERT_TRUE(*output == *expected);
MS_LOG(INFO) << "MindDataTestFillOp-TestCasting end.";
}
TEST_F(MindDataTestFillOp, ScalarFill) {
MS_LOG(INFO) << "Doing MindDataTestFillOp-ScalarFill.";
uint64_t labels[3] = {0, 1, 2};
TensorShape shape({3});
std::shared_ptr<Tensor> input =
std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
TensorShape fill_shape({2});
uint64_t fill_labels[3] = {0, 1};
std::shared_ptr<Tensor> fill_tensor =
std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(fill_labels));
std::shared_ptr<Tensor> output;
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
Status s = op->Compute(input, &output);
EXPECT_TRUE(s.IsError());
ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
MS_LOG(INFO) << "MindDataTestFillOp-ScalarFill end.";
}
TEST_F(MindDataTestFillOp, StringFill) {
MS_LOG(INFO) << "Doing MindDataTestFillOp-StringFill.";
std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"};
TensorShape shape({3});
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
TensorShape fill_shape({});
std::string fill_string = "hello";
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string);
std::shared_ptr<Tensor> output;
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
Status s = op->Compute(input, &output);
std::vector<std::string> expected_strings = {"hello", "hello", "hello"};
TensorShape expected_shape({3});
std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(expected_strings, expected_shape);
EXPECT_TRUE(s.IsOk());
ASSERT_TRUE(output->shape() == expected->shape());
ASSERT_TRUE(output->type() == expected->type());
MS_LOG(DEBUG) << *output << std::endl;
MS_LOG(DEBUG) << *expected << std::endl;
ASSERT_TRUE(*output == *expected);
MS_LOG(INFO) << "MindDataTestFillOp-StringFill end.";
}
TEST_F(MindDataTestFillOp, NumericToString) {
MS_LOG(INFO) << "Doing MindDataTestFillOp-NumericToString.";
std::vector<std::string> strings = {"xyzzy", "plugh", "abracadabra"};
TensorShape shape({3});
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(strings, shape);
TensorShape fill_shape({});
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_shape, DataType(DataType::DE_FLOAT32));
fill_tensor->SetItemAt<float>({}, 2.0);
std::shared_ptr<Tensor> output;
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
Status s = op->Compute(input, &output);
EXPECT_TRUE(s.IsError());
ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
MS_LOG(INFO) << "MindDataTestFillOp-NumericToString end.";
}
TEST_F(MindDataTestFillOp, StringToNumeric) {
MS_LOG(INFO) << "Doing MindDataTestFillOp-StringToNumeric.";
uint64_t labels[3] = {0, 1, 2};
TensorShape shape({3});
std::shared_ptr<Tensor> input =
std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64), reinterpret_cast<unsigned char *>(labels));
TensorShape fill_shape({});
std::string fill_string = "hello";
std::shared_ptr<Tensor> fill_tensor = std::make_shared<Tensor>(fill_string);
std::shared_ptr<Tensor> output;
std::unique_ptr<FillOp> op(new FillOp(fill_tensor));
Status s = op->Compute(input, &output);
EXPECT_TRUE(s.IsError());
ASSERT_TRUE(s.get_code() == StatusCode::kUnexpectedError);
MS_LOG(INFO) << "MindDataTestFillOp-StringToNumeric end.";
}
\ No newline at end of file
......@@ -13,9 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//
// Created by jesse on 10/3/19.
//
#include "common/common.h"
#include "gtest/gtest.h"
......@@ -25,32 +22,32 @@
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestQueue : public UT::Common {
public:
MindDataTestQueue() {}
MindDataTestQueue() {}
void SetUp() {}
void SetUp() {}
};
int gRefCountDestructorCalled;
class RefCount {
public:
RefCount() : v_(nullptr) {}
explicit RefCount(int x) : v_(std::make_shared<int>(x)) {}
explicit RefCount(const RefCount &o) : v_(o.v_) {}
~RefCount() {
MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl;
gRefCountDestructorCalled++;
}
RefCount& operator=(const RefCount &o) {
v_ = o.v_;
return *this;
}
RefCount() : v_(nullptr) {}
explicit RefCount(int x) : v_(std::make_shared<int>(x)) {}
explicit RefCount(const RefCount &o) : v_(o.v_) {}
~RefCount() {
MS_LOG(DEBUG) << "Destructor of RefCount called" << std::endl;
gRefCountDestructorCalled++;
}
RefCount &operator=(const RefCount &o) {
v_ = o.v_;
return *this;
}
std::shared_ptr<int> v_;
};
......@@ -70,22 +67,22 @@ TEST_F(MindDataTestQueue, Test1) {
// Use count should remain 2. a and b. No copy in the queue.
ASSERT_EQ(a.use_count(), 2);
a.reset(new int(5));
ASSERT_EQ(a.use_count(),1);
ASSERT_EQ(a.use_count(), 1);
// Push again but expect a is nullptr after push
rc = que.Add(std::move(a));
ASSERT_TRUE(rc.IsOk());
ASSERT_EQ(a.use_count(),0);
ASSERT_EQ(a.use_count(), 0);
rc = que.PopFront(&b);
ASSERT_TRUE(rc.IsOk());
ASSERT_EQ(*b, 5);
ASSERT_EQ(b.use_count(),1);
ASSERT_EQ(b.use_count(), 1);
// Test construct in place
rc = que.EmplaceBack(std::make_shared<int>(100));
ASSERT_TRUE(rc.IsOk());
rc = que.PopFront(&b);
ASSERT_TRUE(rc.IsOk());
ASSERT_EQ(*b, 100);
ASSERT_EQ(b.use_count(),1);
ASSERT_EQ(b.use_count(), 1);
// Test the destructor of the Queue by add an element in the queue without popping it and let the queue go
// out of scope.
rc = que.EmplaceBack(std::make_shared<int>(2000));
......@@ -127,7 +124,7 @@ TEST_F(MindDataTestQueue, Test3) {
ASSERT_EQ(*b, 40);
}
void test4(){
void test4() {
gRefCountDestructorCalled = 0;
// Pass a structure along the queue.
Queue<RefCount> que(3);
......@@ -144,9 +141,7 @@ void test4(){
ASSERT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestQueue, Test4) {
test4();
}
TEST_F(MindDataTestQueue, Test4) { test4(); }
TEST_F(MindDataTestQueue, Test5) {
test4();
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing fill op
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as data_trans
def test_fillop_basic():
def gen():
yield (np.array([4, 5, 6, 7], dtype=np.uint8),)
data = ds.GeneratorDataset(gen, column_names=["col"])
fill_op = data_trans.Fill(3)
data = data.map(input_columns=["col"], operations=fill_op)
expected = np.array([3, 3, 3, 3], dtype=np.uint8)
for data_row in data:
np.testing.assert_array_equal(data_row[0], expected)
def test_fillop_down_type_cast():
def gen():
yield (np.array([4, 5, 6, 7], dtype=np.uint8),)
data = ds.GeneratorDataset(gen, column_names=["col"])
fill_op = data_trans.Fill(-3)
data = data.map(input_columns=["col"], operations=fill_op)
expected = np.array([253, 253, 253, 253], dtype=np.uint8)
for data_row in data:
np.testing.assert_array_equal(data_row[0], expected)
def test_fillop_up_type_cast():
def gen():
yield (np.array([4, 5, 6, 7], dtype=np.float),)
data = ds.GeneratorDataset(gen, column_names=["col"])
fill_op = data_trans.Fill(3)
data = data.map(input_columns=["col"], operations=fill_op)
expected = np.array([3., 3., 3., 3.], dtype=np.float)
for data_row in data:
np.testing.assert_array_equal(data_row[0], expected)
def test_fillop_string():
def gen():
yield (np.array(["45555", "45555"], dtype='S'),)
data = ds.GeneratorDataset(gen, column_names=["col"])
fill_op = data_trans.Fill("error")
data = data.map(input_columns=["col"], operations=fill_op)
expected = np.array(['error', 'error'], dtype='S')
for data_row in data:
np.testing.assert_array_equal(data_row[0], expected)
def test_fillop_error_handling():
def gen():
yield (np.array([4, 4, 4, 4]),)
data = ds.GeneratorDataset(gen, column_names=["col"])
fill_op = data_trans.Fill("words")
data = data.map(input_columns=["col"], operations=fill_op)
with pytest.raises(RuntimeError) as error_info:
for data_row in data:
print(data_row)
assert "Types do not match" in repr(error_info.value)
if __name__ == "__main__":
test_fillop_basic()
test_fillop_up_type_cast()
test_fillop_down_type_cast()
test_fillop_string()
test_fillop_error_handling()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册