Commit cce54cb6 authored by dingminghui, committed by jackzhang235

feat(squeeze): add squeeze converter
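Lower squeeze and squeeze2 to cnmlReshapeOp in the MLU subgraph bridge, register both ops, extend DimNCHW2NHWC/DimNHWC2NCHW in utility.h to handle 1-D through 5-D dims, and add converter unit tests.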

Parent 148fafe8
......@@ -23,6 +23,7 @@ lite_cc_library(subgraph_bridge_dropout_op_mlu SRCS dropout_op.cc DEPS ${subgrap
lite_cc_library(subgraph_bridge_slice_op_mlu SRCS slice_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_split_op_mlu SRCS split_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_argmax_op_mlu SRCS argmax_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_squeeze_op_mlu SRCS squeeze_op.cc DEPS ${subgraph_bridge_deps_mlu})
set(mlu_subgraph_bridges
subgraph_bridge_registry
subgraph_bridge_utility_mlu
......@@ -42,6 +43,7 @@ set(mlu_subgraph_bridges
subgraph_bridge_slice_op_mlu
subgraph_bridge_split_op_mlu
subgraph_bridge_argmax_op_mlu
subgraph_bridge_squeeze_op_mlu
CACHE INTERNAL "mlu_subgraph_bridges")
......@@ -66,6 +68,7 @@ lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optim
lite_cc_test(test_slice_converter_mlu SRCS slice_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_split_converter_mlu SRCS split_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
if (LITE_BUILD_EXTRA)
lite_cc_test(test_lrn_converter_mlu SRCS lrn_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
endif()
......
......@@ -34,6 +34,9 @@ USE_SUBGRAPH_BRIDGE(elementwise_mul, kMLU);
USE_SUBGRAPH_BRIDGE(dropout, kMLU);
USE_SUBGRAPH_BRIDGE(argmax, kMLU);
USE_SUBGRAPH_BRIDGE(split, kMLU);
USE_SUBGRAPH_BRIDGE(slice, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
#ifdef LITE_BUILD_EXTRA
USE_SUBGRAPH_BRIDGE(lrn, kMLU);
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int SqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Create reshape node that implements squeeze and set its params from the op
auto fp_type = graph->FPType();
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, fp_type);
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
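// CNML tensors are laid out as NHWC, so permute the NCHW output dims and
// narrow them to int as required by cnmlCreateNdReshapeOpParam.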
auto output_dims_nhwc = DimNCHW2NHWC(output_dims);
std::vector<int> o_dims(output_dims.size());
std::transform(output_dims_nhwc.cbegin(),
output_dims_nhwc.cend(),
o_dims.begin(),
[](DDim::value_type d) { return static_cast<int>(d); });
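// Squeeze only drops size-1 dims, so it lowers to a plain reshape on MLU.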
cnmlReshapeOpParam_t param;
cnmlBaseOp_t squeeze_op;
CNML_CALL(cnmlCreateNdReshapeOpParam(&param, o_dims.data(), o_dims.size()));
CNML_CALL(cnmlCreateReshapeOp(&squeeze_op,
param,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor()));
CNML_CALL(cnmlDestroyReshapeOpParam(&param));
graph->FuseOp(squeeze_op);
CNML_CALL(cnmlDestroyBaseOp(&squeeze_op));
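// squeeze2 additionally outputs XShape, which preserves the input's shape;
// emit a second reshape that copies X into it.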
if (op_type == "squeeze2") {
auto xshape_var_name = op_info->Output("XShape").front();
auto xshape = scope->FindVar(xshape_var_name)->GetMutable<Tensor>();
auto dims_64 = xshape->dims().Vectorize();
auto dims_64_nhwc = DimNCHW2NHWC(dims_64);
auto xshape_tensor = graph->AddNode(
xshape_var_name, dims_64, CNML_TENSOR, CNML_NCHW, fp_type);
std::vector<int> xshape_dims(dims_64.size());
std::transform(dims_64_nhwc.cbegin(),
dims_64_nhwc.cend(),
xshape_dims.begin(),
[](DDim::value_type d) { return static_cast<int>(d); });
cnmlBaseOp_t squeeze2_op;
CNML_CALL(cnmlCreateNdReshapeOpParam(
&param, xshape_dims.data(), xshape_dims.size()));
CNML_CALL(cnmlCreateReshapeOp(&squeeze2_op,
param,
input_tensor->mlu_tensor(),
xshape_tensor->mlu_tensor()));
CNML_CALL(cnmlDestroyReshapeOpParam(&param));
graph->FuseOp(squeeze2_op);
CNML_CALL(cnmlDestroyBaseOp(&squeeze2_op));
}
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(squeeze,
kMLU,
paddle::lite::subgraph::mlu::SqueezeConverter);
REGISTER_SUBGRAPH_BRIDGE(squeeze2,
kMLU,
paddle::lite::subgraph::mlu::SqueezeConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/squeeze_op.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
// squeeze
TEST(MLUBridges, squeeze) {
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string ref_var_name("ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(ref_var_name)->GetMutable<Tensor>();
std::vector<int64_t> x_shape({1, 3, 1, 5});
x->Resize(x_shape);
out_ref->Resize(x_shape);
std::vector<int64_t> out_shape({3, 5});
out->Resize(out_shape);
FillTensor<float>(x, 0, 10);
out_ref->CopyDataFrom(*x);
cpp::OpDesc opdesc;
opdesc.SetType("squeeze");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
std::vector<int> axes{0, -2};
opdesc.SetAttr("axes", axes);
// Create the op, convert it to an MLU model, then run it on MLU
auto op = CreateOp<operators::SqueezeOp>(opdesc, &scope);
LaunchOp(op, {x_var_name}, {out_var_name});
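// Squeeze preserves the data and only changes the shape; compare elementwise.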
auto x_data = out_ref->data<float>();
auto out_data = out->data<float>();
for (int j = 0; j < out->numel(); ++j) {
EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
}
}
// squeeze2
TEST(MLUBridges, squeeze2) {
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string xshape_var_name("xshape");
std::string ref_var_name("ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* xshape = scope.Var(xshape_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(ref_var_name)->GetMutable<Tensor>();
std::vector<int64_t> x_shape({1, 3, 1, 5});
x->Resize(x_shape);
out_ref->Resize(x_shape);
std::vector<int64_t> out_shape({3, 5});
out->Resize(out_shape);
std::vector<int64_t> xshape_shape({1, 3, 1, 5});
xshape->Resize(xshape_shape);
FillTensor<float>(x, 0, 10);
out_ref->CopyDataFrom(*x);
cpp::OpDesc opdesc;
opdesc.SetType("squeeze2");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetOutput("XShape", {xshape_var_name});
std::vector<int> axes({0, -2});
opdesc.SetAttr("axes", axes);
// Create the op, convert it to an MLU model, then run it on MLU
auto op = CreateOp<operators::SqueezeOp>(opdesc, &scope);
LaunchOp(op, {x_var_name}, {out_var_name, xshape_var_name});
auto x_data = out_ref->data<float>();
auto out_data = out->data<float>();
auto xshape_data = xshape->data<float>();
for (int j = 0; j < out->numel(); ++j) {
EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
EXPECT_NEAR(xshape_data[j], x_data[j], 1e-5);
}
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
......@@ -103,14 +103,44 @@ inline const ::paddle::lite::DDimLite DimNCHW2NHWC(
std::vector<int64_t>({dim[0], dim[2], dim[3], dim[1]}));
}
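// Permute a dim vector from NHWC back to NCHW; 1-D and 2-D dims have no
// spatial axes and are returned unchanged.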
inline const std::vector<DDimLite::value_type> DimNHWC2NCHW(
const std::vector<DDimLite::value_type>& dim) {
switch (dim.size()) {
case 1:
return dim;
case 2:
return dim;
case 3:
return std::vector<DDimLite::value_type>({dim[0], dim[2], dim[1]});
case 4:
return std::vector<DDimLite::value_type>(
{dim[0], dim[3], dim[1], dim[2]});
case 5:
return std::vector<DDimLite::value_type>(
{dim[0], dim[4], dim[1], dim[2], dim[3]});
default:
CHECK(0) << "unsupported dimension";
}
}
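// Permute a dim vector from NCHW to the NHWC layout used by CNML tensors;
// 1-D and 2-D dims are returned unchanged.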
inline const std::vector<DDimLite::value_type> DimNCHW2NHWC(
const std::vector<DDimLite::value_type>& dim) {
switch (dim.size()) {
case 1:
return dim;
case 2:
return dim;
case 3:
return std::vector<DDimLite::value_type>({dim[0], dim[2], dim[1]});
case 4:
return std::vector<DDimLite::value_type>(
{dim[0], dim[2], dim[3], dim[1]});
case 5:
return std::vector<DDimLite::value_type>(
{dim[0], dim[2], dim[3], dim[4], dim[1]});
default:
CHECK(0) << "unsupported dimension";
}
}
template <paddle::lite_api::PrecisionType>
......