clip_op_xpu.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_XPU

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

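// Clips every element of the input tensor X into the closed range
// [min, max] on the XPU (Baidu Kunlun) backend.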
template <typename DeviceContext, typename T>
class ClipXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

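    // The clamp bounds default to the "min"/"max" float attributes; the
    // optional "Min"/"Max" tensor inputs take precedence when present.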
    auto max = static_cast<T>(ctx.Attr<float>("max"));
    if (ctx.HasInput("Max")) {
      Tensor max_cpu;
      auto* max_t = ctx.Input<Tensor>("Max");
      auto* max_data = max_t->data<T>();
      if (platform::is_xpu_place(max_t->place())) {
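        // The scalar lives in XPU device memory; copy it to the CPU so the
        // host can read its value.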
        paddle::framework::TensorCopySync(
            *max_t, platform::CPUPlace(), &max_cpu);
        max_data = max_cpu.data<T>();
      }
      max = max_data[0];
    }

    auto min = static_cast<T>(ctx.Attr<float>("min"));
    if (ctx.HasInput("Min")) {
      Tensor min_cpu;
      auto* min_t = ctx.Input<Tensor>("Min");
      auto* min_data = min_t->data<T>();
      if (platform::is_xpu_place(min_t->place())) {
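        // As with "Max", bring the device-resident scalar back to the host.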
        paddle::framework::TensorCopySync(
            *min_t, platform::CPUPlace(), &min_cpu);
        min_data = min_cpu.data<T>();
      }
      min = min_data[0];
    }

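    // Map T to its XPU-native representation and run the clip on the device.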
    using XPUDataType = typename XPUTypeTrait<T>::Type;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto x_data = reinterpret_cast<const XPUDataType*>(x->data<T>());
    auto out_data = reinterpret_cast<XPUDataType*>(out->data<T>());
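    // clip_v2 clamps all numel() elements of x into [min, max] in a single
    // device call.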
    int r = xpu::clip_v2(
        dev_ctx.x_context(), x_data, out_data, x->numel(), min, max);
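    // Surface a descriptive error if the XPU API reports a failure.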
    PADDLE_ENFORCE_EQ(
        r,
        XPU_SUCCESS,
        platform::errors::External("XPU API(clip_v2) return wrong "
                                   "value[%d %s]",
                                   r,
                                   XPUAPIErrorMsg[r]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

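// Only a float32 kernel is registered for clip on XPU.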
REGISTER_OP_XPU_KERNEL(clip, ops::ClipXPUKernel<plat::XPUDeviceContext, float>);

#endif