Unverified commit eed7a506 authored by zhupengyang, committed by GitHub

[XPU] add elementwise, pool, softmax op bridges and unit tests (#2264)

test=develop
Parent 06d058fe
@@ -4,14 +4,23 @@ set(xpu_bridge_deps xpu_bridge_registry xpu_builder op)
lite_cc_library(xpu_bridge_act_op SRCS act_op.cc DEPS ${xpu_bridge_deps})
lite_cc_library(xpu_bridge_conv_op SRCS conv_op.cc DEPS ${xpu_bridge_deps})
lite_cc_library(xpu_bridge_elementwise_ops SRCS elementwise_ops.cc DEPS ${xpu_bridge_deps})
lite_cc_library(xpu_bridge_pool_op SRCS pool_op.cc DEPS ${xpu_bridge_deps})
lite_cc_library(xpu_bridge_softmax_op SRCS softmax_op.cc DEPS ${xpu_bridge_deps})
set(xpu_bridges
xpu_bridge_registry
xpu_bridge_act_op
xpu_bridge_conv_op
xpu_bridge_elementwise_ops
xpu_bridge_pool_op
xpu_bridge_softmax_op
CACHE INTERNAL "xpu_bridges")
set(xpu_bridge_test_deps ${xpu_bridges} ${xpu_kernels} ${ops})
lite_cc_test(test_xpu_bridge_act_op SRCS act_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
lite_cc_test(test_xpu_bridge_conv_op SRCS conv_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
lite_cc_test(test_xpu_bridge_elementwise_ops SRCS elementwise_ops_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
lite_cc_test(test_xpu_bridge_pool_op SRCS pool_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
lite_cc_test(test_xpu_bridge_softmax_op SRCS softmax_op_test.cc test_helper.cc DEPS ${xpu_bridge_test_deps})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
node_map_type ElementwiseConverter(const std::shared_ptr<lite::OpLite> op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::xpu::UniqueName(op_type);
LOG(INFO) << "[XPU] Converting " + op_type + "...";
// check context
CHECK(graph_ctx != nullptr);
CHECK(graph_ctx->builder != nullptr);
CHECK(graph_ctx->params != nullptr);
// get input and attributes
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
CHECK(input_nodes.count(x_var_name));
CHECK(input_nodes.count(y_var_name));
auto axis = op_info->GetAttr<int>("axis");
auto x_dims = scope->FindTensor(x_var_name)->dims();
auto y_dims = scope->FindTensor(y_var_name)->dims();
// create elementwise node and set input, attributes
std::shared_ptr<xtcl::xExpr> elementwise_node = nullptr;
if (y_dims.size() == 1) {
elementwise_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateBiasAdd(
*input_nodes.at(x_var_name), *input_nodes.at(y_var_name), axis));
} else if (x_dims.size() == y_dims.size()) {
elementwise_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateBinaryOp(
"add", *input_nodes.at(x_var_name), *input_nodes.at(y_var_name)));
} else {
LOG(ERROR) << "XPU elementwise_add only support y of one dimension, or x "
"and y of the same dimension. But recieved x's dimension: "
<< x_dims << ", y's dimension: " << y_dims << ", axis: " << axis;
}
graph_ctx->builder->SetLayer(unique_op_type);
// output converted nodes
node_map_type output_nodes;
output_nodes[op_info->Output("Out").front()] = elementwise_node;
return output_nodes;
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_XPU_BRIDGE(elementwise_add,
paddle::lite::kernels::xpu::bridges::ElementwiseConverter);
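
The converter above lowers elementwise_add to CreateBiasAdd when y has rank 1 (broadcast along `axis`) and to a plain binary add when x and y share the same rank. Below is a minimal, self-contained sketch of the rank-1 broadcast rule the bridge relies on; `BiasAddRef` is a hypothetical helper for illustration only, not part of the Lite or XTCL API.

#include <cstdint>
#include <vector>

// Adds a rank-1 y (length x_dims[axis]) along `axis` of x -- the
// semantics the bridge expects CreateBiasAdd to implement on the XPU side.
std::vector<float> BiasAddRef(const std::vector<float>& x,
                              const std::vector<int64_t>& x_dims,
                              const std::vector<float>& y,
                              int axis) {
  int64_t outer = 1, inner = 1;
  for (int i = 0; i < axis; ++i) outer *= x_dims[i];
  for (size_t i = axis + 1; i < x_dims.size(); ++i) inner *= x_dims[i];
  const int64_t channels = x_dims[axis];  // must equal y.size()
  std::vector<float> out(x.size());
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t c = 0; c < channels; ++c) {
      for (int64_t i = 0; i < inner; ++i) {
        const int64_t idx = (o * channels + c) * inner + i;
        out[idx] = x[idx] + y[c];
      }
    }
  }
  return out;
}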
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/elementwise_ops.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
template <typename dtype>
void elementwise_add_ref(const std::shared_ptr<operators::ElementwiseOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto y = scope->FindVar(op_info->Input("Y").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_data = x->data<dtype>();
auto y_data = y->data<dtype>();
dtype* out_data = out->mutable_data<dtype>();
auto x_dims = x->dims();
auto y_dims = y->dims();
int axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
int batch = 1;
int channels = 1;
int num = 1;
for (int i = 0; i < axis; ++i) {
batch *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
channels *= y_dims[i];
}
for (int i = y_dims.size() + axis; i < x_dims.size(); ++i) {
num *= x_dims[i];
}
// do elementwise add/sub/max...
std::string elt_type = "add";
if (elt_type == "add") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (elt_type == "sub") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr - diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (elt_type == "mul") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr * diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (elt_type == "max") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = std::max(*din_ptr, diny_data);
dout_ptr++;
din_ptr++;
}
}
}
} else {
LOG(FATAL) << "unsupported Elementwise type: " << elt_type;
}
}
void test_elementwise_add(std::vector<int64_t> x_dims,
std::vector<int64_t> y_dims,
int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string y_var_name = "y";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* y = scope.Var(y_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(x_dims);
if (y_dims.size() == 0) {
y->Resize(x_dims);
} else {
y->Resize(y_dims);
}
// initialize input&output data
FillTensor<float>(x);
FillTensor<float>(y);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("elementwise_add");
opdesc.SetInput("X", {x_var_name});
opdesc.SetInput("Y", {y_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
// create and convert op to XPU model, then run it on XPU
auto op = CreateOp<operators::ElementwiseOp>(opdesc, &scope);
LauchOp(op, {x_var_name, y_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
elementwise_add_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
// XPU's bias_add only supports y with one dimension
TEST(XPUBridges, elementwise_add) {
test_elementwise_add({1, 2, 3, 4}, {1}, 0);
test_elementwise_add({1, 2, 3, 4}, {2}, 1);
test_elementwise_add({2, 2, 3, 4}, {3}, 2);
test_elementwise_add({2, 2, 3, 4}, {4}, 3);
test_elementwise_add({2, 2, 3, 4}, {4}, -1);
test_elementwise_add({2, 2, 3, 4}, {}, -1);
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_XPU_BRIDGE(elementwise_add);
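
For reference, here is one of the cases exercised above traced through the batch/channels/num decomposition in elementwise_add_ref (values derived from the test inputs, not from the XPU output):

// x_dims = {2, 2, 3, 4}, y_dims = {3}, axis = 2:
//   batch    = 2 * 2 = 4  // product of x dims before axis
//   channels = 3          // product of y dims
//   num      = 4          // product of x dims after axis + y's rank
// y[j] is therefore added to each contiguous run of num = 4 elements
// belonging to channel j.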
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::xpu::UniqueName(op_type);
LOG(INFO) << "[XPU] Converting " + op_type + "...";
// check context
CHECK(graph_ctx != nullptr);
CHECK(graph_ctx->builder != nullptr);
CHECK(graph_ctx->params != nullptr);
// get input and attributes
auto x_var_name = op_info->Input("X").front();
auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
auto ceil_mode = op_info->GetAttr<bool>("ceil_mode");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto global_pooling = op_info->GetAttr<bool>("global_pooling");
auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto exclusive = op_info->GetAttr<bool>("exclusive");
// create pool node and set params from op
CHECK(input_nodes.count(x_var_name));
std::shared_ptr<xtcl::xExpr> pool_node = nullptr;
if (pooling_type == "max") {
if (global_pooling) {
pool_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateGlobalMaxPool2D(
*input_nodes.at(x_var_name)));
} else {
pool_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateMaxPool2D(*input_nodes.at(x_var_name),
lite::xpu::CvtShape(ksize),
lite::xpu::CvtShape(strides),
lite::xpu::CvtShape(paddings),
"NCHW",
ceil_mode));
}
} else if (pooling_type == "avg") {
if (global_pooling) {
pool_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateGlobalAvgPool2D(
*input_nodes.at(x_var_name)));
} else {
pool_node = std::make_shared<xtcl::xExpr>(
// Paddle's `exclusive` is the inverse of XTCL's `count_include_pad`
graph_ctx->builder->CreateAvgPool2D(*input_nodes.at(x_var_name),
lite::xpu::CvtShape(ksize),
lite::xpu::CvtShape(strides),
lite::xpu::CvtShape(paddings),
"NCHW",
ceil_mode,
!exclusive));
}
} else {
LOG(FATAL) << "Unsupported pooling type: " << pooling_type;
}
graph_ctx->builder->SetLayer(unique_op_type);
// output converted nodes
node_map_type output_nodes;
output_nodes[op_info->Output("Out").front()] = pool_node;
return output_nodes;
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_XPU_BRIDGE(pool2d, paddle::lite::kernels::xpu::bridges::PoolConverter);
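
The bridge forwards `ceil_mode` straight to CreateMaxPool2D/CreateAvgPool2D. As a reminder of what that flag does to the output shape, here is the standard per-dimension pooling arithmetic as a sketch; `PooledExtent` is a hypothetical helper, not the Lite shape-inference code.

#include <cmath>

// Output extent of one spatial dimension for a pooling window.
int PooledExtent(int in, int ksize, int stride, int pad, bool ceil_mode) {
  float raw = static_cast<float>(in + 2 * pad - ksize) / stride + 1.0f;
  return ceil_mode ? static_cast<int>(std::ceil(raw))
                   : static_cast<int>(std::floor(raw));
}
// e.g. in = 3, ksize = 2, stride = 2, pad = 0:
//   floor -> 1 output element, ceil -> 2 output elements.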
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/pool_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
void pool_ref(const std::shared_ptr<operators::PoolOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto& in_dims = x->dims();
auto& out_dims = out->dims();
const float* src_ptr = x->data<const float>();
float* dst_ptr = out->mutable_data<float>();
std::vector<int> ksize = op_info->GetAttr<std::vector<int>>("ksize");
std::vector<int> strides = op_info->GetAttr<std::vector<int>>("strides");
std::vector<int> paddings = op_info->GetAttr<std::vector<int>>("paddings");
bool exclusive = op_info->GetAttr<bool>("exclusive");
std::string pooling_type = op_info->GetAttr<std::string>("pooling_type");
bool global_pooling = op_info->GetAttr<bool>("global_pooling");
int in_n = in_dims[0];
int in_c = in_dims[1];
int in_h = in_dims[2];
int in_w = in_dims[3];
int size_in_n = in_c * in_h * in_w;
int size_in_c = in_h * in_w;
int out_h = out_dims[2];
int out_w = out_dims[3];
int size_out_n = in_c * out_h * out_w;
int size_out_c = out_h * out_w;
int window_h = ksize[0];
int window_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
if (global_pooling) {
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
const float* src = src_ptr + n * size_in_n + c * size_in_c;
float res = src[0];
if (pooling_type == "max") {
for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i];
res = cur_val > res ? cur_val : res;
}
} else if (pooling_type == "avg") {
for (int i = 1; i < size_in_c; ++i) {
float cur_val = src[i];
res += cur_val;
}
res /= size_in_c;
}
dst_ptr[n * size_out_n + c] = res;
}
}
} else {
for (int n = 0; n < in_n; ++n) {
for (int c = 0; c < in_c; ++c) {
for (int h = 0; h < out_h; ++h) {
int sh = h * stride_h;
int eh = sh + window_h;
sh = (sh - pad_h) < 0 ? 0 : sh - pad_h;
eh = (eh - pad_h) > in_h ? in_h : eh - pad_h;
for (int w = 0; w < out_w; ++w) {
int sw = w * stride_w;
int ew = sw + window_w;
sw = (sw - pad_w) < 0 ? 0 : sw - pad_w;
ew = (ew - pad_w) > in_w ? in_w : ew - pad_w;
int pooling_size = (ew - sw) * (eh - sh);
if (pooling_size == 0) continue;
float res = 0.f;
for (int kh = sh; kh < eh; ++kh) {
for (int kw = sw; kw < ew; ++kw) {
int src_idx = n * size_in_n + c * size_in_c + kh * in_w + kw;
if (kh == sh && kw == sw) {
res = src_ptr[src_idx];
} else {
if (pooling_type == "max") {
res = res >= src_ptr[src_idx] ? res : src_ptr[src_idx];
}
if (pooling_type == "avg") {
res += src_ptr[src_idx];
}
}
}
}
if (pooling_type == "avg") {
if (exclusive) {
res /= pooling_size;
} else {
res /= window_h * window_w;
}
}
dst_ptr[n * size_out_n + c * size_out_c + h * out_w + w] = res;
}
}
}
}
}
}
void test_pool(int bs,
int ic,
int ih,
int iw,
std::string pooling_type,
bool ceil_mode,
bool global_pooling,
bool exclusive,
int ksize,
int stride,
int padding) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("pool2d");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("pooling_type", pooling_type);
opdesc.SetAttr("ksize", std::vector<int>({ksize, ksize}));
opdesc.SetAttr("global_pooling", global_pooling);
opdesc.SetAttr("exclusive", exclusive);
opdesc.SetAttr("strides", std::vector<int>({stride, stride}));
opdesc.SetAttr("paddings", std::vector<int>({padding, padding}));
opdesc.SetAttr("ceil_mode", ceil_mode);
// create and convert op to XPU model, then run it on XPU
auto op = CreateOp<operators::PoolOpLite>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
pool_ref(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
TEST(XPUBridges, pool) {
for (auto pooling_type : {"max", "avg"}) {
for (auto bs : {1, 3}) {
for (auto ic : {2}) {
for (auto ih : {3}) {
for (auto iw : {4}) {
test_pool(bs, ic, ih, iw, pooling_type, true, true, true, 0, 1, 0);
}
}
}
}
}
for (auto pooling_type : {"max"}) {
for (auto ceil_mode : {true, false}) {
for (auto ksize : {2, 3}) {
for (auto stride : {1, 2}) {
for (auto padding : {0, 1}) {
for (auto bs : {1, 3}) {
for (auto ic : {2}) {
for (auto ih : {3}) {
for (auto iw : {4}) {
test_pool(bs,
ic,
ih,
iw,
pooling_type,
ceil_mode,
false,
true,
ksize,
stride,
padding);
}
}
}
}
}
}
}
}
}
for (auto pooling_type : {"avg"}) {
for (auto ceil_mode : {true, false}) {
for (auto exclusive : {true, false}) {
for (auto ksize : {2, 3}) {
for (auto stride : {1, 2}) {
for (auto padding : {0, 1}) {
for (auto bs : {1, 3}) {
for (auto ic : {2}) {
for (auto ih : {3}) {
for (auto iw : {4}) {
test_pool(bs,
ic,
ih,
iw,
pooling_type,
ceil_mode,
false,
exclusive,
ksize,
stride,
padding);
}
}
}
}
}
}
}
}
}
}
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(pool2d);
USE_XPU_BRIDGE(pool2d);
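
pool_ref above divides an average-pool window by `pooling_size` (valid elements only) when `exclusive` is true, and by the full window area otherwise, which is why the bridge passes `!exclusive` as XTCL's count_include_pad. A toy illustration, assuming a 2x2 window with two of its elements falling in the padding; `AvgWindow` is a hypothetical helper for illustration only:

float AvgWindow(float sum, int valid, int window_area, bool exclusive) {
  return exclusive ? sum / valid : sum / window_area;
}
// AvgWindow(6.f, /*valid=*/2, /*window_area=*/4, true)  == 3.0f
// AvgWindow(6.f, /*valid=*/2, /*window_area=*/4, false) == 1.5f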
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::xpu::UniqueName(op_type);
LOG(INFO) << "[XPU] Converting " + op_type + "...";
// check context
CHECK(graph_ctx != nullptr);
CHECK(graph_ctx->builder != nullptr);
CHECK(graph_ctx->params != nullptr);
// get op's attributes
auto x_var_name = op_info->Input("X").front();
auto axis = op_info->GetAttr<int>("axis");
// create softmax node and set params from op
CHECK(input_nodes.count(x_var_name));
auto softmax_node = std::make_shared<xtcl::xExpr>(
graph_ctx->builder->CreateSoftmax(*input_nodes.at(x_var_name), axis));
graph_ctx->builder->SetLayer(unique_op_type);
// output converted nodes
node_map_type output_nodes;
output_nodes[op_info->Output("Out").front()] = softmax_node;
return output_nodes;
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_XPU_BRIDGE(softmax,
paddle::lite::kernels::xpu::bridges::SoftmaxConverter);
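
The softmax_ref used by the test below subtracts the per-slice maximum before exponentiating. That is the usual numerical-stability trick, and it leaves the result unchanged because the shared factor cancels:

\mathrm{softmax}(x)_j = \frac{e^{x_j}}{\sum_k e^{x_k}}
                      = \frac{e^{x_j - m}}{\sum_k e^{x_k - m}},
\qquad m = \max_k x_k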
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/softmax_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
template <typename dtype>
void softmax_ref(const std::shared_ptr<operators::SoftmaxOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_data = x->data<dtype>();
auto out_data = out->mutable_data<dtype>();
DDim x_dims = x->dims();
auto x_rank = x_dims.size();
int axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis += x_rank;
}
int axis_size = x_dims[axis];
int outer_num = x_dims.Slice(0, axis).production();
int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int start = idx_outer * inner_num + idx_inner;
int offset;
offset = start;
dtype max_data = std::numeric_limits<dtype>::lowest();
for (int j = 0; j < axis_size; j++) {
max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
offset += inner_num;
}
offset = start;
dtype sum_data = (dtype)0;
for (int j = 0; j < axis_size; j++) {
out_data[offset] = exp(x_data[offset] - max_data);
sum_data += out_data[offset];
offset += inner_num;
}
offset = start;
for (int j = 0; j < axis_size; j++) {
out_data[offset] /= sum_data;
offset += inner_num;
}
}
}
void test_softmax(int bs, int ic, int ih, int iw, int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("softmax");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
// create and convert op to XPU model, then run it on XPU
auto op = CreateOp<operators::SoftmaxOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
softmax_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
TEST(XPUBridges, softmax) {
for (auto bs : {2, 3}) {
for (auto ic : {4}) {
for (auto ih : {5}) {
for (auto iw : {6}) {
for (auto axis : {-3, -1, 0, 1, 2, 3}) {
test_softmax(bs, ic, ih, iw, axis);
}
}
}
}
}
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(softmax);
USE_XPU_BRIDGE(softmax);