提交 f5338469 编写于 作者: L liu zhengxi 提交者: lijianshe02

add fill_constant_batch_size_like op and add its unittest (#2044)

* add fill_constant_batch_size_like op and add its unittest
上级 e7eea682
......@@ -9,6 +9,7 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li
add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function)
add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
# lite_cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op)
# lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
......@@ -45,6 +46,7 @@ add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kerne
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
lite_cc_test(test_slice_compute_x86 SRCS slice_compute_test.cc DEPS slice_compute_x86)
lite_cc_test(test_squeeze_compute_x86 SRCS squeeze_compute_test.cc DEPS squeeze_compute_x86)
lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_batch_size_like_compute_test.cc DEPS fill_constant_batch_size_like_compute_x86)
lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86)
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h"
REGISTER_LITE_KERNEL(
fill_constant_batch_size_like,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::FillConstantBatchSizeLikeCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Core>
#include "lite/backends/x86/math/blas.h"
#include "lite/backends/x86/math/math_function.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class FillConstantBatchSizeLikeCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::FillConstantBatchSizeLikeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto* out = param.Out;
auto* in = param.Input;
if (in->lod().size() && param.input_dim_idx == 0) {
// set the correct batch size for the LoDTensor.
auto odims = out->dims();
int output_dim_idx = param.output_dim_idx;
odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
}
auto value = param.value;
paddle::lite::x86::math::SetConstant<TargetType::kX86, T> setter;
Context<TargetType::kX86> ctx;
setter(ctx, out, static_cast<T>(value));
}
virtual ~FillConstantBatchSizeLikeCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/fill_constant_batch_size_like_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(fill_constant_batch_size_like_x86, retrive_op) {
auto fill_constant_batch_size_like =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"fill_constant_batch_size_like");
ASSERT_FALSE(fill_constant_batch_size_like.empty());
ASSERT_TRUE(fill_constant_batch_size_like.front());
}
TEST(fill_constant_batch_size_like_x86, init) {
lite::kernels::x86::FillConstantBatchSizeLikeCompute<float>
fill_constant_batch_size_like;
ASSERT_EQ(fill_constant_batch_size_like.precision(), PRECISION(kFloat));
ASSERT_EQ(fill_constant_batch_size_like.target(), TARGET(kX86));
}
TEST(fill_constant_batch_size_like_x86, run_test) {
lite::Tensor input;
lite::Tensor out;
std::vector<int64_t> input_shape{219, 232};
input.Resize(input_shape);
std::vector<int64_t> out_shape{219, 132, 7};
auto input_data = input.mutable_data<float>();
auto out_data = out.mutable_data<float>();
for (int64_t i = 0; i < input.dims().production(); ++i) {
input_data[i] = static_cast<float>(i);
}
FillConstantBatchSizeLikeCompute<float> fill_constant_batch_size_like;
operators::FillConstantBatchSizeLikeParam param;
param.Input = &input;
param.Out = &out;
std::vector<int> shape{-1, 132, 7};
float value = 3.5;
param.shape = shape;
param.value = value;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
fill_constant_batch_size_like.SetParam(param);
fill_constant_batch_size_like.Run();
for (int i = 0; i < out.dims().production(); i++) {
LOG(INFO) << out_data[i];
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fill_constant_batch_size_like, kX86, kFloat, kNCHW, def);
......@@ -25,6 +25,7 @@ add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DE
add_operator(fusion_elementwise_activation_ops basic SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops ${op_DEPS})
add_operator(mean_op basic SRCS mean_op.cc DEPS ${op_DEPS})
add_operator(fill_constant_op basic SRCS fill_constant_op.cc DEPS ${op_DEPS})
add_operator(fill_constant_batch_size_like_op basic SRCS fill_constant_batch_size_like_op.cc DEPS ${op_DEPS})
#add_operator(sgd_op basic SRCS sgd_op.cc DEPS ${op_DEPS})
add_operator(uniform_random_op basic SRCS uniform_random_op.cc DEPS ${op_DEPS})
add_operator(power_op basic SRCS power_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fill_constant_batch_size_like_op.h"
#include <algorithm>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace operators {
bool FillConstantBatchSizeLikeOp::CheckShape() const {
CHECK_OR_FALSE(param_.Input);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool FillConstantBatchSizeLikeOp::InferShape() const {
auto shape = param_.shape;
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) {
return static_cast<int64_t>(a);
});
lite::DDim output_dim(shape_int64);
int input_dim_idx = param_.input_dim_idx;
int output_dim_idx = param_.output_dim_idx;
output_dim[output_dim_idx] = param_.Input->dims()[input_dim_idx];
param_.Out->Resize(output_dim);
return true;
}
bool FillConstantBatchSizeLikeOp::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto Input = op_desc.Input("X").front();
auto Out = op_desc.Output("Out").front();
param_.Input = scope->FindVar(Input)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.shape = op_desc.GetAttr<std::vector<int>>("shape");
param_.input_dim_idx = op_desc.GetAttr<int>("input_dim_idx");
param_.output_dim_idx = op_desc.GetAttr<int>("output_dim_idx");
param_.dtype = op_desc.GetAttr<int>("dtype");
param_.value = op_desc.GetAttr<float>("value");
CHECK(param_.Input);
CHECK(param_.Out);
return true;
}
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
REGISTER_LITE_OP(fill_constant_batch_size_like,
paddle::lite::operators::FillConstantBatchSizeLikeOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FillConstantBatchSizeLikeOp : public OpLite {
public:
FillConstantBatchSizeLikeOp() {}
explicit FillConstantBatchSizeLikeOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fill_constant_batch_size_like";
}
private:
mutable FillConstantBatchSizeLikeParam param_;
};
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
......@@ -371,6 +371,17 @@ struct FillConstantParam {
lite::Tensor* Out{};
};
struct FillConstantBatchSizeLikeParam {
lite::Tensor* Input;
lite::Tensor* Out;
std::vector<int> shape;
int input_dim_idx{0};
int output_dim_idx{0};
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
float value{0.0f};
};
//
struct FakeQuantizeMovingAvgMaxAbsParam {
const lite::Tensor* x{};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册