From 900b4cdd6dea19e809a51cf0734dbed6b69fa53a Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 28 Jun 2019 05:54:52 +0000 Subject: [PATCH] Expose CXXTrainer related API to python --- cmake/version.cmake | 3 +- paddle/fluid/lite/api/cxx_api.cc | 11 + paddle/fluid/lite/api/cxx_api.h | 23 +++ paddle/fluid/lite/core/CMakeLists.txt | 7 +- paddle/fluid/lite/core/mir/ssa_graph_test.cc | 4 +- paddle/fluid/lite/core/program.cc | 2 + paddle/fluid/lite/core/scope.h | 6 + paddle/fluid/lite/core/variable.h | 7 +- .../fluid/lite/kernels/host/feed_compute.cc | 2 +- paddle/fluid/lite/kernels/x86/CMakeLists.txt | 8 +- .../lite/kernels/x86/elementwise_compute.cc | 23 ++- .../lite/kernels/x86/elementwise_compute.h | 12 +- paddle/fluid/lite/kernels/x86/sgd_compute.cc | 1 + .../kernels/x86/uniform_random_compute.cc | 67 +++++++ .../fluid/lite/model_parser/compatible_pb.cc | 4 + paddle/fluid/lite/model_parser/cpp/op_desc.cc | 1 + paddle/fluid/lite/model_parser/cpp/op_desc.h | 11 + paddle/fluid/lite/model_parser/pb/op_desc.cc | 1 + paddle/fluid/lite/operators/CMakeLists.txt | 8 +- paddle/fluid/lite/operators/activation_ops.cc | 15 ++ .../fluid/lite/operators/elementwise_ops.cc | 18 +- .../fluid/lite/operators/fill_constant_op.cc | 2 +- paddle/fluid/lite/operators/mean_op.cc | 4 +- paddle/fluid/lite/operators/mul_op.cc | 47 +++-- paddle/fluid/lite/operators/mul_op.h | 2 + paddle/fluid/lite/operators/op_params.h | 12 +- paddle/fluid/lite/operators/sgd_op.cc | 7 +- paddle/fluid/lite/operators/sgd_op.h | 2 +- .../fluid/lite/operators/uniform_random_op.cc | 45 +++++ .../fluid/lite/operators/uniform_random_op.h | 50 +++++ paddle/fluid/lite/python/lite_test.py | 103 ++++++++++ paddle/fluid/lite/tools/build.sh | 24 +++ paddle/fluid/pybind/CMakeLists.txt | 13 +- paddle/fluid/pybind/executor_lite.cc | 189 ++++++++++++++++++ paddle/fluid/pybind/executor_lite.h | 26 +++ paddle/fluid/pybind/pybind.cc | 5 + python/paddle/fluid/__init__.py | 1 + python/paddle/fluid/backward.py | 1 + python/paddle/fluid/cxx_trainer.py | 163 +++++++++++++++ 39 files changed, 870 insertions(+), 60 deletions(-) create mode 100644 paddle/fluid/lite/kernels/x86/uniform_random_compute.cc create mode 100644 paddle/fluid/lite/operators/uniform_random_op.cc create mode 100644 paddle/fluid/lite/operators/uniform_random_op.h create mode 100644 paddle/fluid/lite/python/lite_test.py create mode 100644 paddle/fluid/pybind/executor_lite.cc create mode 100644 paddle/fluid/pybind/executor_lite.h create mode 100644 python/paddle/fluid/cxx_trainer.py diff --git a/cmake/version.cmake b/cmake/version.cmake index f7b065b58..8bcc4ffe7 100644 --- a/cmake/version.cmake +++ b/cmake/version.cmake @@ -3,7 +3,8 @@ set(PADDLE_VERSION $ENV{PADDLE_VERSION}) set(tmp_version "HEAD") set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?") set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+") -set(LATEST_PADDLE_VERSION "latest") +# set(LATEST_PADDLE_VERSION "latest") +set(LATEST_PADDLE_VERSION "0.0.0") while ("${PADDLE_VERSION}" STREQUAL "") # Check current branch name diff --git a/paddle/fluid/lite/api/cxx_api.cc b/paddle/fluid/lite/api/cxx_api.cc index 0cb6c12db..afb25271a 100644 --- a/paddle/fluid/lite/api/cxx_api.cc +++ b/paddle/fluid/lite/api/cxx_api.cc @@ -79,5 +79,16 @@ const lite::Tensor *Predictor::GetTensor(const std::string &name) const { return &var->Get(); } +#ifdef LITE_WITH_X86 +void Predictor::FeedVars(const std::vector &tensors) { + auto var = scope_->FindVar("feed"); + auto &feed_list = *(var->GetMutable>()); + 
feed_list.resize(tensors.size()); + + for (size_t i = 0; i < tensors.size(); ++i) + feed_list[i].ShareDataWith(tensors[i]); +} +#endif + } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h index e7bed3f90..5a3b6976e 100644 --- a/paddle/fluid/lite/api/cxx_api.h +++ b/paddle/fluid/lite/api/cxx_api.h @@ -24,6 +24,10 @@ #include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/model_parser/model_parser.h" +#ifdef LITE_WITH_X86 +#include "paddle/fluid/framework/program_desc.h" +#endif + namespace paddle { namespace lite { @@ -63,6 +67,15 @@ class Predictor { // This method is disabled in mobile, for unnecessary dependencies required. void SaveModel(const std::string& dir); +#ifdef LITE_WITH_X86 + void Run(const std::vector& tensors) { + FeedVars(tensors); + program_->Run(); + } + + void FeedVars(const std::vector& tensors); +#endif + private: Optimizer optimizer_; framework::proto::ProgramDesc program_desc_; @@ -105,6 +118,16 @@ class CXXTrainer { return main_program_executor_; } +#ifdef LITE_WITH_X86 + Predictor& BuildMainProgramExecutor(framework::ProgramDesc& desc) { // NOLINT + return BuildMainProgramExecutor(*desc.Proto()); + } + + void RunStartupProgram(framework::ProgramDesc& desc) { // NOLINT + RunStartupProgram(*desc.Proto()); + } +#endif + // Run the startup program. It just executes once, no cache needed. void RunStartupProgram(const framework::proto::ProgramDesc& desc, int block_id = 0) { diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index 9bcb5a0aa..77cc6784a 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -20,14 +20,19 @@ endif() proto_library(framework_proto_lite SRCS framework.proto) cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite any_lite op_params_lite framework_proto_lite ${tensor_lite}) +if (LITE_WITH_X86) +cc_library(variable_lite SRCS variable.cc DEPS framework_proto) +cc_library(types_lite SRCS types.cc DEPS framework_proto) +else() cc_library(variable_lite SRCS variable.cc) +cc_library(types_lite SRCS types.cc) +endif() cc_library(op_registry_lite SRCS op_registry.cc DEPS framework_proto_lite) cc_library(scope_lite SRCS scope.cc DEPS ${tensor_lite}) cc_library(cpu_info_lite SRCS cpu_info.cc) lite_cc_library(context_lite SRCS context.cc DEPS ${tensor_lite} any_lite cpu_info_lite eigen3) cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapper_lite kernel_lite cpp_op_desc_lite ${tensor_lite}) -cc_library(types_lite SRCS types.cc) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) lite_cc_library(program_lite SRCS program.cc diff --git a/paddle/fluid/lite/core/mir/ssa_graph_test.cc b/paddle/fluid/lite/core/mir/ssa_graph_test.cc index 34d8c859a..98a93b463 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph_test.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph_test.cc @@ -52,4 +52,6 @@ TEST(SSAGraph, test) { } // namespace paddle USE_LITE_OP(fc); -USE_LITE_KERNEL(fc, kHost, kFloat, kNCHW, def); +#ifdef LITE_WITH_X86 +// USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +#endif diff --git a/paddle/fluid/lite/core/program.cc b/paddle/fluid/lite/core/program.cc index 9f12f4b87..5a13a4ecc 100644 --- a/paddle/fluid/lite/core/program.cc +++ b/paddle/fluid/lite/core/program.cc @@ -64,6 +64,7 @@ void RuntimeProgram::SaveParams(const std::string &dir, void Program::Build(const framework::proto::ProgramDesc &program) { 
CHECK(ops_.empty()) << "Executor duplicate Build found"; + // Create operators. for (const auto &proto_op_desc : program.blocks(0).ops()) { lite::OpDesc op_desc_dummy(proto_op_desc); @@ -98,6 +99,7 @@ void Program::PrepareWorkspace(const framework::proto::ProgramDesc &program) { } else { if (var_desc.Name() == "feed" || var_desc.Name() == "fetch") continue; weights_.push_back(var_desc.Name()); + if (var_desc.Persistable()) scope_->Var(var_desc.Name()); } } } diff --git a/paddle/fluid/lite/core/scope.h b/paddle/fluid/lite/core/scope.h index 57287c17e..67bf52774 100644 --- a/paddle/fluid/lite/core/scope.h +++ b/paddle/fluid/lite/core/scope.h @@ -27,6 +27,12 @@ namespace lite { class Scope final { public: Scope() {} + // delete below two functions to allow pybind to recognise it cannot make a + // copy + // link: + // https://stackoverflow.com/questions/53807248/pybind11-returning-a-pointer-to-a-container-of-unique-ptr + Scope(const Scope&) = delete; + Scope& operator=(const Scope&) = delete; ~Scope(); Scope& NewScope() const; diff --git a/paddle/fluid/lite/core/variable.h b/paddle/fluid/lite/core/variable.h index d52a813a0..e4ab30a36 100644 --- a/paddle/fluid/lite/core/variable.h +++ b/paddle/fluid/lite/core/variable.h @@ -15,12 +15,15 @@ #pragma once #include #include +#include #include "paddle/fluid/lite/core/compatible_tensor.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { +using FeedFetchList = std::vector; + class Variable { public: template @@ -40,7 +43,9 @@ class Variable { } private: - variant blob_; + // variant blob_; + variant> + blob_; }; } // namespace lite diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc index 1c944e5e0..f594b6d20 100644 --- a/paddle/fluid/lite/kernels/host/feed_compute.cc +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -29,7 +29,7 @@ class FeedCompute auto ¶m = Param(); VLOG(4) << "feed_list.size: " << param.feed_list->size(); VLOG(4) << "col " << param.col; - const lite::Tensor &feed_item = (*param.feed_list)[0]; + const lite::Tensor &feed_item = (*param.feed_list)[param.col]; param.out->ShareDataWith(feed_item); } }; diff --git a/paddle/fluid/lite/kernels/x86/CMakeLists.txt b/paddle/fluid/lite/kernels/x86/CMakeLists.txt index fb3ea2926..496f40aab 100644 --- a/paddle/fluid/lite/kernels/x86/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/x86/CMakeLists.txt @@ -18,6 +18,7 @@ cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} ) cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col) cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling) cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) +cc_library(uniform_random_compute_x86 SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps} ) lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86) lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86) @@ -47,6 +48,7 @@ set(x86_kernels conv_compute_x86 pool_compute_x86 batch_norm_compute_x86 - ) - -set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels") + uniform_random_compute_x86 + sgd_compute_x86 + CACHE INTERNAL "x86 kernels") + diff --git a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc index 5024e4986..068020706 100644 --- a/paddle/fluid/lite/kernels/x86/elementwise_compute.cc +++ 
b/paddle/fluid/lite/kernels/x86/elementwise_compute.cc @@ -22,9 +22,19 @@ REGISTER_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); -REGISTER_LITE_KERNEL(elementwise_sub_grad, kX86, kFloat, kNCHW, - paddle::lite::kernels::x86::ElementwiseSubCompute, +REGISTER_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::ElementwiseAddCompute, def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); + +#ifdef LITE_WITH_X86 +REGISTER_LITE_KERNEL( + elementwise_sub_grad, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::ElementwiseSubGradCompute, def) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput(paddle::framework::GradVarName("Out"), {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput(paddle::framework::GradVarName("X"), @@ -32,11 +42,4 @@ REGISTER_LITE_KERNEL(elementwise_sub_grad, kX86, kFloat, kNCHW, .BindOutput(paddle::framework::GradVarName("Y"), {LiteType::GetTensorTy(TARGET(kX86))}) .Finalize(); - -REGISTER_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, - paddle::lite::kernels::x86::ElementwiseAddCompute, - def) - .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))}) - .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) - .Finalize(); +#endif diff --git a/paddle/fluid/lite/kernels/x86/elementwise_compute.h b/paddle/fluid/lite/kernels/x86/elementwise_compute.h index 5e46bf8d4..de976e526 100644 --- a/paddle/fluid/lite/kernels/x86/elementwise_compute.h +++ b/paddle/fluid/lite/kernels/x86/elementwise_compute.h @@ -68,6 +68,7 @@ struct SubGradDY { T operator()(T x, T y, T out, T dout) const { return -dout; } }; +#ifdef LITE_WITH_X86 template class ElementwiseSubGradCompute : public KernelLite { @@ -79,20 +80,25 @@ class ElementwiseSubGradCompute CHECK(context.x86_device_context()); param.X_grad->template mutable_data(); - param.Y_grad->template mutable_data(); // skip out, x, y auto dout = param.Out_grad->raw_tensor(); auto dx = param.X_grad->raw_tensor(); - auto dy = param.Y_grad->raw_tensor(); + + framework::Tensor* dy = nullptr; + if (param.Y_grad) { + param.Y_grad->template mutable_data(); + dy = ¶m.Y_grad->raw_tensor(); + } auto& skip = dout; paddle::operators::ElemwiseExplicitGradCompute< platform::CPUDeviceContext, T, SubGradDX, SubGradDY>( *context.x86_execution_context(), skip, skip, skip, dout, param.axis, - &dx, &dy, SubGradDX(), SubGradDY()); + &dx, dy, SubGradDX(), SubGradDY()); } virtual ~ElementwiseSubGradCompute() = default; }; +#endif template class ElementwiseAddCompute diff --git a/paddle/fluid/lite/kernels/x86/sgd_compute.cc b/paddle/fluid/lite/kernels/x86/sgd_compute.cc index 2b50c9172..593b14eb5 100644 --- a/paddle/fluid/lite/kernels/x86/sgd_compute.cc +++ b/paddle/fluid/lite/kernels/x86/sgd_compute.cc @@ -49,6 +49,7 @@ class SGDCompute : public KernelLite { const T *param_data = param->template data(); const T *grad_data = grad->template data(); int64_t rows_idx = 0; + T *out_data = param_out->template mutable_data( context.x86_device_context()->GetPlace()); diff --git a/paddle/fluid/lite/kernels/x86/uniform_random_compute.cc b/paddle/fluid/lite/kernels/x86/uniform_random_compute.cc new file mode 100644 index 000000000..58e0d1693 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/uniform_random_compute.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2019 PaddlePaddle 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/operators/jit/kernels.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class UniformRandomCompute + : public KernelLite { + public: + void Run() override { + auto &context = ctx_->As(); + auto ¶m = *param_.get_mutable(); + CHECK(context.x86_device_context()); + + auto *param_out = ¶m.Out->raw_tensor(); + + T *data = + param_out->mutable_data(context.x86_device_context()->GetPlace()); + + unsigned int seed = static_cast(param.seed); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist(static_cast(param.min), + static_cast(param.max)); + int64_t size = param_out->numel(); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } + } + + virtual ~UniformRandomCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +// float +REGISTER_LITE_KERNEL(uniform_random, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::UniformRandomCompute, + def) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/model_parser/compatible_pb.cc b/paddle/fluid/lite/model_parser/compatible_pb.cc index 23a09f8af..e89ae6fe8 100644 --- a/paddle/fluid/lite/model_parser/compatible_pb.cc +++ b/paddle/fluid/lite/model_parser/compatible_pb.cc @@ -72,6 +72,10 @@ void AttrsPbToCpp(const pb::OpDesc &pb_desc, cpp::OpDesc *cpp_desc) { cpp_desc->SetAttr>( name, pb_desc.GetAttr>(name)); break; + case AttrType::LONGS: + cpp_desc->SetAttr>( + name, pb_desc.GetAttr>(name)); + break; default: LOG(FATAL) << "Unsupported attr type found " << static_cast(type); } diff --git a/paddle/fluid/lite/model_parser/cpp/op_desc.cc b/paddle/fluid/lite/model_parser/cpp/op_desc.cc index b6b854d72..6708b8bd8 100644 --- a/paddle/fluid/lite/model_parser/cpp/op_desc.cc +++ b/paddle/fluid/lite/model_parser/cpp/op_desc.cc @@ -34,6 +34,7 @@ SET_ATTR_IMPL(bool, BOOLEAN); SET_ATTR_IMPL(std::vector, INTS); SET_ATTR_IMPL(std::vector, FLOATS); SET_ATTR_IMPL(std::vector, STRINGS); +SET_ATTR_IMPL(std::vector, LONGS); std::pair FindAttr(const cpp::OpDesc& desc, const std::string& name) { diff --git a/paddle/fluid/lite/model_parser/cpp/op_desc.h b/paddle/fluid/lite/model_parser/cpp/op_desc.h index b70c16926..ac001f3e7 100644 --- a/paddle/fluid/lite/model_parser/cpp/op_desc.h +++ b/paddle/fluid/lite/model_parser/cpp/op_desc.h @@ -58,6 +58,12 @@ class OpDesc : public OpDescAPI { std::map>* mutable_outputs() { return &outputs_; } + + bool HasInput(const std::string& param) const { + auto it = inputs_.find(param); + return it != inputs_.end(); + } + std::vector Input(const std::string& param) const override { auto it = 
inputs_.find(param); CHECK(it != inputs_.end()); @@ -75,6 +81,11 @@ class OpDesc : public OpDescAPI { return res; } + bool HasOutput(const std::string& param) const { + auto it = outputs_.find(param); + return it != outputs_.end(); + } + std::vector Output(const std::string& param) const override { auto it = outputs_.find(param); CHECK(it != outputs_.end()); diff --git a/paddle/fluid/lite/model_parser/pb/op_desc.cc b/paddle/fluid/lite/model_parser/pb/op_desc.cc index 7f84510a3..38130a923 100644 --- a/paddle/fluid/lite/model_parser/pb/op_desc.cc +++ b/paddle/fluid/lite/model_parser/pb/op_desc.cc @@ -121,6 +121,7 @@ GET_ATTRS_IMPL(std::vector, ints); GET_ATTRS_IMPL(std::vector, floats); GET_ATTRS_IMPL(std::vector, strings); GET_ATTR_IMPL(std::string, s); +GET_ATTRS_IMPL(std::vector, longs); } // namespace pb } // namespace lite diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index c99d17657..43f745c15 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -17,7 +17,9 @@ cc_library(elementwise_ops_lite SRCS elementwise_ops.cc DEPS ${op_DEPS}) cc_library(fusion_elementwise_activation_ops_lite SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops_lite ${op_DEPS}) cc_library(mean_op_lite SRCS mean_op.cc DEPS ${op_DEPS}) cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) -#cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS}) +cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS}) +cc_library(uniform_random_op_lite SRCS uniform_random_op.cc DEPS ${op_DEPS}) + cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) @@ -52,7 +54,9 @@ set(ops_lite transpose_op_lite fake_quant fake_dequant - PARENT_SCOPE) + sgd_op_lite + uniform_random_op_lite + CACHE INTERNAL "ops lite") lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite diff --git a/paddle/fluid/lite/operators/activation_ops.cc b/paddle/fluid/lite/operators/activation_ops.cc index 8cda67af1..bcdc781e0 100644 --- a/paddle/fluid/lite/operators/activation_ops.cc +++ b/paddle/fluid/lite/operators/activation_ops.cc @@ -72,6 +72,21 @@ class ActivationGradOp : public OpLite { param_.Out_grad = GetVar(scope, Out_grad_name); param_.X_grad = GetMutableVar(scope, X_grad_name); + + if (opdesc.HasInput("X")) { + auto X_name = opdesc.Input("X").front(); + param_.X = GetVar(scope, X_name); + } else { + param_.X = param_.X_grad; + } + + if (opdesc.HasInput("Out")) { + auto Out_name = opdesc.Input("Out").front(); + param_.Out = GetVar(scope, Out_name); + } else { + param_.Out = param_.Out_grad; + } + return true; } diff --git a/paddle/fluid/lite/operators/elementwise_ops.cc b/paddle/fluid/lite/operators/elementwise_ops.cc index 2c6d4e709..221b41104 100644 --- a/paddle/fluid/lite/operators/elementwise_ops.cc +++ b/paddle/fluid/lite/operators/elementwise_ops.cc @@ -48,31 +48,35 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { bool ElementwiseGradExplicitOp::CheckShape() const { CHECK_OR_FALSE(param_.Y); CHECK_OR_FALSE(param_.X_grad); - CHECK_OR_FALSE(param_.Y_grad); CHECK_OR_FALSE(param_.Out_grad); return true; } bool ElementwiseGradExplicitOp::InferShape() const { param_.X_grad->Resize(param_.Out_grad->dims()); - param_.Y_grad->Resize(param_.Y->dims()); + if (param_.Y_grad) 
param_.Y_grad->Resize(param_.Y->dims()); return true; } bool ElementwiseGradExplicitOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { - CHECK_EQ(opdesc.InputArgumentNames().size(), 1UL); + CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL); + auto Y_name = opdesc.Input("Y").front(); auto Out_name = opdesc.Input(framework::GradVarName("Out")).front(); - auto X_name = opdesc.Output(framework::GradVarName("X")).front(); - auto Y_name = opdesc.Output(framework::GradVarName("Y")).front(); + auto X_grad = opdesc.Output(framework::GradVarName("X")).front(); + if (opdesc.Output(framework::GradVarName("Y")).size() > 0) { + auto Y_grad = opdesc.Output(framework::GradVarName("Y")).front(); + param_.Y_grad = GetMutableVar(scope, Y_grad); + } + param_.Y = GetVar(scope, Y_name); param_.Out_grad = GetVar(scope, Out_name); - param_.X_grad = GetMutableVar(scope, X_name); - param_.Y_grad = GetMutableVar(scope, Y_name); + param_.X_grad = GetMutableVar(scope, X_grad); param_.axis = opdesc.GetAttr("axis"); return true; } + #endif } // namespace operators diff --git a/paddle/fluid/lite/operators/fill_constant_op.cc b/paddle/fluid/lite/operators/fill_constant_op.cc index b762f0d3c..10079d20a 100644 --- a/paddle/fluid/lite/operators/fill_constant_op.cc +++ b/paddle/fluid/lite/operators/fill_constant_op.cc @@ -36,7 +36,7 @@ class FillConstantOp : public OpLite { bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { auto Out_name = opdesc.Output("Out").front(); - param_.Out = GetMutableVar(scope, Out_name); + param_.Out = GetMutableVar(scope, Out_name); param_.dtype = opdesc.GetAttr("dtype"); param_.shape = opdesc.GetAttr>("shape"); param_.value = opdesc.GetAttr("value"); diff --git a/paddle/fluid/lite/operators/mean_op.cc b/paddle/fluid/lite/operators/mean_op.cc index 411dcbb73..596f4bda0 100644 --- a/paddle/fluid/lite/operators/mean_op.cc +++ b/paddle/fluid/lite/operators/mean_op.cc @@ -51,7 +51,7 @@ class MeanOp : public OpLite { std::string DebugString() const override { return "mean"; } private: - mutable operators::ElementwiseParam param_; + mutable operators::MeanParam param_; }; #ifdef LITE_WITH_X86 @@ -73,7 +73,7 @@ class MeanGradOp : public OpLite { } bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override { - CHECK_EQ(opdesc.InputArgumentNames().size(), 3UL); + CHECK_EQ(opdesc.InputArgumentNames().size(), 2UL); auto X_name = opdesc.Input("X").front(); auto Out_grad_name = opdesc.Input(framework::GradVarName("Out")).front(); auto X_grad_name = opdesc.Output(framework::GradVarName("X")).front(); diff --git a/paddle/fluid/lite/operators/mul_op.cc b/paddle/fluid/lite/operators/mul_op.cc index 70eb37dd0..4e3660d50 100644 --- a/paddle/fluid/lite/operators/mul_op.cc +++ b/paddle/fluid/lite/operators/mul_op.cc @@ -31,16 +31,18 @@ bool MulOpLite::CheckShape() const { CHECK_GT_OR_FALSE(x_dims.size(), static_cast(param_.x_num_col_dims)); CHECK_GT_OR_FALSE(y_dims.size(), static_cast(param_.y_num_col_dims)); - // auto x_mat_dims = - // framework::flatten_to_2d(x_dims.data(), param_.x_num_col_dims); - // auto y_mat_dims = - // framework::flatten_to_2d(y_dims.data(), param_.y_num_col_dims); - - // PADDLE_ENFORCE_EQ(x_mat_dims[1], y_mat_dims[0], - // "First matrix's width must be equal with second matrix's - // " - // "height. 
%s, %s", - // x_mat_dims[1], y_mat_dims[0]); +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + auto x_mat_dims = + framework::flatten_to_2d(x_dims.data(), param_.x_num_col_dims); + auto y_mat_dims = + framework::flatten_to_2d(y_dims.data(), param_.y_num_col_dims); + + PADDLE_ENFORCE_EQ(x_mat_dims[1], y_mat_dims[0], + "First matrix's width must be equal with second matrix's" + "height. %s, %s", + x_mat_dims[1], y_mat_dims[0]); +#endif + return true; } @@ -73,30 +75,34 @@ bool MulGradOpLite::CheckShape() const { CHECK_OR_FALSE(param_.x); CHECK_OR_FALSE(param_.y); CHECK_OR_FALSE(param_.output_grad); - CHECK_OR_FALSE(param_.x_grad); - CHECK_OR_FALSE(param_.y_grad); return true; } bool MulGradOpLite::InferShape() const { - param_.x_grad->Resize(param_.x->dims()); - param_.y_grad->Resize(param_.y->dims()); + if (param_.x_grad) param_.x_grad->Resize(param_.x->dims()); + if (param_.y_grad) param_.y_grad->Resize(param_.y->dims()); return true; } bool MulGradOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto X_name = op_desc.Input("X").front(); auto Y_name = op_desc.Input("Y").front(); - auto Out_grad_name = op_desc.Output(framework::GradVarName("Out")).front(); - auto X_grad_name = op_desc.Output(framework::GradVarName("X")).front(); - auto Y_grad_name = op_desc.Output(framework::GradVarName("Y")).front(); + auto Out_grad_name = op_desc.Input(framework::GradVarName("Out")).front(); + + if (op_desc.Output(framework::GradVarName("X")).size()) { + auto X_grad_name = op_desc.Output(framework::GradVarName("X")).front(); + param_.x_grad = GetMutableVar(scope, X_grad_name); + } + + if (op_desc.Output(framework::GradVarName("Y")).size()) { + auto Y_grad_name = op_desc.Output(framework::GradVarName("Y")).front(); + param_.y_grad = GetMutableVar(scope, Y_grad_name); + } param_.x = GetVar(scope, X_name); param_.y = GetVar(scope, Y_name); param_.output_grad = GetVar(scope, Out_grad_name); - param_.x_grad = GetMutableVar(scope, X_grad_name); - param_.y_grad = GetMutableVar(scope, Y_grad_name); return true; } @@ -107,3 +113,6 @@ bool MulGradOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { } // namespace paddle REGISTER_LITE_OP(mul, paddle::lite::operators::MulOpLite); +#ifdef LITE_WITH_X86 +REGISTER_LITE_OP(mul_grad, paddle::lite::operators::MulGradOpLite); +#endif diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h index a01427b1f..05c3a2761 100644 --- a/paddle/fluid/lite/operators/mul_op.h +++ b/paddle/fluid/lite/operators/mul_op.h @@ -66,6 +66,7 @@ class MulOpLite : public OpLite { mutable MulParam param_; }; +#ifdef LITE_WITH_X86 class MulGradOpLite : public OpLite { public: MulGradOpLite() {} @@ -85,6 +86,7 @@ class MulGradOpLite : public OpLite { private: mutable MulGradParam param_; }; +#endif } // namespace operators } // namespace lite diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 5bbbcc98b..cfde5a077 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -36,7 +36,7 @@ using param_t = Any; /// ----------------------- Functional operators ------------------------------ struct FeedParam { - const std::vector* feed_list{}; + std::vector* feed_list{}; lite::Tensor* out{}; int col; }; @@ -317,6 +317,16 @@ struct SGDParam { lite::Tensor* ParamOut{}; }; +/// ----------------------- uniform_random operators ---------------------- +struct UniformRandomParam { + std::vector shape{}; + float min{-1.0f}; + float max{1.0f}; + int seed{0}; + 
int dtype{framework::proto::VarType::FP32}; + lite::Tensor* Out{}; +}; + } // namespace operators } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/operators/sgd_op.cc b/paddle/fluid/lite/operators/sgd_op.cc index 2571ad0b1..666ca7980 100644 --- a/paddle/fluid/lite/operators/sgd_op.cc +++ b/paddle/fluid/lite/operators/sgd_op.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "/paddle/paddle/fluid/lite/operators/sgd_op.h" +#include "paddle/fluid/lite/operators/sgd_op.h" #include "paddle/fluid/lite/core/op_lite.h" #include "paddle/fluid/lite/core/op_registry.h" @@ -30,13 +30,14 @@ bool SGDOpLite::CheckShape() const { bool SGDOpLite::InferShape() const { auto lr_dims = param_.LearningRate->dims().data(); +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK CHECK_EQ_OR_FALSE(framework::product(lr_dims), 1); +#endif param_.ParamOut->Resize(param_.Param->dims()); return true; } -bool SGDOpLite::AttachImpl(const OpDesc& opdesc, lite::Scope* scope) { - CHECK_EQ(opdesc.Inputs().size(), 3UL); +bool SGDOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { auto Param_name = opdesc.Input("Param").front(); auto LearningRate_name = opdesc.Input("LearningRate").front(); auto Grad_name = opdesc.Input("Grad").front(); diff --git a/paddle/fluid/lite/operators/sgd_op.h b/paddle/fluid/lite/operators/sgd_op.h index dea045c0b..5847a2cc5 100644 --- a/paddle/fluid/lite/operators/sgd_op.h +++ b/paddle/fluid/lite/operators/sgd_op.h @@ -37,7 +37,7 @@ class SGDOpLite : public OpLite { void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } - bool AttachImpl(const OpDesc &op_desc, lite::Scope *scope) override; + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; std::string DebugString() const override { return "sgd"; } diff --git a/paddle/fluid/lite/operators/uniform_random_op.cc b/paddle/fluid/lite/operators/uniform_random_op.cc new file mode 100644 index 000000000..5f38b9ed1 --- /dev/null +++ b/paddle/fluid/lite/operators/uniform_random_op.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/operators/uniform_random_op.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool UniformRandomOpLite::CheckShape() const { return true; } + +bool UniformRandomOpLite::InferShape() const { + param_.Out->Resize(param_.shape); + return true; +} + +bool UniformRandomOpLite::AttachImpl(const cpp::OpDesc& opdesc, + lite::Scope* scope) { + param_.shape = opdesc.GetAttr>("shape"); + param_.min = opdesc.GetAttr("min"); + param_.max = opdesc.GetAttr("max"); + param_.seed = opdesc.GetAttr("seed"); + param_.dtype = opdesc.GetAttr("dtype"); + param_.Out = GetMutableVar(scope, opdesc.Output("Out").front()); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(uniform_random, paddle::lite::operators::UniformRandomOpLite); diff --git a/paddle/fluid/lite/operators/uniform_random_op.h b/paddle/fluid/lite/operators/uniform_random_op.h new file mode 100644 index 000000000..0d85baf59 --- /dev/null +++ b/paddle/fluid/lite/operators/uniform_random_op.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/operators/op_params.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class UniformRandomOpLite : public OpLite { + public: + UniformRandomOpLite() {} + + explicit UniformRandomOpLite(const std::string &type) : OpLite(type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + + bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; + + std::string DebugString() const override { return "uniform_random"; } + + private: + mutable UniformRandomParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/python/lite_test.py b/paddle/fluid/lite/python/lite_test.py new file mode 100644 index 000000000..5ef354883 --- /dev/null +++ b/paddle/fluid/lite/python/lite_test.py @@ -0,0 +1,103 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.compiler as compiler
+import paddle.fluid.core as core
+import paddle.fluid.core.lite as lite
+import paddle.fluid.layers as layers
+import numpy as np
+import unittest
+
+from paddle.fluid.cxx_trainer import add_feed_fetch_op
+
+
+def _as_lodtensor(data, place):
+    # single tensor case
+    tensor = core.LoDTensor()
+    tensor.set(data, place)
+    return tensor
+
+
+data_label = [[
+    0.753544, 0.772977, 0.646915, 0.747543, 0.528923, 0.0517749, 0.248678,
+    0.75932, 0.960376, 0.606618
+]]
+data_a = [[
+    0.874445, 0.21623, 0.713262, 0.702672, 0.396977, 0.828285, 0.932995,
+    0.442674, 0.0321735, 0.484833, 0.045935, 0.21276, 0.556421, 0.131825,
+    0.285626, 0.741409, 0.257467, 0.975958, 0.444006, 0.114553
+]]
+
+data_loss = [0.9876687]
+
+
+class NaiveModelTest(unittest.TestCase):
+    def test_model(self):
+
+        start_prog = fluid.Program()
+        main_prog = fluid.Program()
+
+        start_prog.random_seed = 100
+        main_prog.random_seed = 100
+
+        with fluid.program_guard(main_prog, start_prog):
+            a = fluid.layers.data(name="a", shape=[1, 20], dtype='float32')
+            label = fluid.layers.data(name="label", shape=[10], dtype='float32')
+            a1 = fluid.layers.fc(input=a, size=10, act=None, bias_attr=False)
+            cost = fluid.layers.square_error_cost(a1, label)
+            avg_cost = fluid.layers.mean(cost)
+
+            optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+            optimizer.minimize(avg_cost)
+
+        x86_place = lite.Place(lite.TargetType.kX86,
+                               lite.PrecisionType.kFloat,
+                               lite.DataLayoutType.kNCHW, 0)
+        host_place = lite.Place(lite.TargetType.kHost,
+                                lite.PrecisionType.kFloat,
+                                lite.DataLayoutType.kNCHW, 0)
+        scope = lite.Scope()
+
+        trainer = lite.CXXTrainer(scope, x86_place, [x86_place, host_place])
+        trainer.run_startup_program(start_prog.desc)
+
+        cpu = fluid.core.CPUPlace()
+        main_prog = add_feed_fetch_op(
+            main_prog,
+            feed=['a', 'label'],
+            fetch_list={avg_cost},
+            scope=scope,
+            place=cpu)
+        # print(main_prog)
+        exe = trainer.build_main_program_executor(main_prog.desc)
+
+        feed_data = [
+            _as_lodtensor(np.array(data_a, object), cpu),
+            _as_lodtensor(np.array(data_label, object), cpu)
+        ]
+
+        exe.run(feed_data)
+        # print(np.array(exe.get_output(0).raw_tensor()))
+        self.assertTrue(
+            np.allclose(
+                np.array(data_loss),
+                np.array(exe.get_output(0).raw_tensor()),
+                atol=1e-8),
+            "lite result not equal to offline result")
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paddle/fluid/lite/tools/build.sh b/paddle/fluid/lite/tools/build.sh
index c832c6304..9128cb8d1 100755
--- a/paddle/fluid/lite/tools/build.sh
+++ b/paddle/fluid/lite/tools/build.sh
@@ -112,6 +112,26 @@ function build_test_server {
 test_lite $TESTS_FILE
 }
 
+function build_test_train {
+ mkdir -p ./build
+ cd ./build
+ export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/paddle/build/third_party/install/mklml/lib"
+ prepare_workspace # fake an empty __generated_code__.cc to pass cmake.
+ cmake .. 
-DWITH_LITE=ON -DWITH_GPU=OFF -DWITH_PYTHON=ON -DLITE_WITH_X86=ON -DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF -DWITH_TESTING=ON -DWITH_MKL=OFF + + make test_gen_code_lite -j$NUM_CORES_FOR_COMPILE + make test_cxx_api_lite -j$NUM_CORES_FOR_COMPILE + ctest -R test_cxx_api_lite + ctest -R test_gen_code_lite + make test_generated_code -j$NUM_CORES_FOR_COMPILE + + make -j$NUM_CORES_FOR_COMPILE + + find -name "*.whl" | xargs pip2 install + python ../paddle/fluid/lite/python/lite_test.py + +} + # test_arm_android function test_arm_android { local test_name=$1 @@ -543,6 +563,10 @@ function main { build_test_server shift ;; + build_test_train) + build_test_train + shift + ;; build_test_arm) build_test_arm shift diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index d709508a6..bee702519 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,6 +1,13 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wrapper nccl_wrapper prune - feed_fetch_method pass_builder parallel_executor profiler layer scope_pool - tracer analysis_predictor imperative_profiler nccl_context) +message(STATUS "use ${x86_kernels}") +message(STATUS "use ${ops_lite}") + +if(WITH_PYTHON) + cc_library(bind_executor_lite SRCS executor_lite.cc DEPS pybind framework_proto) + set(PYBIND_DEPS pybind python proto_desc memory executor async_executor fleet_wrapper nccl_wrapper prune + feed_fetch_method pass_builder parallel_executor profiler layer scope_pool bind_executor_lite cxx_api_lite scope_lite ${ops_lite} ${host_kernels} ${x86_kernels} mir_passes kernel_lite op_lite optimizer_lite + tracer analysis_predictor imperative_profiler nccl_context) +endif(WITH_PYTHON) + if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) diff --git a/paddle/fluid/pybind/executor_lite.cc b/paddle/fluid/pybind/executor_lite.cc new file mode 100644 index 000000000..2ca4e1dce --- /dev/null +++ b/paddle/fluid/pybind/executor_lite.cc @@ -0,0 +1,189 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/pybind/executor_lite.h" +#include +#include +#include +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/api/paddle_use_passes.h" +#include "paddle/fluid/lite/core/hvy_tensor.h" +#include "paddle/fluid/lite/core/scope.h" +#include "pybind11/pybind11.h" + +namespace lt = paddle::lite; +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +void BindTensor(pybind11::module* m) { + pybind11::class_(*m, "Tensor") + .def(pybind11::init<>()) + .def("raw_tensor", [](lt::TensorHvy& self) { return self.raw_tensor(); }) + .def("share_data_with", + [](lt::TensorHvy& self, const framework::Tensor& other) { + self.ShareDataWith(other); + }); +} + +void BindVariable(pybind11::module* m) { + pybind11::class_(*m, "Variable") + .def("get_mutable_tensor", + [](lt::Variable& self) { return self.GetMutable(); }) + .def("get_mutable_fetch_list", + [](lt::Variable& self) -> paddle::lite::FeedFetchList* { + return self.GetMutable(); + }, + py::return_value_policy::reference); +} + +void BindScope(pybind11::module* m) { + py::class_>(*m, "Scope") + .def(pybind11::init<>()) + .def("new_scope", + [](lt::Scope& self) -> lt::Scope* { return &self.NewScope(); }, + py::return_value_policy::reference) + .def("var", <::Scope::Var, pybind11::return_value_policy::reference) + .def("find_var", <::Scope::FindVar, + pybind11::return_value_policy::reference) + .def("find_local_var", <::Scope::FindLocalVar, + pybind11::return_value_policy::reference) + .def("parent", <::Scope::parent, + pybind11::return_value_policy::reference) + .def("local_var_names", <::Scope::LocalVarNames, + pybind11::return_value_policy::reference); +} + +void BindExecutorLite(pybind11::module* m) { + py::class_(*m, "Predictor") + .def(pybind11::init<>()) + .def("__init__", + [](lt::Predictor& self, + const std::shared_ptr& root_scope) { + new (&self) lt::Predictor(root_scope); + }) + .def("get_input", <::Predictor::GetInput, + pybind11::return_value_policy::reference) + .def("get_output", <::Predictor::GetOutput, + pybind11::return_value_policy::reference) + .def("run", [](lt::Predictor& self) { self.Run(); }) + .def("run", [](lt::Predictor& self, + const std::vector& tensors) { + self.Run(tensors); + }); +} + +void BindEnums(pybind11::module* m) { + py::enum_(*m, "TargetType", py::arithmetic(), + "TargetType enum") + .value("kUnk", lt::TargetType::kUnk) + .value("kHost", lt::TargetType::kHost) + .value("kX86", lt::TargetType::kX86) + .value("kCUDA", lt::TargetType::kCUDA) + .value("kARM", lt::TargetType::kARM) + .value("kAny", lt::TargetType::kAny) + .value("NUM", lt::TargetType::NUM); + + py::enum_(*m, "PrecisionType", py::arithmetic(), + "PrecisionType enum") + .value("kUnk", lt::PrecisionType::kUnk) + .value("kFloat", lt::PrecisionType::kFloat) + .value("kInt8", lt::PrecisionType::kInt8) + .value("kAny", lt::PrecisionType::kAny) + .value("NUM", lt::PrecisionType::NUM); + + py::enum_(*m, "DataLayoutType", py::arithmetic(), + "DataLayoutType enum") + .value("kUnk", lt::DataLayoutType::kUnk) + .value("kNCHW", lt::DataLayoutType::kNCHW) + .value("kAny", lt::DataLayoutType::kAny) + .value("NUM", lt::DataLayoutType::NUM); +} + +void BindPlace(pybind11::module* m) { + pybind11::class_>(*m, "Place") + .def(pybind11::init<>()) + .def("__init__", + [](lt::Place& self, lt::TargetType target, + lt::PrecisionType precision, lt::DataLayoutType layout, + int16_t device) { + new (&self) lt::Place(target, precision, layout, device); + }) + .def("is_valid", <::Place::is_valid, + 
pybind11::return_value_policy::reference); +} + +void BindCXXTrainer(pybind11::module* m) { + pybind11::class_>( + *m, "CXXTrainer") + .def( + "__init__", + [](lt::CXXTrainer& self, const std::shared_ptr& root_scope, + const lt::Place& preferred_place, + const std::vector& valid_places) { + new (&self) + lt::CXXTrainer(root_scope, preferred_place, valid_places); + }) + .def("build_main_program_executor", + [](lt::CXXTrainer& self, + framework::ProgramDesc& desc) -> lt::Predictor& { + return self.BuildMainProgramExecutor(desc); + }, + pybind11::return_value_policy::reference) + .def("run_startup_program", + [](lt::CXXTrainer& self, framework::ProgramDesc& desc) { + return self.RunStartupProgram(desc); + }); +} + +void BindLite(pybind11::module* m) { + BindTensor(m); + BindVariable(m); + BindScope(m); + BindExecutorLite(m); + BindEnums(m); + BindPlace(m); + BindCXXTrainer(m); +} + +} // namespace pybind +} // namespace paddle + +// USE_LITE_OP(mul); +USE_LITE_OP(elementwise_sub); +USE_LITE_OP(uniform_random); +USE_LITE_OP(feed); +USE_LITE_OP(fetch); +USE_LITE_OP(fill_constant); +USE_LITE_OP(mul); +USE_LITE_OP(mul_grad); +USE_LITE_OP(mean); +USE_LITE_OP(square); +USE_LITE_OP(sgd); + +USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); + +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(uniform_random, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(fill_constant, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(mul_grad, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(square, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(mean, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(sgd, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_sub_grad, kX86, kFloat, kNCHW, def); +#endif diff --git a/paddle/fluid/pybind/executor_lite.h b/paddle/fluid/pybind/executor_lite.h new file mode 100644 index 000000000..c53e92d31 --- /dev/null +++ b/paddle/fluid/pybind/executor_lite.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include + +#include "pybind11/pybind11.h" + +namespace paddle { +namespace pybind { + +void BindLite(pybind11::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 63d37223c..fa8cee26b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -54,6 +54,7 @@ limitations under the License. 
*/ #include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/exception.h" +#include "paddle/fluid/pybind/executor_lite.h" #include "paddle/fluid/pybind/fleet_wrapper_py.h" #include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/inference_api.h" @@ -366,6 +367,7 @@ PYBIND11_MODULE(core, m) { .def("set", PyCUDAPinnedTensorSetFromArray) #endif .def("shape", [](Tensor &self) { return vectorize(self.dims()); }) + .def("memory_size", [](Tensor &self) { return self.memory_size(); }) .def("_set_float_element", TensorSetElement) .def("_get_float_element", TensorGetElement) .def("_set_double_element", TensorSetElement) @@ -1528,6 +1530,9 @@ All parameter, weight, gradient are variables in Paddle. BindNode(&m); BindInferenceApi(&m); BindDataset(&m); + + py::module lite = m.def_submodule("lite", "submodule lite"); + BindLite(&lite); } } // namespace pybind } // namespace paddle diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index adc7c23f4..cf22c109b 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -65,6 +65,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable from . import install_check from .dygraph.nn import * from .dygraph.layers import * +from .cxx_trainer import * Tensor = LoDTensor diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 41f9016ed..c57b35d02 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -71,6 +71,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): op_desc.set_block_attr(name, val.desc) else: op_desc._set_attr(name, val) + op_desc.check_attrs() return op_desc diff --git a/python/paddle/fluid/cxx_trainer.py b/python/paddle/fluid/cxx_trainer.py new file mode 100644 index 000000000..d25e44224 --- /dev/null +++ b/python/paddle/fluid/cxx_trainer.py @@ -0,0 +1,163 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from . import core +from . import framework +from . import executor +from . import compiler +import sys + +from .framework import default_main_program, Variable + +__all__ = ['add_feed_fetch_op'] + + +def _has_feed_operators(block, feed_targets, feed_holder_name): + """ Check whether the block already has feed operators. + + Return false if the block does not have any feed operators. + If some feed operators have been prepended to the block, check that + the info contained in these feed operators matches the feed_targets + and feed_holder_name. Raise exception when any mismatch is found. + Return true when the block has feed operators with matching info. + + Args: + block: a block instance (typically global block of a program) + feed_targets: a dictionary of {feed_target_name: feed_target_data} + feed_holder_name: the name of the variable that holds the data of + all feed targets. 
The type of this feed_holder variable is
+            FEED_MINIBATCH, which is essentially vector<LoDTensor>.
+
+    Returns:
+        A boolean value that indicates whether a block has feed operators
+        that match the info contained in feed_targets and feed_holder_name.
+    """
+
+    feed_count = 0
+    for op in block.ops:
+        if op.desc.type() == 'feed':
+            feed_count += 1
+            assert op.desc.input('X')[0] == feed_holder_name
+            feed_target_name = op.desc.output('Out')[0]
+            if feed_target_name not in feed_targets:
+                raise Exception("'feed_targets' does not have {} variable".
+                                format(feed_target_name))
+        else:
+            break
+    if feed_count > 0 and feed_count != len(feed_targets):
+        raise Exception(
+            "Feed operators in program desc do not match 'feed_targets'")
+    return feed_count > 0
+
+
+def _has_fetch_operators(block, fetch_targets, fetch_holder_name):
+    """ Check whether the block already has fetch operators.
+
+    Return false if the block does not have any fetch operators.
+    If some fetch operators have been appended to the block, check that
+    the info contained in these fetch operators matches the fetch_targets
+    and fetch_holder_name. Raise exception when any mismatch is found.
+    Return true when the block has fetch operators with matching info.
+
+    Args:
+        block: a block instance (typically global block of a program)
+        fetch_targets: a dictionary of {fetch_target_name: fetch_target_data}
+        fetch_holder_name: the name of the variable that holds the data of
+            all fetch targets. The type of this fetch_holder variable is
+            FETCH_LIST, which is essentially vector<LoDTensor>.
+
+    Returns:
+        A boolean value that indicates whether a block has fetch operators
+        that match the info contained in fetch_targets and fetch_holder_name.
+    """
+
+    fetch_count = 0
+    for op in block.ops:
+        if op.desc.type() == 'fetch':
+            fetch_count += 1
+            assert op.desc.output('Out')[0] == fetch_holder_name
+            fetch_target_name = op.desc.input('X')[0]
+            if fetch_target_name not in [
+                    var.desc.name() for var in fetch_targets
+            ]:
+                raise Exception("'fetch_targets' does not have {} variable".
+ format(fetch_target_name)) + idx = op.desc.attr('col') + assert fetch_target_name == fetch_targets[idx].desc.name() + if fetch_count > 0 and fetch_count != len(fetch_targets): + raise Exception( + "Fetch operators in program desc do not match 'fetch_targets'") + return fetch_count > 0 + + +def _add_feed_fetch_ops(program, + feed, + fetch_list, + feed_var_name='feed', + fetch_var_name='fetch'): + tmp_program = program.clone() + + global_block = tmp_program.global_block() + + if feed_var_name in global_block.vars: + feed_var = global_block.var(feed_var_name) + else: + feed_var = global_block.create_var( + name=feed_var_name, + type=core.VarDesc.VarType.FEED_MINIBATCH, + persistable=True) + + if fetch_var_name in global_block.vars: + fetch_var = global_block.var(fetch_var_name) + else: + fetch_var = global_block.create_var( + name=fetch_var_name, + type=core.VarDesc.VarType.FETCH_LIST, + persistable=True) + + # prepend feed operators + if not _has_feed_operators(global_block, feed, feed_var_name): + for i, name in enumerate(feed): + out = global_block.var(name) + global_block._prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}) + + # append fetch_operators + if not _has_fetch_operators(global_block, fetch_list, fetch_var_name): + for i, var in enumerate(fetch_list): + assert isinstance(var, Variable) or isinstance( + var, six.string_types), ("Wrong type for fetch_list[%s]: %s" % + (i, type(var))) + global_block.append_op( + type='fetch', + inputs={'X': [var]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}) + + return tmp_program + + +def add_feed_fetch_op(program, feed, fetch_list, scope, place): + + if program is None: + program = default_main_program() + + program = _add_feed_fetch_ops( + program=program, feed=feed, fetch_list=fetch_list) + + return program -- GitLab