From 367a2814def03e213700553db18b7081b321591f Mon Sep 17 00:00:00 2001 From: Superjomn Date: Fri, 19 Apr 2019 08:07:37 +0800 Subject: [PATCH] add feed_op --- paddle/fluid/lite/core/CMakeLists.txt | 2 +- paddle/fluid/lite/core/memory.h | 7 ++ paddle/fluid/lite/core/tensor.h | 31 +++++++-- paddle/fluid/lite/kernels/host/CMakeLists.txt | 2 + paddle/fluid/lite/kernels/host/fc_compute.cc | 4 +- paddle/fluid/lite/kernels/host/fc_compute.h | 4 +- .../fluid/lite/kernels/host/feed_compute.cc | 46 +++++++++++++ paddle/fluid/lite/operators/CMakeLists.txt | 2 + paddle/fluid/lite/operators/feed_op.cc | 64 +++++++++++++++++++ paddle/fluid/lite/operators/op_params.h | 8 ++- 10 files changed, 157 insertions(+), 13 deletions(-) create mode 100644 paddle/fluid/lite/kernels/host/feed_compute.cc create mode 100644 paddle/fluid/lite/operators/feed_op.cc diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index 1c58ceda545..51d15a25082 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -1,6 +1,6 @@ cc_library(memory_lite SRCS memory.cc) cc_library(target_wrapper_lite SRCS target_wrapper.cc) -cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite) +cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite target_wrapper_lite) cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite) cc_library(variable_lite SRCS variable.cc) cc_library(op_registry_lite SRCS op_registry.cc) diff --git a/paddle/fluid/lite/core/memory.h b/paddle/fluid/lite/core/memory.h index 0267dd43ed1..4cb46607f4a 100644 --- a/paddle/fluid/lite/core/memory.h +++ b/paddle/fluid/lite/core/memory.h @@ -79,6 +79,13 @@ class Buffer { space_ = 0; } + void CopyDataFrom(const Buffer& other, size_t nbytes) { + target_ = other.target_; + ResizeLazy(nbytes); + // TODO(Superjomn) support copy between different targets. + memcpy(data_, other.data_, nbytes); + } + private: size_t space_{0}; void* data_{nullptr}; diff --git a/paddle/fluid/lite/core/tensor.h b/paddle/fluid/lite/core/tensor.h index e308913a3dc..a9a129cda42 100644 --- a/paddle/fluid/lite/core/tensor.h +++ b/paddle/fluid/lite/core/tensor.h @@ -14,9 +14,11 @@ #pragma once #include +#include #include #include -#include "memory.h" +#include "paddle/fluid/lite/core/memory.h" +#include "paddle/fluid/lite/core/target_wrapper.h" namespace paddle { namespace lite { @@ -62,11 +64,11 @@ using LoD = std::vector>; // A light-weight tensor implementation. class Tensor { public: - Tensor() = default; + Tensor() : buffer_(std::make_shared()) {} template const T* data() const { - return static_cast(buffer_.data()); + return static_cast(buffer_->data()); } void Resize(const DDim& ddim) { dims_ = ddim; } @@ -78,16 +80,31 @@ class Tensor { template T* mutable_data() { - buffer_.ResetLazy(target_, product(dims_) * sizeof(T)); - return static_cast(buffer_.data()); + buffer_->ResetLazy(target_, product(dims_) * sizeof(T)); + return static_cast(buffer_->data()); } - bool IsInitialized() const { return buffer_.data(); } + bool IsInitialized() const { return buffer_->data(); } + + // Other share data to this. + void ShareDataWith(const Tensor& other) { + buffer_ = other.buffer_; + dims_ = other.dims_; + target_ = other.target_; + lod_ = other.lod_; + } + + void CopyDataFrom(const Tensor& other) { + dims_ = other.dims_; + target_ = other.target_; + lod_ = other.lod_; + *buffer_ = *other.buffer_; + } private: TargetType target_{TargetType::kHost}; DDim dims_; - Buffer buffer_; + std::shared_ptr buffer_; LoD lod_; }; diff --git a/paddle/fluid/lite/kernels/host/CMakeLists.txt b/paddle/fluid/lite/kernels/host/CMakeLists.txt index 7c416dbf505..60e500630d5 100644 --- a/paddle/fluid/lite/kernels/host/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/host/CMakeLists.txt @@ -2,12 +2,14 @@ cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite) cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite) cc_library(mul_compute_host SRCS mul_compute.cc DEPS tensor_lite) cc_library(scale_compute_host SRCS scale_compute.cc DEPS tensor_lite) +cc_library(feed_compute_host SRCS feed_compute.cc DEPS tensor_lite) cc_library(host_kernels DEPS fc_compute_host relu_compute_host mul_compute_host scale_compute_host + feed_compute_host DEPS kernel_lite ) diff --git a/paddle/fluid/lite/kernels/host/fc_compute.cc b/paddle/fluid/lite/kernels/host/fc_compute.cc index 9c2b6c6205b..e81900bad9d 100644 --- a/paddle/fluid/lite/kernels/host/fc_compute.cc +++ b/paddle/fluid/lite/kernels/host/fc_compute.cc @@ -42,9 +42,9 @@ void FcCompute::Run() { param.output->mutable_data()); } -TargetType FcCompute::target() const { return TARGET(kHost); } +// TargetType FcCompute::target() const { return TARGET(kHost); } -PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } +// PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } } // namespace host } // namespace kernels diff --git a/paddle/fluid/lite/kernels/host/fc_compute.h b/paddle/fluid/lite/kernels/host/fc_compute.h index 355f1be8503..d835e96d409 100644 --- a/paddle/fluid/lite/kernels/host/fc_compute.h +++ b/paddle/fluid/lite/kernels/host/fc_compute.h @@ -29,8 +29,8 @@ class FcCompute : public OpKernel { void Run() override; - TargetType target() const override; - PrecisionType precision() const override; + // TargetType target() const override; + // PrecisionType precision() const override; virtual ~FcCompute() = default; }; diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc new file mode 100644 index 00000000000..6a0f480a4d1 --- /dev/null +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class FeedCompute : public OpKernel { + public: + using param_t = operators::FeedParam; + + void Run() override { + auto &theparam = param(); + const Tensor &feed_item = theparam.feed_list->at(theparam.col); + theparam.out->CopyDataFrom(feed_item); + } +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(feed, kHost, kFloat, + paddle::lite::kernels::host::FeedCompute) + .BindInput("X", {paddle::lite::Type::Get( + TARGET(kHost))}) + .BindOutput("Out", {paddle::lite::Type::Get( + TARGET(kHost))}) + .Finalize(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 9b33e75d5b7..3a80f3b0229 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -2,6 +2,7 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite prot cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite) cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite) cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite) +cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite) cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite) cc_library(ops_lite DEPS @@ -9,6 +10,7 @@ cc_library(ops_lite DEPS relu_op_lite mul_op_lite scale_op_lite + feed_op_lite ) cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host) diff --git a/paddle/fluid/lite/operators/feed_op.cc b/paddle/fluid/lite/operators/feed_op.cc new file mode 100644 index 00000000000..45d2f640a35 --- /dev/null +++ b/paddle/fluid/lite/operators/feed_op.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +class FeedOp : public OpLite { + public: + explicit FeedOp(const std::string& type) : OpLite(type) {} + + bool CheckShape() const override { + CHECK_OR_FALSE(param_.feed_list); + CHECK_OR_FALSE(param_.out); + return true; + } + + bool InferShape() const override { return true; } + + protected: + bool AttachImpl(const framework::OpDesc& opdesc, + lite::Scope* scope) override { + auto feed_var_name = opdesc.Input("X").front(); + auto* feed_var = scope->FindVar(feed_var_name); + CHECK(feed_var); + auto& feed_tensor_list = feed_var->Get>(); + param_.feed_list = &feed_tensor_list; + + auto out_name = opdesc.Output("Out").front(); + auto* out_var = scope->FindVar(out_name); + CHECK(out_var); + param_.out = out_var->GetMutable(); + + // NOTE need boost here + // TODO(Superjomn) drop the need of framework::op_desc + param_.col = boost::get(opdesc.GetAttr("col")); + return true; + } + + std::string DebugString() const override { return "feed"; } + + private: + mutable FeedParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(feed, paddle::lite::operators::FeedOp); diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index dea8759eca9..bf57af0a4ee 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -24,6 +24,12 @@ namespace paddle { namespace lite { namespace operators { +struct FeedParam { + const std::vector* feed_list; + Tensor* out; + int col; +}; + struct FcParam { Tensor* input{}; Tensor* w{}; @@ -58,7 +64,7 @@ struct ScaleParam { bool bias_after_scale{true}; }; -using param_t = variant; +using param_t = variant; } // namespace operators } // namespace lite -- GitLab