Commit 71266960 authored by 李寅

Fold embedding lookup

Parent fc7f4967
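The change below folds an embedding lookup that is immediately scaled by a constant: a Gather whose output feeds only an element-wise Mul with a scalar is rewritten at model-conversion time by pre-multiplying the scalar into the embedding table. That is why the runtime "y" scaling argument disappears from GatherOp, and the op is templated over its data type so a uint8 table can be registered as well. A minimal NumPy sketch of the identity the fold relies on (illustrative names only, not MACE code):

import numpy as np

table = np.arange(20, dtype=np.float32).reshape(10, 2)   # Gather "params", i.e. the embedding table
indices = np.array([2, 4, 6])
y = 2.0                                                   # scalar factor of the Mul that follows

gather_then_mul = table[indices] * y                      # original graph: Gather -> Mul
folded_table = table * y                                  # conversion-time fold: scale the table once
fold_then_gather = folded_table[indices]                  # folded graph: Gather only

assert np.allclose(gather_then_mul, fold_then_gather)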
@@ -20,15 +20,11 @@ namespace mace {
namespace ops {
template <DeviceType D, class T>
class GatherOp;
template <>
class GatherOp<DeviceType::CPU, float> : public Operation {
class GatherOp : public Operation {
public:
explicit GatherOp(OpConstructContext *context)
: Operation(context),
axis_(Operation::GetOptionalArg<int>("axis", 0)),
y_(Operation::GetOptionalArg<float>("y", 1.0)) {}
axis_(Operation::GetOptionalArg<int>("axis", 0)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
@@ -54,8 +50,8 @@ class GatherOp<DeviceType::CPU, float> : public Operation {
Tensor::MappingGuard params_guard(params);
Tensor::MappingGuard output_guard(output);
const int32_t *indices_data = indices->data<int32_t>();
const float *params_data = params->data<float>();
float *output_data = output->mutable_data<float>();
const T *params_data = params->data<T>();
T *output_data = output->mutable_data<T>();
index_t axis_dim_size = params->dim(axis_);
index_t lhs_size = std::accumulate(params->shape().begin(),
@@ -74,23 +70,18 @@ class GatherOp<DeviceType::CPU, float> : public Operation {
memcpy(
output_data + ((l * index_size) + idx) * rhs_size,
params_data + ((l * axis_dim_size) + indices_data[idx]) * rhs_size,
sizeof(float) * rhs_size);
sizeof(T) * rhs_size);
}
}
if (std::fabs(y_ - 1.0) > 1e-6) {
#pragma omp parallel for
for (index_t i = 0; i < output->size(); ++i) {
output_data[i] *= y_;
}
}
output->SetScale(params->scale());
output->SetZeroPoint(params->zero_point());
return MaceStatus::MACE_SUCCESS;
}
private:
int axis_;
float y_;
MACE_OP_INPUT_TAGS(PARAMS, INDICES);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
@@ -98,6 +89,8 @@ class GatherOp<DeviceType::CPU, float> : public Operation {
void RegisterGather(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Gather", GatherOp,
DeviceType::CPU, uint8_t);
}
} // namespace ops
......
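Why templating GatherOp over T plus the new SetScale/SetZeroPoint calls are enough to support the uint8 registration above (my reading of the diff, not text from the commit): Gather only memcpy's whole rows of the params tensor, so a quantized table's scale and zero point remain valid for the output and can simply be propagated. A small NumPy sketch with hypothetical quantize/dequantize helpers (not MACE APIs):

import numpy as np

def quantize(x, scale, zero_point):
    # hypothetical affine uint8 quantization helper
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

table = np.linspace(-1.0, 1.0, 20, dtype=np.float32).reshape(10, 2)
scale, zero_point = 2.0 / 255.0, 128
q_table = quantize(table, scale, zero_point)

indices = np.array([2, 4, 6])
q_out = q_table[indices]          # Gather copies whole rows of the uint8 table
# The copied rows still dequantize with the table's own scale/zero_point,
# which is exactly what the op propagates to its output.
assert np.allclose(dequantize(q_out, scale, zero_point),
                   dequantize(q_table, scale, zero_point)[indices])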
@@ -28,7 +28,6 @@ void TestGather(const std::vector<index_t> &weight_shape,
const std::vector<index_t> &input_shape,
const std::vector<int32_t> &input,
const int axis,
const float y,
const std::vector<index_t> &output_shape,
const std::vector<float> &output) {
OpsTestNet net;
@@ -40,7 +39,6 @@ void TestGather(const std::vector<index_t> &weight_shape,
.Input("Params")
.Input("Indices")
.AddIntArg("axis", axis)
.AddFloatArg("y", y)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
@@ -55,25 +53,25 @@ void TestGather(const std::vector<index_t> &weight_shape,
TEST_F(GatherOpTest, CPUScalarIndex) {
TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{}, {5}, 0, 2.0, {2}, {20, 22});
{}, {5}, 0, {2}, {10, 11});
}
TEST_F(GatherOpTest, CPURank1Index) {
TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{3}, {2, 4, 6}, 0, 1.0, {3, 2}, {4, 5, 8, 9, 12, 13});
{3}, {2, 4, 6}, 0, {3, 2}, {4, 5, 8, 9, 12, 13});
}
TEST_F(GatherOpTest, CPURank1IndexWithAxis1) {
TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{1}, {1}, 1, 1.0, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19});
{1}, {1}, 1, {10, 1}, {1, 3, 5, 7, 9, 11, 13, 15, 17, 19});
}
TEST_F(GatherOpTest, CPURankHighIndex) {
TestGather({10, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
{1, 3}, {2, 4, 6}, 0, 1.0, {1, 3, 2}, {4, 5, 8, 9, 12, 13});
{1, 3}, {2, 4, 6}, 0, {1, 3, 2}, {4, 5, 8, 9, 12, 13});
}
} // namespace test
......
@@ -222,6 +222,7 @@ class TransformerRule(Enum):
FOLD_DECONV_AND_BN = 32
FOLD_SQRDIFF_MEAN = 33
TRANSPOSE_MATMUL_WEIGHT = 34
FOLD_EMBEDDING_LOOKUP = 35
class ConverterInterface(object):
......
@@ -79,6 +79,7 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.FLATTEN_ATROUS_CONV: self.flatten_atrous_conv,
TransformerRule.FOLD_ACTIVATION: self.fold_activation,
TransformerRule.FOLD_SQRDIFF_MEAN: self.fold_squared_diff_mean,
TransformerRule.FOLD_EMBEDDING_LOOKUP: self.fold_embedding_lookup,
TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
TransformerRule.TRANSPOSE_MATMUL_WEIGHT:
self.transpose_matmul_weight,
@@ -392,6 +393,27 @@ class Transformer(base_converter.ConverterInterface):
return False
def fold_embedding_lookup(self):
net = self._model
for op in net.op:
# gather -> mul
if (op.type == MaceOp.Gather.name and
self.consumer_count(op.output[0]) == 1):
consumer_op = self._consumers[op.output[0]][0]
if (consumer_op.type == MaceOp.Eltwise.name and
ConverterUtil.get_arg(consumer_op,
MaceKeyword.mace_element_type_str).i == EltwiseType.PROD.value and # noqa
len(consumer_op.input) == 1 and
op.input[0] in self._consts and
self.consumer_count(op.input[0]) == 1):
print("Fold Gather and Mul: %s" % op.name)
gather_weights = self._consts[op.input[0]]
mul_weight = ConverterUtil.get_arg(consumer_op,
MaceKeyword.mace_scalar_input_str).f # noqa
gather_weights.float_data[:] = gather_weights.float_data * mul_weight # noqa
self.safe_remove_node(consumer_op, None,
remove_input_tensor=True)
def transform_lstmcell_zerostate(self):
net = self._model
......
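One precondition in fold_embedding_lookup above is worth spelling out: the gather weights are scaled in place, so the fold is only applied when the const table feeds this single Gather (consumer_count(op.input[0]) == 1) and the Gather output has exactly one consumer; otherwise pre-scaling would silently change the values seen by other ops. A toy NumPy sketch of the hazard the check avoids (illustrative only):

import numpy as np

table = np.arange(20, dtype=np.float32).reshape(10, 2)
y = 2.0

out_a = table[[1, 2]] * y        # Gather -> Mul branch (the fold candidate)
out_b = table[[3, 4]]            # a second consumer of the same table, with no Mul

folded = table * y               # in-place fold applied to the shared table
assert not np.allclose(folded[[3, 4]], out_b)   # the unscaled branch would change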