diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt
index 5a471a3c38b402b17e47a24cfbfc3cca292b8cd8..804a0bcda3562d89ee95ca96b55fbd0cb98f6976 100644
--- a/paddle/fluid/lite/CMakeLists.txt
+++ b/paddle/fluid/lite/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(operators)
 add_subdirectory(kernels)
 add_subdirectory(model_parser)
 add_subdirectory(utils)
+add_subdirectory(api)
diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cae0912bd749a2fb8f76b80fbaf9ada62bcf27ec
--- /dev/null
+++ b/paddle/fluid/lite/api/CMakeLists.txt
@@ -0,0 +1,3 @@
+cc_library(cxx_api_lite SRCS cxx_api.h DEPS scope_lite executor_lite host_kernels ops_lite)
+
+cc_test(test_cxx_api_lite SRCS cxx_api_test.cc DEPS cxx_api_lite model_parser_lite)
diff --git a/paddle/fluid/lite/api/cxx_api.cc b/paddle/fluid/lite/api/cxx_api.cc
new file mode 100644
index 0000000000000000000000000000000000000000..81450cb8d15dc7948c67251fc9af64a90d790d34
--- /dev/null
+++ b/paddle/fluid/lite/api/cxx_api.cc
@@ -0,0 +1,19 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//
+// Created by chunwei on 19-4-11.
+//
+
+#include "paddle/fluid/lite/api/cxx_api.h"
diff --git a/paddle/fluid/lite/api/cxx_api.h b/paddle/fluid/lite/api/cxx_api.h
new file mode 100644
index 0000000000000000000000000000000000000000..9c304b036a5dde924f0d0071112d9d5b3aae7a88
--- /dev/null
+++ b/paddle/fluid/lite/api/cxx_api.h
@@ -0,0 +1,20 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/lite/model_parser/model_parser.h"
+
+namespace paddle {
+namespace lite {}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..09fd7a78bfcfd7f224af135a6d1a88f7520075dd
--- /dev/null
+++ b/paddle/fluid/lite/api/cxx_api_test.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/api/cxx_api.h"
+#include <gtest/gtest.h>
+#include "paddle/fluid/lite/core/executor.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+
+TEST(CXXApi, test) {
+  Scope scope;
+  framework::proto::ProgramDesc prog;
+  LoadModel("/home/chunwei/project2/models/model2", &scope, &prog);
+  framework::ProgramDesc prog_desc(prog);
+
+  lite::Executor executor(&scope,
+                          {OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
+
+  auto x = scope.Var("a")->GetMutable<Tensor>();
+  x->Resize({100, 100});
+  x->mutable_data<float>();
+
+  executor.PrepareWorkspace(prog_desc, &scope);
+  executor.Build(prog_desc);
+  executor.Run();
+}
+
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(mul);
+USE_LITE_OP(fc);
+USE_LITE_OP(scale);
+USE_LITE_KERNEL(fc, kHost, kFloat);
+USE_LITE_KERNEL(mul, kHost, kFloat);
+USE_LITE_KERNEL(scale, kHost, kFloat);
diff --git a/paddle/fluid/lite/core/executor.h b/paddle/fluid/lite/core/executor.h
index d73dff5908ae19fae23f1111ca52fde9c7897bc7..b87cb232d9d866c0c1e6315060049a4463bbe342 100644
--- a/paddle/fluid/lite/core/executor.h
+++ b/paddle/fluid/lite/core/executor.h
@@ -47,6 +47,7 @@ class Executor {
     // Create operators.
     for (auto* op_desc : program.Block(0).AllOps()) {
       auto op_type = op_desc->Type();
+      if (op_type == "feed" || op_type == "fetch") continue;
       LOG(INFO) << "create Op [" << op_type << "]";
       ops_.emplace_back(LiteOpRegistry::Global().Create(op_type));
       // pick initial kernel
diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h
index 2966df4d3f0cda824bc8c27330e5745dafbf8100..b2208a015d343302022ac0111947c202e20701fa 100644
--- a/paddle/fluid/lite/core/kernel.h
+++ b/paddle/fluid/lite/core/kernel.h
@@ -67,6 +67,9 @@ class OpKernel : public KernelBase {
   void Touch() {}
 
+  TargetType target() const override { return Target; }
+  PrecisionType precision() const override { return Precision; }
+
   OpKernel() = default;
   virtual ~OpKernel() = default;
 
diff --git a/paddle/fluid/lite/core/op_lite.cc b/paddle/fluid/lite/core/op_lite.cc
index 53098a2a957dc7333395305101fb602447417f96..a053b77974d6c962d17555432e17558455f81e20 100644
--- a/paddle/fluid/lite/core/op_lite.cc
+++ b/paddle/fluid/lite/core/op_lite.cc
@@ -20,13 +20,14 @@ namespace paddle {
 namespace lite {
 
 std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
-    const std::vector<Place> &places) {
+    const std::vector<Place> &places, const std::string &kernel_type) {
   std::vector<std::unique_ptr<KernelBase>> kernels;
   CHECK(!op_type_.empty()) << "op_type_ should be set first";
 
   for (auto place : places) {
-    kernels.emplace_back(KernelRegistry::Global().Create(op_type_, place.target,
-                                                         place.precision));
+    kernels.emplace_back(KernelRegistry::Global().Create(
+        (kernel_type.empty() ? op_type_ : kernel_type), place.target,
+        place.precision));
   }
 
   return kernels;
diff --git a/paddle/fluid/lite/core/op_lite.h b/paddle/fluid/lite/core/op_lite.h
index 3ba849eb96060ad71a023dbff2af90c36caee4d6..2d9ad332ec2f7cff45720c025df36206c9416269 100644
--- a/paddle/fluid/lite/core/op_lite.h
+++ b/paddle/fluid/lite/core/op_lite.h
@@ -119,7 +119,8 @@ class OpLite : public Registry {
   // Create all the kernels for the valid targets.
   std::vector<std::unique_ptr<KernelBase>> CreateKernels(
-      const std::vector<Place> &places);
+      const std::vector<Place> &places,
+      const std::string &kernel_type = "");
 
  protected:
  std::unique_ptr op_context_;
diff --git a/paddle/fluid/lite/core/tensor.cc b/paddle/fluid/lite/core/tensor.cc
index b56baf7d17c8eb20170e152d99e090cfd9d7ff71..39ffd2ee66ec8d5bcb118ffd0fc70e61e5a7d205 100644
--- a/paddle/fluid/lite/core/tensor.cc
+++ b/paddle/fluid/lite/core/tensor.cc
@@ -32,8 +32,8 @@ std::ostream &operator<<(std::ostream &os, const DDim &dims) {
 }
 
 std::ostream &operator<<(std::ostream &os, const Tensor &tensor) {
-  os << "Tensor:" << std::endl;
-  os << "dim: " << tensor.dims();
+  os << "Tensor:" << '\n';
+  os << "dim: " << tensor.dims() << '\n';
   for (int i = 0; i < product(tensor.dims()); i++) {
     os << tensor.data<float>()[i] << " ";
   }
diff --git a/paddle/fluid/lite/core/types.h b/paddle/fluid/lite/core/types.h
index 0e95882c56056e93d02e65fb34152e2818da0972..52b5ea7d02abe5800bbeea0082874d5413a1882a 100644
--- a/paddle/fluid/lite/core/types.h
+++ b/paddle/fluid/lite/core/types.h
@@ -25,6 +25,21 @@ using any_context_t = variant<Context<TARGET(kX86)>,  //
                               Context<TARGET(kCUDA)>  //
                               >;
 
+struct dim2 {
+  int x{};
+  int y{};
+
+  dim2(int x, int y) : x(x), y(y) {}
+};
+
+struct dim3 {
+  int x{};
+  int y{};
+  int z{};
+
+  dim3(int x, int y, int z) : x(x), y(y), z(z) {}
+};
+
 }  // namespace core
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/kernels/host/CMakeLists.txt b/paddle/fluid/lite/kernels/host/CMakeLists.txt
index 0b1ed64263aa0af8f762809f3c8d132acd4c88d9..17935f8094d1aec04ca8ccd7cd18345cacbc8faa 100644
--- a/paddle/fluid/lite/kernels/host/CMakeLists.txt
+++ b/paddle/fluid/lite/kernels/host/CMakeLists.txt
@@ -1,8 +1,13 @@
 cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite)
 cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite)
+cc_library(mul_compute_host SRCS mul_compute.cc DEPS tensor_lite)
+cc_library(scale_compute_host SRCS scale_compute.cc DEPS tensor_lite)
 
 cc_library(host_kernels DEPS
     fc_compute_host
-    relu_compute_host)
+    relu_compute_host
+    mul_compute_host
+    scale_compute_host
+    )
 
 cc_test(test_fc_compute SRCS fc_compute_test.cc DEPS fc_compute_host fc_op_lite)
diff --git a/paddle/fluid/lite/kernels/host/mul_compute.cc b/paddle/fluid/lite/kernels/host/mul_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..08bb2a737a23886c40e2d0b90d462114cbe3fe09
--- /dev/null
+++ b/paddle/fluid/lite/kernels/host/mul_compute.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <Eigen/Core>
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/types.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+template <typename T>
+void mul_compute_eigen(const T* x, int x_h, int x_w, const T* y, int y_h,
+                       int y_w, T* out) {
+  using matrix_t =
+      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+
+  Eigen::Map<const matrix_t> X(x, x_h, x_w);
+  Eigen::Map<const matrix_t> Y(y, y_h, y_w);
+  Eigen::Map<matrix_t> Out(out, x_h, y_w);
+
+  Out = X * Y;
+}
+
+class MulCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::MulParam;
+
+  void Run() override {
+    auto& theparam = param<operators::MulParam>();
+    core::dim2 x_shape(
+        {product(theparam.x->dims().begin(),
+                 theparam.x->dims().begin() + theparam.x_num_col_dims),
+         product(theparam.x->dims().begin() + theparam.x_num_col_dims,
+                 theparam.x->dims().end())});
+
+    core::dim2 y_shape(
+        {product(theparam.y->dims().begin(),
+                 theparam.y->dims().begin() + theparam.x_num_col_dims),
+         product(theparam.y->dims().begin() + theparam.x_num_col_dims,
+                 theparam.y->dims().end())});
+
+    mul_compute_eigen(theparam.x->data<float>(), x_shape.x, x_shape.y,  //
+                      theparam.y->data<float>(), y_shape.x, y_shape.y,  //
+                      theparam.output->mutable_data<float>());
+  }
+
+  virtual ~MulCompute() = default;
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(mul, kHost, kFloat,
+                     paddle::lite::kernels::host::MulCompute);
diff --git a/paddle/fluid/lite/kernels/host/scale_compute.cc b/paddle/fluid/lite/kernels/host/scale_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1eaea0312729c803a6d5df17229f5dabcd78f1e9
--- /dev/null
+++ b/paddle/fluid/lite/kernels/host/scale_compute.cc
@@ -0,0 +1,54 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <Eigen/Core>
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/types.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+template <typename T>
+void scale_compute(const T* x, T* out, int size, float scale, float bias,
+                   bool bias_before) {
+  if (bias_before) bias *= scale;
+  for (int i = 0; i < size; i++) {
+    out[i] = x[i] * scale + bias;
+  }
+}
+
+class ScaleCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ScaleParam;
+
+  void Run() override {
+    auto& theparam = param<operators::ScaleParam>();
+    scale_compute(theparam.x->data<float>(), theparam.x->mutable_data<float>(),
+                  product(theparam.x->dims()), theparam.scale, theparam.bias,
+                  theparam.bias_after_scale);
+  }
+
+  virtual ~ScaleCompute() = default;
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(scale, kHost, kFloat,
+                     paddle::lite::kernels::host::ScaleCompute);
diff --git a/paddle/fluid/lite/model_parser/CMakeLists.txt b/paddle/fluid/lite/model_parser/CMakeLists.txt
index 78de6f9367da5aa106f4b02b38256324015a3fd0..36044b9ad9fb567dd9f85e22e4067ef7ab168884 100644
--- a/paddle/fluid/lite/model_parser/CMakeLists.txt
+++ b/paddle/fluid/lite/model_parser/CMakeLists.txt
@@ -1,3 +1,3 @@
-cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite)
+cc_library(model_parser_lite SRCS model_parser.cc DEPS variable_lite scope_lite tensor_lite scope_lite)
 cc_library(runtime_lite SRCS runtime.cc)
 cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
diff --git a/paddle/fluid/lite/model_parser/model_parser.cc b/paddle/fluid/lite/model_parser/model_parser.cc
index fb3b1e9ac7f272f0c882d148b6a996a825a462c3..63316d34cd0b9e5023eebfeae5c2295e9692a6d1 100644
--- a/paddle/fluid/lite/model_parser/model_parser.cc
+++ b/paddle/fluid/lite/model_parser/model_parser.cc
@@ -138,15 +138,27 @@ void LoadParam(const std::string &path, Variable *out) {
   LoadLoDTensor(fin, out);
 }
 
-void LoadModel(const std::string &model_dir, Scope *scope) {
+void LoadModel(const std::string &model_dir, Scope *scope,
+               framework::proto::ProgramDesc *prog) {
   const std::string prog_path = model_dir + "/__model__";
-  auto prog = LoadProgram(prog_path);
+  *prog = *LoadProgram(prog_path);
 
   auto main_block = prog->blocks(0);
   for (auto &var : main_block.vars()) {
+    if (var.name() == "feed" || var.name() == "fetch" || !var.persistable())
+      continue;
+
     std::string file_path = model_dir + "/" + var.name();
+    LOG(INFO) << "reading weight " << var.name();
+
     std::ifstream file(file_path);
-    LoadLoDTensor(file, scope->Var(var.name()));
+    switch (var.type().type()) {
+      case framework::proto::VarType_Type_LOD_TENSOR:
+        LoadLoDTensor(file, scope->Var(var.name()));
+        break;
+      default:
+        CHECK(false) << "unknown weight type";
+    }
   }
 }
diff --git a/paddle/fluid/lite/model_parser/model_parser.h b/paddle/fluid/lite/model_parser/model_parser.h
index 358bb9a1abeb6c595c92b62e7079051699ce9843..41a5a9a93172df330064aeeb9f60fe54b8fee652 100644
--- a/paddle/fluid/lite/model_parser/model_parser.h
+++ b/paddle/fluid/lite/model_parser/model_parser.h
@@ -19,6 +19,7 @@
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/lite/core/scope.h"
 #include "paddle/fluid/lite/core/tensor.h"
 #include "paddle/fluid/lite/core/variable.h"
 
@@ -36,7 +37,8 @@ void LoadParams(const std::string& path);
 void LoadParam(const std::string& path, Variable* out);
 
 // Read a model and files of parameters.
-void LoadModel(const std::string& model_dir);
+void LoadModel(const std::string& model_dir, Scope* scope,
+               framework::proto::ProgramDesc* prog);
 
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/model_parser/model_parser_test.cc b/paddle/fluid/lite/model_parser/model_parser_test.cc
index 1ce7eb83a5204a837785c5b32afae8172725d044..b5d721809a68913f485f6b73fc120e735623546a 100644
--- a/paddle/fluid/lite/model_parser/model_parser_test.cc
+++ b/paddle/fluid/lite/model_parser/model_parser_test.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/lite/model_parser/model_parser.h"
 #include <gtest/gtest.h>
+#include "paddle/fluid/lite/core/scope.h"
 
 namespace paddle {
 namespace lite {
@@ -23,5 +24,20 @@ TEST(ModelParser, LoadProgram) {
       "/home/chunwei/project2/models/fc/fluid_checkpoint/__model__");
 }
 
+TEST(ModelParser, LoadParam) {
+  Scope scope;
+  auto* v = scope.Var("xxx");
+  LoadParam("/home/chunwei/project2/models/fc/fluid_checkpoint/b1", v);
+  const auto& t = v->Get<Tensor>();
+  LOG(INFO) << "loaded\n";
+  LOG(INFO) << t;
+}
+
+TEST(ModelParser, LoadModel) {
+  Scope scope;
+  framework::proto::ProgramDesc prog;
+  LoadModel("/home/chunwei/project2/models/fc/fluid_checkpoint", &scope, &prog);
+}
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 7da1abd4e0b278d4ee6f1c058b9277e6dc6c77d7..9b33e75d5b75b48e00a24af88667e48c2d6d7c3b 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -1,8 +1,14 @@
 cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite proto_desc)
 cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
+cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite)
+cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite)
+
 cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
 cc_library(ops_lite DEPS
     fc_op_lite
-    relu_op_lite)
+    relu_op_lite
+    mul_op_lite
+    scale_op_lite
+    )
 
 cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host)
diff --git a/paddle/fluid/lite/operators/fc_op.h b/paddle/fluid/lite/operators/fc_op.h
index e3893cc938531c3755265f9b8af9e4ed9f65401f..71cd815184012be3072c37acb88b970120f72c60 100644
--- a/paddle/fluid/lite/operators/fc_op.h
+++ b/paddle/fluid/lite/operators/fc_op.h
@@ -66,11 +66,6 @@ class FcOpLite : public OpLite {
 
   std::string DebugString() const override { return "fc"; }
 
-  void StaticPickKernel(const std::vector<Place> &valid_targets) override {
-    auto kernels = CreateKernels(valid_targets);
-    kernel_ = std::move(kernels.front());
-  }
-
  private:
   mutable FcParam param_;
 };
diff --git a/paddle/fluid/lite/operators/mul_op.cc b/paddle/fluid/lite/operators/mul_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5f1dd0de970edca22f4c12369ffb1300451a8518
--- /dev/null
+++ b/paddle/fluid/lite/operators/mul_op.cc
@@ -0,0 +1,58 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/mul_op.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool MulOpLite::CheckShape() const {
+  CHECK_OR_FALSE(param_.x);
+  CHECK_OR_FALSE(param_.y);
+  CHECK_OR_FALSE(param_.output);
+  // bias is optional.
+
+  const auto x_dims = param_.x->dims();
+  const auto y_dims = param_.y->dims();
+
+  CHECK_EQ_OR_FALSE(y_dims.size(), 2UL);
+  CHECK_GT_OR_FALSE(x_dims.size(), static_cast<size_t>(param_.x_num_col_dims));
+
+  return true;
+}
+
+bool MulOpLite::InferShape() const {
+  const auto x_dims = param_.x->dims();
+  const auto y_dims = param_.y->dims();
+
+  // Set output dims
+  std::vector<int64_t> out_dims(param_.x_num_col_dims + 1, 0);
+  for (int i = 0; i < param_.x_num_col_dims; ++i) {
+    out_dims[i] = x_dims[i];
+  }
+  out_dims.back() = y_dims[1];
+  param_.output->Resize(out_dims);
+
+  // share LoD
+  // param_.output->set_lod(param_.input->lod());
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(mul, paddle::lite::operators::MulOpLite);
diff --git a/paddle/fluid/lite/operators/mul_op.h b/paddle/fluid/lite/operators/mul_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..392d3e72beca3fa0a18e1a387e36a987b2fd194f
--- /dev/null
+++ b/paddle/fluid/lite/operators/mul_op.h
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/core/tensor.h"
+#include "paddle/fluid/lite/operators/op_params.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class MulOpLite : public OpLite {
+ public:
+  MulOpLite() {}
+
+  explicit MulOpLite(const std::string &type) : OpLite(type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  // TODO(Superjomn) replace framework::OpDesc with a lite one.
+  bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
+    auto input = op_desc.Input("X").front();
+    auto W = op_desc.Input("Y").front();
+    auto out = op_desc.Output("Out").front();
+
+    param_.x = scope->FindVar(input)->GetMutable<Tensor>();
+    param_.y = scope->FindVar(W)->GetMutable<Tensor>();
+    CHECK(scope->FindVar(out));
+    param_.output = scope->FindVar(out)->GetMutable<Tensor>();
+    param_.x_num_col_dims = boost::get<int>(op_desc.GetAttr("x_num_col_dims"));
+    param_.y_num_col_dims = boost::get<int>(op_desc.GetAttr("y_num_col_dims"));
+
+    CHECK(kernel_);
+    kernel_->SetParam(param_);
+
+    return true;
+  }
+
+  std::string DebugString() const override { return "mul"; }
+
+ private:
+  mutable MulParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h
index d701f5518c35b559b2c1126fbd00a8930028f2ba..dea8759eca9401b0ff9b21c142ad8dcb6bded8da 100644
--- a/paddle/fluid/lite/operators/op_params.h
+++ b/paddle/fluid/lite/operators/op_params.h
@@ -30,7 +30,7 @@ struct FcParam {
   Tensor* bias{};
   Tensor* output{};
   DDim in_mat_dims;
-  int in_num_col_dims{0};
+  int in_num_col_dims{1};
 };
 
 struct ReluParam {
@@ -38,7 +38,27 @@
   Tensor* input{};
   Tensor* output{};
 };
 
-using param_t = variant<FcParam, ReluParam>;
+// For Mul Op
+struct MulParam {
+  Tensor* x{};
+  Tensor* y{};
+  Tensor* output{};
+
+  int x_num_col_dims{1};
+  int y_num_col_dims{1};
+};
+
+// For Scale Op
+struct ScaleParam {
+  Tensor* x{};
+  Tensor* output{};
+
+  float scale{1.};
+  float bias{};
+  bool bias_after_scale{true};
+};
+
+using param_t = variant<FcParam, ReluParam, MulParam, ScaleParam>;
 
 }  // namespace operators
 }  // namespace lite
diff --git a/paddle/fluid/lite/operators/relu_op.h b/paddle/fluid/lite/operators/relu_op.h
index e7412d709cda6a913e5bc6e7eceb563d5060c385..da8553000b5d59b5445d67dc0f4fb497de556f32 100644
--- a/paddle/fluid/lite/operators/relu_op.h
+++ b/paddle/fluid/lite/operators/relu_op.h
@@ -36,11 +36,6 @@ class ReluOp : public OpLite {
 
   std::string DebugString() const override { return "tanh"; }
 
-  void StaticPickKernel(
-      const std::vector<Place> &valid_targets) override {
-    kernel_ = std::move(CreateKernels(valid_targets).front());
-  }
-
  private:
   mutable ReluParam param_;
 };
diff --git a/paddle/fluid/lite/operators/scale_op.cc b/paddle/fluid/lite/operators/scale_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1371c01277e47e31ce7d1302114d70850640c36e
--- /dev/null
+++ b/paddle/fluid/lite/operators/scale_op.cc
@@ -0,0 +1,75 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/core/tensor.h"
+#include "paddle/fluid/lite/operators/op_params.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class ScaleOp : public OpLite {
+ public:
+  ScaleOp() {}
+
+  explicit ScaleOp(const std::string &type) : OpLite(type) {}
+
+  bool CheckShape() const override {
+    CHECK_OR_FALSE(param_.x);
+    CHECK_OR_FALSE(param_.output);
+    return true;
+  }
+
+  bool InferShape() const override {
+    param_.output->Resize(param_.x->dims());
+    return true;
+  }
+
+  // TODO(Superjomn) replace framework::OpDesc with a lite one.
+  bool Attach(const framework::OpDesc &op_desc, lite::Scope *scope) override {
+    auto x = op_desc.Input("X").front();
+    auto out = op_desc.Output("Out").front();
+
+    param_.x = scope->FindVar(x)->GetMutable<Tensor>();
+    CHECK(scope->FindVar(out));
+    param_.output = scope->FindVar(out)->GetMutable<Tensor>();
+    param_.scale = boost::get<float>(op_desc.GetAttr("scale"));
+    param_.bias = boost::get<float>(op_desc.GetAttr("bias"));
+    param_.bias_after_scale =
+        boost::get<bool>(op_desc.GetAttr("bias_after_scale"));
+
+    CHECK(kernel_);
+    kernel_->SetParam(param_);
+
+    return true;
+  }
+
+  std::string DebugString() const override { return op_type_; }
+
+ private:
+  mutable ScaleParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(scale, paddle::lite::operators::ScaleOp);
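
Note: the least obvious piece of this patch is how MulCompute maps x_num_col_dims onto a 2-D matrix shape before the Eigen matmul: dims [d0, ..., dn-1] with num_col_dims = k collapse into a (d0*...*d(k-1)) x (dk*...*d(n-1)) matrix. The standalone sketch below is an illustration only, not code from the patch; flatten_to_2d is a hypothetical helper, and std::accumulate stands in for lite's product() so the snippet has no Paddle dependencies.

// Illustrative sketch only -- not part of the patch above.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper mirroring MulCompute's use of x_num_col_dims:
// dims [d0, ..., dn-1] collapse to a (d0*...*d(k-1)) x (dk*...*d(n-1)) matrix.
std::pair<int64_t, int64_t> flatten_to_2d(const std::vector<int64_t>& dims,
                                          int num_col_dims) {
  auto mid = dims.begin() + num_col_dims;
  int64_t rows = std::accumulate(dims.begin(), mid, int64_t{1},
                                 std::multiplies<int64_t>());
  int64_t cols = std::accumulate(mid, dims.end(), int64_t{1},
                                 std::multiplies<int64_t>());
  return {rows, cols};
}

int main() {
  // A [2, 3, 4] tensor with num_col_dims = 1 multiplies as a 2 x 12 matrix.
  auto shape = flatten_to_2d({2, 3, 4}, 1);
  std::cout << shape.first << " x " << shape.second << "\n";  // prints 2 x 12
  return 0;
}

With x of shape [2, 3, 4] and x_num_col_dims = 1, MulCompute therefore feeds a 2 x 12 matrix (x_shape.x by x_shape.y) into mul_compute_eigen, with y flattened by the same rule.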