operator_test.cc 9.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
D
dzhwinter 已提交
15

Y
Yi Wang 已提交
16 17 18
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
19
#include "paddle/fluid/platform/init.h"
Q
Qiao Longfei 已提交
20 21 22 23

namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
24 25 26
// How many times OpWithoutKernelTest::RunImpl has executed in this process.
static int op_run_num{0};

class OpWithoutKernelTest : public OperatorBase {
Q
Qiao Longfei 已提交
27
 public:
Y
Yu Yang 已提交
28 29
  OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
30
      : OperatorBase(type, inputs, outputs, attrs), x(1) {}
31 32 33 34

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {
Y
Yu Yang 已提交
35 36 37 38
    ++op_run_num;
    ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
    ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
    ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr);
Q
Qiao Longfei 已提交
39
    ASSERT_EQ(x, 1);
Y
Yu Yang 已提交
40
    ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr);
Q
Qiao Longfei 已提交
41
  }
Q
Qiao Longfei 已提交
42 43

 public:
Y
Yu Yang 已提交
44
  int x{0};
Q
Qiao Longfei 已提交
45 46
};

D
dzhwinter 已提交
47
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
Q
Qiao Longfei 已提交
48
 public:
Y
Yu Yang 已提交
49
  void Make() {
Q
Qiao Longfei 已提交
50 51
    AddInput("input", "input of test op");
    AddOutput("output", "output of test op");
Q
Qiao Longfei 已提交
52
    AddAttr<float>("scale", "scale of cosine op");
X
Xin Pan 已提交
53 54
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Q
Qiao Longfei 已提交
55 56 57 58 59 60 61
    AddComment("This is test op");
  }
};

}  // namespace framework
}  // namespace paddle

Y
Yu Yang 已提交
62 63
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
64
                     paddle::framework::proto::OpDesc::Var* var) {
Y
Yu Yang 已提交
65 66 67 68 69 70
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
    *var->mutable_arguments()->Add() = arg_name;
  }
}

D
dzhwinter 已提交
71 72 73
// `test_operator` runs through OpWithoutKernelTest::RunImpl. Its schema comes
// from OpWithoutKernelCheckerMaker, and no gradient op is registered.
REGISTER_OP_WITHOUT_GRADIENT(
    test_operator, paddle::framework::OpWithoutKernelTest,
    paddle::framework::OpWithoutKernelCheckerMaker);
Q
Qiao Longfei 已提交
74 75

// Builds a `test_operator` OpDesc (input IN1, output OUT1, scale = 3.14),
// instantiates it through the registry and checks that one Run() call reaches
// OpWithoutKernelTest::RunImpl exactly once.
TEST(OperatorBase, all) {
  paddle::framework::InitDevices(true);
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("test_operator");
  BuildVar("input", {"IN1"}, op_desc.add_inputs());
  BuildVar("output", {"OUT1"}, op_desc.add_outputs());

  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  // RunImpl asserts that OUT1 exists in the scope and IN1 does not, so only
  // the output variable is created here.
  scope.Var("OUT1");

  // op_run_num is process-global. Compare against a baseline instead of
  // assuming it starts at 0, so the test also passes under --gtest_repeat.
  const int runs_before = paddle::framework::op_run_num;
  op->Run(scope, cpu_place);
  ASSERT_EQ(paddle::framework::op_run_num, runs_before + 1);
}

namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
100 101
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
102
  void Make() {
Y
Yan Chunwei 已提交
103 104 105 106
    AddInput("x", "input of test op");
    AddOutput("y", "output of test op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
107
        .GreaterThan(0.0);
X
Xin Pan 已提交
108 109
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Q
Qiao Longfei 已提交
110 111 112 113
    AddComment("This is test op");
  }
};

Q
Qiao Longfei 已提交
114
// Per-kernel invocation counters. CPUKernelTest::Compute increments the first
// and CPUKernel2Test::Compute increments the second.
static int cpu_kernel_run_num{0};
static int cpu_kernel2_run_num{0};
Q
Qiao Longfei 已提交
116

Q
Qiao Longfei 已提交
117
class OpWithKernelTest : public OperatorWithKernel {
Y
Yu Yang 已提交
118 119 120
 public:
  using OperatorWithKernel::OperatorWithKernel;

Y
Yu Yang 已提交
121
 protected:
122
  void InferShape(framework::InferShapeContext* ctx) const override {}
123 124
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
X
Xin Pan 已提交
125 126 127 128
    int sub_type = ctx.Attr<int>("kernel_sub_type");
    return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
                        framework::DataLayout::kAnyLayout,
                        framework::LibraryType::kPlain, sub_type);
Y
Yu Yang 已提交
129
  }
Q
Qiao Longfei 已提交
130 131
};

132
template <typename T1, typename T2>
Y
Yu Yang 已提交
133
class CPUKernelTest : public OpKernel<float> {
Q
Qiao Longfei 已提交
134
 public:
135
  void Compute(const ExecutionContext& ctx) const {
Q
qiaolongfei 已提交
136
    std::cout << ctx.op().DebugString() << std::endl;
Q
Qiao Longfei 已提交
137
    cpu_kernel_run_num++;
Q
qiaolongfei 已提交
138 139
    ASSERT_EQ(ctx.op().Input("x"), "IN1");
    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
Y
Yan Chunwei 已提交
140 141 142
  }
};

X
Xin Pan 已提交
143 144 145 146 147 148 149 150 151 152 153
// Alternate CPU kernel for `op_with_kernel`, registered under custom type
// SPECIAL (id 1). It is identical to CPUKernelTest except that it increments
// cpu_kernel2_run_num, so the test can tell which kernel was dispatched.
template <typename T1, typename T2>
class CPUKernel2Test : public OpKernel<float> {
 public:
  // Overrides OpKernel's pure virtual Compute; `override` guards the signature.
  void Compute(const ExecutionContext& ctx) const override {
    std::cout << ctx.op().DebugString() << std::endl;
    cpu_kernel2_run_num++;
    ASSERT_EQ(ctx.op().Input("x"), "IN1");
    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
  }
};

Y
Yan Chunwei 已提交
154 155 156
class OpKernelTestMultiInputsProtoAndCheckerMaker
    : public OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
157
  void Make() {
Y
Yu Yang 已提交
158
    AddInput("xs", "inputs of test op").AsDuplicable();
Y
Yan Chunwei 已提交
159
    AddInput("k", "input of test op");
Y
Yu Yang 已提交
160
    AddOutput("ys", "outputs of test op").AsDuplicable();
Y
Yan Chunwei 已提交
161 162
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
163
        .GreaterThan(0.0);
X
Xin Pan 已提交
164 165
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Y
Yan Chunwei 已提交
166 167 168 169
    AddComment("This is test op");
  }
};

Y
Yu Yang 已提交
170
// CPU kernel for `op_multi_inputs_with_kernel`. It checks the duplicable
// "xs"/"ys" slots and the single "k" slot in three views:
//   - the argument names recorded on the op;
//   - the Variables returned by the execution context;
//   - the same slots viewed as Tensors.
class CPUKernalMultiInputsTest : public OpKernel<float> {
 public:
  // Overrides OpKernel's pure virtual Compute; `override` guards the signature.
  void Compute(const ExecutionContext& ctx) const override {
    // Argument names bound to the duplicable input slot.
    const auto& xs = ctx.op().Inputs("xs");
    ASSERT_EQ(xs.size(), 3UL);
    ASSERT_EQ(xs[0], "x0");
    ASSERT_EQ(xs[1], "x1");
    ASSERT_EQ(xs[2], "x2");

    // Variables resolved through the execution context.
    auto x_vars = ctx.MultiInputVar("xs");
    ASSERT_EQ(x_vars.size(), 3UL);

    auto k_var = ctx.InputVar("k");
    ASSERT_NE(k_var, nullptr);

    auto y_vars = ctx.MultiOutputVar("ys");
    ASSERT_EQ(y_vars.size(), 2UL);

    // The same slots viewed as tensors.
    auto x_tensors = ctx.MultiInput<Tensor>("xs");
    ASSERT_EQ(x_tensors.size(), 3UL);

    auto k_tensor = ctx.Input<Tensor>("k");
    ASSERT_NE(k_tensor, nullptr);

    auto y_tensors = ctx.MultiOutput<Tensor>("ys");
    ASSERT_EQ(y_tensors.size(), 2UL);

    // Argument names for the single input and the duplicable output.
    auto k = ctx.op().Input("k");
    ASSERT_EQ(k, "k0");

    const auto& ys = ctx.op().Outputs("ys");
    ASSERT_EQ(ys.size(), 2UL);
    ASSERT_EQ(ys[0], "y0");
    ASSERT_EQ(ys[1], "y1");
  }
};

Y
Yu Yang 已提交
207 208 209
}  // namespace framework
}  // namespace paddle

F
fengjiayi 已提交
210 211 212
REGISTER_OP_WITHOUT_GRADIENT(op_with_kernel,
                             paddle::framework::OpWithKernelTest,
                             paddle::framework::OpKernelTestProtoAndCheckerMaker);

// Two CPU kernels for the same op that differ only in custom type id.
// OpWithKernelTest::GetExpectedKernelType passes the "kernel_sub_type"
// attribute as that id:
//   - 0 (DEFAULT_TYPE) selects CPUKernelTest;
//   - 1 (SPECIAL) selects CPUKernel2Test.
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    op_with_kernel, CPU, paddle::platform::CPUPlace,
    DEFAULT_TYPE, 0,
    paddle::framework::CPUKernelTest<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    op_with_kernel, CPU, paddle::platform::CPUPlace,
    SPECIAL, 1,
    paddle::framework::CPUKernel2Test<float, float>);
Q
Qiao Longfei 已提交
224

Y
Yan Chunwei 已提交
225
// test with single input
Q
Qiao Longfei 已提交
226
// Checks that `op_with_kernel` dispatches on the "kernel_sub_type" attribute:
//   - with the default (0), Run() reaches CPUKernelTest only;
//   - after adding kernel_sub_type = 1, a freshly created op reaches
//     CPUKernel2Test only.
TEST(OpKernel, all) {
  paddle::framework::InitDevices(true);
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("op_with_kernel");
  BuildVar("x", {"IN1"}, op_desc.add_inputs());
  BuildVar("y", {"OUT1"}, op_desc.add_outputs());

  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  // The kernel counters are process-global. Measure deltas against a baseline
  // so the test does not depend on starting at 0 (e.g. under --gtest_repeat).
  const int kernel1_before = paddle::framework::cpu_kernel_run_num;
  const int kernel2_before = paddle::framework::cpu_kernel2_run_num;
  op->Run(scope, cpu_place);
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, kernel1_before + 1);
  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, kernel2_before);

  // Request the kernel registered under custom type id 1.
  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("kernel_sub_type");
  attr->set_type(paddle::framework::proto::AttrType::INT);
  attr->set_i(1);
  auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
  op2->Run(scope, cpu_place);
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, kernel1_before + 1);
  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, kernel2_before + 1);
}
Y
Yan Chunwei 已提交
256

F
fengjiayi 已提交
257 258 259
// `op_multi_inputs_with_kernel` reuses OpWithKernelTest with a schema that has
// duplicable "xs"/"ys" slots. CPUKernalMultiInputsTest is its single CPU kernel.
REGISTER_OP_WITHOUT_GRADIENT(
    op_multi_inputs_with_kernel,
    paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(
    op_multi_inputs_with_kernel,
    paddle::framework::CPUKernalMultiInputsTest);

// test with multi inputs
TEST(OpKernel, multi_inputs) {
  paddle::framework::InitDevices(true);

  // xs -> {x0, x1, x2}, k -> {k0}, ys -> {y0, y1}; scale = 3.14.
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("op_multi_inputs_with_kernel");
  BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
  BuildVar("k", {"k0"}, op_desc.add_inputs());
  BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs());

  auto* scale_attr = op_desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
  // Create every bound variable in the scope, each holding a LoDTensor.
  for (const char* var_name : {"x0", "x1", "x2", "k0", "y0", "y1"}) {
    scope.Var(var_name)->GetMutable<paddle::framework::LoDTensor>();
  }

  // The actual checks live in CPUKernalMultiInputsTest::Compute.
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  op->Run(scope, cpu_place);
}