// one_hot_op_test.cc (2.0 KB) — captured from a web blame view ("Newer / Older");
// this section was committed by zhunaipan.
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common.h"
#include "dataset/kernels/data/one_hot_op.h"
#include "utils/log_adapter.h"

using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

// (blame view: the fixture below was committed by hesham)
class MindDataTestOneHotOp : public UT::Common {
Z
zhunaipan 已提交
26
 protected:
H
hesham 已提交
27
    MindDataTestOneHotOp() {}
Z
zhunaipan 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
};

// Verifies that OneHotOp(5) turns the 1-D DE_UINT64 label tensor {0, 1, 2} into a
// 3x5 one-hot matrix that keeps the DE_UINT64 type and has a single 1 per row,
// at the column given by the label.
TEST_F(MindDataTestOneHotOp, TestOp) {
  MS_LOG(INFO) << "Doing MindDataTestOneHotOp.";
  uint64_t labels[3] = {0, 1, 2};
  TensorShape shape({3});
  // The Tensor constructor used here receives the buffer as an `unsigned char *`,
  // hence the byte-pointer cast.
  std::shared_ptr<Tensor> input = std::make_shared<Tensor>(shape, DataType(DataType::DE_UINT64),
                                                           reinterpret_cast<unsigned char *>(labels));
  std::shared_ptr<Tensor> output;  // stays null unless Compute() fills it in

  // 5 matches the column count of the expected output below.
  auto op = std::make_unique<OneHotOp>(5);
  Status s = op->Compute(input, &output);
  // ASSERT (not EXPECT): if Compute() fails, `output` may still be null and the
  // dereferences below would crash the test binary instead of reporting a failure.
  ASSERT_TRUE(s.IsOk());
  ASSERT_TRUE(output != nullptr);

  uint64_t out[15] = {1, 0, 0, 0, 0,
                      0, 1, 0, 0, 0,
                      0, 0, 1, 0, 0};
  std::shared_ptr<Tensor> expected = std::make_shared<Tensor>(TensorShape{3, 5}, DataType(DataType::DE_UINT64),
                                                              reinterpret_cast<unsigned char *>(out));
  ASSERT_TRUE(output->shape() == expected->shape());
  ASSERT_TRUE(output->type() == expected->type());

  MS_LOG(DEBUG) << *output << std::endl;
  MS_LOG(DEBUG) << *expected << std::endl;

  ASSERT_TRUE(*output == *expected);
  MS_LOG(INFO) << "MindDataTestOneHotOp end.";
}