operator_test.cc 7.8 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
22 23 24
// Incremented by OpWithoutKernelTest::Run; TEST(OperatorBase, all) reads it
// before and after running the operator.
static int op_run_num{0};

class OpWithoutKernelTest : public OperatorBase {
Q
Qiao Longfei 已提交
25
 public:
Q
Qiao Longfei 已提交
26
  void Init() override { x = 1; }
Y
Yu Yang 已提交
27 28
  void InferShape(const Scope& scope) const override {}
  void Run(const Scope& scope,
Y
Yu Yang 已提交
29
           const platform::DeviceContext& dev_ctx) const override {
Y
Yu Yang 已提交
30 31 32 33
    ++op_run_num;
    ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
    ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
    ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr);
Q
Qiao Longfei 已提交
34
    ASSERT_EQ(x, 1);
Y
Yu Yang 已提交
35
    ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr);
Q
Qiao Longfei 已提交
36
  }
Q
Qiao Longfei 已提交
37 38 39

 public:
  float x = 0;
Q
Qiao Longfei 已提交
40 41
};

Q
Qiao Longfei 已提交
42
class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
Q
Qiao Longfei 已提交
43
 public:
Q
Qiao Longfei 已提交
44 45
  OpeWithoutKernelTestProtoAndCheckerMaker(OpProto* proto,
                                           OpAttrChecker* op_checker)
Q
Qiao Longfei 已提交
46 47 48
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of test op");
    AddOutput("output", "output of test op");
Q
Qiao Longfei 已提交
49
    AddAttr<float>("scale", "scale of cosine op");
Q
Qiao Longfei 已提交
50 51 52 53 54 55 56
    AddComment("This is test op");
  }
};

}  // namespace framework
}  // namespace paddle

Q
Qiao Longfei 已提交
57 58
// Registers OpWithoutKernelTest together with its proto maker under the type
// name "test_operator", which TEST(OperatorBase, all) uses via set_type().
REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
            paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
Q
Qiao Longfei 已提交
59 60 61 62

// Describes a "test_operator" op, creates it through OpRegistry and checks
// that running it increments op_run_num from 0 to 1.
TEST(OperatorBase, all) {
  namespace fw = paddle::framework;

  fw::OpDesc desc;
  desc.set_type("test_operator");

  // Input slot "input" bound to variable "IN1" (never added to the scope).
  auto* in_slot = desc.mutable_inputs()->Add();
  in_slot->set_parameter("input");
  auto* in_arg = in_slot->mutable_arguments()->Add();
  *in_arg = "IN1";

  // Output slot "output" bound to variable "OUT1".
  auto* out_slot = desc.mutable_outputs()->Add();
  out_slot->set_parameter("output");
  auto* out_arg = out_slot->mutable_arguments()->Add();
  *out_arg = "OUT1";

  // Float attribute "scale".
  auto* scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUDeviceContext device_context;
  fw::Scope scope;

  auto op = fw::OpRegistry::CreateOp(desc);
  // Only the output variable is created; Run() asserts the input is missing.
  scope.NewVar("OUT1");

  ASSERT_EQ(paddle::framework::op_run_num, 0);
  op->InferShape(scope);
  op->Run(scope, device_context);
  ASSERT_EQ(paddle::framework::op_run_num, 1);
}

namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
89 90 91 92
// Proto/attribute-checker maker for the single-input kernel test operator.
// Declares input "x", output "y" and a float attribute "scale" with default
// 1.0 and a LargerThan(0.0) check attached.
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Slots.
    this->AddInput("x", "input of test op");
    this->AddOutput("y", "output of test op");

    // Attributes.
    this->AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
        .LargerThan(0.0);

    this->AddComment("This is test op");
  }
};

Q
Qiao Longfei 已提交
102 103
// Incremented by CPUKernelTest::Compute; TEST(OpKernel, all) reads it before
// and after running the operator.
static int cpu_kernel_run_num{0};

Q
Qiao Longfei 已提交
104
class OpWithKernelTest : public OperatorWithKernel {
Y
Yu Yang 已提交
105
 protected:
106
  void InferShape(const framework::InferShapeContext& ctx) const override {}
Q
Qiao Longfei 已提交
107 108
};

109
template <typename T1, typename T2>
Q
Qiao Longfei 已提交
110 111
class CPUKernelTest : public OpKernel {
 public:
112
  void Compute(const ExecutionContext& ctx) const {
Y
Yan Chunwei 已提交
113 114
    std::cout << "this is cpu kernel" << std::endl;
    std::cout << ctx.op_.DebugString() << std::endl;
Q
Qiao Longfei 已提交
115
    cpu_kernel_run_num++;
Y
Yan Chunwei 已提交
116 117 118 119 120 121 122 123 124 125 126
    ASSERT_EQ(ctx.op_.Input("x"), "IN1");
    ASSERT_EQ(ctx.op_.Output("y"), "OUT1");
  }
};

// Proto/attribute-checker maker for the multi-input kernel test operator.
// "xs" and "ys" are marked with SetMultiple(); "k" is a plain input. The
// float attribute "scale" has default 1.0 and a LargerThan(0.0) check.
class OpKernelTestMultiInputsProtoAndCheckerMaker
    : public OpProtoAndCheckerMaker {
 public:
  OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
                                              OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Slots.
    this->AddInput("xs", "inputs of test op").SetMultiple();
    this->AddInput("k", "input of test op");
    this->AddOutput("ys", "outputs of test op").SetMultiple();

    // Attributes.
    this->AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
        .LargerThan(0.0);

    this->AddComment("This is test op");
  }
};

class CPUKernalMultiInputsTest : public OpKernel {
 public:
139
  void Compute(const ExecutionContext& ctx) const {
Y
Yan Chunwei 已提交
140 141 142 143 144 145
    auto xs = ctx.op_.Inputs("xs");
    ASSERT_EQ(xs.size(), 3UL);
    ASSERT_EQ(xs[0], "x0");
    ASSERT_EQ(xs[1], "x1");
    ASSERT_EQ(xs[2], "x2");

146
    auto inVar0 = ctx.MultiInputVar("xs");
147
    ASSERT_EQ(inVar0.size(), 3U);
148 149 150 151 152

    auto intVar1 = ctx.InputVar("k");
    ASSERT_NE(intVar1, nullptr);

    auto outVar0 = ctx.MultiOutputVar("ys");
153
    ASSERT_EQ(outVar0.size(), 2U);
154 155

    auto inTensor0 = ctx.MultiInput<Tensor>("xs");
156
    ASSERT_EQ(inTensor0.size(), 3U);
157 158 159 160 161

    auto intTensor1 = ctx.Input<Tensor>("k");
    ASSERT_NE(intTensor1, nullptr);

    auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
162
    ASSERT_EQ(outTensor0.size(), 2U);
163

Y
Yan Chunwei 已提交
164 165 166 167 168 169 170
    auto k = ctx.op_.Input("k");
    ASSERT_EQ(k, "k0");

    auto ys = ctx.op_.Outputs("ys");
    ASSERT_EQ(ys.size(), 2UL);
    ASSERT_EQ(ys[0], "y0");
    ASSERT_EQ(ys[1], "y1");
Q
Qiao Longfei 已提交
171 172 173
  }
};

Y
Yu Yang 已提交
174 175 176 177 178
}  // namespace framework
}  // namespace paddle

// Registers OpWithKernelTest and its proto maker under the type name
// "op_with_kernel", which TEST(OpKernel, all) uses via set_type().
REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
            paddle::framework::OpKernelTestProtoAndCheckerMaker);
179 180
// Attaches CPUKernelTest<float, float> as the CPU kernel of "op_with_kernel".
REGISTER_OP_CPU_KERNEL(op_with_kernel,
                       paddle::framework::CPUKernelTest<float, float>);
Q
Qiao Longfei 已提交
181

Y
Yan Chunwei 已提交
182
// test with single input
Q
Qiao Longfei 已提交
183
// Describes an "op_with_kernel" op with one input and one output, creates it
// through OpRegistry and checks that running it increments
// cpu_kernel_run_num from 0 to 1.
TEST(OpKernel, all) {
  namespace fw = paddle::framework;

  fw::OpDesc desc;
  desc.set_type("op_with_kernel");

  // Input slot "x" bound to variable "IN1".
  auto* in_slot = desc.mutable_inputs()->Add();
  auto* in_arg = in_slot->mutable_arguments()->Add();
  *in_arg = "IN1";
  in_slot->set_parameter("x");

  // Output slot "y" bound to variable "OUT1".
  auto* out_slot = desc.mutable_outputs()->Add();
  auto* out_arg = out_slot->mutable_arguments()->Add();
  *out_arg = "OUT1";
  out_slot->set_parameter("y");

  // Float attribute "scale".
  auto* scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUDeviceContext cpu_device_context;
  fw::Scope scope;

  auto op = fw::OpRegistry::CreateOp(desc);

  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
  op->Run(scope, cpu_device_context);
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
}
Y
Yan Chunwei 已提交
207 208 209 210 211 212 213 214 215 216 217 218

// Registers OpWithKernelTest a second time, now with the multi-input proto
// maker, under the type name "op_multi_inputs_with_kernel", and attaches
// CPUKernalMultiInputsTest as its CPU kernel.
REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
            paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
                       paddle::framework::CPUKernalMultiInputsTest);

// test with multi inputs
// Describes an "op_multi_inputs_with_kernel" op with three "xs" inputs, one
// "k" input and two "ys" outputs, creates every referenced variable as a
// Tensor in the scope, and runs the op. All checks live in
// CPUKernalMultiInputsTest::Compute.
TEST(OpKernel, multi_inputs) {
  namespace fw = paddle::framework;

  fw::OpDesc desc;
  desc.set_type("op_multi_inputs_with_kernel");

  // Multi-valued input slot "xs" -> x0, x1, x2.
  auto* xs_slot = desc.mutable_inputs()->Add();
  xs_slot->set_parameter("xs");
  const char* const xs_args[] = {"x0", "x1", "x2"};
  for (const char* arg : xs_args) {
    *xs_slot->mutable_arguments()->Add() = arg;
  }

  // Single-valued input slot "k" -> k0.
  auto* k_slot = desc.mutable_inputs()->Add();
  k_slot->set_parameter("k");
  *k_slot->mutable_arguments()->Add() = "k0";

  // Multi-valued output slot "ys" -> y0, y1.
  auto* ys_slot = desc.mutable_outputs()->Add();
  ys_slot->set_parameter("ys");
  const char* const ys_args[] = {"y0", "y1"};
  for (const char* arg : ys_args) {
    *ys_slot->mutable_arguments()->Add() = arg;
  }

  // Float attribute "scale".
  auto* scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUDeviceContext cpu_device_context;
  fw::Scope scope;

  // Create every variable named in the OpDesc, each holding a Tensor.
  const char* const var_names[] = {"x0", "x1", "x2", "k0", "y0", "y1"};
  for (const char* name : var_names) {
    scope.NewVar(name)->GetMutable<fw::Tensor>();
  }

  auto op = fw::OpRegistry::CreateOp(desc);
  op->Run(scope, cpu_device_context);
}