Commit 367a2814 authored by: S Superjomn

add feed_op

Parent 65bfecc9
cc_library(memory_lite SRCS memory.cc) cc_library(memory_lite SRCS memory.cc)
cc_library(target_wrapper_lite SRCS target_wrapper.cc) cc_library(target_wrapper_lite SRCS target_wrapper.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite) cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite target_wrapper_lite)
cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite) cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
cc_library(variable_lite SRCS variable.cc) cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc) cc_library(op_registry_lite SRCS op_registry.cc)
......
...@@ -79,6 +79,13 @@ class Buffer { ...@@ -79,6 +79,13 @@ class Buffer {
space_ = 0; space_ = 0;
} }
// Copies `nbytes` of raw data from `other` into this buffer, adopting
// `other`'s target and growing this buffer's allocation if needed.
// NOTE(review): memcpy only works host-to-host; the target is copied but a
// cross-target copy is not implemented yet (see TODO below).
void CopyDataFrom(const Buffer& other, size_t nbytes) {
target_ = other.target_;
// ResizeLazy reallocates only when the current space is too small.
ResizeLazy(nbytes);
// TODO(Superjomn) support copy between different targets.
memcpy(data_, other.data_, nbytes);
}
private: private:
size_t space_{0}; size_t space_{0};
void* data_{nullptr}; void* data_{nullptr};
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <memory>
#include <numeric> #include <numeric>
#include <vector> #include <vector>
#include "memory.h" #include "paddle/fluid/lite/core/memory.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -62,11 +64,11 @@ using LoD = std::vector<std::vector<size_t>>; ...@@ -62,11 +64,11 @@ using LoD = std::vector<std::vector<size_t>>;
// A light-weight tensor implementation. // A light-weight tensor implementation.
class Tensor { class Tensor {
public: public:
Tensor() = default; Tensor() : buffer_(std::make_shared<Buffer>()) {}
template <typename T> template <typename T>
const T* data() const { const T* data() const {
return static_cast<const T*>(buffer_.data()); return static_cast<const T*>(buffer_->data());
} }
void Resize(const DDim& ddim) { dims_ = ddim; } void Resize(const DDim& ddim) { dims_ = ddim; }
...@@ -78,16 +80,31 @@ class Tensor { ...@@ -78,16 +80,31 @@ class Tensor {
template <typename T> template <typename T>
T* mutable_data() { T* mutable_data() {
buffer_.ResetLazy(target_, product(dims_) * sizeof(T)); buffer_->ResetLazy(target_, product(dims_) * sizeof(T));
return static_cast<T*>(buffer_.data()); return static_cast<T*>(buffer_->data());
} }
bool IsInitialized() const { return buffer_.data(); } bool IsInitialized() const { return buffer_->data(); }
// Share `other`'s data with this tensor (shallow copy): the underlying
// Buffer is shared via the shared_ptr, so both tensors afterwards view the
// same storage; dims, target and LoD are mirrored as well.
void ShareDataWith(const Tensor& other) {
buffer_ = other.buffer_;
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
}
// Deep copy from `other`: mirrors dims/target/LoD and copies the buffer
// contents into this tensor's own Buffer (unlike ShareDataWith, storage is
// not shared afterwards).
// NOTE(review): `*buffer_ = *other.buffer_` relies on Buffer's copy
// assignment performing a deep copy of the data; if Buffer uses the
// compiler-generated assignment, the raw data pointer would be aliased and
// double-freed — verify Buffer's assignment operator (e.g. it should
// delegate to Buffer::CopyDataFrom).
void CopyDataFrom(const Tensor& other) {
dims_ = other.dims_;
target_ = other.target_;
lod_ = other.lod_;
*buffer_ = *other.buffer_;
}
private: private:
TargetType target_{TargetType::kHost}; TargetType target_{TargetType::kHost};
DDim dims_; DDim dims_;
Buffer buffer_; std::shared_ptr<Buffer> buffer_;
LoD lod_; LoD lod_;
}; };
......
...@@ -2,12 +2,14 @@ cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite) ...@@ -2,12 +2,14 @@ cc_library(fc_compute_host SRCS fc_compute.cc DEPS tensor_lite)
cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite) cc_library(relu_compute_host SRCS relu_compute.cc DEPS tensor_lite)
cc_library(mul_compute_host SRCS mul_compute.cc DEPS tensor_lite) cc_library(mul_compute_host SRCS mul_compute.cc DEPS tensor_lite)
cc_library(scale_compute_host SRCS scale_compute.cc DEPS tensor_lite) cc_library(scale_compute_host SRCS scale_compute.cc DEPS tensor_lite)
cc_library(feed_compute_host SRCS feed_compute.cc DEPS tensor_lite)
cc_library(host_kernels DEPS cc_library(host_kernels DEPS
fc_compute_host fc_compute_host
relu_compute_host relu_compute_host
mul_compute_host mul_compute_host
scale_compute_host scale_compute_host
feed_compute_host
DEPS kernel_lite DEPS kernel_lite
) )
......
...@@ -42,9 +42,9 @@ void FcCompute::Run() { ...@@ -42,9 +42,9 @@ void FcCompute::Run() {
param.output->mutable_data<float>()); param.output->mutable_data<float>());
} }
TargetType FcCompute::target() const { return TARGET(kHost); } // TargetType FcCompute::target() const { return TARGET(kHost); }
PrecisionType FcCompute::precision() const { return PRECISION(kFloat); } // PrecisionType FcCompute::precision() const { return PRECISION(kFloat); }
} // namespace host } // namespace host
} // namespace kernels } // namespace kernels
......
...@@ -29,8 +29,8 @@ class FcCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -29,8 +29,8 @@ class FcCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
void Run() override; void Run() override;
TargetType target() const override; // TargetType target() const override;
PrecisionType precision() const override; // PrecisionType precision() const override;
virtual ~FcCompute() = default; virtual ~FcCompute() = default;
}; };
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
// Host kernel for the feed op: copies the `col`-th tensor out of the feed
// list into the op's output tensor, moving externally-fed data into the
// execution graph.
class FeedCompute : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
 public:
  using param_t = operators::FeedParam;

  void Run() override {
    // Use the declared param_t alias instead of respelling the type.
    auto &theparam = param<param_t>();
    // at() bounds-checks `col` against the feed list size.
    const Tensor &feed_item = theparam.feed_list->at(theparam.col);
    // Deep copy so the output owns its data independently of the feed list.
    theparam.out->CopyDataFrom(feed_item);
  }
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Register FeedCompute as the (host, float) kernel for the "feed" op;
// both input and output are bound as fp32 NCHW tensors on the host target.
REGISTER_LITE_KERNEL(feed, kHost, kFloat,
                     paddle::lite::kernels::host::FeedCompute)
    .BindInput("X", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                        TARGET(kHost))})
    .BindOutput("Out", {paddle::lite::Type::Get<paddle::lite::TensorFp32NCHWTy>(
                           TARGET(kHost))})
    .Finalize();
...@@ -2,6 +2,7 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite prot ...@@ -2,6 +2,7 @@ cc_library(fc_op_lite SRCS fc_op.cc DEPS op_lite op_params_lite tensor_lite prot
cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite) cc_library(relu_op_lite SRCS relu_op.cc DEPS op_lite)
cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite) cc_library(mul_op_lite SRCS mul_op.cc DEPS op_lite)
cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite) cc_library(scale_op_lite SRCS scale_op.cc DEPS op_lite)
cc_library(feed_op_lite SRCS feed_op.cc DEPS op_lite)
cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite) cc_library(op_params_lite SRCS op_params.cc DEPS tensor_lite)
cc_library(ops_lite DEPS cc_library(ops_lite DEPS
...@@ -9,6 +10,7 @@ cc_library(ops_lite DEPS ...@@ -9,6 +10,7 @@ cc_library(ops_lite DEPS
relu_op_lite relu_op_lite
mul_op_lite mul_op_lite
scale_op_lite scale_op_lite
feed_op_lite
) )
cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host) cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite fc_compute_host)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
// FeedOp wires one tensor out of the scope's feed list (the "X" input
// variable, a std::vector<Tensor>) to a regular output tensor ("Out") so
// downstream ops can consume externally-fed data. The element index is the
// "col" attribute.
class FeedOp : public OpLite {
public:
explicit FeedOp(const std::string& type) : OpLite(type) {}
// Only validates that AttachImpl populated both endpoints; there is no
// shape to check for feed.
bool CheckShape() const override {
CHECK_OR_FALSE(param_.feed_list);
CHECK_OR_FALSE(param_.out);
return true;
}
// Nothing to infer: the output shape comes from the fed tensor at run time.
bool InferShape() const override { return true; }
protected:
// Resolves the feed-list variable and the output variable from `scope` as
// named by `opdesc`, and records the column index. Returns true on success
// (CHECK aborts on missing variables).
bool AttachImpl(const framework::OpDesc& opdesc,
lite::Scope* scope) override {
auto feed_var_name = opdesc.Input("X").front();
auto* feed_var = scope->FindVar(feed_var_name);
CHECK(feed_var);
// The feed variable holds the whole list of fed tensors, not a single one.
auto& feed_tensor_list = feed_var->Get<std::vector<Tensor>>();
param_.feed_list = &feed_tensor_list;
auto out_name = opdesc.Output("Out").front();
auto* out_var = scope->FindVar(out_name);
CHECK(out_var);
param_.out = out_var->GetMutable<Tensor>();
// NOTE need boost here
// TODO(Superjomn) drop the need of framework::op_desc
param_.col = boost::get<int>(opdesc.GetAttr("col"));
return true;
}
std::string DebugString() const override { return "feed"; }
private:
// mutable: presumably mutated from const member functions elsewhere in the
// OpLite contract — TODO confirm; AttachImpl itself is non-const.
mutable FeedParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Register FeedOp under the op type "feed" so the framework can create it.
REGISTER_LITE_OP(feed, paddle::lite::operators::FeedOp);
...@@ -24,6 +24,12 @@ namespace paddle { ...@@ -24,6 +24,12 @@ namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
// Parameters of the feed op: copy the `col`-th tensor of `feed_list` into
// `out`. Members are default-initialized (consistent with FcParam below) so
// an unattached param reads as null/zero instead of indeterminate values.
struct FeedParam {
  const std::vector<Tensor>* feed_list{};  // non-owning; lives in the scope
  Tensor* out{};                           // non-owning destination tensor
  int col{};                               // index into feed_list
};
struct FcParam { struct FcParam {
Tensor* input{}; Tensor* input{};
Tensor* w{}; Tensor* w{};
...@@ -58,7 +64,7 @@ struct ScaleParam { ...@@ -58,7 +64,7 @@ struct ScaleParam {
bool bias_after_scale{true}; bool bias_after_scale{true};
}; };
using param_t = variant<FcParam, ReluParam, MulParam, ScaleParam>; using param_t = variant<FeedParam, FcParam, ReluParam, MulParam, ScaleParam>;
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.