未验证 提交 e9f33320 编写于 作者: Y Yan Chunwei 提交者: GitHub

Add some ops for training (#17442)

上级 72fb4adb
......@@ -232,6 +232,17 @@ using OpKernelConfigsMap =
std::unordered_map<OpKernelType, std::vector<KernelConfig>,
OpKernelType::Hash>;
class OpDuppy : public OperatorBase {
public:
OpDuppy() : OperatorBase("duppy", {}, {}, {}) {}
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
OpDuppy op_duppy;
Scope scope_duppy;
RuntimeContext runtime_context_duppy({}, {});
class ExecutionContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
......@@ -244,6 +255,13 @@ class ExecutionContext {
ctx_(ctx),
kernel_configs_(configs) {}
ExecutionContext(const platform::DeviceContext& device_context)
: op_(op_duppy),
scope_(scope_duppy),
device_context_(device_context),
ctx_(runtime_context_duppy),
kernel_configs_(nullptr) {}
const OperatorBase& op() const { return op_; }
const Scope& scope() const { return scope_; }
......
include_directories(lite)
\ No newline at end of file
......@@ -25,7 +25,7 @@ namespace paddle {
namespace lite {
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
void CXXPredictor::SaveModel(const std::string &dir) {
void ExecutorLite::SaveModel(const std::string &dir) {
MkDirRecursively(dir.c_str());
program_->PersistModel(dir, program_desc_);
}
......
......@@ -28,14 +28,24 @@ namespace lite {
struct Config {};
class CXXPredictor {
class ExecutorLite {
public:
CXXPredictor() { scope_ = std::make_shared<Scope>(); }
ExecutorLite() { scope_ = std::make_shared<Scope>(); }
explicit ExecutorLite(const std::shared_ptr<lite::Scope>& root_scope) {
scope_ = root_scope;
}
void Build(const std::string& model_path, const Place& prefer_place,
const std::vector<Place>& valid_places) {
LoadModel(model_path, scope_.get(), &program_desc_);
Program program(program_desc_, scope_, valid_places);
Build(program_desc_, prefer_place, valid_places);
}
void Build(const framework::proto::ProgramDesc& desc,
const Place& prefer_place,
const std::vector<Place>& valid_places) {
program_desc_ = desc;
Program program(desc, scope_, valid_places);
optimizer_.KernelPickPreferPlace(prefer_place);
core::KernelPickFactor factor;
......@@ -81,5 +91,57 @@ class CXXPredictor {
std::unique_ptr<RuntimeProgram> program_;
};
/*
* An executor for training.
*
* Usage:
*
* CXXTrainer trainer(...);
* trainer.RunStartupProgram(...);
* auto exe = BuildMainProgramExecutor(...);
*
* for (auto& epoch : epoches) {
* auto* tensor0 = exe.GetInput(...);
* // fill data for tensor0
* exe.Run();
* }
*/
class CXXTrainer {
public:
CXXTrainer(const std::shared_ptr<lite::Scope>& root_scope,
const Place& preferred_place,
const std::vector<Place>& valid_places)
: scope_(root_scope),
preferred_place_(preferred_place),
valid_places_(valid_places),
main_program_executor_(ExecutorLite(scope_)) {}
// Build the RuntimeProgram cache for the main program. The cache will run
// multiple times for the epoches.
// NOTE Just support to execute the 0-th block currently.
ExecutorLite& BuildMainProgramExecutor(
const framework::proto::ProgramDesc& desc, int block_id = 0) {
main_program_executor_.Build(desc, preferred_place_, valid_places_);
return main_program_executor_;
}
// Run the startup program. It just executes once, no cache needed.
void RunStartupProgram(const framework::proto::ProgramDesc& desc,
int block_id = 0) {
ExecutorLite exe(scope_);
exe.Build(desc, preferred_place_, valid_places_);
exe.Run();
}
private:
std::shared_ptr<lite::Scope> scope_;
Place preferred_place_;
std::vector<Place> valid_places_;
// The training program.
ExecutorLite main_program_executor_;
};
} // namespace lite
} // namespace paddle
......@@ -20,7 +20,7 @@ namespace paddle {
namespace lite {
void Run(const char* model_dir) {
lite::CXXPredictor predictor;
lite::Executor predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
#else
......
......@@ -22,11 +22,15 @@
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
// For training.
DEFINE_string(startup_program_path, "", "");
DEFINE_string(main_program_path, "", "");
namespace paddle {
namespace lite {
TEST(CXXApi, test) {
lite::CXXPredictor predictor;
lite::ExecutorLite predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
#else
......@@ -64,14 +68,48 @@ TEST(CXXApi, test) {
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(CXXApi, save_model) {
lite::CXXPredictor predictor;
lite::ExecutorLite predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)}});
predictor.Build(FLAGS_model_dir, Place{TARGET(kCUDA), PRECISION(kFloat)},
valid_places);
predictor.SaveModel(FLAGS_optimized_model);
}
#endif
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(CXXTrainer, train) {
Place prefer_place({TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)});
std::vector<Place> valid_places({prefer_place});
auto scope = std::make_shared<lite::Scope>();
CXXTrainer trainer(scope, prefer_place, valid_places);
std::string main_program_pb, startup_program_pb;
ReadBinaryFile(FLAGS_main_program_path, &main_program_pb);
ReadBinaryFile(FLAGS_startup_program_path, &startup_program_pb);
framework::proto::ProgramDesc main_program_desc, startup_program_desc;
main_program_desc.ParseFromString(main_program_pb);
startup_program_desc.ParseFromString(startup_program_pb);
LOG(INFO) << main_program_desc.DebugString();
for (const auto& op : main_program_desc.blocks(0).ops()) {
LOG(INFO) << "get op " << op.type();
}
return;
trainer.RunStartupProgram(startup_program_desc);
auto& exe = trainer.BuildMainProgramExecutor(main_program_desc);
auto* tensor0 = exe.GetInput(0);
tensor0->Resize(std::vector<int64_t>({100, 100}));
auto* data0 = tensor0->mutable_data<float>();
data0[0] = 0;
exe.Run();
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
} // namespace lite
} // namespace paddle
......
......@@ -16,9 +16,13 @@
#include "paddle/fluid/lite/utils/any.h"
#ifdef LITE_WITH_CUDA
#include <paddle/fluid/lite/cuda/blas.h>
#include "paddle/fluid/lite/cuda/blas.h"
#include "paddle/fluid/lite/cuda/cuda_utils.h"
#endif
#ifdef LITE_WITH_X86
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
#endif
#include <memory>
#include <set>
#include <vector>
......@@ -54,6 +58,10 @@ struct X86Context {
// overall information
// kernel information
// legacy info.
std::unique_ptr<::paddle::platform::CPUDeviceContext> x86_device_context;
std::unique_ptr<::paddle::framework::ExecutionContext> x86_execution_context;
};
#endif
......
......@@ -29,7 +29,7 @@ namespace lite {
class DDimHvy : public DDimBase<DDimHvy> {
public:
DDimHvy() = default;
explicit DDimHvy(const std::vector<value_type>& x) : DDimBase<DDimHvy>() {
DDimHvy(const std::vector<value_type>& x) : DDimBase<DDimHvy>() { // NOLINT
ConstructFrom(x);
}
explicit DDimHvy(const framework::DDim& x) : data_(x) {}
......@@ -47,6 +47,14 @@ class DDimHvy : public DDimBase<DDimHvy> {
size_t size() const { return data_.size(); }
bool empty() const { return data_.size() == 0; }
bool operator==(const DDimHvy& other) {
if (data_.size() != other.data_.size()) return false;
for (int i = 0; i < data_.size(); i++) {
if (data_[i] != other.data_[i]) return false;
}
return true;
}
private:
framework::DDim data_;
};
......@@ -85,8 +93,7 @@ class TensorHvy : public TensorBase<TensorHvy> {
const void* raw_data() const { return data_.raw_data(); }
template <typename DimT>
void Resize(const DimT& dims) {
void Resize(const DDimHvy& dims) {
LOG(INFO) << "dims.size " << dims.size();
data_.Resize(framework::make_ddim(dims.Vectorize()));
}
......@@ -103,6 +110,9 @@ class TensorHvy : public TensorBase<TensorHvy> {
const framework::LoD& lod() const { return data_.lod(); }
framework::LoD* mutable_lod() { return data_.mutable_lod(); }
const framework::LoDTensor& raw_tensor() const { return data_; }
framework::LoDTensor& raw_tensor() { return data_; }
private:
framework::LoDTensor data_;
};
......
import numpy
import sys, os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward
a = fluid.layers.data(name="a", shape=[100], dtype='float32')
label = fluid.layers.data(name="label", shape=[100], dtype='float32')
a1 = fluid.layers.fc(input=a, size=500, act=None, bias_attr=False)
cost = fluid.layers.square_error_cost(a1, label)
avg_cost = fluid.layers.mean(cost)
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
optimizer.minimize(cost)
cpu = fluid.core.CPUPlace()
loss = exe = fluid.Executor(cpu)
exe.run(fluid.default_startup_program())
with open('startup_program.pb', 'wb') as f:
f.write(fluid.default_startup_program().desc.serialize_to_string())
data_1 = np.array(numpy.random.random([100, 100]), dtype='float32')
#fluid.default_main_program().desc.
#prog = fluid.compiler.CompiledProgram(fluid.default_main_program())
prog = fluid.default_main_program()
#append_backward(loss)
with open('main_program.pb', 'wb') as f:
f.write(prog.desc.serialize_to_string())
#outs = exe.run(program=prog, feed={'a':data_1, }, fetch_list=[cost])
sys.exit(0)
fluid.io.save_inference_model("./model2", [a.name], [a1], exe)
print(numpy.array(outs))
......@@ -71,7 +71,7 @@ bool OpLite::Run() {
bool OpLite::Attach(const OpDesc &opdesc, lite::Scope *scope) {
// valid_places_.clear();
CHECK(scope != nullptr);
// CHECK(!op_info_.get());
//CHECK(!op_info_.get());
scope_ = scope;
op_info_.reset(new OpInfo); // Force clean the out-of-date infomation.
op_info_->Build(opdesc.ReadonlyProto());
......
......@@ -116,6 +116,22 @@ class OpLite : public Registry {
friend class mir::Node;
friend class mir::SSAGraph;
protected:
// some helper functions.
template <typename T>
const T *GetVar(Scope *scope, const std::string &name) {
auto *var = scope->FindVar(name);
CHECK(var) << "No var found for " << name;
return &var->Get<T>();
}
template <typename T>
T *GetMutableVar(Scope *scope, const std::string &name) {
auto *var = scope->FindVar(name);
CHECK(var) << "No var found for " << name;
return var->GetMutable<T>();
}
protected:
lite::Scope *scope_{};
std::unique_ptr<KernelBase> kernel_;
......
......@@ -62,13 +62,13 @@ struct Program {
// Build from a program and scope.
void Build(const framework::proto::ProgramDesc& program) {
CHECK(ops.empty()) << "Executor duplicate Build found";
// Create operators.
for (const auto& proto_op_desc : program.blocks(0).ops()) {
lite::OpDesc op_desc(proto_op_desc);
auto op_type = op_desc.Type();
// if (op_type == "feed" || op_type == "fetch") continue;
VLOG(4) << "create Op [" << op_type << "]";
LOG(INFO) << "create Op [" << op_type << "]";
auto op = LiteOpRegistry::Global().Create(op_type);
CHECK(op) << "no Op found for " << op_type;
ops.emplace_back(std::move(op));
......@@ -86,6 +86,7 @@ struct Program {
tmp_vars.push_back("feed");
tmp_vars.push_back("fetch");
CHECK(!program.blocks().empty());
for (auto proto_var_desc : program.blocks(0).vars()) {
lite::VarDesc var_desc(proto_var_desc);
if (!var_desc.Persistable()) {
......
......@@ -3,3 +3,4 @@ set(lite_kernel_deps type_system kernel_lite op_lite op_registry_lite ${tensor_l
add_subdirectory(host)
add_subdirectory(arm)
add_subdirectory(cuda)
add_subdirectory(x86)
if(NOT LITE_WITH_X86)
return()
endif()
cc_library(activation_compute SRCS activation_compute.cc DEPS ${lite_kernel_deps} activation_op)
cc_library(elementwise_compute SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_op)
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename Functor>
void Activate(const platform::CPUDeviceContext& context,
const framework::LoDTensor* X, framework::LoDTensor* Out) {
using T = typename Functor::ELEMENT_TYPE;
auto* place = context.eigen_device();
auto x =
framework::EigenVector<T>::Flatten(paddle::operators::detail::Ref(X));
auto out =
framework::EigenVector<T>::Flatten(paddle::operators::detail::Ref(Out));
Functor()(*place, x, out);
}
template <typename Functor>
void ActivateGrad(const platform::CPUDeviceContext& context,
const framework::LoDTensor* X,
const framework::LoDTensor* Out,
const framework::LoDTensor* Out_grad,
framework::LoDTensor* X_grad) {
using T = typename Functor::ELEMENT_TYPE;
auto* place = context.eigen_device();
auto x =
framework::EigenVector<T>::Flatten(paddle::operators::detail::Ref(X));
auto out =
framework::EigenVector<T>::Flatten(paddle::operators::detail::Ref(Out));
auto x_grad = framework::EigenVector<T>::Flatten(
paddle::operators::detail::Ref(X_grad));
auto out_grad = framework::EigenVector<T>::Flatten(
paddle::operators::detail::Ref(Out_grad));
Functor()(*place, x, out, out_grad, x_grad);
}
template <typename T>
class SquareCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override {
auto& context = context_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationParam>();
CHECK(context.x86_device_context);
param.Out->template mutable_data<T>();
Activate<paddle::operators::SquareFunctor<T>>(*context.x86_device_context,
&param.X->raw_tensor(),
&param.Out->raw_tensor());
}
// TargetType target() const override;
// PrecisionType precision() const override;
virtual ~SquareCompute() = default;
};
template <typename T>
class SquareGradCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationGradParam;
void Run() override {
auto& context = context_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationGradParam>();
CHECK(context.x86_device_context);
param.X_grad->template mutable_data<T>();
ActivateGrad<paddle::operators::SquareGradFunctor<T>>(
*context.x86_device_context, &param.X->raw_tensor(),
&param.Out->raw_tensor(), &param.Out_grad->raw_tensor(),
&param.X_grad->raw_tensor());
}
// TargetType target() const override;
// PrecisionType precision() const override;
virtual ~SquareGradCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// float
REGISTER_LITE_KERNEL(square, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::SquareCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
REGISTER_LITE_KERNEL(square_grad, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::SquareGradCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.Finalize();
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T>
class ElementwiseSubCompute
: public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = context_->As<X86Context>();
CHECK(context.x86_device_context);
param.Out->template mutable_data<T>();
paddle::operators::ElementwiseComputeEx<SubFunctor<T>,
platform::CPUDeviceContext, T>(
*context.x86_execution_context, &param.X->raw_tensor(),
&param.Y->raw_tensor(), param.axis, SubFunctor<T>(),
&param.Out->raw_tensor());
}
// TargetType target() const override;
// PrecisionType precision() const override;
virtual ~ElementwiseSubCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// float
REGISTER_LITE_KERNEL(square, kHost, kFloat, kNCHW,
paddle::lite::kernels::x86::ElementwiseSubCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -114,7 +114,7 @@ void LoadLoDTensor(std::istream &is, Variable *var) {
void ReadBinaryFile(const std::string &filename, std::string *contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
CHECK(fin.is_open()) << "Cannot open file " << filename;
CHECK(fin.is_open()) << "Cannot open file: " << filename;
fin.seekg(0, std::ios::end);
auto size = fin.tellg();
contents->clear();
......
......@@ -47,5 +47,7 @@ void SerializeTensor(std::ostream& os, const lite::Scope& scope,
// LoDTensor to ostream
void TensorToStream(std::ostream& os, const lite::Tensor& tensor);
void ReadBinaryFile(const std::string& filename, std::string* contents);
} // namespace lite
} // namespace paddle
......@@ -7,6 +7,8 @@ cc_library(scale_op_lite SRCS scale_op.cc DEPS ${op_DEPS})
cc_library(feed_op_lite SRCS feed_op.cc DEPS ${op_DEPS})
cc_library(fetch_op_lite SRCS fetch_op.cc DEPS ${op_DEPS})
cc_library(io_copy_op_lite SRCS io_copy_op.cc DEPS ${op_DEPS})
cc_library(activation_ops_lite SRCS activation_ops.cc DEPS ${op_DEPS})
cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite})
set(ops_lite
......
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class ActivationOp : public OpLite {
public:
explicit ActivationOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override { return true; }
bool InferShape() const override {
param_.Out->Resize(param_.X->dims());
return true;
}
bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
auto X_name = opdesc.Input("X").front();
auto Out_name = opdesc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out = GetMutableVar<Tensor>(scope, Out_name);
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
private:
mutable ActivationParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(square, paddle::lite::operators::ActivationOp);
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
class ElementwiseOp : public OpLite {
public:
explicit ElementwiseOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool InferShape() const override {
CHECK_OR_FALSE(param_.X->dims() == param_.Y->dims());
param_.Out->Resize(param_.X->dims());
return true;
}
bool AttachImpl(const OpDesc& opdesc, lite::Scope* scope) override {
CHECK_EQ(opdesc.Inputs().size(), 2UL);
auto X_name = opdesc.Input("X").front();
auto Y_name = opdesc.Input("Y").front();
auto Out_name = opdesc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Y = GetVar<lite::Tensor>(scope, Y_name);
param_.Out = GetMutableVar<Tensor>(scope, Out_name);
param_.axis = boost::get<int>(opdesc.GetAttr("axis"));
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
private:
mutable operators::ElementwiseParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(elementwise_sub, paddle::lite::operators::ElementwiseOp);
......@@ -31,7 +31,6 @@ class FeedOp : public OpLite {
bool InferShape() const override { return true; }
protected:
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
protected:
......
......@@ -41,9 +41,12 @@ class MulOpLite : public OpLite {
auto input = op_desc.Input("X").front();
auto W = op_desc.Input("Y").front();
auto out = op_desc.Output("Out").front();
param_.x = scope->FindVar(input)->GetMutable<Tensor>();
param_.y = scope->FindVar(W)->GetMutable<Tensor>();
auto *var = scope->FindVar(input);
CHECK(var);
param_.x = var->GetMutable<Tensor>();
var = scope->FindVar(W);
CHECK(var);
param_.y = var->GetMutable<Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.x_num_col_dims = GetAttr<int>(op_desc.GetAttr("x_num_col_dims"));
......
......@@ -25,6 +25,9 @@ namespace paddle {
namespace lite {
namespace operators {
using param_t = Any;
/// ----------------------- Functional operators ------------------------------
struct FeedParam {
const std::vector<lite::Tensor>* feed_list{};
lite::Tensor* out{};
......@@ -37,6 +40,14 @@ struct FetchParam {
int col;
};
// Helper op for lite framework
struct IoCopyParam {
const lite::Tensor* x{};
lite::Tensor* y{};
};
/// -------------------------- NN operators ------------------------------------
struct FcParam {
lite::Tensor* input{};
lite::Tensor* w{};
......@@ -71,13 +82,34 @@ struct ScaleParam {
bool bias_after_scale{true};
};
struct IoCopyParam {
const lite::Tensor* x{};
lite::Tensor* y{};
/// ----------------------- element wise operators ----------------------
struct ElementwiseParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
lite::Tensor* Out{};
int axis{-1}; // for broadcasting.
};
struct ElementwiseGradParam {
const lite::Tensor* X_grad{};
const lite::Tensor* Y_grad{};
lite::Tensor* Out_grad{};
int axis{-1}; // for broadcasting.
};
using param_t = variant<FeedParam, FetchParam, FcParam, ReluParam, MulParam,
ScaleParam, IoCopyParam>;
/// ----------------------- activation operators ----------------------
struct ActivationParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
};
struct ActivationGradParam {
const lite::Tensor* X{};
const lite::Tensor* Out{};
// for backward
lite::Tensor* X_grad{};
const lite::Tensor* Out_grad{};
};
} // namespace operators
} // namespace lite
......
......@@ -21,3 +21,4 @@
#include "paddle/fluid/lite/utils/io.h"
#include "paddle/fluid/lite/utils/macros.h"
#include "paddle/fluid/lite/utils/varient.h"
#include "paddle/fluid/lite/utils/any.h"
......@@ -55,7 +55,10 @@ class Factory {
}
item_ptr_t Create(const std::string& op_type) const {
return std::move(Creates(op_type).front());
auto res = Creates(op_type);
if (res.empty()) return nullptr;
CHECK_EQ(res.size(), 1UL) << "Get multiple Op for type " << op_type;
return std::move(res.front());
}
std::list<item_ptr_t> Creates(const std::string& op_type) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册