/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <string>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/init.h"

namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
24 25 26
static int op_run_num = 0;

class OpWithoutKernelTest : public OperatorBase {
Q
Qiao Longfei 已提交
27
 public:
Y
Yu Yang 已提交
28 29
  OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
30
      : OperatorBase(type, inputs, outputs, attrs), x(1) {}
31 32 33 34

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {
Y
Yu Yang 已提交
35 36 37 38
    ++op_run_num;
    ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
    ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
    ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr);
Q
Qiao Longfei 已提交
39
    ASSERT_EQ(x, 1);
Y
Yu Yang 已提交
40
    ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr);
Q
Qiao Longfei 已提交
41
  }
Q
Qiao Longfei 已提交
42 43

 public:
Y
Yu Yang 已提交
44
  int x{0};
Q
Qiao Longfei 已提交
45 46
};

D
dzhwinter 已提交
47
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
Q
Qiao Longfei 已提交
48
 public:
Y
Yu Yang 已提交
49
  void Make() {
Q
Qiao Longfei 已提交
50 51
    AddInput("input", "input of test op");
    AddOutput("output", "output of test op");
Q
Qiao Longfei 已提交
52
    AddAttr<float>("scale", "scale of cosine op");
X
Xin Pan 已提交
53 54
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Q
Qiao Longfei 已提交
55 56 57 58 59 60 61
    AddComment("This is test op");
  }
};

}  // namespace framework
}  // namespace paddle

Y
Yu Yang 已提交
62 63
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
64
                     paddle::framework::proto::OpDesc::Var* var) {
Y
Yu Yang 已提交
65 66 67 68 69 70
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
    *var->mutable_arguments()->Add() = arg_name;
  }
}

D
dzhwinter 已提交
71 72 73
REGISTER_OP_WITHOUT_GRADIENT(test_operator,
                             paddle::framework::OpWithoutKernelTest,
                             paddle::framework::OpWithoutKernelCheckerMaker);
Q
Qiao Longfei 已提交
74 75

TEST(OperatorBase, all) {
X
Xin Pan 已提交
76
  paddle::framework::InitDevices(true);
77
  paddle::framework::proto::OpDesc op_desc;
Q
Qiao Longfei 已提交
78
  op_desc.set_type("test_operator");
Y
Yu Yang 已提交
79 80
  BuildVar("input", {"IN1"}, op_desc.add_inputs());
  BuildVar("output", {"OUT1"}, op_desc.add_outputs());
Y
Yu Yang 已提交
81

Q
Qiao Longfei 已提交
82 83
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
84
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
Q
Qiao Longfei 已提交
85
  attr->set_f(3.14);
Q
Qiao Longfei 已提交
86

D
dzhwinter 已提交
87
  paddle::platform::CPUPlace cpu_place;
Y
Yu Yang 已提交
88
  paddle::framework::Scope scope;
Q
Qiao Longfei 已提交
89

90
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
D
dongzhihong 已提交
91
  scope.Var("OUT1");
Q
Qiao Longfei 已提交
92
  ASSERT_EQ(paddle::framework::op_run_num, 0);
D
dzhwinter 已提交
93
  op->Run(scope, cpu_place);
Q
Qiao Longfei 已提交
94
  ASSERT_EQ(paddle::framework::op_run_num, 1);
Q
Qiao Longfei 已提交
95 96 97 98 99
}

namespace paddle {
namespace framework {

X
Xin Pan 已提交
100 101
static int special_type_value = 1;

Q
Qiao Longfei 已提交
102 103
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
104
  void Make() {
Y
Yan Chunwei 已提交
105 106 107 108
    AddInput("x", "input of test op");
    AddOutput("y", "output of test op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
109
        .GreaterThan(0.0);
X
Xin Pan 已提交
110 111
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Q
Qiao Longfei 已提交
112 113 114 115
    AddComment("This is test op");
  }
};

Q
Qiao Longfei 已提交
116
static int cpu_kernel_run_num = 0;
X
Xin Pan 已提交
117
static int cpu_kernel2_run_num = 0;
Q
Qiao Longfei 已提交
118

Q
Qiao Longfei 已提交
119
class OpWithKernelTest : public OperatorWithKernel {
Y
Yu Yang 已提交
120 121 122
 public:
  using OperatorWithKernel::OperatorWithKernel;

Y
Yu Yang 已提交
123
 protected:
124
  void InferShape(framework::InferShapeContext* ctx) const override {}
125 126
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
X
Xin Pan 已提交
127 128 129 130
    int sub_type = ctx.Attr<int>("kernel_sub_type");
    return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
                        framework::DataLayout::kAnyLayout,
                        framework::LibraryType::kPlain, sub_type);
Y
Yu Yang 已提交
131
  }
Q
Qiao Longfei 已提交
132 133
};

134
template <typename T1, typename T2>
Y
Yu Yang 已提交
135
class CPUKernelTest : public OpKernel<float> {
Q
Qiao Longfei 已提交
136
 public:
137
  void Compute(const ExecutionContext& ctx) const {
Q
qiaolongfei 已提交
138
    std::cout << ctx.op().DebugString() << std::endl;
Q
Qiao Longfei 已提交
139
    cpu_kernel_run_num++;
Q
qiaolongfei 已提交
140 141
    ASSERT_EQ(ctx.op().Input("x"), "IN1");
    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
Y
Yan Chunwei 已提交
142 143 144
  }
};

X
Xin Pan 已提交
145 146 147 148 149 150 151 152 153 154 155
template <typename T1, typename T2>
class CPUKernel2Test : public OpKernel<float> {
 public:
  void Compute(const ExecutionContext& ctx) const {
    std::cout << ctx.op().DebugString() << std::endl;
    cpu_kernel2_run_num++;
    ASSERT_EQ(ctx.op().Input("x"), "IN1");
    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
  }
};

Y
Yan Chunwei 已提交
156 157 158
class OpKernelTestMultiInputsProtoAndCheckerMaker
    : public OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
159
  void Make() {
Y
Yu Yang 已提交
160
    AddInput("xs", "inputs of test op").AsDuplicable();
Y
Yan Chunwei 已提交
161
    AddInput("k", "input of test op");
Y
Yu Yang 已提交
162
    AddOutput("ys", "outputs of test op").AsDuplicable();
Y
Yan Chunwei 已提交
163 164
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
165
        .GreaterThan(0.0);
X
Xin Pan 已提交
166 167
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Y
Yan Chunwei 已提交
168 169 170 171
    AddComment("This is test op");
  }
};

Y
Yu Yang 已提交
172
class CPUKernalMultiInputsTest : public OpKernel<float> {
Y
Yan Chunwei 已提交
173
 public:
174
  void Compute(const ExecutionContext& ctx) const {
Q
qiaolongfei 已提交
175
    auto xs = ctx.op().Inputs("xs");
Y
Yan Chunwei 已提交
176 177 178 179 180
    ASSERT_EQ(xs.size(), 3UL);
    ASSERT_EQ(xs[0], "x0");
    ASSERT_EQ(xs[1], "x1");
    ASSERT_EQ(xs[2], "x2");

181
    auto inVar0 = ctx.MultiInputVar("xs");
182
    ASSERT_EQ(inVar0.size(), 3U);
183 184 185 186 187

    auto intVar1 = ctx.InputVar("k");
    ASSERT_NE(intVar1, nullptr);

    auto outVar0 = ctx.MultiOutputVar("ys");
188
    ASSERT_EQ(outVar0.size(), 2U);
189 190

    auto inTensor0 = ctx.MultiInput<Tensor>("xs");
191
    ASSERT_EQ(inTensor0.size(), 3U);
192 193 194 195 196

    auto intTensor1 = ctx.Input<Tensor>("k");
    ASSERT_NE(intTensor1, nullptr);

    auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
197
    ASSERT_EQ(outTensor0.size(), 2U);
198

Q
qiaolongfei 已提交
199
    auto k = ctx.op().Input("k");
Y
Yan Chunwei 已提交
200 201
    ASSERT_EQ(k, "k0");

Q
qiaolongfei 已提交
202
    auto ys = ctx.op().Outputs("ys");
Y
Yan Chunwei 已提交
203 204 205
    ASSERT_EQ(ys.size(), 2UL);
    ASSERT_EQ(ys[0], "y0");
    ASSERT_EQ(ys[1], "y1");
Q
Qiao Longfei 已提交
206 207 208
  }
};

Y
Yu Yang 已提交
209 210 211
}  // namespace framework
}  // namespace paddle

F
fengjiayi 已提交
212 213 214
REGISTER_OP_WITHOUT_GRADIENT(
    op_with_kernel, paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestProtoAndCheckerMaker);
X
Xin Pan 已提交
215

X
Xin Pan 已提交
216 217
REGISTER_OP_CPU_KERNEL(op_with_kernel,
                       paddle::framework::CPUKernelTest<float, float>);
X
Xin Pan 已提交
218 219

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
X
Xin Pan 已提交
220 221
    op_with_kernel, CPU, paddle::platform::CPUPlace, MY_SPECIAL_NAME,
    paddle::framework::special_type_value,
X
Xin Pan 已提交
222
    paddle::framework::CPUKernel2Test<float, float>);
Q
Qiao Longfei 已提交
223

Y
Yan Chunwei 已提交
224
// test with single input
Q
Qiao Longfei 已提交
225
TEST(OpKernel, all) {
X
Xin Pan 已提交
226
  paddle::framework::InitDevices(true);
227
  paddle::framework::proto::OpDesc op_desc;
Q
Qiao Longfei 已提交
228
  op_desc.set_type("op_with_kernel");
Y
Fix CI  
Yu Yang 已提交
229 230
  BuildVar("x", {"IN1"}, op_desc.add_inputs());
  BuildVar("y", {"OUT1"}, op_desc.add_outputs());
Y
Yu Yang 已提交
231

Q
Qiao Longfei 已提交
232 233
  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
234
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
Q
Qiao Longfei 已提交
235 236
  attr->set_f(3.14);

D
dzhwinter 已提交
237
  paddle::platform::CPUPlace cpu_place;
Y
Yu Yang 已提交
238
  paddle::framework::Scope scope;
Q
Qiao Longfei 已提交
239

240
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
Q
Qiao Longfei 已提交
241
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
D
dzhwinter 已提交
242
  op->Run(scope, cpu_place);
X
Xin Pan 已提交
243
  // kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called.
Q
Qiao Longfei 已提交
244
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
X
Xin Pan 已提交
245 246 247 248 249 250 251 252
  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);

  attr = op_desc.mutable_attrs()->Add();
  attr->set_name("kernel_sub_type");
  attr->set_type(paddle::framework::proto::AttrType::INT);
  attr->set_i(1);
  auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
  op2->Run(scope, cpu_place);
X
Xin Pan 已提交
253
  // kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called.
X
Xin Pan 已提交
254 255
  ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
  ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
Q
Qiao Longfei 已提交
256
}
Y
Yan Chunwei 已提交
257

F
fengjiayi 已提交
258 259 260
REGISTER_OP_WITHOUT_GRADIENT(
    op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
Y
Yan Chunwei 已提交
261 262 263 264 265
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
                       paddle::framework::CPUKernalMultiInputsTest);

// Test with multiple inputs. Runs "op_multi_inputs_with_kernel" with:
//   - "xs" bound to three variables;
//   - "k" bound to one;
//   - "ys" bound to two.
// Every variable is created as a LoDTensor beforehand. The actual checks live
// in CPUKernalMultiInputsTest::Compute.
TEST(OpKernel, multi_inputs) {
  paddle::framework::InitDevices(true);
  paddle::framework::proto::OpDesc op_desc;

  op_desc.set_type("op_multi_inputs_with_kernel");
  BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
  BuildVar("k", {"k0"}, op_desc.add_inputs());
  BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs());

  auto attr = op_desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(paddle::framework::proto::AttrType::FLOAT);
  attr->set_f(3.14);

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;
  scope.Var("x0")->GetMutable<paddle::framework::LoDTensor>();
  scope.Var("x1")->GetMutable<paddle::framework::LoDTensor>();
  scope.Var("x2")->GetMutable<paddle::framework::LoDTensor>();
  scope.Var("k0")->GetMutable<paddle::framework::LoDTensor>();
  scope.Var("y0")->GetMutable<paddle::framework::LoDTensor>();
  scope.Var("y1")->GetMutable<paddle::framework::LoDTensor>();

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  op->Run(scope, cpu_place);
}
291

M
minqiyang 已提交
292
TEST(VarNameTest, all) {
293 294
  std::string var_name("X");
  std::string grad_var_name = paddle::framework::GradVarName(var_name);
M
minqiyang 已提交
295
  ASSERT_EQ(grad_var_name, "X@GRAD");
296
  std::string original_var_name =
M
minqiyang 已提交
297
      paddle::framework::GradOriginalVarName(grad_var_name);
M
minqiyang 已提交
298
  ASSERT_EQ(original_var_name, "X");
M
minqiyang 已提交
299
  original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
M
minqiyang 已提交
300 301 302 303 304
  ASSERT_EQ(original_var_name, "X");

  std::string var_name_2("XYZ");
  grad_var_name = paddle::framework::GradVarName(var_name_2);
  ASSERT_EQ(grad_var_name, "XYZ@GRAD");
M
minqiyang 已提交
305
  original_var_name = paddle::framework::GradOriginalVarName(grad_var_name);
M
minqiyang 已提交
306
  ASSERT_EQ(original_var_name, "XYZ");
M
minqiyang 已提交
307
  original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
M
minqiyang 已提交
308 309 310 311 312
  ASSERT_EQ(original_var_name, "XYZ");

  std::string var_name_3("");
  grad_var_name = paddle::framework::GradVarName(var_name_3);
  ASSERT_EQ(grad_var_name, "@GRAD");
M
minqiyang 已提交
313
  original_var_name = paddle::framework::GradOriginalVarName(grad_var_name);
M
minqiyang 已提交
314
  ASSERT_EQ(original_var_name, "");
M
minqiyang 已提交
315
  original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
M
minqiyang 已提交
316
  ASSERT_EQ(original_var_name, "");
317
}
318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333

namespace paddle {
namespace framework {

class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LoDTensor");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
334

335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385
class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("LoDTensor", "Input of Tensor type Variable.");
    AddComment("This Op is only for IndicateVarDataType inferface test.");
  }
};

class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
class IndicateSelectedRowsDataTypeTestProtoMaker
    : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("SelectedRows", "Input of SelectedRows type Variable.");
    AddComment("This Op is only for IndicateVarDataType inferface test.");
  }
};

class IndicateOtherDataTypeTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {}
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("Other", "Input of Other type Variable");
    AddComment("This Op is only for IndicateVarDataType inferface test.");
  }
};

template <typename DeviceContext, typename T>
386
class EmptyTestKernel : public OpKernel<T> {
387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406
 public:
  void Compute(const ExecutionContext& ctx) const {}
};

}  // namespace framework
}  // namespace paddle

// The three IndicateVarDataType test ops, each with a no-op int CPU kernel.
REGISTER_OP_WITHOUT_GRADIENT(
    indicate_lod_tensor_data_type_test,
    paddle::framework::IndicateLoDTensorDataTypeTest,
    paddle::framework::IndicateLoDTensorDataTypeTestProtoMaker);
REGISTER_OP_WITHOUT_GRADIENT(
    indicate_selected_rows_data_type_test,
    paddle::framework::IndicateSelectedRowsDataTypeTest,
    paddle::framework::IndicateSelectedRowsDataTypeTestProtoMaker);
REGISTER_OP_WITHOUT_GRADIENT(
    indicate_other_data_type_test, paddle::framework::IndicateOtherDataTypeTest,
    paddle::framework::IndicateOtherDataTypeTestProtoMaker);

REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test,
                       paddle::framework::EmptyTestKernel<
                           paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test,
                       paddle::framework::EmptyTestKernel<
                           paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test,
                       paddle::framework::EmptyTestKernel<
                           paddle::platform::CPUDeviceContext, int>);

// The input LoDTensor is created but this test allocates no data for it.
// Running the op is therefore expected to throw EnforceNotMet reporting an
// uninitialized tensor.
TEST(IndicateVarDataTypeTest, lodtensor) {
  paddle::framework::InitDevices(true);
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("indicate_lod_tensor_data_type_test");
  BuildVar("LoDTensor", {"lodtensor_1"}, op_desc.add_inputs());

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  auto* var = scope.Var("lodtensor_1");
  var->GetMutable<paddle::framework::LoDTensor>();

  bool caught = false;
  try {
    op->Run(scope, cpu_place);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    std::string ex_msg = err.what();
    // Stream the full message so a mismatch is diagnosable.
    EXPECT_TRUE(
        ex_msg.find(
            "The Tensor in the indicate_lod_tensor_data_type_test Op's "
            "Input Variable LoDTensor(lodtensor_1) is not initialized") !=
        std::string::npos)
        << ex_msg;
  }
  ASSERT_TRUE(caught);
}

// The input SelectedRows is created but this test populates nothing in it.
// Running the op is therefore expected to throw EnforceNotMet reporting an
// uninitialized tensor.
TEST(IndicateVarDataTypeTest, selectedrows) {
  paddle::framework::InitDevices(true);
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("indicate_selected_rows_data_type_test");
  BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs());

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  auto* var = scope.Var("selected_rows_1");
  var->GetMutable<paddle::framework::SelectedRows>();

  bool caught = false;
  try {
    op->Run(scope, cpu_place);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    std::string ex_msg = err.what();
    // Stream the full message so a mismatch is diagnosable.
    EXPECT_TRUE(
        ex_msg.find("The Tensor in the indicate_selected_rows_data_type_test "
                    "Op's Input Variable SelectedRows(selected_rows_1) is not "
                    "initialized") != std::string::npos)
        << ex_msg;
  }
  ASSERT_TRUE(caught);
}

// The input is a LoDTensorArray, which is neither LoDTensor nor
// SelectedRows. Running the op is therefore expected to throw EnforceNotMet
// rejecting the variable's type.
TEST(IndicateVarDataTypeTest, other) {
  paddle::framework::InitDevices(true);
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type("indicate_other_data_type_test");
  BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs());

  paddle::platform::CPUPlace cpu_place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  auto* var = scope.Var("lod_tensor_array_1");
  var->GetMutable<paddle::framework::LoDTensorArray>();

  bool caught = false;
  try {
    op->Run(scope, cpu_place);
  } catch (const paddle::platform::EnforceNotMet& err) {
    caught = true;
    std::string ex_msg = err.what();
    // Stream the full message so a mismatch is diagnosable.
    EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of "
                            "indicate_other_data_type_test Op used to "
                            "determine kernel data type "
                            "is empty or not LoDTensor or SelectedRows") !=
                std::string::npos)
        << ex_msg;
  }
  ASSERT_TRUE(caught);
}
498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582

namespace paddle {
namespace framework {

class GetLoDLevelTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
                      "Input(X) should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) should not be null.");
    PADDLE_ENFORCE_GT(ctx->GetLoDLevel("X"), 0,
                      "The LoD level Input(X) should be larger than 0.");
  }
};

class SetLoDLevelTest : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
                      "Input(X) should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) should not be null.");
    ctx->SetLoDLevel("Out", 1);
  }
};

class GetSetLoDLevelTestMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(LoDTensor) Input Variable.");
    AddOutput("Out", "(LoDTensor) Output Variable.");
    AddComment("This Op is only for Get/SetLoDLevel inferface test.");
  }
};

}  // namespace framework
}  // namespace paddle

// "get_lod_level_test": GetLoDLevelTest paired with a no-op float CPU kernel.
REGISTER_OP_WITHOUT_GRADIENT(
    get_lod_level_test, paddle::framework::GetLoDLevelTest,
    paddle::framework::GetSetLoDLevelTestMaker);
REGISTER_OP_CPU_KERNEL(
    get_lod_level_test,
    paddle::framework::EmptyTestKernel<paddle::platform::CPUDeviceContext,
                                       float>);

// "set_lod_level_test": SetLoDLevelTest paired with a no-op float CPU kernel.
REGISTER_OP_WITHOUT_GRADIENT(
    set_lod_level_test, paddle::framework::SetLoDLevelTest,
    paddle::framework::GetSetLoDLevelTestMaker);
REGISTER_OP_CPU_KERNEL(
    set_lod_level_test,
    paddle::framework::EmptyTestKernel<paddle::platform::CPUDeviceContext,
                                       float>);

void SetGetLoDLevelTestMain(std::string op_type) {
  paddle::framework::InitDevices(false, {});
  paddle::framework::proto::OpDesc op_desc;
  op_desc.set_type(op_type);
  BuildVar("X", {"x.0"}, op_desc.add_inputs());
  BuildVar("Out", {"out.0"}, op_desc.add_outputs());

  paddle::platform::CPUPlace place;
  paddle::framework::Scope scope;

  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  auto* x_var = scope.Var("x.0");
  auto* x = x_var->GetMutable<paddle::framework::LoDTensor>();
  x->mutable_data<float>(paddle::framework::make_ddim({64}), place);
  auto* out_var = scope.Var("out.0");
  out_var->GetMutable<paddle::framework::LoDTensor>();

  bool caught = false;
  std::string err_str =
      (op_type == "get_lod_level_test") ? "GetLoDLevel" : "SetLoDLevel";
  err_str +=
      " is only used in compile time. The calculation of output's actual lod "
      "is different among operators so that should be set in the runtime "
      "kernel.";
  try {
    op->Run(scope, place);
Z
Zeng Jinle 已提交
583
  } catch (paddle::platform::EnforceNotMet& err) {
584 585 586 587 588 589 590 591 592 593
    caught = true;
    std::string ex_msg = err.what();
    EXPECT_TRUE(ex_msg.find(err_str) != std::string::npos);
  }
  ASSERT_TRUE(caught);
}

// Calling GetLoDLevel from a runtime InferShape must be rejected.
TEST(GetLoDLevelTest, base) {
  SetGetLoDLevelTestMain("get_lod_level_test");
}

// Calling SetLoDLevel from a runtime InferShape must be rejected.
TEST(SetLoDLevelTest, base) {
  SetGetLoDLevelTestMain("set_lod_level_test");
}