Commit 3153a201 authored by dingminghui, committed by jackzhang235

feat(leaky_relu): support leaky_relu

Parent 2042a830
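CNML's generic activation op (cnmlCreateActiveOp) covers sigmoid, relu, and tanh, but the patch routes leaky_relu through a PReLU op instead, presumably because cnmlActiveFunction_t has no leaky_relu variant. For reference, leaky_relu computes f(x) = x for x > 0 and f(x) = alpha * x otherwise. A minimal host-side sketch of that semantics (illustrative only, not part of this patch):

#include <cstddef>
#include <vector>

// Reference semantics of leaky_relu (host-side sketch, not converter code):
// positive inputs pass through, negative inputs are scaled by alpha.
std::vector<float> leaky_relu_ref(const std::vector<float>& x, float alpha) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = x[i] > 0.0f ? x[i] : alpha * x[i];
  }
  return y;
}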
@@ -31,20 +31,34 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   VLOG(3) << "[MLU] Converting " + op_type + "...";
   // Create act node and set params from op
+  auto fp_type = graph->FPType();
   auto x_var_name = op_info->Input("X").front();
   auto out_var_name = op_info->Output("Out").front();
   auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
   auto output_dims = output->dims().Vectorize();
   auto output_tensor = graph->AddNode(
-      out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, graph->FPType());
+      out_var_name, output_dims, CNML_TENSOR, CNML_NHWC, fp_type);
   CHECK(graph->HasNode(x_var_name));
   auto input_tensor = graph->GetNode(x_var_name);
-  cnmlActiveFunction_t act_type = OpTypeToCNMLActType(op_type);
   cnmlBaseOp_t activation_op;
-  CNML_CALL(cnmlCreateActiveOp(&activation_op,
-                               act_type,
-                               input_tensor->mlu_tensor(),
-                               output_tensor->mlu_tensor()));
+  if (op_type == "leaky_relu") {
+    auto alpha = op_info->GetAttr<float>("alpha");
+    std::vector<int64_t> shape = {1, 1, 1, 1};
+    std::string alpha_var_name = string_format("leaky_relu_alpha_%p", op);
+    auto alpha_tensor =
+        graph->AddNode(alpha_var_name, shape, CNML_CONST, CNML_NHWC, fp_type);
+    graph->BindConstRawData(alpha_var_name, &alpha, 1, true);
+    CNML_CALL(cnmlCreatePreluOp(&activation_op,
+                                input_tensor->mlu_tensor(),
+                                output_tensor->mlu_tensor(),
+                                alpha_tensor->mlu_tensor()));
+  } else {
+    cnmlActiveFunction_t act_type = OpTypeToCNMLActType(op_type);
+    CNML_CALL(cnmlCreateActiveOp(&activation_op,
+                                 act_type,
+                                 input_tensor->mlu_tensor(),
+                                 output_tensor->mlu_tensor()));
+  }
   graph->FuseOp(activation_op);
   return SUCCESS;
 }
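Two details in the leaky_relu branch are worth noting. First, the slope is registered as a {1, 1, 1, 1} CNML_CONST tensor bound via BindConstRawData, and its node name is keyed by the op pointer (string_format("leaky_relu_alpha_%p", op)) so that multiple leaky_relu ops in one graph get distinct const nodes. Second, a PReLU with a single shared slope computes exactly leaky_relu; a CPU sketch of that equivalence (the broadcasting shown is an assumption for illustration, not CNML code):

#include <cstddef>
#include <vector>

// prelu with a per-channel slope degenerates to leaky_relu when the slope
// tensor holds a single element (NHWC: channel varies fastest).
std::vector<float> prelu_ref(const std::vector<float>& x,
                             const std::vector<float>& alpha,
                             size_t channels) {
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float a = alpha.size() == 1 ? alpha[0] : alpha[i % channels];
    y[i] = x[i] > 0.0f ? x[i] : a * x[i];
  }
  return y;
}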
@@ -59,3 +73,6 @@ REGISTER_SUBGRAPH_BRIDGE(sigmoid,
                          paddle::lite::subgraph::mlu::ActConverter);
 REGISTER_SUBGRAPH_BRIDGE(relu, kMLU, paddle::lite::subgraph::mlu::ActConverter);
 REGISTER_SUBGRAPH_BRIDGE(tanh, kMLU, paddle::lite::subgraph::mlu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(leaky_relu,
+                         kMLU,
+                         paddle::lite::subgraph::mlu::ActConverter);
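REGISTER_SUBGRAPH_BRIDGE binds the Paddle op name to the converter for the kMLU target; the string must match the op type checked inside ActConverter ("leaky_relu"), and the matching USE_SUBGRAPH_BRIDGE declaration in the test diff below references that registration so the linker keeps it.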
@@ -134,7 +134,7 @@ void test_act(std::vector<int64_t> x_shape, std::string op_type) {
 TEST(MLUBridges, activation) {
   std::vector<std::vector<int64_t>> shapes{{1}, {2, 3}, {1, 2, 3, 4}};
-  std::vector<std::string> types{"sigmoid", "relu", "tanh"};
+  std::vector<std::string> types{"sigmoid", "relu", "tanh", "leaky_relu"};
   for (auto x_shape : shapes) {
     for (auto op_type : types) {
       test_act(x_shape, op_type);
@@ -150,3 +150,4 @@ TEST(MLUBridges, activation) {
 USE_SUBGRAPH_BRIDGE(sigmoid, kMLU)
 USE_SUBGRAPH_BRIDGE(relu, kMLU)
 USE_SUBGRAPH_BRIDGE(tanh, kMLU)
+USE_SUBGRAPH_BRIDGE(leaky_relu, kMLU)
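The test change only extends the op list; test_act itself (defined earlier in the file, not shown in this diff) builds the op and compares device output against a host reference. Assuming that structure, the reference for the new type could look like this sketch (function name and shape are assumptions, not the actual test code):

#include <cmath>
#include <string>

// Hypothetical per-element host reference used to check MLU output; the
// real test_act in the Paddle-Lite test file may be organized differently.
float act_ref(const std::string& op_type, float x, float alpha) {
  if (op_type == "sigmoid") return 1.0f / (1.0f + std::exp(-x));
  if (op_type == "relu") return x > 0.0f ? x : 0.0f;
  if (op_type == "tanh") return std::tanh(x);
  if (op_type == "leaky_relu") return x > 0.0f ? x : alpha * x;
  return x;
}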