diff --git a/mace/ops/gather.cc b/mace/ops/gather.cc
index f8ceb54383cb21a45a63d2797ec7346a94fc020a..4357e466df43c6c9e7d5338fdd80a897f79d828f 100644
--- a/mace/ops/gather.cc
+++ b/mace/ops/gather.cc
@@ -20,15 +20,11 @@ namespace mace {
 namespace ops {
 
 template <DeviceType D, class T>
-class GatherOp;
-
-template <>
-class GatherOp<DeviceType::CPU, float> : public Operation {
+class GatherOp : public Operation {
  public:
   explicit GatherOp(OpConstructContext *context)
       : Operation(context),
-        axis_(Operation::GetOptionalArg<int>("axis", 0)),
-        y_(Operation::GetOptionalArg<float>("y", 1.0)) {}
+        axis_(Operation::GetOptionalArg<int>("axis", 0)) {}
 
   MaceStatus Run(OpContext *context) override {
     MACE_UNUSED(context);
@@ -54,8 +50,8 @@ class GatherOp : public Operation {
     Tensor::MappingGuard params_guard(params);
     Tensor::MappingGuard output_guard(output);
     const int32_t *indices_data = indices->data<int32_t>();
-    const float *params_data = params->data<float>();
-    float *output_data = output->mutable_data<float>();
+    const T *params_data = params->data<T>();
+    T *output_data = output->mutable_data<T>();
 
     index_t axis_dim_size = params->dim(axis_);
     index_t lhs_size = std::accumulate(params->shape().begin(),
@@ -74,23 +70,18 @@ class GatherOp : public Operation {
         memcpy(
             output_data + ((l * index_size) + idx) * rhs_size,
             params_data + ((l * axis_dim_size) + indices_data[idx]) * rhs_size,
-            sizeof(float) * rhs_size);
+            sizeof(T) * rhs_size);
       }
     }
 
-    if (std::fabs(y_ - 1.0) > 1e-6) {
-#pragma omp parallel for
-      for (index_t i = 0; i < output->size(); ++i) {
-        output_data[i] *= y_;
-      }
-    }
+    output->SetScale(params->scale());
+    output->SetZeroPoint(params->zero_point());
 
     return MaceStatus::MACE_SUCCESS;
   }
 
  private:
   int axis_;
-  float y_;
 
   MACE_OP_INPUT_TAGS(PARAMS, INDICES);
   MACE_OP_OUTPUT_TAGS(OUTPUT);
 };
@@ -98,6 +89,8 @@ class GatherOp : public Operation {
 void RegisterGather(OpRegistryBase *op_registry) {
   MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
                    DeviceType::CPU, float);
+  MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
+                   DeviceType::CPU, uint8_t);
 }
 
 }  // namespace ops
diff --git a/mace/ops/gather_test.cc b/mace/ops/gather_test.cc
index 2da0338ba667e0161a4bd81ac3c9fef10ac849d1..e5cbdb26bf550206f12f42f3df97d59a7f05f841 100644
--- a/mace/ops/gather_test.cc
+++ b/mace/ops/gather_test.cc
@@ -28,7 +28,6 @@ void TestGather(const std::vector<index_t> &weight_shape,
                 const std::vector<index_t> &input_shape,
                 const std::vector<int32_t> &input,
                 const int axis,
-                const float y,
                 const std::vector<index_t> &output_shape,
                 const std::vector<float> &output) {
   OpsTestNet net;
@@ -40,7 +39,6 @@ void TestGather(const std::vector<index_t> &weight_shape,
       .Input("Params")
       .Input("Indices")
       .AddIntArg("axis", axis)
-      .AddFloatArg("y", y)
       .Output("Output")
       .Finalize(net.NewOperatorDef());
   // Run
@@ -55,25 +53,25 @@ void TestGather(const std::vector<index_t> &weight_shape,
 TEST_F(GatherOpTest, CPUScalarIndex) {
   TestGather({10, 2},
              {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
-             {}, {5}, 0, 2.0, {2}, {20, 22});
+             {}, {5}, 0, {2}, {10, 11});
 }
 
 TEST_F(GatherOpTest, CPURank1Index) {
   TestGather({10, 2},
              {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
-             {3}, {2, 4, 6}, 0, 1.0, {3, 2}, {4, 5, 8, 9, 12, 13});
+             {3}, {2, 4, 6}, 0, {3, 2}, {4, 5, 8, 9, 12, 13});
 }
 
 TEST_F(GatherOpTest, CPURank1IndexWithAxis1) {
   TestGather({10, 2},
              {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
-             {1}, {1}, 1, 1.0, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19});
+             {1}, {1}, 1, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19});
 }
 
 TEST_F(GatherOpTest, CPURankHighIndex) {
   TestGather({10, 2},
              {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
-             {1, 3}, {2, 4, 6}, 0, 1.0, {1, 3, 2}, {4, 5, 8, 9, 12, 13});
+             {1, 3}, {2, 4, 6}, 0, {1, 3, 2}, {4, 5, 8, 9, 12, 13});
 }
 
 }  // namespace test
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index fe030985efc7484a648dc98c81c3d1c438140a27..6b5d227eb5e5e6967b7510602e7d03fd9ef033c4 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -222,6 +222,7 @@ class TransformerRule(Enum):
     FOLD_DECONV_AND_BN = 32
     FOLD_SQRDIFF_MEAN = 33
     TRANSPOSE_MATMUL_WEIGHT = 34
+    FOLD_EMBEDDING_LOOKUP = 35
 
 
 class ConverterInterface(object):
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 107aab023a119ea524f9f06aac0390b94fac8f82..7d6893442acf982529ea5c2b38425bf33795eb5a 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -79,6 +79,7 @@ class Transformer(base_converter.ConverterInterface):
             TransformerRule.FLATTEN_ATROUS_CONV: self.flatten_atrous_conv,
             TransformerRule.FOLD_ACTIVATION: self.fold_activation,
             TransformerRule.FOLD_SQRDIFF_MEAN: self.fold_squared_diff_mean,
+            TransformerRule.FOLD_EMBEDDING_LOOKUP: self.fold_embedding_lookup,
             TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
             TransformerRule.TRANSPOSE_MATMUL_WEIGHT:
                 self.transpose_matmul_weight,
@@ -392,6 +393,27 @@ class Transformer(base_converter.ConverterInterface):
 
         return False
 
+    def fold_embedding_lookup(self):
+        net = self._model
+        for op in net.op:
+            # gather -> mul
+            if (op.type == MaceOp.Gather.name and
+                    self.consumer_count(op.output[0]) == 1):
+                consumer_op = self._consumers[op.output[0]][0]
+                if (consumer_op.type == MaceOp.Eltwise.name and
+                        ConverterUtil.get_arg(consumer_op,
+                                              MaceKeyword.mace_element_type_str).i == EltwiseType.PROD.value and  # noqa
+                        len(consumer_op.input) == 1 and
+                        op.input[0] in self._consts and
+                        self.consumer_count(op.input[0]) == 1):
+                    print("Fold Gather and Mul: %s" % op.name)
+                    gather_weights = self._consts[op.input[0]]
+                    mul_weight = ConverterUtil.get_arg(consumer_op,
+                                                       MaceKeyword.mace_scalar_input_str).f  # noqa
+                    gather_weights.float_data[:] = gather_weights.float_data * mul_weight  # noqa
+                    self.safe_remove_node(consumer_op, None,
+                                          remove_input_tensor=True)
+
     def transform_lstmcell_zerostate(self):
         net = self._model