Commit d5e58a98 authored by lijianshe02

add asr related kernel test=develop

Parent f480d474
......
@@ -101,6 +101,8 @@ static size_t PrecisionTypeLength(PrecisionType type) {
return 1;
case PrecisionType::kInt32:
return 4;
case PrecisionType::kInt64:
return 8;
case PrecisionType::kFP16:
return 2;
default:
......
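For intuition (illustration only, not part of the commit), the length table is the per-element byte width, so kInt64 now contributes 8 bytes per element; a hedged sketch with a hypothetical helper name:

// Hypothetical helper: byte size of a dense buffer, using
// PrecisionTypeLength from the hunk above (kInt64 -> numel * 8).
static size_t BufferBytes(PrecisionType type, size_t numel) {
  return numel * PrecisionTypeLength(type);
}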
......
@@ -54,6 +54,8 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
CREATE_KERNEL1(target__, kFP16); \
case PRECISION(kAny): \
CREATE_KERNEL1(target__, kAny); \
case PRECISION(kInt64): \
CREATE_KERNEL1(target__, kInt64); \
default: \
CHECK(false) << "not supported kernel precision " \
<< PrecisionToStr(precision); \
......
@@ -123,6 +125,7 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kX86, kFloat, kNCHW);
INIT_FOR(kX86, kAny, kNCHW);
INIT_FOR(kX86, kAny, kAny);
INIT_FOR(kX86, kInt64, kNCHW);
INIT_FOR(kARM, kFloat, kNCHW);
INIT_FOR(kARM, kInt8, kNCHW);
......
......
@@ -32,7 +32,7 @@ struct EigenDim {
static Type From(const lite::DDim& dims) {
PADDLE_ENFORCE(dims.size() == D, "D must match DDim::size");
Type ret;
- for (int64_t d = 0; d < dims.size(); d++) {
+ for (size_t d = 0; d < dims.size(); d++) {
ret[d] = dims[d];
}
return ret;
......
@@ -118,7 +118,9 @@ struct EigenScalar {
using ConstType = Eigen::TensorMap<
Eigen::TensorFixedSize<const T, Eigen::Sizes<>, MajorType, IndexType>>;
- static Type From(Tensor& tensor) { return Type(tensor.data<T>()); }  // NOLINT
+ static Type From(const Tensor& tensor) {
+   return Type(const_cast<T*>(tensor.data<T>()));
+ }  // NOLINT
static ConstType From(const Tensor& tensor) {
return ConstType(tensor.data<T>());
......
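Note on the EigenScalar change above: the rewritten From overload now takes a const Tensor& and uses const_cast to produce a writable map, so scalar outputs can be built from const-qualified tensor references; the read-only ConstType overload below it is unchanged.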
......
@@ -35,6 +35,9 @@ add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEP
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
if(NOT LITE_WITH_X86)
return()
......
......
@@ -35,3 +35,13 @@ REGISTER_LITE_KERNEL(relu,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(softsign,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SoftsignCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......
@@ -115,6 +115,31 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
virtual ~ReluCompute() = default;
};
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
template <typename T>
class SoftsignCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override {
// auto& context = ctx_->As<X86Context>();
auto& param = *param_.get_mutable<operators::ActivationParam>();
param.Out->template mutable_data<T>();
Activate<SoftsignFunctor<T>>(param.X, param.Out);
}
virtual ~SoftsignCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
......
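As a sanity check of the math above (not part of the commit): softsign maps any real input into (-1, 1). A minimal standalone sketch, assuming only Eigen:

#include <Eigen/Core>
#include <iostream>

int main() {
  Eigen::ArrayXf x(4);
  x << -2.f, -0.5f, 0.f, 3.f;
  // softsign(x) = x / (1 + |x|), matching SoftsignFunctor above
  Eigen::ArrayXf y = x / (1.f + x.abs());
  std::cout << y.transpose() << std::endl;  // -0.666667 -0.333333 0 0.75
  return 0;
}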
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/lookup_table_compute.h"
REGISTER_LITE_KERNEL(lookup_table,
kX86,
kInt64,
kNCHW,
paddle::lite::kernels::x86::LookupTableCompute<int64_t>,
def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
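A hedged usage note (the declaration macro belongs to the surrounding framework and is not shown in this diff): a binary that links this kernel statically typically also declares it, e.g.:

USE_LITE_KERNEL(lookup_table, kX86, kInt64, kNCHW, def);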
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class LookupTableCompute : public KernelLite<TARGET(kX86), PRECISION(kInt64)> {
public:
using param_t = operators::LookupTableParam;
void Run() override {
auto &param = *param_.get_mutable<operators::LookupTableParam>();
auto *ids_t = param.Ids;
auto *output_t = param.Out;
int64_t padding_idx = param.padding_idx;
auto *ids = ids_t->data<int64_t>();
int64_t ids_numel = ids_t->dims().production();
auto *table_t = param.W;
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<float>();
auto *output = output_t->mutable_data<float>();
// Gather one table row per id; rows whose id equals padding_idx are
// zero-filled. Every output row is written below, so no upfront memset
// of the whole buffer is needed.
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != -1 && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(float));
} else {
CHECK_LT(ids[i], row_number);
CHECK_GE(ids[i], 0);
memcpy(output + i * row_width,
table + ids[i] * row_width,
row_width * sizeof(float));
}
}
}
virtual ~LookupTableCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
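For reference (a standalone sketch mirroring Run above; LookupTableRef is a hypothetical name, not part of the commit), the kernel is a plain row gather with padding rows zero-filled:

#include <cstdint>
#include <cstring>

void LookupTableRef(const float* table, const int64_t* ids, float* out,
                    int64_t ids_numel, int64_t row_width, int64_t padding_idx) {
  for (int64_t i = 0; i < ids_numel; ++i) {
    if (padding_idx != -1 && ids[i] == padding_idx) {
      std::memset(out + i * row_width, 0, row_width * sizeof(float));
    } else {
      std::memcpy(out + i * row_width, table + ids[i] * row_width,
                  row_width * sizeof(float));
    }
  }
}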
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/reduce_compute.h"
REGISTER_LITE_KERNEL(reduce_sum,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::ReduceSumCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
#include "lite/kernels/x86/reduce_op_function.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
struct SumFunctor {
template <typename X, typename Y, typename Dim>
void operator()(X* x, Y* y, const Dim& dim) {
y->device(lite::fluid::EigenDeviceType<TARGET(kX86)>()) = x->sum(dim);
}
};
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
paddle::lite::kernels::x86:: \
ReduceFunctor<lite::TargetType::kX86, T, NDIM, RDIM, SumFunctor>( \
*input, output, dims, keep_dim); \
}
template <typename T>
class ReduceSumCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ReduceParam;
void Run() override {
auto& param = *param_.get_mutable<operators::ReduceParam>();
// auto& context = ctx_->As<X86Context>();
bool reduce_all = param.reduce_all;
auto* input = param.x;
auto* output = param.output;
param.output->mutable_data<T>();
auto dims = param.dim;
bool keep_dim = param.keep_dim;
if (reduce_all) {
// Flatten and reduce 1-D tensor
auto x = lite::fluid::EigenVector<T>::Flatten(*input);
auto out = lite::fluid::EigenScalar<T>::From(*output);
// auto& place = *platform::CPUDeviceContext().eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
SumFunctor functor;
functor(&x, &out, reduce_dim);
} else {
int ndim = input->dims().size();
int rdim = dims.size();
HANDLE_DIM(4, 3);
HANDLE_DIM(4, 2);
HANDLE_DIM(4, 1);
HANDLE_DIM(3, 2);
HANDLE_DIM(3, 1);
HANDLE_DIM(2, 1);
HANDLE_DIM(1, 1);
}
}
virtual ~ReduceSumCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
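For intuition (illustrative reference, not part of the commit): reduce_sum over axis 1 of a row-major [rows, cols] buffer, i.e. the 2-D/1-reduce case dispatched by HANDLE_DIM(2, 1):

#include <vector>

std::vector<float> ReduceSumAxis1(const std::vector<float>& x, int rows, int cols) {
  std::vector<float> out(rows, 0.f);  // shape [rows], or [rows, 1] with keep_dim
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) out[r] += x[r * cols + c];
  return out;
}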
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = lite::fluid::EigenTensor<T, D, MajorType, IndexType>;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = lite::fluid::EigenScalar<T, MajorType, IndexType>;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = lite::fluid::EigenVector<T, MajorType, IndexType>;
template <lite::TargetType Target,
typename T,
size_t D,
size_t R_D,
typename Functor>
void ReduceFunctor(const lite::Tensor& input,
lite::Tensor* output,
const std::vector<int>& dims,
bool keep_dim) {
auto x = EigenTensor<T, D>::From(input);
auto x_rank = static_cast<int>(x.dimensions().size());
auto reduce_dim = Eigen::array<int, R_D>();
std::vector<int> dims_ref = dims;
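// Normalize negative axes, e.g. dim = -1 on a rank-3 input becomes axis 2.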
for (size_t i = 0; i < dims_ref.size(); ++i) {
if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i];
reduce_dim[i] = dims_ref[i];
}
// construct the squeezed output tensor
lite::DDim out_dims = output->dims();
if (keep_dim && x_rank > 1) {
const int kDelFlag = -2;
auto dims_vector = out_dims.Vectorize();
for (size_t i = 0; i < dims_ref.size(); ++i) {
dims_vector[dims_ref[i]] = kDelFlag;
}
dims_vector.erase(std::remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
out_dims = lite::DDim(dims_vector);
}
// auto& place = *context.eigen_device();
Functor functor;
if (D == 1) {
auto out = EigenScalar<T>::From(*output);
functor(&x, &out, reduce_dim);
} else {
auto out = EigenTensor<T, (D - R_D)>::From(*output, out_dims);
functor(&x, &out, reduce_dim);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_reshape_compute.h"
REGISTER_LITE_KERNEL(sequence_reshape,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SequenceReshapeCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kFloat))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/fluid/eigen.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SequenceReshapeCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceReshapeParam;
void Run() override {
auto& param = *param_.get_mutable<operators::SequenceReshapeParam>();
// auto& context = context_->As<X86Context>();
auto* in = param.x;
auto* out = param.output;
int out_width = param.new_dim;
auto in_dims = in->dims();
int64_t in_width = in_dims[1];
auto& in_lod = in->lod();
CHECK_EQ(in_lod.size(), 1UL);
CHECK_EQ((uint64_t)in_dims[0], in_lod[0].back());
auto in_lod_l0 = in_lod[0];
int seq_num = in_lod_l0.size() - 1;
if (in_width == out_width) {
out->set_lod(in->lod());
} else {
auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
out_lod[0].resize(seq_num + 1);
out_lod[0][0] = 0;
for (int i = 0; i < seq_num; ++i) {
size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
size_t offset = (seq_len * in_width) / out_width;
CHECK_EQ(offset * out_width, seq_len * in_width);
out_lod[0][i + 1] = out_lod[0][i] + offset;
}
}
out->Resize(in_dims);
auto* dst_ptr = out->mutable_data<T>();
auto size = in->numel() * sizeof(T);
std::memcpy(dst_ptr, in->data<T>(), size);
std::vector<int64_t> out_shape{static_cast<int64_t>(out->lod()[0].back()),
out_width};
out->Resize(lite::DDim(out_shape));
}
virtual ~SequenceReshapeCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
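The LoD arithmetic above can be checked in isolation (hypothetical helper, same formula as the loop in Run): each sequence's length is rescaled by in_width / out_width, and the division must be exact:

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<size_t> ReshapeLod(const std::vector<size_t>& lod,
                               int64_t in_width, int64_t out_width) {
  std::vector<size_t> out(lod.size(), 0);
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    size_t seq_len = lod[i + 1] - lod[i];
    assert((seq_len * in_width) % out_width == 0);  // mirrors CHECK_EQ above
    out[i + 1] = out[i] + seq_len * in_width / out_width;
  }
  return out;
}

For example, lod {0, 2, 5} with in_width = 4 and out_width = 2 yields {0, 4, 10}.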
......
@@ -76,6 +76,8 @@ add_operator(sequence_expand_as_op_lite basic SRCS sequence_expand_as_op.cc DEPS
add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op basic SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite basic SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite basic SRCS reduce_ops.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
......
@@ -118,6 +118,7 @@ REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(softsign, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
......
......
@@ -21,7 +21,7 @@ namespace lite {
namespace operators {
bool ConcatOpLite::CheckShape() const {
- CHECK_GT_OR_FALSE(param_.x.size(), 1UL);
+ CHECK_GE_OR_FALSE(param_.x.size(), 1UL);
CHECK_OR_FALSE(param_.output);
return true;
}
......
......
@@ -717,6 +717,12 @@ struct SequencePoolParam {
#endif
};
struct SequenceReshapeParam {
lite::Tensor* x{};
lite::Tensor* output{};
int new_dim{0};
};
struct SequenceExpandParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
......
@@ -749,6 +755,15 @@ struct IsEmptyParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
};
struct ReduceParam {
lite::Tensor* x{};
lite::Tensor* output{};
std::vector<int> dim{0};
bool keep_dim{false};
bool reduce_all{false};
};
/// ----------------------- shape operators ----------------------
struct ShapeParam {
const lite::Tensor* X{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reduce_ops.h"
#include <algorithm>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ReduceOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims();
auto x_rank = x_dims.size();
CHECK_LE(x_rank, 6UL) << "Tensors with rank at most 6 are supported.";
return true;
}
bool ReduceOp::InferShape() const {
auto x_dims = param_.x->dims();
auto x_rank = x_dims.size();
auto dims = param_.dim;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) dims[i] = x_rank + dims[i];
CHECK_LT(dims[i], x_rank)
<< "The dim should be in the range [-rank(input), rank(input).";
}
sort(dims.begin(), dims.end());
bool reduce_all = param_.reduce_all;
bool keep_dim = param_.keep_dim;
if (reduce_all) {
if (keep_dim)
param_.output->Resize(lite::DDim(std::vector<int64_t>(x_rank, 1)));
else
param_.output->Resize(lite::DDim(std::vector<int64_t>{1}));
} else {
auto dims_vector = x_dims.Vectorize();
if (keep_dim) {
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = 1;
}
} else {
const int kDelFlag = -2;
for (size_t i = 0; i < dims.size(); ++i) {
dims_vector[dims[i]] = kDelFlag;
}
dims_vector.erase(
remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
dims_vector.end());
}
auto out_dims = lite::DDim(dims_vector);
param_.output->Resize(out_dims);
if (dims[0] != 0) {
param_.output->set_lod(param_.x->lod());
}
}
return true;
}
bool ReduceOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.dim = opdesc.GetAttr<std::vector<int>>("dim");
param_.reduce_all = opdesc.GetAttr<bool>("reduce_all");
param_.keep_dim = opdesc.GetAttr<bool>("keep_dim");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(reduce_sum, paddle::lite::operators::ReduceOp);
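Worked example (illustrative): for x of shape [2, 3, 4] and dim = {1}, InferShape yields [2, 4] with keep_dim = false and [2, 1, 4] with keep_dim = true; with reduce_all = true the output is [1], or [1, 1, 1] when keep_dim is set.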
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ReduceOp : public OpLite {
public:
ReduceOp() {}
explicit ReduceOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "reduce"; }
private:
mutable ReduceParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_reshape_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceReshapeOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.output);
auto x_dims = param_.x->dims();
CHECK_EQ_OR_FALSE(x_dims.size(), 2U);
return true;
}
bool SequenceReshapeOp::InferShape() const {
int new_dim = param_.new_dim;
auto x_numel = param_.x->dims().production();
std::vector<int64_t> out_shape{x_numel / new_dim,
static_cast<int64_t>(new_dim)};
param_.output->Resize(lite::DDim(out_shape));
return true;
}
bool SequenceReshapeOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
param_.x =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.new_dim = opdesc.GetAttr<int>("new_dim");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_reshape, paddle::lite::operators::SequenceReshapeOp);
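Worked example (illustrative): an input x of dims [12, 4] with new_dim = 6 reshapes to [8, 6], since 12 * 4 / 6 = 8; the total element count is preserved.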
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceReshapeOp : public OpLite {
public:
SequenceReshapeOp() {}
explicit SequenceReshapeOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_reshape"; }
private:
mutable SequenceReshapeParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle