diff --git a/lite/kernels/mlu/bridges/argmax_op.cc b/lite/kernels/mlu/bridges/argmax_op.cc index 8ad5941c4004ac5f4fe0320c2e1d27540436e38f..11ef93b7d29df3cad9275ac524103f44cfc6c183 100644 --- a/lite/kernels/mlu/bridges/argmax_op.cc +++ b/lite/kernels/mlu/bridges/argmax_op.cc @@ -66,6 +66,6 @@ int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) { } // namespace lite } // namespace paddle -REGISTER_SUBGRAPH_BRIDGE(argmax, +REGISTER_SUBGRAPH_BRIDGE(arg_max, kMLU, paddle::lite::subgraph::mlu::ArgmaxConverter); diff --git a/lite/kernels/mlu/bridges/argmax_op_test.cc b/lite/kernels/mlu/bridges/argmax_op_test.cc index 8c6915d24873754b07c4724597f541d283858565..9eeb172812b8deecd6a8f1f2eb321ade4289fa9b 100644 --- a/lite/kernels/mlu/bridges/argmax_op_test.cc +++ b/lite/kernels/mlu/bridges/argmax_op_test.cc @@ -88,7 +88,7 @@ void test_argmax(const std::vector& input_shape, int axis) { FillTensor(x, -9, 9); // initialize op desc cpp::OpDesc opdesc; - opdesc.SetType("argmax"); + opdesc.SetType("arg_max"); opdesc.SetInput("X", {x_var_name}); opdesc.SetOutput("Out", {out_var_name}); opdesc.SetAttr("axis", static_cast(axis)); @@ -131,7 +131,7 @@ void test_argmax(const std::vector& input_shape, int axis) { } } -TEST(MLUBridges, argmax) { +TEST(MLUBridges, arg_max) { test_argmax({1, 2, 3, 4}, 1); test_argmax({1, 2, 3, 4}, 2); test_argmax({1, 2, 3, 4}, 3); @@ -142,4 +142,4 @@ TEST(MLUBridges, argmax) { } // namespace lite } // namespace paddle -USE_SUBGRAPH_BRIDGE(argmax, kMLU); +USE_SUBGRAPH_BRIDGE(arg_max, kMLU); diff --git a/lite/kernels/mlu/bridges/paddle_use_bridges.h b/lite/kernels/mlu/bridges/paddle_use_bridges.h index f286bb66fdde422d9900ee358c7834b8f15c4bf9..703687df875e86e65e007bddca62ef159bb8707e 100644 --- a/lite/kernels/mlu/bridges/paddle_use_bridges.h +++ b/lite/kernels/mlu/bridges/paddle_use_bridges.h @@ -32,7 +32,7 @@ USE_SUBGRAPH_BRIDGE(scale, kMLU); USE_SUBGRAPH_BRIDGE(sigmoid, kMLU); USE_SUBGRAPH_BRIDGE(elementwise_mul, kMLU); USE_SUBGRAPH_BRIDGE(dropout, kMLU); -USE_SUBGRAPH_BRIDGE(argmax, kMLU); +USE_SUBGRAPH_BRIDGE(arg_max, kMLU); USE_SUBGRAPH_BRIDGE(split, kMLU); USE_SUBGRAPH_BRIDGE(cast, kMLU); USE_SUBGRAPH_BRIDGE(layout, kMLU);