提交 b78b416f 编写于 作者: S superjomn

add executor

上级 e579eb00
cc_library(executor_lite SRCS executor.cc)
cc_library(memory_lite SRCS memory.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc)
cc_library(scope_lite SRCS scope.cc)
cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite)
cc_library(executor_lite SRCS executor.cc DEPS scope_lite tensor_lite op_lite op_registry_lite
#TODO(Superjomn) remove these dependencies from original framework
proto_desc)
cc_test(test_scope_lite SRCS scope_test.cc DEPS scope_lite)
cc_test(test_kernel_lite SRCS kernel_test.cc DEPS target_wrapper_x86)
cc_test(test_op_lite SRCS op_lite_test.cc DEPS op_lite)
cc_test(test_tensor_lite SRCS tensor_test.cc)
cc_test(test_executor_lite SRCS executor_test.cc DEPS executor_lite ops_lite host_kernels)
......@@ -11,3 +11,67 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/scope.h"
namespace paddle {
namespace lite {
// The Executor is used to run the operators.
class Executor {
public:
Executor(lite::Scope* scope, const std::vector<OpLite::Place>& valid_places)
: scope_(scope), valid_places_(valid_places) {}
// Create temporary variables.
void PrepareWorkspace(framework::ProgramDesc& program, lite::Scope* scope) {
CHECK(!exec_scope_) << "Duplicate PrepareWorkspace found";
exec_scope_ = &scope_->NewScope();
for (auto var_desc : program.Block(0).AllVars()) {
if (!var_desc->Persistable()) {
auto* var = exec_scope_->Var(var_desc->Name());
LOG(INFO) << "create tmp var " << var_desc->Name() << " " << var;
}
}
}
// Build from a program and scope.
void Build(framework::ProgramDesc& program) {
CHECK(ops_.empty()) << "Executor duplicate Build found";
// Create operators.
for (auto* op_desc : program.Block(0).AllOps()) {
auto op_type = op_desc->Type();
LOG(INFO) << "create Op [" << op_type << "]";
ops_.emplace_back(LiteOpRegistry::Global().Create(op_type));
// pick initial kernel
ops_.back()->PickKernel(valid_places_);
ops_.back()->Attach(*op_desc, exec_scope_);
}
}
// Run the program.
void Run() {
for (auto& op : ops_) {
LOG(INFO) << op->DebugString();
// TODO(Superjomn) check only once
op->CheckShape();
op->InferShape();
op->Run();
}
}
private:
std::vector<std::unique_ptr<OpLite>> ops_;
lite::Scope* scope_{};
std::vector<OpLite::Place> valid_places_;
lite::Scope* exec_scope_{};
};
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/executor.h"
#include <gtest/gtest.h>
#include <vector>
namespace paddle {
namespace lite {
TEST(executor, test) {
std::vector<OpLite::Place> valid_places{
OpLite::Place{TARGET(kHost), PRECISION(kFloat)}};
Scope scope;
Executor executor(&scope, valid_places);
framework::ProgramDesc program;
program.MutableBlock(0)->Var("x");
program.MutableBlock(0)->Var("bias")->SetPersistable(true);
program.MutableBlock(0)->Var("w")->SetPersistable(true);
program.MutableBlock(0)->Var("output");
auto& op_desc = *program.MutableBlock(0)->AppendOp();
op_desc.SetType("fc");
op_desc.SetInput("Input", {"x"});
op_desc.SetInput("W", {"w"});
op_desc.SetInput("Bias", {"bias"});
op_desc.SetOutput("Out", {"output"});
op_desc.SetAttr("in_num_col_dims", static_cast<int>(1));
program.Flush();
auto* w = scope.Var("w")->GetMutable<Tensor>();
w->Resize({20, 20});
auto* x = scope.Var("x")->GetMutable<Tensor>();
x->Resize({1, 10, 20});
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
bias->Resize({1, 20});
bias->mutable_data<float>();
w->mutable_data<float>();
x->mutable_data<float>();
executor.PrepareWorkspace(program, &scope);
executor.Build(program);
executor.Run();
}
} // namespace lite
} // namespace paddle
USE_LITE_OP(fc);
USE_LITE_KERNEL(fc, kHost, kFloat);
......@@ -44,8 +44,8 @@ class OpLiteRegistor : public Registor<OpClass> {
OpLiteRegistor(const std::string &op_type)
: Registor<OpClass>([&] {
LiteOpRegistry::Global().Register(
op_type, []() -> std::unique_ptr<OpLite> {
return std::unique_ptr<OpLite>(new OpClass);
op_type, [op_type]() -> std::unique_ptr<OpLite> {
return std::unique_ptr<OpLite>(new OpClass(op_type));
});
}) {}
};
......@@ -134,11 +134,15 @@ class KernelRegistor : public lite::Registor<KernelType> {
#define LITE_OP_REGISTER_FAKE(op_type__) op_type__##__registry__
#define REGISTER_LITE_OP(op_type__, OpClass) \
static paddle::lite::OpLiteRegistor<OpClass> LITE_OP_REGISTER_INSTANCE( \
op_type__)(#op_type__);
op_type__)(#op_type__); \
int touch_op_##op_type__() { \
return LITE_OP_REGISTER_INSTANCE(op_type__).Touch(); \
}
#define USE_LITE_OP(op_type__) \
int LITE_OP_REGISTER_FAKE(op_type__)((unused)) = \
LITE_OP_REGISTER_INSTANCE(op_type__).Touch();
extern int touch_op_##op_type__(); \
int LITE_OP_REGISTER_FAKE(op_type__) __attribute__((unused)) = \
touch_op_##op_type__();
// Kernel registry
#define LITE_KERNEL_REGISTER(op_type__, target__, precision__) \
......
cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite)
cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite)
cc_library(host_kernels DEPS
fc_compute_host
relu_compute_host)
cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite)
......@@ -14,7 +14,6 @@
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/kernels/fc_compute.h"
#include "paddle/fluid/lite/operators/fc_op.h"
namespace paddle {
......
......@@ -14,17 +14,33 @@
#pragma once
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class ReluCompute final : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
class ReluCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public:
void Run() override {
auto& theparam = param<operators::ReluParam>();
auto n = product(theparam.input->dims());
const float* input = theparam.input->data<float>();
float* output = theparam.output->mutable_data<float>();
for (int i = 0; i < n; i++) {
output[i] = std::max(0.f, input[i]);
}
}
TargetType target() const override { return TARGET(kHost); }
PrecisionType precision() const override { return PRECISION(kFloat); }
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(relu, kHost, kFloat,
paddle::lite::kernels::host::ReluCompute);
cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite proto_desc)
cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
cc_library(ops_lite DEPS
fc_op_lite
relu_op_lite)
cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host)
......@@ -43,7 +43,7 @@ bool FcOpLite::CheckShape() const {
static_cast<size_t>(param_.in_num_col_dims));
param_.in_mat_dims = lite::flatten_to_2d(input_dims, param_.in_num_col_dims);
CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]);
// CHECK_EQ_OR_FALSE(param_.in_mat_dims[1], w_dims[0]);
return true;
}
......
......@@ -35,11 +35,13 @@ class FcOpLite : public OpLite {
bool InferShape() const override;
/*
bool Run() override {
CHECK(kernel_);
kernel_->Run();
return true;
}
*/
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
......@@ -51,10 +53,12 @@ class FcOpLite : public OpLite {
param_.input = scope->FindVar(input)->GetMutable<Tensor>();
param_.w = scope->FindVar(W)->GetMutable<Tensor>();
param_.bias = scope->FindVar(bias)->GetMutable<Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.in_num_col_dims =
boost::get<int>(op_desc.GetAttr("in_num_col_dims"));
CHECK(kernel_);
kernel_->SetParam(param_);
return true;
......
......@@ -25,7 +25,7 @@ namespace lite {
namespace operators {
struct FcParam {
Tensor* input{nullptr};
Tensor* input{};
Tensor* w{};
Tensor* bias{};
Tensor* output{};
......@@ -33,7 +33,12 @@ struct FcParam {
int in_num_col_dims{0};
};
using param_t = variant<FcParam>;
struct ReluParam {
Tensor* input{};
Tensor* output{};
};
using param_t = variant<FcParam, ReluParam>;
} // namespace operators
} // namespace lite
......
......@@ -31,10 +31,15 @@ bool ReluOp::InferShape() const {
return true;
}
bool ReluOp::Run() { return false; }
bool ReluOp::Attach(const framework::OpDesc &opdesc, framework::Scope *scope) {
return false;
bool ReluOp::Attach(const framework::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<Tensor *>(
&scope->FindVar(opdesc.Input("Input").front())->Get<Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<Tensor>();
CHECK(param_.input);
CHECK(param_.output);
kernel_->SetParam(param_);
return true;
}
REGISTER_LITE_OP(relu, ReluOp);
......
......@@ -15,6 +15,7 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/utils/all.h"
......@@ -22,27 +23,23 @@ namespace paddle {
namespace lite {
namespace operators {
struct ReluParam {
Tensor* input{nullptr};
Tensor* output{nullptr};
};
class ReluOp : public OpLite {
public:
ReluOp() {}
ReluOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool Run() override;
bool Attach(const framework::OpDesc& opdesc,
framework::Scope* scope) override;
bool Attach(const framework::OpDesc &opdesc, lite::Scope *scope) override;
std::string DebugString() const override { return "tanh"; }
void StaticPickKernel(const std::vector<OpTarget>& valid_targets) override {}
void StaticPickKernel(
const std::vector<OpLite::Place> &valid_targets) override {
kernel_ = std::move(CreateKernels(valid_targets).front());
}
private:
mutable ReluParam param_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册