/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/scatter.h"
#include <gtest/gtest.h>
#include <iostream>
#include <string>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"

// Scatters a single 1x4 source row (values 0, 1, 2, 3) into row 1 of a 4x4
// output tensor and checks that row 1 receives the source values while the
// remaining rows keep their explicitly zeroed initial contents.
TEST(scatter, ScatterUpdate) {
  const paddle::platform::CPUPlace cpu_place{};

  // Stack-allocated tensors: released automatically on every exit path,
  // replacing the previous manual new/delete pairs.
  paddle::framework::Tensor src;
  paddle::framework::Tensor index;
  paddle::framework::Tensor output;

  // Source: one row of four floats, 0.0f .. 3.0f.
  float* p_src =
      src.mutable_data<float>(paddle::framework::make_ddim({1, 4}), cpu_place);
  for (size_t i = 0; i < 4; ++i) p_src[i] = static_cast<float>(i);

  // Index: scatter source row 0 into output row 1.
  int* p_index =
      index.mutable_data<int>(paddle::framework::make_ddim({1}), cpu_place);
  p_index[0] = 1;

  float* p_output = output.mutable_data<float>(
      paddle::framework::make_ddim({4, 4}), cpu_place);
  // Nothing else in this test initializes the output buffer, yet the checks
  // below expect every non-scattered row to be 0. Zero it explicitly so those
  // checks do not depend on whatever memory the allocation happened to hold.
  for (size_t i = 0; i < 16; ++i) p_output[i] = 0.0f;

  paddle::platform::CPUDeviceContext ctx(cpu_place);
  paddle::operators::ScatterAssign<float>(ctx, src, index, &output);

  // Elements 4..7 (row 1) hold the source row; all other elements stay 0.
  // Both the cached pointer and a fresh data<float>() view are checked.
  for (size_t i = 0; i < 16; ++i) {
    const float expected =
        (i >= 4 && i < 8) ? static_cast<float>(i - 4) : 0.0f;
    EXPECT_EQ(p_output[i], expected);
    EXPECT_EQ(output.data<float>()[i], expected);
  }
}