未验证 提交 55db1963 编写于 作者: M MaxwellDing 提交者: GitHub

[MLU] feat: add some kernels as bridge, test=develop (#3896)

add mlu kernel dropout split squeeze reshape
上级 9ae1e645
......@@ -18,8 +18,12 @@ lite_cc_library(subgraph_bridge_fc_op_mlu SRCS fc_op.cc DEPS ${subgraph_bridge_d
lite_cc_library(subgraph_bridge_scale_op_mlu SRCS scale_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_interp_op_mlu SRCS interpolate_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_concat_op_mlu SRCS concat_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_dropout_op_mlu SRCS dropout_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_split_op_mlu SRCS split_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_cast_op_mlu SRCS cast_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_layout_op_mlu SRCS layout_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_squeeze_op_mlu SRCS squeeze_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_reshape_op_mlu SRCS reshape_op.cc DEPS ${subgraph_bridge_deps_mlu})
set(mlu_subgraph_bridges
subgraph_bridge_registry
subgraph_bridge_utility_mlu
......@@ -34,10 +38,15 @@ set(mlu_subgraph_bridges
subgraph_bridge_scale_op_mlu
subgraph_bridge_interp_op_mlu
subgraph_bridge_concat_op_mlu
subgraph_bridge_dropout_op_mlu
subgraph_bridge_split_op_mlu
subgraph_bridge_cast_op_mlu
subgraph_bridge_layout_op_mlu
subgraph_bridge_squeeze_op_mlu
subgraph_bridge_reshape_op_mlu
CACHE INTERNAL "mlu_subgraph_bridges")
lite_cc_library(subgraph_test_helper_mlu SRCS test_helper.cc DEPS ${mlu_subgraph_bridges})
lite_cc_test(test_conv_converter_mlu SRCS conv_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_act_converter_mlu SRCS act_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
......@@ -49,6 +58,10 @@ lite_cc_test(test_fc_converter_mlu SRCS fc_op_test.cc DEPS scope optimizer targe
lite_cc_test(test_scale_converter_mlu SRCS scale_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_interp_converter_mlu SRCS interpolate_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_concat_converter_mlu SRCS concat_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_split_converter_mlu SRCS split_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_layout_converter_mlu SRCS layout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_cast_converter_mlu SRCS cast_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_reshape_converter_mlu SRCS reshape_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Create act node and set params from op
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
/* auto mask_var_name = op_info->Output("Mask").front(); */
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
/* auto mask = scope->FindVar(mask_var_name)->GetMutable<Tensor>(); */
/* auto mask_dims = mask->dims().Vectorize(); */
/* auto mask_tensor = graph->AddNode( */
/* mask_var_name, mask_dims, CNML_TENSOR, CNML_NCHW, graph->FPType()); */
// is_test is true by default
// if(op_info->HasAttr("is_test")){
// auto is_test = op_info->GetAttr<bool>("is_test");
// CHECK(is_test != true);
// }
// Param fix_seed and seed is useless in MLU
auto dropout_implementation =
op_info->GetAttr<std::string>("dropout_implementation");
auto dropout_prob = op_info->GetAttr<float>("dropout_prob");
float alpha = 1.0f - dropout_prob;
if (dropout_implementation == "upscale_in_train") {
alpha = 1.;
}
float beta = 0.;
std::vector<int64_t> shape = {1, 1, 1, 1};
std::string alpha_var_name = string_format("dropout_alpha_%p", op);
std::string beta_var_name = string_format("dropout_beta_%p", op);
auto alpha_tensor = graph->AddNode(
alpha_var_name, shape, CNML_CONST, CNML_NHWC, graph->FPType());
auto beta_tensor = graph->AddNode(
beta_var_name, shape, CNML_CONST, CNML_NHWC, graph->FPType());
graph->BindConstRawData(alpha_var_name, &alpha, 1);
graph->BindConstRawData(beta_var_name, &beta, 1);
auto input_tensor = graph->GetNode(x_var_name);
cnmlBaseOp_t scale_op;
CNML_CALL(cnmlCreateScaleOp(&scale_op,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
alpha_tensor->mlu_tensor(),
beta_tensor->mlu_tensor()));
graph->FuseOp(scale_op);
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(dropout,
kMLU,
paddle::lite::subgraph::mlu::DropoutConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/dropout_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
void dropout_ref(const std::shared_ptr<operators::DropoutOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto dropout_implementation =
op_info->GetAttr<std::string>("dropout_implementation");
auto dropout_prob = op_info->GetAttr<float>("dropout_prob");
float alpha = 1.0f - dropout_prob;
if (dropout_implementation == "upscale_in_train") {
alpha = 1.;
}
float beta = 0.;
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
DDim x_dims = x->dims();
DDim out_dims = out->dims();
CHECK_EQ(x_dims.production(), out_dims.production());
for (int i = 0; i < out_dims.production(); i++) {
out_data[i] = x_data[i] * alpha + beta;
}
}
void test_dropout(int bs,
int ic,
int ih,
int iw,
std::string dropout_implementation,
float dropout_prob,
float bias) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string mask_var_name("mask");
std::string out_ref_var_name("out_ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* mask = scope.Var(mask_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float, int>(x);
// initialize op desc
bool is_test = true;
bool fix_seed = false;
int seed = 0;
cpp::OpDesc opdesc;
opdesc.SetType("dropout");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetOutput("Mask", {mask_var_name});
opdesc.SetAttr("is_test", is_test);
opdesc.SetAttr("fix_seed", fix_seed);
opdesc.SetAttr("seed", seed);
opdesc.SetAttr("dropout_implementation", dropout_implementation);
opdesc.SetAttr("dropout_prob", dropout_prob);
VLOG(6) << "mask: " << mask->dims()[0] << std::endl;
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::DropoutOp>(opdesc, &scope);
dropout_ref(op);
out_ref->CopyDataFrom(*out);
Tensor input_trans;
input_trans.Resize({bs, ic, ih, iw});
transpose(x->mutable_data<float>(),
input_trans.mutable_data<float>(),
{bs, ic, ih, iw},
{0, 2, 3, 1});
auto os = out->dims();
out->Resize({static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])});
x->CopyDataFrom(input_trans);
x->Resize({bs, ih, iw, ic});
LaunchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor('out')
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
Tensor output_trans;
output_trans.Resize(os);
transpose(out_data,
output_trans.mutable_data<float>(),
{static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])},
{0, 3, 1, 2});
out_data = output_trans.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
TEST(MLUBridges, dropout) {
for (auto bs : {1, 3}) {
for (auto ic : {1, 3}) {
for (auto ih : {3, 4}) {
for (auto iw : {4, 3}) {
for (auto dropout_implementation :
{"downgrade_in_infer", "upscale_in_train"}) {
for (auto dropout_prob : {0.f, 1.0f}) {
VLOG(3) << "bs: " << bs << " ic: " << ic << " ih: " << ih
<< " iw: " << iw
<< " dropout_implementation: " << dropout_implementation
<< " dropout_prob: " << dropout_prob;
test_dropout(
bs, ic, ih, iw, dropout_implementation, dropout_prob, 0.);
}
}
}
}
}
}
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(dropout, kMLU);
......@@ -29,5 +29,11 @@ USE_SUBGRAPH_BRIDGE(concat, kMLU);
USE_SUBGRAPH_BRIDGE(scale, kMLU);
USE_SUBGRAPH_BRIDGE(sigmoid, kMLU);
USE_SUBGRAPH_BRIDGE(elementwise_mul, kMLU);
USE_SUBGRAPH_BRIDGE(dropout, kMLU);
USE_SUBGRAPH_BRIDGE(split, kMLU);
USE_SUBGRAPH_BRIDGE(cast, kMLU);
USE_SUBGRAPH_BRIDGE(layout, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
USE_SUBGRAPH_BRIDGE(reshape, kMLU);
USE_SUBGRAPH_BRIDGE(reshape2, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
// ================== Trans1: NHWC => NCHW ===========================
auto input_tensor = graph->GetNode(x_var_name);
auto trans_1_axis = std::move(GetAxisNHWC2NCHW<int>(x->dims().size()));
auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
x->dims().Vectorize(),
CNML_TENSOR,
CNML_NCHW,
graph->FPType(),
CNML_NCHW);
cnmlBaseOp_t trans1_op{nullptr};
cnmlNdTransposeOpParam_t trans1_param{nullptr};
CNML_CALL(cnmlCreateNdTransposeOpParam(
&trans1_param, trans_1_axis.data(), trans_1_axis.size()));
CNML_CALL(cnmlCreateNdTransposeProOp(&trans1_op,
input_tensor->mlu_tensor(),
trans1_out->mlu_tensor(),
trans1_param));
// ======================== Trans1 End ==================================
// ======================= Reshape op ===================================
cnmlBaseOp_t reshape_op;
auto trans2_input = graph->AddNode(out_var_name + ".trans.o",
output_dims,
CNML_TENSOR,
CNML_NCHW,
graph->FPType(),
CNML_NCHW);
cnmlReshapeOpParam_t reshape_param{nullptr};
int cnml_trans2_input_shape[4];
CNML_CALL(
cnmlGetTensorShape(trans2_input->mlu_tensor(), cnml_trans2_input_shape));
CNML_CALL(
cnmlCreateNdReshapeOpParam(&reshape_param, cnml_trans2_input_shape, 4));
// Use cnmlCreatexxxOpForward to create op.
CNML_CALL(cnmlCreateReshapeOp(&reshape_op,
reshape_param,
trans1_out->mlu_tensor(),
trans2_input->mlu_tensor()));
// ======================= Reshape op End ===================================
// ================== Trans2: NCHW => NHWC ===============================
auto trans_2_axis = std::move(GetAxisNCHW2NHWC<int>(output->dims().size()));
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
cnmlBaseOp_t trans2_op{nullptr};
cnmlNdTransposeOpParam_t trans2_param{nullptr};
CNML_CALL(cnmlCreateNdTransposeOpParam(
&trans2_param, trans_2_axis.data(), trans_2_axis.size()));
CNML_CALL(cnmlCreateNdTransposeProOp(&trans2_op,
trans2_input->mlu_tensor(),
output_tensor->mlu_tensor(),
trans2_param));
// ======================== Trans2 End ==================================
// =============== DEBUG ====================
VLOG(6) << "x_var_name: " << x_var_name;
VLOG(6) << "out_var_name: " << out_var_name;
VLOG(6) << "input dim: " << x->dims();
VLOG(6) << "output dim: " << output->dims();
int cnml_input_shape[4];
CNML_CALL(cnmlGetTensorShape(input_tensor->mlu_tensor(), cnml_input_shape));
VLOG(6) << "cnml input dim: ";
for (size_t i = 0; i < 4; i++) {
VLOG(6) << cnml_input_shape[i];
}
// cnmlPrintTensor(input_tensor->mlu_tensor(), CNML_TENSOR);
// cnmlPrintTensor(trans1_out->mlu_tensor(), CNML_TENSOR);
// cnmlPrintTensor(trans2_input->mlu_tensor(), CNML_TENSOR);
// cnmlPrintTensor(output_tensor->mlu_tensor(), CNML_TENSOR);
// =============== DEBUG END =================
graph->FuseOp(trans1_op);
graph->FuseOp(reshape_op);
graph->FuseOp(trans2_op);
CNML_CALL(cnmlDestroyBaseOp(&trans1_op));
CNML_CALL(cnmlDestroyBaseOp(&reshape_op));
CNML_CALL(cnmlDestroyBaseOp(&trans2_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(reshape,
kMLU,
paddle::lite::subgraph::mlu::ReshapeConverter);
REGISTER_SUBGRAPH_BRIDGE(reshape2,
kMLU,
paddle::lite::subgraph::mlu::ReshapeConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reshape_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
void test_reshape(std::vector<int64_t> input_shape,
std::vector<int64_t> out_shape) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
x->Resize(input_shape);
Tensor x_cpu;
// initialize input&output data
FillTensor<float, int>(x);
x_cpu.CopyDataFrom(*x);
Tensor input_trans;
input_trans.Resize(input_shape);
transpose(x->mutable_data<float>(),
input_trans.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])},
{0, 2, 3, 1});
x->CopyDataFrom(input_trans);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("reshape2");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
std::vector<int> shape_attr;
shape_attr.resize(out_shape.size());
for (size_t i = 0; i < out_shape.size(); i++) {
shape_attr[i] = static_cast<int>(out_shape[i]);
}
opdesc.SetAttr<std::vector<int>>("shape", shape_attr);
auto op = CreateOp<operators::ReshapeOp>(opdesc, &scope);
auto os = out->dims();
out->Resize(out_shape);
LaunchOp(op, {x_var_name}, {out_var_name});
Tensor out_trans;
out_trans.Resize(out_shape);
transpose(out->mutable_data<float>(),
out_trans.mutable_data<float>(),
{static_cast<int>(out_shape[0]),
static_cast<int>(out_shape[1]),
static_cast<int>(out_shape[2]),
static_cast<int>(out_shape[3])},
{0, 3, 1, 2});
out->CopyDataFrom(out_trans);
// compare results
auto* out_data = out->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], x_cpu.mutable_data<float>()[i], 1e-5);
}
}
TEST(MLUBridges, reshape) { test_reshape({1, 2, 4, 4}, {1, 4, 2, 4}); }
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(reshape, kMLU);
USE_SUBGRAPH_BRIDGE(reshape2, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims().Vectorize();
auto out_var_name = op_info->Output("Out");
auto param_axis = op_info->GetAttr<int>("axis");
auto num = op_info->GetAttr<int>("num");
auto sections = op_info->GetAttr<std::vector<int>>("sections");
int64_t sections_num = static_cast<int64_t>(sections.size());
auto output_num = num > 0 ? num : sections_num;
std::vector<cnmlTensor_t> output_tensor;
for (auto out_name : out_var_name) {
auto out = scope->FindVar(out_name)->GetMutable<Tensor>();
auto out_dims = out->dims().Vectorize();
auto out_tensor = graph->AddNode(
out_name, out_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
output_tensor.push_back(out_tensor->mlu_tensor());
}
auto dims = x_dims.size();
int axis = (param_axis < 0) ? (param_axis + dims) : param_axis;
CHECK_LE(axis, 4) << "Unsupport dims in mlu concat";
int nhwc_axis = GetAxisNHWC2NCHW<int>(dims)[axis];
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
cnmlBaseOp_t split_op;
cnmlTensor_t inputs = input_tensor->mlu_tensor();
CNML_CALL(cnmlCreateNdSplitOp(
&split_op, nhwc_axis, &inputs, 1, output_tensor.data(), output_num));
graph->FuseOp(split_op);
CNML_CALL(cnmlDestroyBaseOp(&split_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(split,
kMLU,
paddle::lite::subgraph::mlu::SplitConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/split_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
template <typename dtype>
void split_ref(const std::shared_ptr<operators::SplitOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
int num = op_info->GetAttr<int>("num");
int axis = op_info->GetAttr<int>("axis");
std::vector<int> sections = op_info->GetAttr<std::vector<int>>("sections");
std::vector<lite::Tensor*> output_vec;
auto output = op_info->Output("Out");
for (auto out_var : output) {
output_vec.push_back(scope->Var(out_var)->GetMutable<Tensor>());
}
auto in_dims = x->dims();
auto rank = in_dims.size();
int outs_number = output_vec.size();
std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number);
if (axis < 0) {
axis += rank;
}
if (num > 0) {
int out_axis_dim = in_dims[axis] / num;
for (int i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = out_axis_dim;
outs_dims.push_back(dim);
}
} else if (sections.size() > 0) {
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = sections[i];
outs_dims.push_back(dim);
}
}
for (int j = 0; j < outs_dims.size(); ++j) {
output_vec[j]->Resize(outs_dims[j]);
}
const dtype* din = x->mutable_data<const dtype>();
std::vector<int> in_strides(in_dims.size());
in_strides[in_dims.size() - 1] = in_dims[in_dims.size() - 1];
for (int i = in_dims.size() - 2; i >= 0; --i) {
in_strides[i] = in_strides[i + 1] * in_dims[i];
}
int input_offset = 0;
for (auto out : output_vec) {
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
for (int i = out_dim.size() - 2; i >= 0; --i) {
out_strides[i] = out_strides[i + 1] * out_dim[i];
}
dtype* out_data = out->mutable_data<dtype>();
int before = out_strides[0] / out_strides[axis];
int in_after = in_strides[axis];
int out_after = out_strides[axis];
for (int i = 0; i < before; ++i) {
std::memcpy(out_data + i * out_after,
din + input_offset + i * in_after,
sizeof(dtype) * out_after);
}
input_offset += out_strides[axis];
}
}
void test_split(int bs,
int ic,
int ih,
int iw,
int axis,
int num,
std::vector<int> sections) {
// prepare input&output variables
std::string x_var_name = "x";
std::string out_var_name_1 = "out_1";
std::string out_var_name_2 = "out_2";
std::string out_ref_var_name_1 = "out_ref_1";
std::string out_ref_var_name_2 = "out_ref_2";
Scope scope;
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out_1 = scope.Var(out_var_name_1)->GetMutable<Tensor>();
auto* out_2 = scope.Var(out_var_name_2)->GetMutable<Tensor>();
auto* out_ref_1 = scope.Var(out_ref_var_name_1)->GetMutable<Tensor>();
auto* out_ref_2 = scope.Var(out_ref_var_name_2)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("split");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name_1, out_var_name_2});
opdesc.SetAttr("axis", axis);
opdesc.SetAttr("sections", sections);
opdesc.SetAttr("num", num);
auto op = CreateOp<operators::SplitOp>(opdesc, &scope);
split_ref<float>(op);
out_ref_1->CopyDataFrom(*out_1);
out_ref_2->CopyDataFrom(*out_2);
// execute reference implementation and save to output tensor
Tensor input;
input.Resize({bs, ic, ih, iw});
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(bs),
static_cast<int>(ic),
static_cast<int>(ih),
static_cast<int>(iw)},
{0, 2, 3, 1});
x->CopyDataFrom(input);
LaunchOp(op, {x_var_name}, {out_var_name_1, out_var_name_2});
// compare results
auto* out_data_1 = out_1->mutable_data<float>();
auto* out_data_2 = out_2->mutable_data<float>();
auto* out_ref_data_1 = out_ref_1->mutable_data<float>();
auto* out_ref_data_2 = out_ref_2->mutable_data<float>();
Tensor output1, output2;
output1.Resize(out_1->dims());
output2.Resize(out_2->dims());
transpose<float>(out_data_1,
output1.mutable_data<float>(),
{static_cast<int>(out_1->dims()[0]),
static_cast<int>(out_1->dims()[2]),
static_cast<int>(out_1->dims()[3]),
static_cast<int>(out_1->dims()[1])},
{0, 3, 1, 2});
transpose<float>(out_data_2,
output2.mutable_data<float>(),
{static_cast<int>(out_2->dims()[0]),
static_cast<int>(out_2->dims()[2]),
static_cast<int>(out_2->dims()[3]),
static_cast<int>(out_2->dims()[1])},
{0, 3, 1, 2});
out_data_1 = output1.mutable_data<float>();
out_data_2 = output2.mutable_data<float>();
for (int i = 0; i < out_1->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data_1[i], out_ref_data_1[i], 5e-4);
}
for (int i = 0; i < out_2->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data_2[i], out_ref_data_2[i], 5e-4);
}
}
TEST(MLUBridges, split) {
test_split(4, 2, 3, 1, 0, 2, {});
test_split(4, 2, 3, 1, 0, 0, {3, 1});
test_split(4, 6, 3, 1, 1, 2, {});
test_split(4, 6, 3, 1, 1, 0, {2, 4});
test_split(4, 2, 2, 1, 2, 2, {});
test_split(4, 2, 6, 1, 2, 0, {3, 3});
test_split(4, 2, 3, 4, 3, 2, {});
test_split(4, 2, 3, 6, 3, 0, {5, 1});
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(split, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int SqueezeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Create act node and set params from op
auto fp_type = graph->FPType();
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, fp_type);
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
auto output_dims_nhwc = DimNCHW2NHWC(output_dims);
std::vector<int> o_dims(output_dims.size());
std::transform(output_dims_nhwc.cbegin(),
output_dims_nhwc.cend(),
o_dims.begin(),
[](DDim::value_type d) { return static_cast<int>(d); });
cnmlReshapeOpParam_t param;
cnmlBaseOp_t squeeze_op;
CNML_CALL(cnmlCreateNdReshapeOpParam(&param, o_dims.data(), o_dims.size()));
CNML_CALL(cnmlCreateReshapeOp(&squeeze_op,
param,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor()));
CNML_CALL(cnmlDestroyReshapeOpParam(&param));
graph->FuseOp(squeeze_op);
CNML_CALL(cnmlDestroyBaseOp(&squeeze_op));
if (op_type == "squeeze2") {
auto xshape_var_name = op_info->Output("XShape").front();
auto xshape = scope->FindVar(xshape_var_name)->GetMutable<Tensor>();
auto dims_64 = xshape->dims().Vectorize();
auto dims_64_nhwc = DimNCHW2NHWC(dims_64);
auto xshape_tensor = graph->AddNode(
xshape_var_name, dims_64, CNML_TENSOR, CNML_NCHW, fp_type);
std::vector<int> xshape_dims(dims_64.size());
std::transform(dims_64_nhwc.cbegin(),
dims_64_nhwc.cend(),
xshape_dims.begin(),
[](DDim::value_type d) { return static_cast<int>(d); });
cnmlBaseOp_t squeeze2_op;
CNML_CALL(cnmlCreateNdReshapeOpParam(
&param, xshape_dims.data(), xshape_dims.size()));
CNML_CALL(cnmlCreateReshapeOp(&squeeze2_op,
param,
input_tensor->mlu_tensor(),
xshape_tensor->mlu_tensor()));
CNML_CALL(cnmlDestroyReshapeOpParam(&param));
graph->FuseOp(squeeze2_op);
CNML_CALL(cnmlDestroyBaseOp(&squeeze2_op));
}
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(squeeze,
kMLU,
paddle::lite::subgraph::mlu::SqueezeConverter);
REGISTER_SUBGRAPH_BRIDGE(squeeze2,
kMLU,
paddle::lite::subgraph::mlu::SqueezeConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/squeeze_op.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
// squeeze
TEST(MLUBridges, squeeze) {
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string ref_var_name("ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(ref_var_name)->GetMutable<Tensor>();
std::vector<int64_t> x_shape({1, 3, 1, 5});
x->Resize(x_shape);
out_ref->Resize(x_shape);
std::vector<int64_t> out_shape({3, 5});
out->Resize(out_shape);
FillTensor<float>(x, 0, 10);
out_ref->CopyDataFrom(*x);
// SqueezeCompute squeeze;
cpp::OpDesc opdesc;
opdesc.SetType("squeeze");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
std::vector<int> axes{0, -2};
opdesc.SetAttr("axes", axes);
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::SqueezeOp>(opdesc, &scope);
LaunchOp(op, {x_var_name}, {out_var_name});
auto x_data = out_ref->data<float>();
auto out_data = out->data<float>();
for (int j = 0; j < out->numel(); ++j) {
EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
}
}
// squeeze2
TEST(MLUBridges, squeeze2) {
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string xshape_var_name("xshape");
std::string ref_var_name("ref");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* xshape = scope.Var(xshape_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(ref_var_name)->GetMutable<Tensor>();
std::vector<int64_t> x_shape({1, 3, 1, 5});
x->Resize(x_shape);
out_ref->Resize(x_shape);
std::vector<int64_t> out_shape({3, 5});
out->Resize(out_shape);
std::vector<int64_t> xshape_shape({1, 3, 1, 5});
xshape->Resize(xshape_shape);
FillTensor<float>(x, 0, 10);
out_ref->CopyDataFrom(*x);
// Squeeze2Compute squeeze2;
cpp::OpDesc opdesc;
opdesc.SetType("squeeze2");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetOutput("XShape", {xshape_var_name});
std::vector<int> axes({0, -2});
opdesc.SetAttr("axes", axes);
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::SqueezeOp>(opdesc, &scope);
LaunchOp(op, {x_var_name}, {out_var_name, xshape_var_name});
auto x_data = out_ref->mutable_data<float>();
auto out_data = out->mutable_data<float>();
auto xshape_data = xshape->mutable_data<float>();
for (int j = 0; j < out->numel(); ++j) {
EXPECT_NEAR(out_data[j], x_data[j], 1e-5);
EXPECT_NEAR(xshape_data[j], x_data[j], 1e-5);
}
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册