/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fake_quantize_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

// Magnitude ordering: `lhs` sorts before `rhs` when |lhs| < |rhs|.
// Passed to std::max_element to locate the element of largest magnitude.
template <typename T>
struct Compare {
 public:
  bool operator()(const T lhs, const T rhs) {
    const T lhs_mag = std::abs(lhs);
    const T rhs_mag = std::abs(rhs);
    return lhs_mag < rhs_mag;
  }
};

// CPU specialization: writes the largest absolute value found in
// in[0, num) to *out.
// NOTE(review): callers must pass num >= 1. With num == 0, std::max_element
// returns the end iterator, which is then dereferenced.
template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
                  const int num, T* out) {
    const T* first = in;
    const T* last = in + num;
    const T* largest = std::max_element(first, last, Compare<T>());
    *out = std::abs(*largest);
  }
};

template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;

// CPU specialization: views `in` as `channel` consecutive slices of
// num / channel elements each, and writes the largest absolute value of
// slice c to out[c]. Trailing elements (num % channel) are not examined.
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
                  const int num, const int channel, T* out) {
    const int channel_size = num / channel;
    const T* slice_begin = in;
    for (int c = 0; c < channel; ++c) {
      const T* slice_end = slice_begin + channel_size;
      out[c] =
          std::abs(*std::max_element(slice_begin, slice_end, Compare<T>()));
      slice_begin = slice_end;
    }
  }
};

template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;

// CPU specialization: clips every element of `in` to [-s, s], where
// s = scale[0], then maps the clipped value x to round(bin_cnt / s * x).
// The quantized values are stored in `out` as T.
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    T s = scale.data<T>()[0];
    platform::Transform<platform::CPUDeviceContext> trans;
    trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
          out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
    // An all-zero input produces s == 0. Then bin_cnt / s is inf, and
    // inf * 0 would turn every output element into NaN. Nudging a
    // (near-)zero scale by a tiny epsilon makes the clipped zeros quantize
    // to 0. Ordinary scales are used unchanged.
    const T kEps = static_cast<T>(1e-30);
    const T safe_s = (s <= kEps) ? s + kEps : s;
    auto out_e = framework::EigenVector<T>::Flatten(*out);
    out_e.device(*ctx.eigen_device()) = (bin_cnt / safe_s * out_e).round();
  }
};

template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;

// CPU specialization of per-channel fake quantization.
//
// First pass: clips each run of in.numel() / channel contiguous elements
// to [-scale[i], scale[i]].
// Second pass: quantizes each dim-0 slice i of `out` to
// round(bin_cnt / scale[i] * x).
// NOTE(review): both passes line up only when `channel` equals the size of
// dim 0 of `in`/`out`. Presumably the caller guarantees this; confirm.
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int channel,
                  framework::Tensor* out) {
    auto* scale_data = scale.data<T>();
    auto* in_data = in.data<T>();
    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
    const int channel_size = in.numel() / channel;
    platform::Transform<platform::CPUDeviceContext> trans;
    for (int i = 0; i < channel; i++) {
      T s = scale_data[i];
      auto* start = in_data + i * channel_size;
      auto* end = in_data + (i + 1) * channel_size;
      trans(ctx, start, end, out_data + i * channel_size,
            ClipFunctor<T>(-s, s));
    }
    // An all-zero channel has scale 0. Then bin_cnt / 0 is inf, and
    // inf * 0 would produce NaN. A tiny epsilon on near-zero scales keeps
    // such channels at 0.
    const T kEps = static_cast<T>(1e-30);
    for (int i = 0; i < channel; i++) {
      T s = scale_data[i];
      const T safe_s = (s <= kEps) ? s + kEps : s;
      framework::Tensor one_channel_out = out->Slice(i, i + 1);
      auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
      out_e.device(*ctx.eigen_device()) = (bin_cnt / safe_s * out_e).round();
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
                                               float>;

// CPU specialization: maintains a ring buffer of the last `window_size`
// per-step scales in `scales_arr` and writes the window maximum to
// `out_scale`.
//
// Slot iter % window_size is overwritten with cur_scale[0]. The previous
// maximum comes from last_scale[0].
// - If the new scale exceeds it, the new scale becomes the maximum.
// - If the evicted entry was (approximately) the previous maximum, the
//   window is rescanned.
// - Otherwise the previous maximum is kept.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    T* scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
    int64_t it = iter.data<int64_t>()[0];
    int idx = it % window_size;
    T removed = scale_arr[idx];
    T cur = cur_scale.data<T>()[0];
    scale_arr[idx] = cur;

    T max = last_scale.data<T>()[0];
    if (max < cur) {
      max = cur;
    } else if (fabs(removed - max) < 1e-6) {
      // Assuming Iter counts from 0 and the buffer persists across steps,
      // slots [0, it] are filled at this point, including slot `idx`, which
      // was just written with `cur`. The live window therefore holds it + 1
      // entries until it saturates at window_size.
      // Using `it` here would drop the newest entry and scan nothing at
      // it == 0.
      const int64_t filled = it + 1;
      const int size = static_cast<int>(
          filled < window_size ? filled : static_cast<int64_t>(window_size));
      FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
                                                         &max);
    }
    out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
  }
};

template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;

// CPU specialization: advances the moving-average statistics by one step
// and writes the new state, accumulator and scale (one element each):
//   state' = rate * state + 1
//   accum' = rate * accum + cur_scale[0]
//   scale' = accum' / state'
// All inputs are read before any output is written.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    const T prev_accum = in_accum.data<T>()[0];
    const T prev_state = in_state.data<T>()[0];
    const T step_scale = cur_scale[0];

    const T next_state = rate * prev_state + 1;
    const T next_accum = rate * prev_accum + step_scale;
    const T next_scale = next_accum / next_state;

    out_state->mutable_data<T>(ctx.GetPlace())[0] = next_state;
    out_accum->mutable_data<T>(ctx.GetPlace())[0] = next_accum;
    out_scale->mutable_data<T>(ctx.GetPlace())[0] = next_scale;
  }
};

template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
                                               float>;

class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
视言's avatar
视言 已提交
154
 public:
155 156 157 158
  FakeQuantizeAbsMaxOp(const std::string& type,
                       const framework::VariableNameMap& inputs,
                       const framework::VariableNameMap& outputs,
                       const framework::AttributeMap& attrs)
视言's avatar
视言 已提交
159 160
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

161
  void InferShape(framework::InferShapeContext* ctx) const override {
视言's avatar
视言 已提交
162 163 164 165
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of FakeQuantizeOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of FakeQuantizeOp should not be null.");
166 167
    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
                   "Output(Scale) of FakeQuantizeOp should not be null.");
视言's avatar
视言 已提交
168
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
169
    ctx->SetOutputDim("OutScale", {1});
视言's avatar
视言 已提交
170 171
    ctx->ShareLoD("X", /*->*/ "Out");
  }
172 173 174 175

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
176 177
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.device_context());
178
  }
视言's avatar
视言 已提交
179 180
};

181
class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
视言's avatar
视言 已提交
182 183
 public:
  void Make() override {
184 185 186 187 188
    AddInput("X", "(Tensor) Input is float data type.");
    AddOutput("Out",
              "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
    AddOutput("OutScale", "(Tensor) Current scale");
视言's avatar
视言 已提交
189 190
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
191
        .AddCustomChecker([](const int& bit_length) {
视言's avatar
视言 已提交
192 193 194 195 196 197
          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                         "'bit_length' should be between 1 and 16.");
        });
    AddComment(R"DOC(
FakeQuantize operator

198
$$scale = max(abs(X))$$
199 200
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
视言's avatar
视言 已提交
201

202 203 204
)DOC");
  }
};
视言's avatar
视言 已提交
205

Z
Zhen Wang 已提交
206 207 208 209 210 211 212 213 214 215 216
// Operator definition for fake_channel_wise_quantize_abs_max.
// Out takes the shape and LoD of X; OutScale has one entry per element of
// X's first dimension.
class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of FakeChannelWiseQuantizeOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("Out"),
        "Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
    // The checked output is named "OutScale"; the message must say so.
    PADDLE_ENFORCE(
        ctx->HasOutput("OutScale"),
        "Output(OutScale) of FakeChannelWiseQuantizeOp should not be null.");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel data type follows X; the place comes from the context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

// Declares inputs, outputs, attributes and documentation of
// fake_channel_wise_quantize_abs_max.
class FakeChannelWiseQuantizeAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddOutput("Out",
              "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
    AddOutput("OutScale", "(Tensor) Current channel wise scale");
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                         "'bit_length' should be between 1 and 16.");
        });
    AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.

$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c})$$
In the above three formulas, the range of c is as follows:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
  }
};

class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  FakeQuantizeRangeAbsMaxOp(const std::string& type,
                            const framework::VariableNameMap& inputs,
                            const framework::VariableNameMap& outputs,
                            const framework::AttributeMap& attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}
视言's avatar
视言 已提交
267

268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of FakeQuantizeRangeAbsMaxOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("Out"),
        "Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("OutScale"),
        "Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null");
    if (ctx->HasOutput("OutScales")) {
      int window_size = ctx->Attrs().Get<int>("window_size");
      ctx->SetOutputDim("OutScales", {window_size});
    }
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
    ctx->ShareLoD("X", /*->*/ "Out");
  }
视言's avatar
视言 已提交
285

286 287 288
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
Y
Yu Yang 已提交
289 290
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.device_context());
291 292
  }
};
视言's avatar
视言 已提交
293

294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
// Declares inputs, outputs, attributes and documentation of
// fake_quantize_range_abs_max.
class FakeQuantizeRangeAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InScale", "Last scale.");
    AddInput("Iter", "Global step iteration.").AsDispensable();
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScale", " Current scale");
    AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
    AddAttr<int>("window_size", "(int, default 10000) window range size.")
        .SetDefault(10000);
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                         "'bit_length' should be between 1 and 16.");
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    // Underscores are escaped so LaTeX does not render them as subscripts.
    AddComment(R"DOC(
FakeQuantize operator is used in static quantization.

$$scale = max(max(abs(x)), history\_abs\_max)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range)$$

)DOC");
  }
};

// Operator definition for fake_quantize_moving_average_abs_max.
// Out takes the shape and LoD of X. OutScale, and the optional OutState /
// OutAccum statistics, hold one value each.
class FakeQuantizeMovingAverageAbsMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(
        ctx->HasInput("X"),
        "Input(X) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("Out"),
        "Output(Out) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
                   "Output(OutScale) of FakeQuantizeMovingAverageAbsMaxOp "
                   "should not be null");

    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
    ctx->SetOutputDim("OutScale", {1});

    // Optional running statistics: one element each.
    if (ctx->HasOutput("OutState")) ctx->SetOutputDim("OutState", {1});
    if (ctx->HasOutput("OutAccum")) ctx->SetOutputDim("OutAccum", {1});
  }

 protected:
  // The kernel data type follows X.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto* x = ctx.Input<framework::LoDTensor>("X");
    return framework::OpKernelType(x->type(), ctx.device_context());
  }
};

// Declares inputs, outputs, attributes and documentation of
// fake_quantize_moving_average_abs_max.
class FakeQuantizeMovingAverageAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InScale", "Last scale.");
    AddInput("InAccum", "Last accum.").AsDispensable();
    AddInput("InState", "Last state.").AsDispensable();
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScale", " Current scale");
    AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
    AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
    AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
        .SetDefault(0.9);
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                         "'bit_length' should be between 1 and 16.");
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
FakeQuantize operator is used in static quantization.

$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range)$$

)DOC");
  }
};

// Operator definition for moving_average_abs_max_scale.
// Out takes the shape and LoD of X. OutScale, and the optional OutState /
// OutAccum statistics, hold one value each.
class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(
        ctx->HasInput("X"),
        "Input(X) of MovingAverageAbsMaxScaleOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("Out"),
        "Output(Out) of MovingAverageAbsMaxScaleOp should not be null.");
    // The trailing space matters: without it the adjacent literals
    // concatenate into "...MovingAverageAbsMaxScaleOpshould not be null".
    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
                   "Output(OutScale) of MovingAverageAbsMaxScaleOp "
                   "should not be null");
    if (ctx->HasOutput("OutState")) {
      ctx->SetOutputDim("OutState", {1});
    }
    if (ctx->HasOutput("OutAccum")) {
      ctx->SetOutputDim("OutAccum", {1});
    }
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel data type follows X; the place comes from the context.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

// Declares inputs, outputs, attributes and documentation of
// moving_average_abs_max_scale. This op only tracks the scale; Out is X.
class MovingAverageAbsMaxScaleOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    static constexpr char kDoc[] = R"DOC(
MovingAverageAbsMaxScale operator is only used for calculating the quantization scale.
And it will not quantize the input tensor.

$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$Out = X$$

)DOC";

    // Inputs: the tensor to observe plus the optional running statistics.
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InAccum", "Last accum.").AsDispensable();
    AddInput("InState", "Last state.").AsDispensable();

    // Outputs: pass-through tensor, current scale, updated statistics.
    AddOutput("Out",
              "(Tensor) Output tensor is just equivalent to the input tensor.");
    AddOutput("OutScale", " Current scale");
    AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
    AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();

    AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
        .SetDefault(0.9);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set true for inference only and false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);

    AddComment(kDoc);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
465 466 467 468 469 470 471
using CPU = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
                  ops::FakeQuantizeAbsMaxOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                       ops::FakeQuantizeAbsMaxKernel<CPU, float>);
视言's avatar
视言 已提交
472

473 474
REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
                  ops::FakeQuantizeRangeAbsMaxOpMaker,
视言's avatar
视言 已提交
475
                  paddle::framework::EmptyGradOpMaker);
476 477
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
                       ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
Z
Zhen Wang 已提交
478

479 480 481 482 483 484
REGISTER_OPERATOR(fake_quantize_moving_average_abs_max,
                  ops::FakeQuantizeMovingAverageAbsMaxOp,
                  ops::FakeQuantizeMovingAverageAbsMaxOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                       ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
Z
Zhen Wang 已提交
485 486 487 488 489 490
REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
                  ops::FakeChannelWiseQuantizeAbsMaxOp,
                  ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
                       ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
Z
Zhen Wang 已提交
491 492 493 494 495 496

REGISTER_OPERATOR(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
                  ops::MovingAverageAbsMaxScaleOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                       ops::MovingAverageAbsMaxScaleKernel<CPU, float>);