Commit a386ed11 authored by hong19860320, committed by GitHub

[LITE][XPU] Fix dropout op bridge and unit test for BERT (#2665)

Parent 96e71565
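Background for the bridge change below (an illustration, not part of the commit): at inference time, dropout reduces to a deterministic elementwise scale, which is what the rewritten converter encodes with CreateScale. A minimal standalone C++ sketch of the two dropout_implementation modes follows; the function name and use of std::vector are illustrative choices, not Paddle-Lite APIs.

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Inference-time dropout: no random masking, only a fixed scale.
std::vector<float> DropoutInfer(const std::vector<float>& x,
                                float dropout_prob,
                                const std::string& impl) {
  float scale = 1.0f;
  if (impl == "downgrade_in_infer") {
    scale = 1.0f - dropout_prob;  // training kept raw activations, so scale down now
  } else if (impl == "upscale_in_train") {
    scale = 1.0f;  // training already divided by (1 - p); identity at inference
  } else {
    assert(false && "unsupported dropout_implementation");
  }
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] * scale;
  return y;
}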
......
@@ -21,34 +21,51 @@ namespace lite {
 namespace subgraph {
 namespace xpu {
 
-int DropoutConverter(void* ctx, OpLite* op) {
+int DropoutConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   CHECK(ctx != nullptr);
   CHECK(op != nullptr);
   auto graph = static_cast<Graph*>(ctx);
   auto op_info = op->op_info();
   auto op_type = op_info->Type();
   auto scope = op->scope();
   VLOG(3) << "[XPU] Converting " + op_type + "...";
 
-  // Create node and set params from op
-  auto x_var_name = op_info->Input("X").front();
-  auto out_var_name = op_info->Output("Out").front();
+  // Get input and output vars and op attributes
+  auto x_name = op_info->Input("X").front();
+  auto x_type = kernel->GetInputDeclType("X");
+  CHECK(x_type->precision() == PRECISION(kFloat));
+  CHECK(x_type->layout() == DATALAYOUT(kNCHW));
+  auto x = scope->FindMutableTensor(x_name);
+  auto x_dims = x->dims();
+  auto out_name = op_info->Output("Out").front();
+  auto out_type = kernel->GetOutputDeclType("Out");
+  CHECK(out_type->precision() == PRECISION(kFloat));
+  CHECK(out_type->layout() == DATALAYOUT(kNCHW));
   auto dropout_prob = op_info->GetAttr<float>("dropout_prob");
   auto dropout_implementation =
       op_info->GetAttr<std::string>("dropout_implementation");
-  double rate;
+
+  // X node
+  std::shared_ptr<xtcl::xExpr> x_node = nullptr;
+  if (graph->HasNode(x_name)) {
+    x_node = graph->GetNode(x_name);
+  } else {
+    x_node = graph->AddNode(x_name, x_dims);
+  }
+
+  // Dropout node
   if (dropout_implementation == "downgrade_in_infer") {
-    rate = 1. - dropout_prob;
+    graph->AddNode(
+        out_name,
+        graph->builder_.CreateScale(*x_node, 1.f - dropout_prob, 0.0f, false));
   } else if (dropout_implementation == "upscale_in_train") {
-    rate = 1.;
+    graph->AddNode(out_name,
+                   graph->builder_.CreateScale(*x_node, 1.0f, 0.0f, false));
   } else {
-    LOG(FATAL) << "unsupported dropout_implementation == "
-               << dropout_implementation << " for dropout";
+    LOG(WARNING) << "[XPU] Unsupported dropout_implementation == "
+                 << dropout_implementation << " for dropout";
+    return FAILED;
   }
-  CHECK(graph->HasNode(x_var_name));
-  graph->AddNode(
-      out_var_name,
-      graph->builder_.CreateDropout(*graph->GetNode(x_var_name), rate));
   return SUCCESS;
 }
......
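Design note (reading of the diff, not stated in the commit): the old lowering delegated the inference-time scaling to a CreateDropout node with a rate argument, so the result depended on how the XPU runtime interprets a dropout op at inference; the new lowering computes the scale factor explicitly (1 - dropout_prob for downgrade_in_infer, 1.0 for upscale_in_train) with CreateScale, making the output deterministic by construction. Swapping LOG(FATAL) for LOG(WARNING) plus return FAILED also appears intended to let an unsupported dropout_implementation reject the op from the XPU subgraph rather than abort the whole process.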
......
@@ -18,6 +18,7 @@
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/core/arena/framework.h"
+#include "lite/tests/utils/fill_data.h"
 
 namespace paddle {
 namespace lite {
......
@@ -26,8 +27,8 @@ class DropoutComputeTester : public arena::TestCase {
  protected:
   // common attributes for this op.
   std::string type_ = "dropout";
-  std::string input_ = "x";
-  std::string output_ = "out";
+  std::string x_ = "x";
+  std::string out_ = "out";
   std::string mask_ = "mask";
   DDim dims_{{1}};
   float dropout_prob_ = 0.5;
......
@@ -51,12 +52,12 @@ class DropoutComputeTester : public arena::TestCase {
         dropout_implementation_(dropout_implementation) {}
 
   void RunBaseline(Scope* scope) override {
-    auto* out = scope->NewTensor(output_);
+    auto* out = scope->NewTensor(out_);
     CHECK(out);
     out->Resize(dims_);
     auto* output_data = out->mutable_data<float>();
 
-    auto* x = scope->FindTensor(input_);
+    auto* x = scope->FindTensor(x_);
     const auto* x_data = x->data<float>();
 
     if (dropout_implementation_ == "downgrade_in_infer") {
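For reference, the baseline computed here follows standard dropout inference semantics: with dropout_prob = 0.5, downgrade_in_infer maps an input element 0.4 to 0.4 * (1 - 0.5) = 0.2, while upscale_in_train passes it through unchanged.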
......
@@ -74,8 +75,8 @@ class DropoutComputeTester : public arena::TestCase {
 
   void PrepareOpDesc(cpp::OpDesc* op_desc) {
     op_desc->SetType(type_);
-    op_desc->SetInput("X", {input_});
-    op_desc->SetOutput("Out", {output_});
+    op_desc->SetInput("X", {x_});
+    op_desc->SetOutput("Out", {out_});
     op_desc->SetOutput("Mask", {mask_});
     op_desc->SetAttr("dropout_prob", dropout_prob_);
     op_desc->SetAttr("fix_seed", fix_seed_);
......
@@ -84,16 +85,9 @@ class DropoutComputeTester : public arena::TestCase {
   }
 
   void PrepareData() override {
-    std::vector<float> input_data(dims_.production());
-    for (int i = 0; i < dims_.production(); i++) {
-#if 0
-      float sign = i % 3 == 0 ? -1.0f : 1.0f;
-      input_data[i] = sign * static_cast<float>(i % 128) * 0.013f + 0.001;
-#else
-      input_data[i] = 1;
-#endif
-    }
-    SetCommonTensor(input_, dims_, input_data.data());
+    std::vector<float> x(dims_.production());
+    fill_data_rand(x.data(), -1.f, 1.f, dims_.production());
+    SetCommonTensor(x_, dims_, x.data());
   }
 };
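fill_data_rand comes from the newly included lite/tests/utils/fill_data.h; its definition is not shown in this diff, so the stand-in below is an assumption inferred from the call site (pointer, min, max, count), not the library's actual implementation. It assumes a floating-point element type, as used above.

#include <cstddef>
#include <random>

// Hypothetical stand-in for fill_data_rand: fills `data` with `size`
// values drawn uniformly from [vmin, vmax].
template <typename T>
void fill_data_rand(T* data, T vmin, T vmax, std::size_t size) {
  std::mt19937 gen(std::random_device{}());
  std::uniform_real_distribution<T> dist(vmin, vmax);
  for (std::size_t i = 0; i < size; ++i) {
    data[i] = dist(gen);
  }
}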
......
@@ -107,16 +101,15 @@ TEST(Dropout, precision) {
   return;
 #endif
 
-  std::vector<std::vector<int64_t>> dims{
-      /*{3} ,*/ {3, 4} /*, {3, 4, 5}, {1, 2, 3, 4}, {2, 3, 4, 5}*/};
-  for (auto dim : dims) {
-    for (auto dropout_prob : {/*0.,*/ 0.5 /*, 1.*/}) {
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {3}, {3, 4}, {3, 4, 5}, {1, 2, 3, 4}, {2, 3, 4, 5}}) {
+    for (auto dropout_prob : {0., 0.5, 1.}) {
       for (auto dropout_implementation :
            {"downgrade_in_infer", "upscale_in_train"}) {
         std::unique_ptr<arena::TestCase> tester(
             new DropoutComputeTester(place,
                                      "def",
-                                     DDim(dim),
+                                     DDim(dims),
                                      dropout_prob,
                                      true,
                                      1,
......
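Net effect on coverage: the loops now sweep 5 shapes × 3 dropout probabilities × 2 implementations = 30 cases, restoring the shape and probability combinations (including the p = 0 and p = 1 edge cases) that were previously commented out. The randomized input from PrepareData also replaces an all-ones tensor, with which indexing or layout bugs would be invisible, since every element is identical.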