Unverified commit 792d898a, authored by lijianshe02 and committed by GitHub

fix asr model related kernel bugs test=develop (#2179)

* fix asr model related kernel bugs test=develop
Parent 253acb80
@@ -375,6 +375,8 @@ endfunction()
 # Bundle several static libraries into one.
 function(bundle_static_library tgt_name bundled_tgt_name fake_target)
   list(APPEND static_libs ${tgt_name})
+  # for x86
+  add_dependencies(lite_compile_deps ${fake_target})
   function(_recursively_collect_dependencies input_target)
     set(_input_link_libraries LINK_LIBRARIES)
@@ -463,9 +463,9 @@ void Blas<Target>::MatMul(const lite::Tensor &mat_a,
   auto dim_out = mat_out->dims();
   PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
                  "The input and output of matmul be matrix");
-  PADDLE_ENFORCE(
-      mat_a.target() == mat_b.target() && mat_a.target() == mat_out->target(),
-      "The targets of matrices must be same");
+  // PADDLE_ENFORCE(
+  //     mat_a.target() == mat_b.target() && mat_a.target() == mat_out->target(),
+  //     "The targets of matrices must be same");
   int M = dim_out[0];
   int N = dim_out[1];
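The hunk above disables the target-equality enforce but keeps the 2-D shape check. As a standalone illustration of the shape contract MatMul still relies on for C = A * B (CheckMatMulDims is a hypothetical helper, not part of Paddle-Lite):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical helper: verifies that A (MxK), B (KxN) and C (MxN) have
    // consistent 2-D shapes, mirroring the dimension check Blas::MatMul keeps.
    void CheckMatMulDims(const std::vector<std::int64_t>& dim_a,
                         const std::vector<std::int64_t>& dim_b,
                         const std::vector<std::int64_t>& dim_out) {
      assert(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2);
      const std::int64_t M = dim_out[0];
      const std::int64_t N = dim_out[1];
      const std::int64_t K = dim_a[1];
      assert(dim_a[0] == M);  // rows of A match rows of C
      assert(dim_b[0] == K);  // inner dimensions agree
      assert(dim_b[1] == N);  // columns of B match columns of C
      (void)M; (void)N; (void)K;  // avoid unused warnings when NDEBUG is set
    }

    int main() {
      CheckMatMulDims({3, 4}, {4, 5}, {3, 5});  // (3x4) * (4x5) -> (3x5)
      return 0;
    }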
@@ -4,17 +4,13 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li
 # lite_cc_library(sgd_compute_x86 SRCS sgd_compute.cc DEPS ${lite_kernel_deps})
 # lite_cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps})
-# lite_cc_library(mul_compute_x86 SRCS mul_compute.cc DEPS ${lite_kernel_deps})
 # lite_cc_library(relu_compute_x86 SRCS relu_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function)
 add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
-# lite_cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op)
-# lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
 # lite_cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
-# lite_cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} )
 # lite_cc_library(conv_compute_x86 SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col)
 # lite_cc_library(pool_compute_x86 SRCS pool_compute.cc DEPS ${lite_kernel_deps} pooling)
 # lite_cc_library(batch_norm_compute_x86 SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
@@ -26,8 +22,6 @@ add_kernel(sequence_expand_as_compute_x86 X86 basic SRCS sequence_expand_as_comp
 # lite_cc_test(test_fc_compute_x86 SRCS fc_compute_test.cc DEPS fc_compute_x86)
 # lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
 # lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
-# lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
-# lite_cc_test(test_elementwise_compute_x86 SRCS elementwise_compute_test.cc DEPS elementwise_compute_x86)
 # lite_cc_test(test_scale_compute_x86 SRCS scale_compute_test.cc DEPS scale_compute_x86)
 # lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
 # lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
@@ -21,6 +21,6 @@ REGISTER_LITE_KERNEL(
     kNCHW,
     paddle::lite::kernels::x86::FillConstantBatchSizeLikeCompute<float>,
     def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
@@ -33,6 +33,7 @@ class FillConstantBatchSizeLikeCompute
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
+    auto& ctx = ctx_->As<X86Context>();
     auto* out = param.Out;
     auto* in = param.Input;
     if (in->lod().size() && param.input_dim_idx == 0) {
@@ -40,11 +41,13 @@ class FillConstantBatchSizeLikeCompute
       auto odims = out->dims();
       int output_dim_idx = param.output_dim_idx;
       odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
+      out->Resize(odims);
+      // out->mutable_data<T>();
     }
+    out->mutable_data<T>();
     auto value = param.value;
-    paddle::lite::x86::math::SetConstant<TargetType::kX86, T> setter;
-    Context<TargetType::kX86> ctx;
+    paddle::lite::x86::math::SetConstant<lite::TargetType::kX86, T> setter;
     setter(ctx, out, static_cast<T>(value));
   }
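The reworked Run() above resizes the output to the LoD-derived shape before calling mutable_data<T>(), so the allocation matches the final element count, and it reuses the kernel's X86Context instead of constructing a fresh one. A toy sketch of the resize-then-allocate ordering, using a stand-in class rather than the real lite::Tensor:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Toy stand-in (not the real lite::Tensor): the buffer is only as large as
    // the dims known at allocation time, so Resize() must run before
    // mutable_data().
    struct ToyTensor {
      std::vector<int> dims;
      std::vector<float> data;
      void Resize(const std::vector<int>& d) { dims = d; }
      float* mutable_data() {
        std::size_t n = 1;
        for (int d : dims) n *= static_cast<std::size_t>(d);
        data.assign(n, 0.f);  // allocate for the *current* dims
        return data.data();
      }
    };

    int main() {
      ToyTensor out;
      out.Resize({219, 132, 7});      // shape fixed first, as in the fixed kernel
      float* p = out.mutable_data();  // then storage is allocated to match
      std::cout << "elements: " << out.data.size() << "\n";  // 219 * 132 * 7
      return p != nullptr ? 0 : 1;
    }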
@@ -45,6 +45,7 @@ TEST(fill_constant_batch_size_like_x86, run_test) {
   std::vector<int64_t> input_shape{219, 232};
   input.Resize(input_shape);
   std::vector<int64_t> out_shape{219, 132, 7};
+  out.Resize(out_shape);
   auto input_data = input.mutable_data<float>();
   auto out_data = out.mutable_data<float>();
@@ -64,11 +65,14 @@ TEST(fill_constant_batch_size_like_x86, run_test) {
   std::unique_ptr<KernelContext> ctx(new KernelContext);
   ctx->As<X86Context>();
+  fill_constant_batch_size_like.SetContext(std::move(ctx));
   fill_constant_batch_size_like.SetParam(param);
   fill_constant_batch_size_like.Run();
-  for (int i = 0; i < out.dims().production(); i++) {
-    LOG(INFO) << out_data[i];
+  std::vector<float> ref_results{
+      3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5};
+  for (int i = 0; i < ref_results.size(); i++) {
+    EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
   }
 }
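The updated test sets the kernel context before Run() and asserts the output against reference values instead of only logging them. A minimal sketch of that verification style, assuming googletest is available (the test name and data here are illustrative only):

    #include <gtest/gtest.h>

    #include <cstddef>
    #include <vector>

    // Compare each output element against an expected constant within a
    // tolerance, the same pattern the updated kernel test follows.
    TEST(FillConstantLikeSketch, CompareAgainstReference) {
      const float value = 3.5f;
      std::vector<float> out_data(10, value);    // stand-in for kernel output
      std::vector<float> ref_results(10, 3.5f);  // expected constants
      for (std::size_t i = 0; i < ref_results.size(); ++i) {
        EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
      }
    }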
@@ -28,9 +28,8 @@ REGISTER_LITE_KERNEL(gru,
     .BindInput("H0", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindOutput("Batch_gate", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindOutput("Batch_reset_hidden_prev",
-                {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindOutput("Batch_hidden", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
@@ -22,4 +22,5 @@ REGISTER_LITE_KERNEL(sequence_pool,
     def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("MaxIndex", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
@@ -20,6 +20,6 @@ REGISTER_LITE_KERNEL(shape,
     kNCHW,
     paddle::lite::kernels::x86::ShapeCompute<float>,
     def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
@@ -46,7 +46,7 @@ bool FillConstantBatchSizeLikeOp::InferShape() const {
 bool FillConstantBatchSizeLikeOp::AttachImpl(const cpp::OpDesc &op_desc,
                                              lite::Scope *scope) {
-  auto Input = op_desc.Input("X").front();
+  auto Input = op_desc.Input("Input").front();
   auto Out = op_desc.Output("Out").front();
   param_.Input = scope->FindVar(Input)->GetMutable<lite::Tensor>();
   param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
@@ -685,6 +685,7 @@ struct SequencePoolParam {
   std::string pool_type{"AVERAGE"};
 #ifdef LITE_WITH_X86
   float pad_value{0.0};
+  lite::Tensor* MaxIndex{};
 #endif
 };
@@ -52,7 +52,7 @@ class Any {
     return static_cast<T*>(data_);
   }
-  bool valid() const { return data_; }
+  bool valid() const { return (data_ != nullptr); }
   // ~Any() {
   //   if (valid()) {
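The last hunk makes Any::valid() an explicit null check rather than an implicit pointer-to-bool conversion. A small self-contained illustration of the same pattern (not the real lite::Any class):

    #include <iostream>

    // Illustration only: valid() spells out the null check explicitly.
    class AnyLike {
     public:
      void set(void* p) { data_ = p; }
      bool valid() const { return data_ != nullptr; }

     private:
      void* data_{nullptr};
    };

    int main() {
      AnyLike a;
      std::cout << std::boolalpha << a.valid() << "\n";  // false
      int x = 0;
      a.set(&x);
      std::cout << a.valid() << "\n";                    // true
      return 0;
    }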