未验证 提交 32a38a86 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] add batchnorm op bridge and unit test (#2323)

* [XPU] add batchnorm op bridge and unit test

test=develop

* fix DMLC_USE_GLOG

test=develop
上级 2aa7b06b
......@@ -99,5 +99,7 @@ else()
set_property(TARGET xpu_sdk_llvm PROPERTY IMPORTED_LOCATION ${XPU_SDK_LLVM_FILE})
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_GLOG=1")
set(xpu_runtime_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu runtime libs")
set(xpu_builder_libs xpu_sdk_xtcl xpu_sdk_tvm xpu_sdk_xpu_api xpu_sdk_xpu_rt xpu_sdk_xpu_jitc xpu_sdk_llvm CACHE INTERNAL "xpu builder libs")
......@@ -8,6 +8,7 @@ lite_cc_library(xpu_bridge_elementwise_ops SRCS elementwise_ops.cc DEPS ${xpu_br
lite_cc_library(xpu_bridge_pool_op SRCS pool_op.cc DEPS ${xpu_bridge_deps})
lite_cc_library(xpu_bridge_softmax_op SRCS softmax_op.cc DEPS ${xpu_bridge_deps})
lite_cc_library(xpu_bridge_mul_op SRCS mul_op.cc DEPS ${xpu_bridge_deps})
lite_cc_library(xpu_bridge_batch_norm_op SRCS batch_norm_op.cc DEPS ${xpu_bridge_deps})
set(xpu_bridges
xpu_bridge_registry
......@@ -17,6 +18,7 @@ set(xpu_bridges
xpu_bridge_pool_op
xpu_bridge_softmax_op
xpu_bridge_mul_op
xpu_bridge_batch_norm_op
CACHE INTERNAL "xpu_bridges")
set(xpu_bridge_test_deps ${xpu_bridges} ${xpu_kernels} ${ops})
......@@ -27,3 +29,4 @@ lite_cc_test(test_xpu_bridge_elementwise_ops SRCS elementwise_ops_test.cc test_h
lite_cc_test(test_xpu_bridge_pool_op SRCS pool_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
lite_cc_test(test_xpu_bridge_softmax_op SRCS softmax_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
lite_cc_test(test_xpu_bridge_mul_op SRCS mul_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
lite_cc_test(test_xpu_bridge_batch_norm_op SRCS batch_norm_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
node_map_type BatchNormConverter(const std::shared_ptr<lite::OpLite> op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::xpu::UniqueName(op_type);
LOG(INFO) << "[XPU] Converting " + op_type + "...";
// check context
CHECK(graph_ctx != nullptr);
CHECK(graph_ctx->builder != nullptr);
CHECK(graph_ctx->params != nullptr);
// get input, and attributes
auto x_var_name = op_info->Input("X").front();
auto scale_var_name = op_info->Input("Scale").front();
auto* scale = scope->FindMutableTensor(scale_var_name);
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindMutableTensor(bias_var_name);
auto mean_var_name = op_info->Input("Mean").front();
auto* mean = scope->FindMutableTensor(mean_var_name);
auto variance_var_name = op_info->Input("Variance").front();
auto* variance = scope->FindMutableTensor(variance_var_name);
auto epsilon = op_info->GetAttr<float>("epsilon");
// create scale node
CHECK(!input_nodes.count(scale_var_name));
auto scale_const_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateTensor(scale_var_name,
lite::xpu::CvtShape(scale->dims()),
::xtcl::Float(32)));
auto scale_const_tensor = lite::xpu::CvtTensor(scale);
graph_ctx->params->emplace(
std::make_pair(scale_var_name, *scale_const_tensor));
// create bias node
CHECK(!input_nodes.count(bias_var_name));
auto bias_const_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateTensor(
bias_var_name, lite::xpu::CvtShape(bias->dims()), ::xtcl::Float(32)));
auto bias_const_tensor = lite::xpu::CvtTensor(bias);
graph_ctx->params->emplace(std::make_pair(bias_var_name, *bias_const_tensor));
// create mean node
CHECK(!input_nodes.count(mean_var_name));
auto mean_const_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateTensor(
mean_var_name, lite::xpu::CvtShape(mean->dims()), ::xtcl::Float(32)));
auto mean_const_tensor = lite::xpu::CvtTensor(mean);
graph_ctx->params->emplace(std::make_pair(mean_var_name, *mean_const_tensor));
// create variance node
CHECK(!input_nodes.count(variance_var_name));
auto variance_const_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateTensor(variance_var_name,
lite::xpu::CvtShape(variance->dims()),
::xtcl::Float(32)));
auto variance_const_tensor = lite::xpu::CvtTensor(variance);
graph_ctx->params->emplace(
std::make_pair(variance_var_name, *variance_const_tensor));
// create batch_norm node and set params from op
CHECK(input_nodes.count(x_var_name));
auto batch_norm_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateBatchNorm(*input_nodes.at(x_var_name),
*scale_const_node,
*bias_const_node,
*mean_const_node,
*variance_const_node,
1,
epsilon));
batch_norm_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->GetField(*batch_norm_node, 0));
graph_ctx->builder->SetLayer(unique_op_type);
// output converted nodes
node_map_type output_nodes;
output_nodes[op_info->Output("Y").front()] = batch_norm_node;
return output_nodes;
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_XPU_BRIDGE(batch_norm,
paddle::lite::kernels::xpu::bridges::BatchNormConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/batch_norm_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
template <typename dtype>
void batch_norm_ref(const std::shared_ptr<operators::BatchNormOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto y = scope->FindVar(op_info->Output("Y").front())->GetMutable<Tensor>();
auto bias =
scope->FindVar(op_info->Input("Bias").front())->GetMutable<Tensor>();
auto scale =
scope->FindVar(op_info->Input("Scale").front())->GetMutable<Tensor>();
auto mean =
scope->FindVar(op_info->Input("Mean").front())->GetMutable<Tensor>();
auto variance =
scope->FindVar(op_info->Input("Variance").front())->GetMutable<Tensor>();
auto x_data = x->data<dtype>();
auto y_data = y->mutable_data<dtype>();
auto scale_data = scale->mutable_data<dtype>();
auto bias_data = bias->mutable_data<dtype>();
auto mean_data = mean->mutable_data<dtype>();
auto variance_data = variance->mutable_data<dtype>();
DDim x_dims = x->dims();
float epsilon = op_info->GetAttr<float>("epsilon");
auto data_layout = op_info->GetAttr<std::string>("data_layout");
bool global_stats = op_info->GetAttr<bool>("use_global_stats");
if (global_stats) {
int64_t outer_size = 0;
int64_t channel_size = 0;
int64_t inner_size = 0;
if (data_layout == "NCHW") {
outer_size = x_dims[0];
channel_size = x_dims[1];
inner_size = x_dims.Slice(2, x_dims.size()).production();
} else {
LOG(FATAL) << "Unknown storage order: " << data_layout;
}
auto x_ptr = x_data;
auto y_ptr = y_data;
for (int o = 0; o < outer_size; o++) {
for (int c = 0; c < channel_size; c++) {
for (int i = 0; i < inner_size; i++) {
dtype norm_x =
(*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon);
*y_ptr = norm_x * scale_data[c] + bias_data[c];
x_ptr++;
y_ptr++;
}
}
}
}
}
void test_batch_norm(int bs, int ic, int ih, int iw, float epsilon) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
std::string scale_var_name = "scale";
std::string bias_var_name = "bias";
std::string mean_var_name = "mean";
std::string variance_var_name = "variance";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* scale = scope.Var(scale_var_name)->GetMutable<Tensor>();
auto* bias = scope.Var(bias_var_name)->GetMutable<Tensor>();
auto* mean = scope.Var(mean_var_name)->GetMutable<Tensor>();
auto* variance = scope.Var(variance_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
scale->Resize({ic});
bias->Resize({ic});
mean->Resize({ic});
variance->Resize({ic});
// initialize input&output data
FillTensor<float>(x);
FillTensor<float>(scale);
FillTensor<float>(bias);
FillTensor<float>(mean);
// variance > 0
FillTensor<float>(variance, 1.f, 5.f);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("batch_norm");
opdesc.SetInput("X", {x_var_name});
opdesc.SetInput("Scale", {scale_var_name});
opdesc.SetInput("Bias", {bias_var_name});
opdesc.SetInput("Mean", {mean_var_name});
opdesc.SetInput("Variance", {variance_var_name});
opdesc.SetOutput("Y", {out_var_name});
opdesc.SetAttr("is_test", 1);
opdesc.SetAttr("use_global_stats", true);
opdesc.SetAttr("epsilon", epsilon);
opdesc.SetAttr("momentum", 0.9f);
opdesc.SetAttr("data_layout", std::string("NCHW"));
// create and convert op to XPU model, then run it on XPU
auto op = CreateOp<operators::BatchNormOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
batch_norm_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
TEST(NPUBridges, batch_norm) {
for (auto bs : {1, 3}) {
for (auto ic : {2, 3}) {
for (auto ih : {4}) {
for (auto iw : {5}) {
for (auto epsilon : {1e-5f}) {
test_batch_norm(bs, ic, ih, iw, epsilon);
}
}
}
}
}
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(batch_norm);
USE_XPU_BRIDGE(batch_norm);
......@@ -23,3 +23,4 @@ USE_XPU_BRIDGE(elementwise_add);
USE_XPU_BRIDGE(pool2d);
USE_XPU_BRIDGE(softmax);
USE_XPU_BRIDGE(mul);
USE_XPU_BRIDGE(batch_norm);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册