提交 97fa88eb 编写于 作者: qnqinan's avatar qnqinan

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into develop

...@@ -45,6 +45,7 @@ if (WITH_TESTING) ...@@ -45,6 +45,7 @@ if (WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "mobilenet_v2_relu.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "resnet50.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz") lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "inception_v4_simple.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "step_rnn.tar.gz")
endif() endif()
endif() endif()
......
...@@ -143,6 +143,11 @@ if(WITH_TESTING) ...@@ -143,6 +143,11 @@ if(WITH_TESTING)
${ops} ${host_kernels} ${x86_kernels} ${ops} ${host_kernels} ${x86_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/resnet50) ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
add_dependencies(test_resnet50_lite_x86 extern_lite_download_resnet50_tar_gz) add_dependencies(test_resnet50_lite_x86 extern_lite_download_resnet50_tar_gz)
lite_cc_test(test_step_rnn_lite_x86 SRCS test_step_rnn_lite_x86.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${x86_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/step_rnn)
add_dependencies(test_step_rnn_lite_x86 extern_lite_download_step_rnn_tar_gz)
endif() endif()
endif() endif()
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex> //NOLINT
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -126,6 +127,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -126,6 +127,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
void Run() override; void Run() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone() override;
std::string GetVersion() const override; std::string GetVersion() const override;
// get inputs names and get outputs names // get inputs names and get outputs names
...@@ -146,6 +149,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor { ...@@ -146,6 +149,8 @@ class CxxPaddleApiImpl : public lite_api::PaddlePredictor {
private: private:
Predictor raw_predictor_; Predictor raw_predictor_;
lite_api::CxxConfig config_;
std::mutex mutex_;
}; };
/* /*
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#include "lite/api/cxx_api.h" #include "lite/api/cxx_api.h"
#include <memory>
#include <mutex> //NOLINT
#include <string> #include <string>
#include "lite/api/paddle_api.h" #include "lite/api/paddle_api.h"
#include "lite/core/device_info.h" #include "lite/core/device_info.h"
...@@ -22,6 +24,7 @@ namespace paddle { ...@@ -22,6 +24,7 @@ namespace paddle {
namespace lite { namespace lite {
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config_ = config;
#ifdef LITE_WITH_CUDA #ifdef LITE_WITH_CUDA
Env<TARGET(kCUDA)>::Init(); Env<TARGET(kCUDA)>::Init();
#endif #endif
...@@ -50,6 +53,13 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() { ...@@ -50,6 +53,13 @@ std::vector<std::string> CxxPaddleApiImpl::GetOutputNames() {
void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); } void CxxPaddleApiImpl::Run() { raw_predictor_.Run(); }
std::shared_ptr<lite_api::PaddlePredictor> CxxPaddleApiImpl::Clone() {
std::lock_guard<std::mutex> lock(mutex_);
auto predictor = std::make_shared<lite::CxxPaddleApiImpl>();
predictor->Init(config_);
return predictor;
}
std::string CxxPaddleApiImpl::GetVersion() const { return version(); } std::string CxxPaddleApiImpl::GetVersion() const { return version(); }
std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor( std::unique_ptr<const lite_api::Tensor> CxxPaddleApiImpl::GetTensor(
......
...@@ -96,6 +96,8 @@ class LightPredictorImpl : public lite_api::PaddlePredictor { ...@@ -96,6 +96,8 @@ class LightPredictorImpl : public lite_api::PaddlePredictor {
void Run() override; void Run() override;
std::shared_ptr<lite_api::PaddlePredictor> Clone() override;
std::string GetVersion() const override; std::string GetVersion() const override;
std::vector<std::string> GetInputNames() override; std::vector<std::string> GetInputNames() override;
std::vector<std::string> GetOutputNames() override; std::vector<std::string> GetOutputNames() override;
......
...@@ -44,6 +44,10 @@ std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetOutput( ...@@ -44,6 +44,10 @@ std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetOutput(
void LightPredictorImpl::Run() { raw_predictor_->Run(); } void LightPredictorImpl::Run() { raw_predictor_->Run(); }
std::shared_ptr<lite_api::PaddlePredictor> LightPredictorImpl::Clone() {
LOG(FATAL) << "The Clone API is not supported in LigthPredictor";
}
std::string LightPredictorImpl::GetVersion() const { return lite::version(); } std::string LightPredictorImpl::GetVersion() const { return lite::version(); }
std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetTensor( std::unique_ptr<const lite_api::Tensor> LightPredictorImpl::GetTensor(
......
...@@ -46,6 +46,10 @@ template <> ...@@ -46,6 +46,10 @@ template <>
const int8_t *Tensor::data() const { const int8_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int8_t>(); return ctensor(raw_tensor_)->data<int8_t>();
} }
template <>
const int64_t *Tensor::data() const {
return ctensor(raw_tensor_)->data<int64_t>();
}
template <> template <>
const int32_t *Tensor::data() const { const int32_t *Tensor::data() const {
...@@ -64,6 +68,10 @@ template <> ...@@ -64,6 +68,10 @@ template <>
int8_t *Tensor::mutable_data(TargetType type) const { int8_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int8_t>(type); return tensor(raw_tensor_)->mutable_data<int8_t>(type);
} }
template <>
int64_t *Tensor::mutable_data(TargetType type) const {
return tensor(raw_tensor_)->mutable_data<int64_t>(type);
}
template <typename T, TargetType type> template <typename T, TargetType type>
void Tensor::CopyFromCpu(const T *src_data) { void Tensor::CopyFromCpu(const T *src_data) {
......
...@@ -78,6 +78,7 @@ class LITE_API PaddlePredictor { ...@@ -78,6 +78,7 @@ class LITE_API PaddlePredictor {
virtual std::unique_ptr<const Tensor> GetOutput(int i) const = 0; virtual std::unique_ptr<const Tensor> GetOutput(int i) const = 0;
virtual void Run() = 0; virtual void Run() = 0;
virtual std::shared_ptr<PaddlePredictor> Clone() = 0;
virtual std::string GetVersion() const = 0; virtual std::string GetVersion() const = 0;
......
...@@ -103,6 +103,8 @@ static size_t PrecisionTypeLength(PrecisionType type) { ...@@ -103,6 +103,8 @@ static size_t PrecisionTypeLength(PrecisionType type) {
return 1; return 1;
case PrecisionType::kInt32: case PrecisionType::kInt32:
return 4; return 4;
case PrecisionType::kInt64:
return 8;
case PrecisionType::kFP16: case PrecisionType::kFP16:
return 2; return 2;
default: default:
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/lite_api_test_helper.h"
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
TEST(Step_rnn, test_step_rnn_lite_x86) {
std::string model_dir = FLAGS_model_dir;
lite_api::CxxConfig config;
config.set_model_dir(model_dir);
config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
auto predictor = lite_api::CreatePaddlePredictor(config);
std::vector<std::string> target_names = {"item_type_id",
"mthid_id",
"source_id_id",
"layout_id",
"mark_id",
"category_id",
"subcategory_id",
"score_segment_id",
"item_attention_id",
"queue_num_id",
"micro_video_id",
"vertical_type_id"};
for (int i = 0; i < target_names.size(); ++i) {
auto input_tensor = predictor->GetInput(i);
int size = 0;
if (i == 6 || i == 8) {
input_tensor->Resize(std::vector<int64_t>{5, 1});
input_tensor->SetLoD({{0, 5}});
size = 5;
} else {
input_tensor->Resize(std::vector<int64_t>{1, 1});
input_tensor->SetLoD({{0, 1}});
size = 1;
}
auto* data = input_tensor->mutable_data<int64_t>();
for (int i = 0; i < size; i++) data[i] = 1;
}
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor->Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) {
predictor->Run();
}
// LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
std::vector<std::vector<float>> results;
// i = 1
results.emplace_back(std::vector<float>({0.5030127, 0.496987}));
auto out = predictor->GetOutput(0);
std::vector<int64_t> out_shape = out->shape();
for (int i = 0; i < results.size(); ++i) {
for (int j = 0; j < results[i].size(); ++j) {
EXPECT_NEAR(
out->data<float>()[j + (out_shape[1] * i)], results[i][j], 1e-6);
}
}
}
} // namespace lite
} // namespace paddle
...@@ -255,8 +255,6 @@ class Context<TargetType::kX86> { ...@@ -255,8 +255,6 @@ class Context<TargetType::kX86> {
public: public:
Context() {} Context() {}
Context(Context&& ctx) {}
// NOTE: InitOnce should only be used by ContextScheduler // NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {} void InitOnce() {}
......
...@@ -54,4 +54,5 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -54,4 +54,5 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass, REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
paddle::lite::mir::QuantDequantFusePass) paddle::lite::mir::QuantDequantFusePass)
.BindTargets({TARGET(kAny)}); .BindTargets({TARGET(kAny)})
.BindKernel("calib");
...@@ -229,6 +229,8 @@ void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) { ...@@ -229,6 +229,8 @@ void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
} }
op->CheckShape(); op->CheckShape();
op->InferShape(); op->InferShape();
#ifndef LITH_WITH_XPU
// TOOD(xxx): remove Launch() at last // TOOD(xxx): remove Launch() at last
auto& kkks = stmt.kernels(); auto& kkks = stmt.kernels();
if (!kkks.empty()) { if (!kkks.empty()) {
...@@ -237,6 +239,7 @@ void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) { ...@@ -237,6 +239,7 @@ void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
kk->Launch(); kk->Launch();
} }
} }
#endif
} }
} }
......
...@@ -54,6 +54,8 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create( ...@@ -54,6 +54,8 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
CREATE_KERNEL1(target__, kFP16); \ CREATE_KERNEL1(target__, kFP16); \
case PRECISION(kAny): \ case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \ CREATE_KERNEL1(target__, kAny); \
case PRECISION(kInt64): \
CREATE_KERNEL1(target__, kInt64); \
default: \ default: \
CHECK(false) << "not supported kernel precision " \ CHECK(false) << "not supported kernel precision " \
<< PrecisionToStr(precision); \ << PrecisionToStr(precision); \
...@@ -126,6 +128,7 @@ KernelRegistry::KernelRegistry() ...@@ -126,6 +128,7 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kX86, kFloat, kNCHW); INIT_FOR(kX86, kFloat, kNCHW);
INIT_FOR(kX86, kAny, kNCHW); INIT_FOR(kX86, kAny, kNCHW);
INIT_FOR(kX86, kAny, kAny); INIT_FOR(kX86, kAny, kAny);
INIT_FOR(kX86, kInt64, kNCHW);
INIT_FOR(kARM, kFloat, kNCHW); INIT_FOR(kARM, kFloat, kNCHW);
INIT_FOR(kARM, kInt8, kNCHW); INIT_FOR(kARM, kInt8, kNCHW);
......
...@@ -32,7 +32,7 @@ struct EigenDim { ...@@ -32,7 +32,7 @@ struct EigenDim {
static Type From(const lite::DDim& dims) { static Type From(const lite::DDim& dims) {
PADDLE_ENFORCE(dims.size() == D, "D must match DDim::size"); PADDLE_ENFORCE(dims.size() == D, "D must match DDim::size");
Type ret; Type ret;
for (int64_t d = 0; d < dims.size(); d++) { for (size_t d = 0; d < dims.size(); d++) {
ret[d] = dims[d]; ret[d] = dims[d];
} }
return ret; return ret;
...@@ -118,7 +118,9 @@ struct EigenScalar { ...@@ -118,7 +118,9 @@ struct EigenScalar {
using ConstType = Eigen::TensorMap< using ConstType = Eigen::TensorMap<
Eigen::TensorFixedSize<const T, Eigen::Sizes<>, MajorType, IndexType>>; Eigen::TensorFixedSize<const T, Eigen::Sizes<>, MajorType, IndexType>>;
static Type From(Tensor& tensor) { return Type(tensor.data<T>()); } // NOLINT static Type From(Tensor* tensor) {
return Type(const_cast<T*>(tensor->data<T>()));
} // NOLINT
static ConstType From(const Tensor& tensor) { static ConstType From(const Tensor& tensor) {
return ConstType(tensor.data<T>()); return ConstType(tensor.data<T>());
......
...@@ -36,6 +36,9 @@ add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEP ...@@ -36,6 +36,9 @@ add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEP
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
if(NOT LITE_WITH_X86) if(NOT LITE_WITH_X86)
return() return()
......
...@@ -57,3 +57,13 @@ REGISTER_LITE_KERNEL(gelu, ...@@ -57,3 +57,13 @@ REGISTER_LITE_KERNEL(gelu,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(softsign,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SoftsignCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
...@@ -187,6 +187,31 @@ class GeluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -187,6 +187,31 @@ class GeluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
virtual ~GeluCompute() = default; virtual ~GeluCompute() = default;
}; };
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
template <typename T>
class SoftsignCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override {
// auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationParam>();
param.Out->template mutable_data<T>();
Activate<SoftsignFunctor<T>>(param.X, param.Out);
}
virtual ~SoftsignCompute() = default;
};
} // namespace x86 } // namespace x86
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -42,7 +42,10 @@ class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -42,7 +42,10 @@ class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
int64_t axis = static_cast<int64_t>(param.axis); int64_t axis = static_cast<int64_t>(param.axis);
auto x_dims = param.x[0]->dims(); auto x_dims = param.x[0]->dims();
auto out = param.output; auto out = param.output;
if (param.x.size() == 1) return; if (param.x.size() == 1) {
param.output->ShareDataWith(*param.x[0]);
return;
}
auto output_data = param.output->template mutable_data<T>(); auto output_data = param.output->template mutable_data<T>();
int offset_concat_axis = 0; int offset_concat_axis = 0;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/lookup_table_compute.h"
// REGISTER_LITE_KERNEL(lookup_table, kX86, kFloat, kNCHW,
// paddle::lite::kernels::x86::LookupTableCompute<float>,
// def)
// .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
// .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86))})
// .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
// .Finalize();
//,
REGISTER_LITE_KERNEL(lookup_table,
kX86,
kInt64,
kNCHW,
paddle::lite::kernels::x86::LookupTableCompute<float>,
def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class LookupTableCompute : public KernelLite<TARGET(kX86), PRECISION(kInt64)> {
public:
using param_t = operators::LookupTableParam;
void Run() override {
auto &param = *param_.get_mutable<operators::LookupTableParam>();
// auto& context = context_->As<X86Context>();
auto *ids_t = param.Ids;
auto *output_t = param.Out;
int64_t padding_idx = param.padding_idx;
auto *ids = ids_t->data<int64_t>();
int64_t ids_numel = ids_t->dims().production();
auto *table_t = param.W;
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<float>();
auto *output = output_t->mutable_data<float>();
memset(output, 0, output_t->dims().production() * sizeof(float));
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != -1 && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(float));
} else {
CHECK_LT(ids[i], row_number);
CHECK_GE(ids[i], 0);
memcpy(output + i * row_width,
table + ids[i] * row_width,
row_width * sizeof(float));
}
}
}
virtual ~LookupTableCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/reduce_compute.h"
REGISTER_LITE_KERNEL(reduce_sum,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::ReduceSumCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
#include "lite/kernels/x86/reduce_op_function.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
struct SumFunctor {
template <typename X, typename Y, typename Dim>
void operator()(X* x, Y* y, const Dim& dim) {
y->device(lite::fluid::EigenDeviceType<TARGET(kX86)>()) = x->sum(dim);
}
};
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
paddle::lite::kernels::x86:: \
ReduceFunctor<lite::TargetType::kX86, T, NDIM, RDIM, SumFunctor>( \
*input, output, dims, keep_dim); \
}
template <typename T>
class ReduceSumCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ReduceParam;
void Run() override {
auto& param = *param_.get_mutable<operators::ReduceParam>();
// auto& context = ctx_->As<X86Context>();
bool reduce_all = param.reduce_all;
auto* input = param.x;
auto* output = param.output;
param.output->mutable_data<T>();
auto dims = param.dim;
bool keep_dim = param.keep_dim;
if (reduce_all) {
// Flatten and reduce 1-D tensor
auto x = lite::fluid::EigenVector<T>::Flatten(*input);
auto out = lite::fluid::EigenScalar<T>::From(output);
// auto& place = *platform::CPUDeviceContext().eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
SumFunctor functor;
functor(&x, &out, reduce_dim);
} else {
int ndim = input->dims().size();
int rdim = dims.size();
HANDLE_DIM(4, 3);
HANDLE_DIM(4, 2);
HANDLE_DIM(4, 1);
HANDLE_DIM(3, 2);
HANDLE_DIM(3, 1);
HANDLE_DIM(2, 1);
HANDLE_DIM(1, 1);
}
}
virtual ~ReduceSumCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = lite::fluid::EigenTensor<T, D, MajorType, IndexType>;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = lite::fluid::EigenScalar<T, MajorType, IndexType>;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = lite::fluid::EigenVector<T, MajorType, IndexType>;
template <lite::TargetType Target,
typename T,
size_t D,
size_t R_D,
typename Functor>
// const lite::Context<Target>& context,
void ReduceFunctor(const lite::Tensor& input,
lite::Tensor* output,
const std::vector<int>& dims,
bool keep_dim) {
auto x = EigenTensor<T, D>::From(input);
auto x_rank = static_cast<int>(x.dimensions().size());
auto reduce_dim = Eigen::array<int, R_D>();
std::vector<int> dims_ref = dims;
for (size_t i = 0; i < dims_ref.size(); ++i) {
if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i];
reduce_dim[i] = dims_ref[i];
}
// construct the squeezed output tensor
lite::DDim out_dims = output->dims();
if (keep_dim && x_rank > 1) {
const int kDelFlag = -2;
auto dims_vector = out_dims.Vectorize();
for (size_t i = 0; i < dims_ref.size(); ++i) {
dims_vector[dims_ref[i]] = kDelFlag;
}
dims_vector.erase(remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
out_dims = lite::DDim(dims_vector);
}
// auto& place = *context.eigen_device();
Functor functor;
if (D == 1) {
auto out = EigenScalar<T>::From(output);
functor(&x, &out, reduce_dim);
} else {
auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
functor(&x, &out, reduce_dim);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -34,3 +34,14 @@ REGISTER_LITE_KERNEL(reshape2, ...@@ -34,3 +34,14 @@ REGISTER_LITE_KERNEL(reshape2,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(reshape2,
kX86,
kInt64,
kNCHW,
paddle::lite::kernels::x86::Reshape2Compute<int64_t>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("XShape",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_reshape_compute.h"
REGISTER_LITE_KERNEL(
sequence_reshape,
kX86,
kInt64,
kNCHW,
paddle::lite::kernels::x86::SequenceReshapeCompute<int64_t>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SequenceReshapeCompute
: public KernelLite<TARGET(kX86), PRECISION(kInt64)> {
public:
using param_t = operators::SequenceReshapeParam;
void Run() override {
auto& param = *param_.get_mutable<operators::SequenceReshapeParam>();
// auto& context = context_->As<X86Context>();
auto* in = param.x;
auto* out = param.output;
int out_width = param.new_dim;
auto in_dims = in->dims();
int64_t in_width = in_dims[1];
// LOG(INFO)<<"sequence_reshape in tensor:"<<*in;
auto& in_lod = in->lod();
CHECK_EQ(in_lod.size(), 1UL);
CHECK_EQ((uint64_t)in_dims[0], in_lod[0].back());
auto in_lod_l0 = in_lod[0];
int seq_num = in_lod_l0.size() - 1;
if (in_width == out_width) {
out->set_lod(in->lod());
} else {
auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
out_lod[0].resize(seq_num + 1);
out_lod[0][0] = 0;
for (int i = 0; i < seq_num; ++i) {
size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
size_t offset = 0;
offset = (seq_len * in_width) / out_width;
CHECK_EQ(offset * out_width, seq_len * in_width);
out_lod[0][i + 1] = out_lod[0][i] + offset;
}
}
out->Resize(in_dims);
auto* dst_ptr = out->mutable_data<T>();
auto size = in->numel() * sizeof(T);
std::memcpy(dst_ptr, in->data<T>(), size);
std::vector<int64_t> out_shape{static_cast<int64_t>(out->lod()[0].back()),
out_width};
out->Resize(lite::DDim(out_shape));
}
virtual ~SequenceReshapeCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -76,6 +76,8 @@ add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS ...@@ -76,6 +76,8 @@ add_operator(sequence_expand_as_op_lite extra SRCS sequence_expand_as_op.cc DEPS
add_operator(range_op extra SRCS range_op.cc DEPS ${op_DEPS}) add_operator(range_op extra SRCS range_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS}) add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
# for OCR specific # for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
...@@ -118,6 +118,7 @@ REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp); ...@@ -118,6 +118,7 @@ REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp); REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN #ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp); REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
......
...@@ -21,7 +21,7 @@ namespace lite { ...@@ -21,7 +21,7 @@ namespace lite {
namespace operators { namespace operators {
bool ConcatOpLite::CheckShape() const { bool ConcatOpLite::CheckShape() const {
CHECK_GT_OR_FALSE(param_.x.size(), 1UL); CHECK_GE_OR_FALSE(param_.x.size(), 1UL);
CHECK_OR_FALSE(param_.output); CHECK_OR_FALSE(param_.output);
return true; return true;
} }
......
...@@ -50,6 +50,7 @@ bool LookupTableOpLite::InferShape() const { ...@@ -50,6 +50,7 @@ bool LookupTableOpLite::InferShape() const {
} }
out_dims.push_back(table_dims[1]); out_dims.push_back(table_dims[1]);
param_.Out->Resize(lite::DDim{out_dims}); param_.Out->Resize(lite::DDim{out_dims});
param_.Out->set_lod(param_.Ids->lod());
return true; return true;
} }
......
...@@ -721,6 +721,12 @@ struct SequencePoolParam { ...@@ -721,6 +721,12 @@ struct SequencePoolParam {
#endif #endif
}; };
struct SequenceReshapeParam {
lite::Tensor* x{};
lite::Tensor* output{};
int new_dim;
};
struct SequenceExpandParam { struct SequenceExpandParam {
const lite::Tensor* X{}; const lite::Tensor* X{};
const lite::Tensor* Y{}; const lite::Tensor* Y{};
...@@ -753,6 +759,15 @@ struct IsEmptyParam { ...@@ -753,6 +759,15 @@ struct IsEmptyParam {
const lite::Tensor* X{}; const lite::Tensor* X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
}; };
struct ReduceParam {
lite::Tensor* x{};
lite::Tensor* output{};
std::vector<int> dim{0};
bool keep_dim{false};
bool reduce_all{false};
};
/// ----------------------- shape operators ---------------------- /// ----------------------- shape operators ----------------------
struct ShapeParam { struct ShapeParam {
const lite::Tensor* X{}; const lite::Tensor* X{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reduce_ops.h"
#include <algorithm>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ReduceOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims();
auto x_rank = x_dims.size();
CHECK_LE(x_rank, 6UL) << "Tensors with rank at most 6 are supported.";
return true;
}
bool ReduceOp::InferShape() const {
auto x_dims = param_.x->dims();
auto x_rank = x_dims.size();
auto dims = param_.dim;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
CHECK_LT(dims[i], x_rank)
<< "The dim should be in the range [-rank(input), rank(input).";
}
sort(dims.begin(), dims.end());
bool reduce_all = param_.reduce_all;
bool keep_dim = param_.keep_dim;
if (reduce_all) {
if (keep_dim)
param_.output->Resize(lite::DDim(std::vector<int64_t>(x_rank, 1)));
else
param_.output->Resize(lite::DDim(std::vector<int64_t>{1}));
} else {
auto dims_vector = x_dims.Vectorize();
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = 1;
}
} else {
const int kDelFlag = -2;
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
auto out_dims = lite::DDim(dims_vector);
param_.output->Resize(out_dims);
if (dims[0] != 0) {
param_.output->set_lod(param_.x->lod());
}
}
return true;
}
bool ReduceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.dim = opdesc.GetAttr<std::vector<int>>("dim");
param_.reduce_all = opdesc.GetAttr<bool>("reduce_all");
param_.keep_dim = opdesc.GetAttr<bool>("keep_dim");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(reduce_sum, paddle::lite::operators::ReduceOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ReduceOp : public OpLite {
public:
ReduceOp() {}
explicit ReduceOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "reduce"; }
private:
mutable ReduceParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_reshape_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceReshapeOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims();
CHECK_EQ_OR_FALSE(x_dims.size(), 2U);
return true;
}
bool SequenceReshapeOp::InferShape() const {
int new_dim = param_.new_dim;
auto x_numel = param_.x->dims().production();
std::vector<int64_t> out_shape{x_numel / new_dim,
static_cast<int64_t>(new_dim)};
param_.output->Resize(lite::DDim(out_shape));
return true;
}
bool SequenceReshapeOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.new_dim = opdesc.GetAttr<int>("new_dim");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_reshape, paddle::lite::operators::SequenceReshapeOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceReshapeOp : public OpLite {
public:
SequenceReshapeOp() {}
explicit SequenceReshapeOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_reshape"; }
private:
mutable SequenceReshapeParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -22,6 +22,14 @@ namespace lite { ...@@ -22,6 +22,14 @@ namespace lite {
class Any { class Any {
public: public:
Any() = default;
explicit Any(const Any& other) {
type_ = other.type_;
data_ = other.clone_data_(other.data_);
deleter_ = other.deleter_;
clone_data_ = other.clone_data_;
}
template <typename T> template <typename T>
void set(const T& v) { void set(const T& v) {
set<T>(); set<T>();
...@@ -34,7 +42,16 @@ class Any { ...@@ -34,7 +42,16 @@ class Any {
CHECK(type_ == typeid(T).hash_code()); CHECK(type_ == typeid(T).hash_code());
} else { } else {
type_ = typeid(T).hash_code(); type_ = typeid(T).hash_code();
deleter_ = [&] { delete static_cast<T*>(data_); }; deleter_ = [&](void** data) {
delete static_cast<T*>(*data);
*data = nullptr;
};
clone_data_ = [&](void* data) {
T* res = new T;
CHECK(data) << "data pointer is nullptr";
*res = *static_cast<T*>(data);
return res;
};
} }
data_ = new T; data_ = new T;
} }
...@@ -54,17 +71,18 @@ class Any { ...@@ -54,17 +71,18 @@ class Any {
bool valid() const { return (data_ != nullptr); } bool valid() const { return (data_ != nullptr); }
// ~Any() { ~Any() {
// if (valid()) { if (valid()) {
// deleter_(); deleter_(&data_);
// } }
// } }
private: private:
static size_t kInvalidType; static size_t kInvalidType;
size_t type_{kInvalidType}; size_t type_{kInvalidType};
void* data_{nullptr}; void* data_{nullptr};
std::function<void()> deleter_; std::function<void(void**)> deleter_;
std::function<void*(void*)> clone_data_;
}; };
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册