operator_test.cc 8.8 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
D
dzhwinter 已提交
15

Y
Yi Wang 已提交
16 17 18 19
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
Q
Qiao Longfei 已提交
20 21 22 23

namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
24 25 26
// Number of times OpWithoutKernelTest::RunImpl has been entered. The
// OperatorBase test reads it before and after running the op.
static int op_run_num = 0;

class OpWithoutKernelTest : public OperatorBase {
Q
Qiao Longfei 已提交
27
 public:
Y
Yu Yang 已提交
28 29
  OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
                      const VariableNameMap& outputs, const AttributeMap& attrs)
Y
Yu Yang 已提交
30
      : OperatorBase(type, inputs, outputs, attrs), x(1) {}
31 32 33 34

 private:
  void RunImpl(const Scope& scope,
               const platform::Place& place) const override {
Y
Yu Yang 已提交
35 36 37 38
    ++op_run_num;
    ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
    ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
    ASSERT_EQ(scope.FindVar(inputs_.at("input")[0]), nullptr);
Q
Qiao Longfei 已提交
39
    ASSERT_EQ(x, 1);
Y
Yu Yang 已提交
40
    ASSERT_NE(scope.FindVar(outputs_.at("output")[0]), nullptr);
Q
Qiao Longfei 已提交
41
  }
Q
Qiao Longfei 已提交
42 43

 public:
Y
Yu Yang 已提交
44
  int x{0};
Q
Qiao Longfei 已提交
45 46
};

D
dzhwinter 已提交
47
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
Q
Qiao Longfei 已提交
48
 public:
D
dzhwinter 已提交
49
  OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
Q
Qiao Longfei 已提交
50 51 52
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of test op");
    AddOutput("output", "output of test op");
Q
Qiao Longfei 已提交
53
    AddAttr<float>("scale", "scale of cosine op");
Q
Qiao Longfei 已提交
54 55 56 57 58 59 60
    AddComment("This is test op");
  }
};

}  // namespace framework
}  // namespace paddle

Y
Yu Yang 已提交
61 62
static void BuildVar(const std::string& param_name,
                     std::initializer_list<const char*> arguments,
63
                     paddle::framework::proto::OpDesc::Var* var) {
Y
Yu Yang 已提交
64 65 66 67 68 69
  var->set_parameter(param_name);
  for (auto& arg_name : arguments) {
    *var->mutable_arguments()->Add() = arg_name;
  }
}

D
dzhwinter 已提交
70 71 72
// Register the "test_operator" op type, pairing the operator class with its
// proto/checker maker. As the macro name indicates, no gradient op is
// registered for it.
REGISTER_OP_WITHOUT_GRADIENT(test_operator,
                             paddle::framework::OpWithoutKernelTest,
                             paddle::framework::OpWithoutKernelCheckerMaker);
Q
Qiao Longfei 已提交
73 74

// Builds a "test_operator" OpDesc by hand, creates the op through the
// registry, runs it once on CPU and checks that RunImpl ran exactly once.
TEST(OperatorBase, all) {
  namespace fw = paddle::framework;

  fw::InitDevices();

  // Describe the op: one input slot, one output slot.
  fw::proto::OpDesc desc;
  desc.set_type("test_operator");
  BuildVar("input", {"IN1"}, desc.add_inputs());
  BuildVar("output", {"OUT1"}, desc.add_outputs());

  // Float attribute "scale".
  auto* scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace cpu;
  fw::Scope scope;

  auto test_op = fw::OpRegistry::CreateOp(desc);

  // Only the output variable is created: RunImpl asserts that "IN1" is absent
  // from the scope and that "OUT1" is present.
  scope.Var("OUT1");

  ASSERT_EQ(fw::op_run_num, 0);
  test_op->Run(scope, cpu);
  ASSERT_EQ(fw::op_run_num, 1);
}

namespace paddle {
namespace framework {

Q
Qiao Longfei 已提交
99 100 101 102
// Proto/checker maker for the single-input kernel test op ("op_with_kernel").
//
// Declares input "x", output "y" and a float attribute "scale" that defaults
// to 1.0 and is constrained to be greater than 0.0.
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Slots.
    AddInput("x", "input of test op");
    AddOutput("y", "output of test op");

    // Attribute with a default value and a lower-bound check.
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
        .GreaterThan(0.0);

    // Op-level description.
    AddComment("This is test op");
  }
};

Q
Qiao Longfei 已提交
112 113
// Number of times CPUKernelTest::Compute has been entered. TEST(OpKernel, all)
// reads it before and after running the op.
static int cpu_kernel_run_num = 0;

Q
Qiao Longfei 已提交
114
class OpWithKernelTest : public OperatorWithKernel {
Y
Yu Yang 已提交
115 116 117
 public:
  using OperatorWithKernel::OperatorWithKernel;

Y
Yu Yang 已提交
118
 protected:
119
  void InferShape(framework::InferShapeContext* ctx) const override {}
120 121
  OpKernelType GetExpectedKernelType(
      const ExecutionContext& ctx) const override {
122
    return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
Y
Yu Yang 已提交
123
  }
Q
Qiao Longfei 已提交
124 125
};

126
template <typename T1, typename T2>
Y
Yu Yang 已提交
127
class CPUKernelTest : public OpKernel<float> {
Q
Qiao Longfei 已提交
128
 public:
129
  void Compute(const ExecutionContext& ctx) const {
Q
qiaolongfei 已提交
130
    std::cout << ctx.op().DebugString() << std::endl;
Q
Qiao Longfei 已提交
131
    cpu_kernel_run_num++;
Q
qiaolongfei 已提交
132 133
    ASSERT_EQ(ctx.op().Input("x"), "IN1");
    ASSERT_EQ(ctx.op().Output("y"), "OUT1");
Y
Yan Chunwei 已提交
134 135 136 137 138 139 140 141 142
  }
};

// Proto/checker maker for the multi-input kernel test op
// ("op_multi_inputs_with_kernel").
//
// Declares a duplicable input "xs", a plain input "k", a duplicable output
// "ys" and a float attribute "scale" that defaults to 1.0 and is constrained
// to be greater than 0.0.
class OpKernelTestMultiInputsProtoAndCheckerMaker
    : public OpProtoAndCheckerMaker {
 public:
  OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
                                              OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Input slots.
    AddInput("xs", "inputs of test op").AsDuplicable();
    AddInput("k", "input of test op");

    // Output slot.
    AddOutput("ys", "outputs of test op").AsDuplicable();

    // Attribute with a default value and a lower-bound check.
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(1.0)
        .GreaterThan(0.0);

    // Op-level description.
    AddComment("This is test op");
  }
};

Y
Yu Yang 已提交
153
// CPU kernel registered for "op_multi_inputs_with_kernel".
//
// Compute performs no computation; it only verifies that the execution
// context exposes the names, variables and tensors that
// TEST(OpKernel, multi_inputs) wired up: three "xs" inputs, one "k" input and
// two "ys" outputs.
//
// NOTE: the class name keeps its original spelling ("Kernal") because the
// kernel registration refers to it by that name.
class CPUKernalMultiInputsTest : public OpKernel<float> {
 public:
  // Marked `override` (like the other virtual overrides in this file) so a
  // signature drift against the base kernel interface becomes a compile error.
  void Compute(const ExecutionContext& ctx) const override {
    // Variable names of the duplicable input slot, in declaration order.
    auto xs = ctx.op().Inputs("xs");
    ASSERT_EQ(xs.size(), 3UL);
    ASSERT_EQ(xs[0], "x0");
    ASSERT_EQ(xs[1], "x1");
    ASSERT_EQ(xs[2], "x2");

    // Variable-level accessors.
    auto in_vars = ctx.MultiInputVar("xs");
    ASSERT_EQ(in_vars.size(), 3U);

    auto k_var = ctx.InputVar("k");
    ASSERT_NE(k_var, nullptr);

    auto out_vars = ctx.MultiOutputVar("ys");
    ASSERT_EQ(out_vars.size(), 2U);

    // Tensor-level accessors.
    auto in_tensors = ctx.MultiInput<Tensor>("xs");
    ASSERT_EQ(in_tensors.size(), 3U);

    auto k_tensor = ctx.Input<Tensor>("k");
    ASSERT_NE(k_tensor, nullptr);

    auto out_tensors = ctx.MultiOutput<Tensor>("ys");
    ASSERT_EQ(out_tensors.size(), 2U);

    // Variable name of the single-valued input slot.
    auto k = ctx.op().Input("k");
    ASSERT_EQ(k, "k0");

    // Variable names of the duplicable output slot, in declaration order.
    auto ys = ctx.op().Outputs("ys");
    ASSERT_EQ(ys.size(), 2UL);
    ASSERT_EQ(ys[0], "y0");
    ASSERT_EQ(ys[1], "y1");
  }
};

Y
Yu Yang 已提交
190 191 192
}  // namespace framework
}  // namespace paddle

F
fengjiayi 已提交
193 194 195
// Register the "op_with_kernel" op type, pairing OpWithKernelTest with its
// proto/checker maker. As the macro name indicates, no gradient op is
// registered for it.
REGISTER_OP_WITHOUT_GRADIENT(
    op_with_kernel, paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestProtoAndCheckerMaker);
196 197
// Register CPUKernelTest<float, float> as the CPU kernel of "op_with_kernel".
REGISTER_OP_CPU_KERNEL(op_with_kernel,
                       paddle::framework::CPUKernelTest<float, float>);
Q
Qiao Longfei 已提交
198

Y
Yan Chunwei 已提交
199
// test with single input
Q
Qiao Longfei 已提交
200
// Builds an "op_with_kernel" OpDesc with one input and one output, runs the op
// once on CPU and checks that the CPU kernel was executed exactly once.
TEST(OpKernel, all) {
  namespace fw = paddle::framework;

  fw::InitDevices();

  // Describe the op: input slot "x" -> "IN1", output slot "y" -> "OUT1".
  fw::proto::OpDesc desc;
  desc.set_type("op_with_kernel");
  BuildVar("x", {"IN1"}, desc.add_inputs());
  BuildVar("y", {"OUT1"}, desc.add_outputs());

  // Float attribute "scale".
  auto* scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace cpu;
  fw::Scope scope;

  auto kernel_op = fw::OpRegistry::CreateOp(desc);

  ASSERT_EQ(fw::cpu_kernel_run_num, 0);
  kernel_op->Run(scope, cpu);
  ASSERT_EQ(fw::cpu_kernel_run_num, 1);
}
Y
Yan Chunwei 已提交
220

F
fengjiayi 已提交
221 222 223
// Register the "op_multi_inputs_with_kernel" op type. It reuses
// OpWithKernelTest as the operator class and pairs it with the multi-input
// proto/checker maker. As the macro name indicates, no gradient op is
// registered for it.
REGISTER_OP_WITHOUT_GRADIENT(
    op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
    paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
Y
Yan Chunwei 已提交
224 225 226 227 228 229 230
// Register CPUKernalMultiInputsTest as the CPU kernel of
// "op_multi_inputs_with_kernel".
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
                       paddle::framework::CPUKernalMultiInputsTest);

// test with multi inputs
// Builds an "op_multi_inputs_with_kernel" OpDesc with three "xs" inputs, one
// "k" input and two "ys" outputs, creates a LoDTensor variable for each name
// in the scope and runs the op once on CPU. The actual checks live in
// CPUKernalMultiInputsTest::Compute.
TEST(OpKernel, multi_inputs) {
  namespace fw = paddle::framework;

  fw::InitDevices();

  // Describe the op and its slots.
  fw::proto::OpDesc desc;
  desc.set_type("op_multi_inputs_with_kernel");
  BuildVar("xs", {"x0", "x1", "x2"}, desc.add_inputs());
  BuildVar("k", {"k0"}, desc.add_inputs());
  BuildVar("ys", {"y0", "y1"}, desc.add_outputs());

  // Float attribute "scale".
  auto* scale_attr = desc.mutable_attrs()->Add();
  scale_attr->set_name("scale");
  scale_attr->set_type(fw::proto::AttrType::FLOAT);
  scale_attr->set_f(3.14);

  paddle::platform::CPUPlace cpu;
  fw::Scope scope;

  // Create every referenced variable as a LoDTensor, in the same order as the
  // slots above.
  for (const char* var_name : {"x0", "x1", "x2", "k0", "y0", "y1"}) {
    scope.Var(var_name)->GetMutable<fw::LoDTensor>();
  }

  auto multi_op = fw::OpRegistry::CreateOp(desc);
  multi_op->Run(scope, cpu);
}
Y
Yu Yang 已提交
256 257 258 259

class OperatorClone : public paddle::framework::OperatorBase {
 public:
  DEFINE_OP_CLONE_METHOD(OperatorClone);
Y
Yu Yang 已提交
260 261 262
  OperatorClone(const std::string& type,
                const paddle::framework::VariableNameMap& inputs,
                const paddle::framework::VariableNameMap& outputs,
Y
Yu Yang 已提交
263 264
                const paddle::framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
265 266 267 268

 private:
  void RunImpl(const paddle::framework::Scope& scope,
               const paddle::platform::Place& place) const override {}
Y
Yu Yang 已提交
269 270 271
};

// Clones an operator built with empty input/output/attribute maps and checks
// that the clone reports the same type string as the original.
TEST(Operator, Clone) {
  namespace fw = paddle::framework;

  fw::InitDevices();

  OperatorClone original("ABC", fw::VariableNameMap{}, fw::VariableNameMap{},
                         fw::AttributeMap{});
  auto copy = original.Clone();

  ASSERT_EQ(original.Type(), copy->Type());
}