operator_test.cc 9.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
Qiao Longfei 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
D
dzhwinter 已提交
15

Y
Yi Wang 已提交
16 17 18
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
19
#include "paddle/fluid/platform/init.h"
Q
Qiao Longfei 已提交
20 21 22 23

namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
24 25 26
// Number of times OpWithoutKernelTest::RunImpl has been executed.
static int op_run_num{0};

class OpWithoutKernelTest : public OperatorBase {
Q
Qiao Longfei 已提交
27
 public:
Y
Yu Yang 已提交
28 29
  OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
30
      : OperatorBase(type, inputs, outputs, attrs), x(1) {}
31 32 33 34

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {
Y
Yu Yang 已提交
35 36 37 38
    ++op_run_num;
    ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
    ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
    ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr);
Q
Qiao Longfei 已提交
39
    ASSERT_EQ(x, 1);
Y
Yu Yang 已提交
40
    ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr);
Q
Qiao Longfei 已提交
41
  }
Q
Qiao Longfei 已提交
42 43

 public:
Y
Yu Yang 已提交
44
  int x{0};
Q
Qiao Longfei 已提交
45 46
};

D
dzhwinter 已提交
47
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
Q
Qiao Longfei 已提交
48
 public:
Y
Yu Yang 已提交
49
  void Make() {
Q
Qiao Longfei 已提交
50 51
    AddInput("input", "input of test op");
    AddOutput("output", "output of test op");
Q
Qiao Longfei 已提交
52
    AddAttr<float>("scale", "scale of cosine op");
X
Xin Pan 已提交
53 54
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Q
Qiao Longfei 已提交
55 56 57 58 59 60 61
    AddComment("This is test op");
  }
};

}  // namespace framework
}  // namespace paddle

Y
Yu Yang 已提交
62 63
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
64
                     paddle::framework::proto::OpDesc::Var* var) {
Y
Yu Yang 已提交
65 66 67 68 69 70
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
    *var->mutable_arguments()->Add() = arg_name;
  }
}

D
dzhwinter 已提交
71 72 73
// Register the op type `test_operator`, implemented by OpWithoutKernelTest and
// described by OpWithoutKernelCheckerMaker (no gradient op, per the macro
// name).
REGISTER_OP_WITHOUT_GRADIENT(test_operator,
                             paddle::framework::OpWithoutKernelTest,
                             paddle::framework::OpWithoutKernelCheckerMaker);
Q
Qiao Longfei 已提交
74 75

// Creates the kernel-less `test_operator` from an OpDesc, runs it once on CPU
// and checks that its RunImpl was executed exactly once.
TEST(OperatorBase, all) {
  namespace fw = paddle::framework;

  fw::InitDevices(true);

  // Describe the op: one input slot, one output slot, float attr "scale".
  fw::proto::OpDesc desc;
  desc.set_type("test_operator");
  BuildVar("input", {"IN1"}, desc.add_inputs());
  BuildVar("output", {"OUT1"}, desc.add_outputs());

  auto* scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace place;
  fw::Scope scope;

  auto test_op = fw::OpRegistry::CreateOp(desc);
  // Only the output variable is created; the op asserts that the input
  // variable is absent from the scope.
  scope.Var("OUT1");

  ASSERT_EQ(fw::op_run_num, 0);
  test_op->Run(scope, place);
  ASSERT_EQ(fw::op_run_num, 1);
}

namespace paddle {
namespace framework {

X
Xin Pan 已提交
100 101
// Custom kernel type value handed to REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE when
// registering CPUKernel2Test.
static int special_type_value{1};

Q
Qiao Longfei 已提交
102 103
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
104
  void Make() {
Y
Yan Chunwei 已提交
105 106 107 108
    AddInput("x", "input of test op");
    AddOutput("y", "output of test op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
109
        .GreaterThan(0.0);
X
Xin Pan 已提交
110 111
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Q
Qiao Longfei 已提交
112 113 114 115
    AddComment("This is test op");
  }
};

Q
Qiao Longfei 已提交
116
// Number of times CPUKernelTest::Compute has been executed.
static int cpu_kernel_run_num{0};
X
Xin Pan 已提交
117
// Number of times CPUKernel2Test::Compute has been executed.
static int cpu_kernel2_run_num{0};
Q
Qiao Longfei 已提交
118

Q
Qiao Longfei 已提交
119
class OpWithKernelTest : public OperatorWithKernel {
Y
Yu Yang 已提交
120 121 122
 public:
  using OperatorWithKernel::OperatorWithKernel;

Y
Yu Yang 已提交
123
 protected:
124
  void InferShape(framework::InferShapeContext* ctx) const override {}
125 126
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
X
Xin Pan 已提交
127 128 129 130
    int sub_type = ctx.Attr<int>("kernel_sub_type");
    return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
                        framework::DataLayout::kAnyLayout,
                        framework::LibraryType::kPlain, sub_type);
Y
Yu Yang 已提交
131
  }
Q
Qiao Longfei 已提交
132 133
};

134
template <typename T1, typename T2>
Y
Yu Yang 已提交
135
class CPUKernelTest : public OpKernel<float> {
Q
Qiao Longfei 已提交
136
 public:
137
  void Compute(const ExecutionContext& ctx) const {
Q
qiaolongfei 已提交
138
    std::cout << ctx.op().DebugString() << std::endl;
Q
Qiao Longfei 已提交
139
    cpu_kernel_run_num++;
Q
qiaolongfei 已提交
140 141
    ASSERT_EQ(ctx.op().Input("x"), "IN1");
    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
Y
Yan Chunwei 已提交
142 143 144
  }
};

X
Xin Pan 已提交
145 146 147 148 149 150 151 152 153 154 155
// Second CPU kernel for `op_with_kernel`. Each Compute call prints the op's
// debug string, increments cpu_kernel2_run_num and checks the op's
// input/output variable names. The template parameters are not used by the
// body.
template <typename T1, typename T2>
class CPUKernel2Test : public OpKernel<float> {
 public:
  // Marked `override` for consistency with the other virtual overrides in
  // this file, so a signature drift in the base class is caught at compile
  // time.
  void Compute(const ExecutionContext& ctx) const override {
    std::cout << ctx.op().DebugString() << std::endl;
    cpu_kernel2_run_num++;
    ASSERT_EQ(ctx.op().Input("x"), "IN1");
    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
  }
};

Y
Yan Chunwei 已提交
156 157 158
class OpKernelTestMultiInputsProtoAndCheckerMaker
    : public OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
159
  void Make() {
Y
Yu Yang 已提交
160
    AddInput("xs", "inputs of test op").AsDuplicable();
Y
Yan Chunwei 已提交
161
    AddInput("k", "input of test op");
Y
Yu Yang 已提交
162
    AddOutput("ys", "outputs of test op").AsDuplicable();
Y
Yan Chunwei 已提交
163 164
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
F
fengjiayi 已提交
165
        .GreaterThan(0.0);
X
Xin Pan 已提交
166 167
    AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
        .SetDefault(0);
Y
Yan Chunwei 已提交
168 169 170 171
    AddComment("This is test op");
  }
};

Y
Yu Yang 已提交
172
// CPU kernel for `op_multi_inputs_with_kernel`. Compute verifies that the
// execution context exposes three "xs" inputs (x0..x2), one "k" input (k0)
// and two "ys" outputs (y0, y1), both by name and as variables/tensors.
class CPUKernalMultiInputsTest : public OpKernel<float> {
 public:
  // Marked `override` for consistency with the other virtual overrides in
  // this file, so a signature drift in the base class is caught at compile
  // time.
  void Compute(const ExecutionContext& ctx) const override {
    // Input names of the duplicable slot "xs".
    auto xs = ctx.op().Inputs("xs");
    ASSERT_EQ(xs.size(), 3UL);
    ASSERT_EQ(xs[0], "x0");
    ASSERT_EQ(xs[1], "x1");
    ASSERT_EQ(xs[2], "x2");

    // Variable-level access.
    auto xs_vars = ctx.MultiInputVar("xs");
    ASSERT_EQ(xs_vars.size(), 3U);

    auto k_var = ctx.InputVar("k");
    ASSERT_NE(k_var, nullptr);

    auto ys_vars = ctx.MultiOutputVar("ys");
    ASSERT_EQ(ys_vars.size(), 2U);

    // Tensor-level access.
    auto xs_tensors = ctx.MultiInput<Tensor>("xs");
    ASSERT_EQ(xs_tensors.size(), 3U);

    auto k_tensor = ctx.Input<Tensor>("k");
    ASSERT_NE(k_tensor, nullptr);

    auto ys_tensors = ctx.MultiOutput<Tensor>("ys");
    ASSERT_EQ(ys_tensors.size(), 2U);

    // Name of the single input "k".
    auto k = ctx.op().Input("k");
    ASSERT_EQ(k, "k0");

    // Output names of the duplicable slot "ys".
    auto ys = ctx.op().Outputs("ys");
    ASSERT_EQ(ys.size(), 2UL);
    ASSERT_EQ(ys[0], "y0");
    ASSERT_EQ(ys[1], "y1");
  }
};

Y
Yu Yang 已提交
209 210 211
}  // namespace framework
}  // namespace paddle

F
fengjiayi 已提交
212 213 214
// Register the op type `op_with_kernel`, implemented by OpWithKernelTest and
// described by OpKernelTestProtoAndCheckerMaker (no gradient op, per the macro
// name).
REGISTER_OP_WITHOUT_GRADIENT(
    op_with_kernel, paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestProtoAndCheckerMaker);
X
Xin Pan 已提交
215

X
Xin Pan 已提交
216 217
// Register CPUKernelTest<float, float> as a CPU kernel of `op_with_kernel`.
REGISTER_OP_CPU_KERNEL(op_with_kernel,
                       paddle::framework::CPUKernelTest<float, float>);
X
Xin Pan 已提交
218 219

// Register CPUKernel2Test<float, float> as an additional CPU kernel of
// `op_with_kernel`, keyed by the custom type value special_type_value.
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
    op_with_kernel, CPU, paddle::platform::CPUPlace, MY_SPECIAL_NAME,
    paddle::framework::special_type_value,
    paddle::framework::CPUKernel2Test<float, float>);
Q
Qiao Longfei 已提交
223

Y
Yan Chunwei 已提交
224
// test with single input
Q
Qiao Longfei 已提交
225
// Runs `op_with_kernel` twice on CPU: first with the default kernel_sub_type
// (0), then with kernel_sub_type = 1, and checks via the run counters which
// of the two registered kernels was invoked each time.
TEST(OpKernel, all) {
  namespace fw = paddle::framework;

  fw::InitDevices(true);

  fw::proto::OpDesc desc;
  desc.set_type("op_with_kernel");
  BuildVar("x", {"IN1"}, desc.add_inputs());
  BuildVar("y", {"OUT1"}, desc.add_outputs());

  auto* attr = desc.mutable_attrs()->Add();
  attr->set_name("scale");
  attr->set_type(fw::proto::AttrType::FLOAT);
  attr->set_f(3.14);

  paddle::platform::CPUPlace place;
  fw::Scope scope;

  auto default_op = fw::OpRegistry::CreateOp(desc);
  ASSERT_EQ(fw::cpu_kernel_run_num, 0);
  default_op->Run(scope, place);
  // kernel_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not
  // called.
  ASSERT_EQ(fw::cpu_kernel_run_num, 1);
  ASSERT_EQ(fw::cpu_kernel2_run_num, 0);

  // Append kernel_sub_type = 1 to the same description and build a second op.
  attr = desc.mutable_attrs()->Add();
  attr->set_name("kernel_sub_type");
  attr->set_type(fw::proto::AttrType::INT);
  attr->set_i(1);

  auto sub_type_op = fw::OpRegistry::CreateOp(desc);
  sub_type_op->Run(scope, place);
  // kernel_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not
  // called.
  ASSERT_EQ(fw::cpu_kernel_run_num, 1);
  ASSERT_EQ(fw::cpu_kernel2_run_num, 1);
}
Y
Yan Chunwei 已提交
257

F
fengjiayi 已提交
258 259 260
// Register the op type `op_multi_inputs_with_kernel`, implemented by
// OpWithKernelTest and described by
// OpKernelTestMultiInputsProtoAndCheckerMaker (no gradient op, per the macro
// name).
REGISTER_OP_WITHOUT_GRADIENT(
    op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
Y
Yan Chunwei 已提交
261 262 263 264 265
// Register CPUKernalMultiInputsTest as the CPU kernel of
// `op_multi_inputs_with_kernel`.
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
                       paddle::framework::CPUKernalMultiInputsTest);

// test with multi inputs
// Runs `op_multi_inputs_with_kernel` on CPU with three "xs" inputs, one "k"
// input and two "ys" outputs; the checks themselves live in the kernel's
// Compute method.
TEST(OpKernel, multi_inputs) {
  namespace fw = paddle::framework;

  fw::InitDevices(true);

  fw::proto::OpDesc desc;
  desc.set_type("op_multi_inputs_with_kernel");
  BuildVar("xs", {"x0", "x1", "x2"}, desc.add_inputs());
  BuildVar("k", {"k0"}, desc.add_inputs());
  BuildVar("ys", {"y0", "y1"}, desc.add_outputs());

  auto* scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace place;
  fw::Scope scope;
  // Create every input/output variable as a LoDTensor, in the same order as
  // they are declared above.
  for (const char* var_name : {"x0", "x1", "x2", "k0", "y0", "y1"}) {
    scope.Var(var_name)->GetMutable<fw::LoDTensor>();
  }

  auto multi_input_op = fw::OpRegistry::CreateOp(desc);
  multi_input_op->Run(scope, place);
}
291 292 293 294 295 296 297 298 299

// Checks the gradient-variable naming helpers: GradVarName("X") must yield
// "X@GRAD" and OriginVarName must map that back to "X".
TEST(Functions, all) {
  std::string var_name("X");
  std::string grad_var_name = paddle::framework::GradVarName(var_name);
  // Compare the std::string itself: ASSERT_EQ on `c_str()` and a literal
  // would compare pointer addresses rather than string contents.
  ASSERT_EQ(grad_var_name, "X@GRAD");
  std::string original_var_name =
      paddle::framework::OriginVarName(grad_var_name);
  ASSERT_EQ(original_var_name, "X");
}