operator_test.cc 24.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14
#include "paddle/fluid/framework/operator.h"
D
dzhwinter 已提交
15

16
#include "gtest/gtest.h"
Y
Yi Wang 已提交
17 18
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
19
#include "paddle/fluid/platform/errors.h"
20
#include "paddle/fluid/platform/init.h"
Q
Qiao Longfei 已提交
21

22 23
DECLARE_bool(enable_unused_var_check);

Q
Qiao Longfei 已提交
24 25 26
namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
27 28 29
static int op_run_num = 0;

class OpWithoutKernelTest : public OperatorBase {
Q
Qiao Longfei 已提交
30
 public:
Y
Yu Yang 已提交
31 32
  OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
33
      : OperatorBase(type, inputs, outputs, attrs), x(1) {}
34 35 36 37

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {
Y
Yu Yang 已提交
38 39 40 41
    ++op_run_num;
    ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
    ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
    ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr);
Q
Qiao Longfei 已提交
42
    ASSERT_EQ(x, 1);
Y
Yu Yang 已提交
43
    ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr);
Q
Qiao Longfei 已提交
44
  }
Q
Qiao Longfei 已提交
45 46

 public:
Y
Yu Yang 已提交
47
  int x{0};
Q
Qiao Longfei 已提交
48 49
};

D
dzhwinter 已提交
50
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
Q
Qiao Longfei 已提交
51
 public:
Y
Yu Yang 已提交
52
  void Make() {
Q
Qiao Longfei 已提交
53 54
    AddInput("input", "input of test op");
    AddOutput("output", "output of test op");
Q
Qiao Longfei 已提交
55
    AddAttr<float>("scale", "scale of cosine op");
X
Xin Pan 已提交
56 57
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Q
Qiao Longfei 已提交
58 59 60 61 62 63 64
    AddComment("This is test op");
  }
};

}  // namespace framework
}  // namespace paddle

Y
Yu Yang 已提交
65 66
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
67
                     paddle::framework::proto::OpDesc::Var* var) {
Y
Yu Yang 已提交
68 69 70 71 72 73
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
    *var->mutable_arguments()->Add() = arg_name;
  }
}

D
dzhwinter 已提交
74 75 76
// Registers "test_operator" (no kernel, no gradient). It is created by
// TEST(OperatorBase, all) and TEST(ExecutionContextAttrAndInOut, new_api).
REGISTER_OP_WITHOUT_GRADIENT(test_operator,
                             paddle::framework::OpWithoutKernelTest,
                             paddle::framework::OpWithoutKernelCheckerMaker);
Q
Qiao Longfei 已提交
77 78

// Builds test_operator from an OpDesc and runs it once; RunImpl must execute
// exactly once.
TEST(OperatorBase, all) {
  namespace fw = paddle::framework;
  fw::InitDevices();

  fw::proto::OpDesc op_desc;
  op_desc.set_type("test_operator");
  BuildVar("input", {"IN1"}, op_desc.add_inputs());
  BuildVar("output", {"OUT1"}, op_desc.add_outputs());

  auto* scale_attr = op_desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace place;
  fw::Scope scope;

  auto op = fw::OpRegistry::CreateOp(op_desc);
  // Only the output variable is created; OpWithoutKernelTest::RunImpl asserts
  // that IN1 is missing and OUT1 is present.
  scope.Var("OUT1");

  ASSERT_EQ(fw::op_run_num, 0);
  op->Run(scope, place);
  ASSERT_EQ(fw::op_run_num, 1);
}

namespace paddle {
namespace framework {

X
Xin Pan 已提交
103 104
static int special_type_value = 1;

Q
Qiao Longfei 已提交
105 106
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
107
  void Make() {
Y
Yan Chunwei 已提交
108 109 110 111
    AddInput("x", "input of test op");
    AddOutput("y", "output of test op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
112
        .GreaterThan(0.0);
X
Xin Pan 已提交
113 114
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Q
Qiao Longfei 已提交
115 116 117 118
    AddComment("This is test op");
  }
};

Q
Qiao Longfei 已提交
119
static int cpu_kernel_run_num = 0;
X
Xin Pan 已提交
120
static int cpu_kernel2_run_num = 0;
Q
Qiao Longfei 已提交
121

Q
Qiao Longfei 已提交
122
class OpWithKernelTest : public OperatorWithKernel {
Y
Yu Yang 已提交
123 124 125
 public:
  using OperatorWithKernel::OperatorWithKernel;

Y
Yu Yang 已提交
126
 protected:
127
  void InferShape(framework::InferShapeContext* ctx) const override {}
128 129
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
X
Xin Pan 已提交
130 131 132 133
    int sub_type = ctx.Attr<int>("kernel_sub_type");
    return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
                        framework::DataLayout::kAnyLayout,
                        framework::LibraryType::kPlain, sub_type);
Y
Yu Yang 已提交
134
  }
Q
Qiao Longfei 已提交
135 136
};

137
template <typename T1, typename T2>
Y
Yu Yang 已提交
138
class CPUKernelTest : public OpKernel<float> {
Q
Qiao Longfei 已提交
139
 public:
140
  void Compute(const ExecutionContext& ctx) const {
H
hong 已提交
141
    std::cout << ctx.DebugString() << std::endl;
Q
Qiao Longfei 已提交
142
    cpu_kernel_run_num++;
H
hong 已提交
143 144
    ASSERT_EQ(ctx.InputName("x"), "IN1");
    ASSERT_EQ(ctx.OutputName("y"), "OUT1");
145 146
    auto* x = ctx.Input<Tensor>("X");
    ASSERT_EQ(x, nullptr);
Y
Yan Chunwei 已提交
147 148 149
  }
};

X
Xin Pan 已提交
150 151 152 153
// Custom-type (special_type_value) CPU kernel of op_with_kernel. It counts
// its invocations and checks the variable names it was wired to. T1/T2 are
// unused; they only mirror the <float, float> registration.
template <typename T1, typename T2>
class CPUKernel2Test : public OpKernel<float> {
 public:
  // Overrides OpKernel's Compute so signature drift is a compile error.
  void Compute(const ExecutionContext& ctx) const override {
    std::cout << ctx.DebugString() << std::endl;
    cpu_kernel2_run_num++;
    ASSERT_EQ(ctx.InputName("x"), "IN1");
    ASSERT_EQ(ctx.OutputName("y"), "OUT1");
  }
};

Y
Yan Chunwei 已提交
161 162 163
class OpKernelTestMultiInputsProtoAndCheckerMaker
    : public OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
164
  void Make() {
Y
Yu Yang 已提交
165
    AddInput("xs", "inputs of test op").AsDuplicable();
Y
Yan Chunwei 已提交
166
    AddInput("k", "input of test op");
Y
Yu Yang 已提交
167
    AddOutput("ys", "outputs of test op").AsDuplicable();
Y
Yan Chunwei 已提交
168 169
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
170
        .GreaterThan(0.0);
X
Xin Pan 已提交
171 172
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Y
Yan Chunwei 已提交
173 174 175 176
    AddComment("This is test op");
  }
};

Y
Yu Yang 已提交
177
// CPU kernel of op_multi_inputs_with_kernel. It checks that the duplicable
// input "xs" (x0..x2), the input "k" (k0) and the duplicable output "ys"
// (y0, y1) are visible through ExecutionContext's name, Variable and Tensor
// accessors.
class CPUKernalMultiInputsTest : public OpKernel<float> {
 public:
  // Overrides OpKernel's Compute so signature drift is a compile error.
  void Compute(const ExecutionContext& ctx) const override {
    auto xs = ctx.InputNames("xs");
    ASSERT_EQ(xs.size(), 3UL);
    ASSERT_EQ(xs[0], "x0");
    ASSERT_EQ(xs[1], "x1");
    ASSERT_EQ(xs[2], "x2");

    // Variable-level accessors.
    auto xs_vars = ctx.MultiInputVar("xs");
    ASSERT_EQ(xs_vars.size(), 3U);

    auto k_var = ctx.InputVar("k");
    ASSERT_NE(k_var, nullptr);

    auto ys_vars = ctx.MultiOutputVar("ys");
    ASSERT_EQ(ys_vars.size(), 2U);

    // Tensor-level accessors.
    auto xs_tensors = ctx.MultiInput<Tensor>("xs");
    ASSERT_EQ(xs_tensors.size(), 3U);

    auto k_tensor = ctx.Input<Tensor>("k");
    ASSERT_NE(k_tensor, nullptr);

    auto ys_tensors = ctx.MultiOutput<Tensor>("ys");
    ASSERT_EQ(ys_tensors.size(), 2U);

    // Name accessors for the single input and the duplicable output.
    auto k = ctx.InputName("k");
    ASSERT_EQ(k, "k0");

    auto ys = ctx.OutputNames("ys");
    ASSERT_EQ(ys.size(), 2UL);
    ASSERT_EQ(ys[0], "y0");
    ASSERT_EQ(ys[1], "y1");
  }
};

Y
Yu Yang 已提交
214 215 216
}  // namespace framework
}  // namespace paddle

F
fengjiayi 已提交
217 218 219
// op_with_kernel has two CPU kernels:
// - CPUKernelTest, the default kernel, used while kernel_sub_type stays 0;
// - CPUKernel2Test, registered under MY_SPECIAL_NAME with special_type_value,
//   which OpWithKernelTest selects when kernel_sub_type == 1.
REGISTER_OP_WITHOUT_GRADIENT(
    op_with_kernel, paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestProtoAndCheckerMaker);

REGISTER_OP_CPU_KERNEL(op_with_kernel,
                       paddle::framework::CPUKernelTest<float, float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    op_with_kernel, CPU, paddle::platform::CPUPlace, MY_SPECIAL_NAME,
    paddle::framework::special_type_value,
    paddle::framework::CPUKernel2Test<float, float>);
Q
Qiao Longfei 已提交
228

Y
Yan Chunwei 已提交
229
// test with single input
Q
Qiao Longfei 已提交
230
TEST(OpKernel, all) {
231
  paddle::framework::InitDevices();
232
  paddle::framework::proto::OpDesc op_desc;
Q
Qiao Longfei 已提交
233
  op_desc.set_type("op_with_kernel");
Y
Fix CI  
Yu Yang 已提交
234 235
  BuildVar("x", {"IN1"}, op_desc.add_inputs());
  BuildVar("y", {"OUT1"}, op_desc.add_outputs());
Y
Yu Yang 已提交
236

Q
Qiao Longfei 已提交
237 238
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
239
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
Q
Qiao Longfei 已提交
240 241
  attr->set_f(3.14);

D
dzhwinter 已提交
242
  paddle::platform::CPUPlace cpu_place;
Y
Yu Yang 已提交
243
  paddle::framework::Scope scope;
Q
Qiao Longfei 已提交
244

245
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
Q
Qiao Longfei 已提交
246
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
D
dzhwinter 已提交
247
  op->Run(scope, cpu_place);
X
Xin Pan 已提交
248
  // kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called.
Q
Qiao Longfei 已提交
249
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
X
Xin Pan 已提交
250 251 252 253 254 255 256 257
  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);

  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("kernel_sub_type");
  attr->set_type(paddle::framework::proto::AttrType::INT);
  attr->set_i(1);
  auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
  op2->Run(scope, cpu_place);
X
Xin Pan 已提交
258
  // kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called.
X
Xin Pan 已提交
259 260
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
Q
Qiao Longfei 已提交
261
}
Y
Yan Chunwei 已提交
262

F
fengjiayi 已提交
263 264 265
// op_multi_inputs_with_kernel reuses OpWithKernelTest with the multi-input
// proto maker; its only CPU kernel is CPUKernalMultiInputsTest.
REGISTER_OP_WITHOUT_GRADIENT(
    op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
                       paddle::framework::CPUKernalMultiInputsTest);

// Multi-input op: CPUKernalMultiInputsTest checks names and counts for the
// duplicable "xs"/"ys" slots and the plain "k" slot.
TEST(OpKernel, multi_inputs) {
  namespace fw = paddle::framework;
  fw::InitDevices();

  fw::proto::OpDesc op_desc;
  op_desc.set_type("op_multi_inputs_with_kernel");
  BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
  BuildVar("k", {"k0"}, op_desc.add_inputs());
  BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs());

  auto* scale_attr = op_desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace place;
  fw::Scope scope;
  // Create every input and output variable as a LoDTensor before running.
  for (const char* var_name : {"x0", "x1", "x2", "k0", "y0", "y1"}) {
    scope.Var(var_name)->GetMutable<fw::LoDTensor>();
  }

  auto op = fw::OpRegistry::CreateOp(op_desc);
  op->Run(scope, place);
}
296

M
minqiyang 已提交
297
// GradVarName appends "@GRAD"; GradOriginalVarName strips it again and leaves
// a name without the suffix unchanged.
TEST(VarNameTest, all) {
  struct Case {
    const char* name;
    const char* grad_name;
  };
  const Case cases[] = {{"X", "X@GRAD"}, {"XYZ", "XYZ@GRAD"}, {"", "@GRAD"}};

  for (const Case& c : cases) {
    const std::string var_name(c.name);
    const std::string grad_var_name = paddle::framework::GradVarName(var_name);
    ASSERT_EQ(grad_var_name, c.grad_name);

    std::string original_var_name =
        paddle::framework::GradOriginalVarName(grad_var_name);
    ASSERT_EQ(original_var_name, var_name);

    // Applying it to a suffix-free name must be a no-op.
    original_var_name =
        paddle::framework::GradOriginalVarName(original_var_name);
    ASSERT_EQ(original_var_name, var_name);
  }
}
323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338

namespace paddle {
namespace framework {

class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LoDTensor");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
339

340 341 342 343
class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("LoDTensor", "Input of Tensor type Variable.");
T
tianshuo78520a 已提交
344
    AddComment("This Op is only for IndicateVarDataType interface test.");
345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
  }
};

class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
class IndicateSelectedRowsDataTypeTestProtoMaker
    : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("SelectedRows", "Input of SelectedRows type Variable.");
T
tianshuo78520a 已提交
366
    AddComment("This Op is only for IndicateVarDataType interface test.");
367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385
  }
};

class IndicateOtherDataTypeTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("Other", "Input of Other type Variable");
T
tianshuo78520a 已提交
386
    AddComment("This Op is only for IndicateVarDataType interface test.");
387 388 389 390
  }
};

template <typename DeviceContext, typename T>
391
class EmptyTestKernel : public OpKernel<T> {
392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411
 public:
  void Compute(const ExecutionContext& ctx) const {}
};

}  // namespace framework
}  // namespace paddle

// Ops exercising OperatorWithKernel::IndicateVarDataType with a LoDTensor, a
// SelectedRows and an "other" input. Each gets an int EmptyTestKernel on CPU.
// The IndicateVarDataTypeTest cases expect Run() to throw from kernel-type
// selection.
REGISTER_OP_WITHOUT_GRADIENT(
    indicate_lod_tensor_data_type_test,
    paddle::framework::IndicateLoDTensorDataTypeTest,
    paddle::framework::IndicateLoDTensorDataTypeTestProtoMaker);
REGISTER_OP_WITHOUT_GRADIENT(
    indicate_selected_rows_data_type_test,
    paddle::framework::IndicateSelectedRowsDataTypeTest,
    paddle::framework::IndicateSelectedRowsDataTypeTestProtoMaker);
REGISTER_OP_WITHOUT_GRADIENT(
    indicate_other_data_type_test, paddle::framework::IndicateOtherDataTypeTest,
    paddle::framework::IndicateOtherDataTypeTestProtoMaker);

REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test,
                       paddle::framework::EmptyTestKernel<
                           paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test,
                       paddle::framework::EmptyTestKernel<
                           paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test,
                       paddle::framework::EmptyTestKernel<
                           paddle::platform::CPUDeviceContext, int>);

// A LoDTensor input that was created but never allocated must make
// IndicateVarDataType fail with an "uninitialized Tensor" EnforceNotMet.
TEST(IndicateVarDataTypeTest, lodtensor) {
  paddle::framework::InitDevices();
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("indicate_lod_tensor_data_type_test");
  BuildVar("LoDTensor", {"lodtensor_1"}, op_desc.add_inputs());

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  auto* var = scope.Var("lodtensor_1");
  var->GetMutable<paddle::framework::LoDTensor>();

  bool caught = false;
  try {
    op->Run(scope, cpu_place);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    const std::string ex_msg = err.what();
    // Print the actual message on mismatch to ease debugging.
    EXPECT_NE(
        ex_msg.find(
            "The indicate_lod_tensor_data_type_test Op's Input Variable "
            "`LoDTensor` contains uninitialized Tensor."),
        std::string::npos)
        << ex_msg;
  }
  ASSERT_TRUE(caught);
}

// A SelectedRows input whose value tensor was never allocated must make
// IndicateVarDataType fail with an "uninitialized Tensor" EnforceNotMet.
TEST(IndicateVarDataTypeTest, selectedrows) {
  paddle::framework::InitDevices();
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("indicate_selected_rows_data_type_test");
  BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs());

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  auto* var = scope.Var("selected_rows_1");
  var->GetMutable<phi::SelectedRows>();

  bool caught = false;
  try {
    op->Run(scope, cpu_place);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    const std::string ex_msg = err.what();
    // Print the actual message on mismatch to ease debugging.
    EXPECT_NE(ex_msg.find("The indicate_selected_rows_data_type_test Op's "
                          "Input Variable `SelectedRows` contains uninitialized "
                          "Tensor."),
              std::string::npos)
        << ex_msg;
  }
  ASSERT_TRUE(caught);
}

// An input holding a LoDRankTable is not a LoDTensor, SelectedRows or
// LoDTensorArray, so IndicateVarDataType must reject it with EnforceNotMet.
TEST(IndicateVarDataTypeTest, other) {
  paddle::framework::InitDevices();
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("indicate_other_data_type_test");
  BuildVar("Other", {"lod_rank_table_1"}, op_desc.add_inputs());

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  auto* var = scope.Var("lod_rank_table_1");
  var->GetMutable<paddle::framework::LoDRankTable>();

  bool caught = false;
  try {
    op->Run(scope, cpu_place);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    const std::string ex_msg = err.what();
    // Print the actual message on mismatch to ease debugging.
    EXPECT_NE(
        ex_msg.find(
            "The Input Variable(Other) of "
            "(indicate_other_data_type_test) Operator used to "
            "determine kernel data type "
            "is empty or not LoDTensor or SelectedRows or LoDTensorArray."),
        std::string::npos)
        << ex_msg;
  }
  ASSERT_TRUE(caught);
}
504

H
hong 已提交
505
// An ExecutionContext built directly from an op (without running it) must
// report the op's input/output sizes, attributes and type.
TEST(ExecutionContextAttrAndInOut, new_api) {
  namespace fw = paddle::framework;
  fw::InitDevices();

  fw::proto::OpDesc op_desc;
  op_desc.set_type("test_operator");
  BuildVar("input", {"IN1"}, op_desc.add_inputs());
  BuildVar("output", {"OUT1"}, op_desc.add_outputs());

  auto* scale_attr = op_desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace place;
  fw::Scope scope;

  auto op = fw::OpRegistry::CreateOp(op_desc);
  scope.Var("OUT1")->GetMutable<fw::LoDTensorArray>();

  auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);

  // Empty runtime variable maps: only op metadata is queried below.
  fw::RuntimeContext runtime_ctx({}, {});
  fw::ExecutionContext exe_context(*op, scope, *dev_ctx, runtime_ctx);

  ASSERT_EQ(exe_context.InputSize("input"), 1u);
  ASSERT_EQ(exe_context.OutputSize("output"), 1u);

  auto attr_map = exe_context.Attrs();
  ASSERT_EQ(BOOST_GET(float, attr_map["scale"]), 3.14f);
  ASSERT_EQ(exe_context.Type(), "test_operator");
}

540 541 542 543 544 545 546 547 548
namespace paddle {
namespace framework {

class GetLoDLevelTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
549 550 551 552 553 554 555
    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "GetLoDLevelTest");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GetLoDLevelTest");

    auto lod_level = ctx->GetLoDLevel("X");
    PADDLE_ENFORCE_GT(lod_level, 0,
                      paddle::platform::errors::InvalidArgument(
                          "The LoD level Input(X) should be larger than 0."));
556 557 558 559 560 561 562 563 564
  }
};

class SetLoDLevelTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
565 566
    OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "SetLoDLevelTest");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SetLoDLevelTest");
567 568 569 570 571 572 573 574 575
    ctx->SetLoDLevel("Out", 1);
  }
};

class GetSetLoDLevelTestMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(LoDTensor) Input Variable.");
    AddOutput("Out", "(LoDTensor) Output Variable.");
T
tianshuo78520a 已提交
576
    AddComment("This Op is only for Get/SetLoDLevel interface test.");
577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597
  }
};

}  // namespace framework
}  // namespace paddle

// get_lod_level_test and set_lod_level_test share GetSetLoDLevelTestMaker and
// a float EmptyTestKernel on CPU; only their operator classes differ.
REGISTER_OP_WITHOUT_GRADIENT(get_lod_level_test,
                             paddle::framework::GetLoDLevelTest,
                             paddle::framework::GetSetLoDLevelTestMaker);
REGISTER_OP_CPU_KERNEL(get_lod_level_test,
                       paddle::framework::EmptyTestKernel<
                           paddle::platform::CPUDeviceContext, float>);

REGISTER_OP_WITHOUT_GRADIENT(set_lod_level_test,
                             paddle::framework::SetLoDLevelTest,
                             paddle::framework::GetSetLoDLevelTestMaker);
REGISTER_OP_CPU_KERNEL(set_lod_level_test,
                       paddle::framework::EmptyTestKernel<
                           paddle::platform::CPUDeviceContext, float>);

void SetGetLoDLevelTestMain(std::string op_type) {
598
  paddle::framework::InitDevices({});
599 600 601 602 603 604 605 606 607 608 609
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type(op_type);
  BuildVar("X", {"x.0"}, op_desc.add_inputs());
  BuildVar("Out", {"out.0"}, op_desc.add_outputs());

  paddle::platform::CPUPlace place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  auto* x_var = scope.Var("x.0");
  auto* x = x_var->GetMutable<paddle::framework::LoDTensor>();
610
  x->mutable_data<float>(phi::make_ddim({64}), place);
611 612 613 614 615 616 617 618 619 620 621 622
  auto* out_var = scope.Var("out.0");
  out_var->GetMutable<paddle::framework::LoDTensor>();

  bool caught = false;
  std::string err_str =
      (op_type == "get_lod_level_test") ? "GetLoDLevel" : "SetLoDLevel";
  err_str +=
      " is only used in compile time. The calculation of output's actual lod "
      "is different among operators so that should be set in the runtime "
      "kernel.";
  try {
    op->Run(scope, place);
Z
Zeng Jinle 已提交
623
  } catch (paddle::platform::EnforceNotMet& err) {
624 625 626 627 628 629 630 631 632 633
    caught = true;
    std::string ex_msg = err.what();
    EXPECT_TRUE(ex_msg.find(err_str) != std::string::npos);
  }
  ASSERT_TRUE(caught);
}

TEST(GetLoDLevelTest, base) {
  // GetLoDLevel must be rejected when reached from Run().
  SetGetLoDLevelTestMain("get_lod_level_test");
}

TEST(SetLoDLevelTest, base) {
  // SetLoDLevel must be rejected when reached from Run().
  SetGetLoDLevelTestMain("set_lod_level_test");
}
634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702

namespace paddle {
namespace framework {

class OpUnusedVarTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
                        framework::DataLayout::kAnyLayout);
  }
};

class OpUnusedVarTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "input of test op");
    AddOutput("Y", "output of test op");
    AddComment("This is test op for unused var check.");
  }
};

template <typename T>
class OpWithUnusedVarKernelTest : public OpKernel<T> {
 public:
  void Compute(const ExecutionContext& ctx) const {
    ASSERT_EQ(ctx.InputName("X"), "X");
    ASSERT_EQ(ctx.OutputName("Y"), "Y");
  }
};

template <typename T>
class OpWithoutUnusedVarKernelTest : public OpKernel<T> {
 public:
  void Compute(const ExecutionContext& ctx) const {
    ASSERT_EQ(ctx.InputName("X"), "X");
    ASSERT_EQ(ctx.OutputName("Y"), "Y");
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Output<Tensor>("Y");
    ASSERT_NE(x, y);
    ASSERT_NE(y, nullptr);
  }
};

}  // namespace framework
}  // namespace paddle

// Both ops share OpUnusedVarTest and its proto maker; only the CPU kernel
// differs. OpWithUnusedVarKernelTest never fetches the X tensor, while
// OpWithoutUnusedVarKernelTest fetches both X and Y.
REGISTER_OP_WITHOUT_GRADIENT(
    op_with_unused_var, paddle::framework::OpUnusedVarTest,
    paddle::framework::OpUnusedVarTestProtoAndCheckerMaker);

REGISTER_OP_CPU_KERNEL(op_with_unused_var,
                       paddle::framework::OpWithUnusedVarKernelTest<float>);

REGISTER_OP_WITHOUT_GRADIENT(
    op_without_unused_var, paddle::framework::OpUnusedVarTest,
    paddle::framework::OpUnusedVarTestProtoAndCheckerMaker);

REGISTER_OP_CPU_KERNEL(op_without_unused_var,
                       paddle::framework::OpWithoutUnusedVarKernelTest<float>);

// The kernel of op_with_unused_var never fetches its X input, so with
// FLAGS_enable_unused_var_check on, Run() is expected to throw.
TEST(OpWithUnusedVar, all) {
  // Enable the unused-var check for this test only. The guard restores the
  // previous value on scope exit, including when ASSERT_THROW fails and
  // returns early, so the flag cannot leak into later tests.
  struct UnusedVarCheckGuard {
    bool saved;
    ~UnusedVarCheckGuard() { FLAGS_enable_unused_var_check = saved; }
  } flag_guard{FLAGS_enable_unused_var_check};
  FLAGS_enable_unused_var_check = true;

  paddle::framework::InitDevices();
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("op_with_unused_var");
  BuildVar("X", {"X"}, op_desc.add_inputs());
  BuildVar("Y", {"Y"}, op_desc.add_outputs());

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
  auto* x = scope.Var("X")->GetMutable<paddle::framework::LoDTensor>();
  auto* y = scope.Var("Y")->GetMutable<paddle::framework::LoDTensor>();
  x->Resize({32, 64});
  y->Resize({32, 64});
  x->mutable_data<float>(cpu_place);
  y->mutable_data<float>(cpu_place);

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  // should throw exception
  ASSERT_THROW(op->Run(scope, cpu_place), paddle::platform::EnforceNotMet);
}

// The kernel of op_without_unused_var fetches both X and Y, so the
// unused-var check must let Run() complete.
TEST(OpWithoutUnusedVar, all) {
  // Enable the unused-var check for this test only. The guard restores the
  // previous value on scope exit, including when ASSERT_NO_THROW fails and
  // returns early, so the flag cannot leak into later tests.
  struct UnusedVarCheckGuard {
    bool saved;
    ~UnusedVarCheckGuard() { FLAGS_enable_unused_var_check = saved; }
  } flag_guard{FLAGS_enable_unused_var_check};
  FLAGS_enable_unused_var_check = true;

  paddle::framework::InitDevices();
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("op_without_unused_var");
  BuildVar("X", {"X"}, op_desc.add_inputs());
  BuildVar("Y", {"Y"}, op_desc.add_outputs());

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
  auto* x = scope.Var("X")->GetMutable<paddle::framework::LoDTensor>();
  auto* y = scope.Var("Y")->GetMutable<paddle::framework::LoDTensor>();
  x->Resize({32, 64});
  y->Resize({32, 64});
  x->mutable_data<float>(cpu_place);
  y->mutable_data<float>(cpu_place);

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  // should not throw exception
  ASSERT_NO_THROW(op->Run(scope, cpu_place));
}