// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"

namespace paddle {
namespace operators {

template <typename T>
class MergedMomentumOpXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    auto params = ctx.MultiInput<framework::Tensor>("Param");
    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
    auto lr = ctx.Input<framework::Tensor>("LearningRate");
    int op_num = params.size();
    auto velocity = ctx.MultiInput<framework::Tensor>("Velocity");
    auto grad = ctx.MultiInput<framework::Tensor>("Grad");
    auto velocity_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
    auto use_nesterov = ctx.Attr<bool>("use_nesterov");
    auto regularization_method =
        ctx.Attr<std::vector<std::string>>("regularization_method");
    auto regularization_coeff =
        ctx.Attr<std::vector<float>>("regularization_coeff");

    // Per-tensor device pointers and metadata, packed so a single fused
    // XDNN call can update every parameter at once.
    std::vector<XPUType*> param_list(op_num);
    std::vector<XPUType*> velocity_list(op_num);
    std::vector<XPUType*> grad_list(op_num);
    std::vector<XPUType*> velocity_out_list(op_num);
    std::vector<XPUType*> param_out_list(op_num);
    std::vector<int> sizes(op_num);
    std::vector<float> l2_weight_decay(op_num);
    if (op_num > 0) {
      for (int j = 0; j < op_num; j++) {
        param_list[j] =
            reinterpret_cast<XPUType*>(const_cast<T*>(params[j]->data<T>()));
        velocity_list[j] =
            reinterpret_cast<XPUType*>(const_cast<T*>(velocity[j]->data<T>()));
        grad_list[j] =
            reinterpret_cast<XPUType*>(const_cast<T*>(grad[j]->data<T>()));
        param_out_list[j] =
            reinterpret_cast<XPUType*>(params_out[j]->data<T>());
        velocity_out_list[j] =
            reinterpret_cast<XPUType*>(velocity_out[j]->data<T>());
        sizes[j] = static_cast<int>(params[j]->numel());
        // Only "l2_decay" is honored; any other method disables weight decay
        // for that tensor.
        if (regularization_method[j] != "l2_decay") {
          l2_weight_decay[j] = 0.0f;
        } else {
          l2_weight_decay[j] = static_cast<float>(regularization_coeff[j]);
        }
        // The update is in-place: each output must alias its input tensor.
        PADDLE_ENFORCE_EQ(params[j],
                          params_out[j],
                          platform::errors::InvalidArgument(
                              "Input(Param) and Output(ParamOut) "
                              "must be the same Tensors."));
        PADDLE_ENFORCE_EQ(velocity[j],
                          velocity_out[j],
                          platform::errors::InvalidArgument(
                              "Input(Velocity) and Output(VelocityOut) "
                              "must be the same Tensors."));
      }
    } else {
      return;
    }
    auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
    PADDLE_ENFORCE_EQ(op_num,
                      params_out.size(),
                      platform::errors::InvalidArgument(
                          "The size of Output(ParamOut) must be equal to "
                          "Input(Param), but got the size of Output(ParamOut) "
                          "is %d, the size of Input(Param) is %d.",
                          params_out.size(),
                          op_num));
    PADDLE_ENFORCE_EQ(op_num,
                      velocity.size(),
                      platform::errors::InvalidArgument(
                          "The size of Input(Velocity) must be equal to "
                          "Input(Param), but got the size of Input(Velocity) "
                          "is %d, the size of Input(Param) is %d.",
                          velocity.size(),
                          op_num));
    PADDLE_ENFORCE_EQ(
        op_num,
        velocity_out.size(),
        platform::errors::InvalidArgument(
            "The size of Output(VelocityOut) must be equal to "
            "Input(Param), but got the size of Output(VelocityOut) "
            "is %d, the size of Input(Param) is %d.",
            velocity_out.size(),
            op_num));
    PADDLE_ENFORCE_EQ(
        op_num,
        grad.size(),
        platform::errors::InvalidArgument(
            "The size of Input(Grad) must be equal to Input(Param), but got "
            "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
            grad.size(),
            op_num));
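    // A sketch of the fused update xpu::merged_momentum is expected to apply
    // per tensor i (assuming l2_weight_decay[i] folds L2 regularization into
    // the gradient, as in the standard momentum formulation):
    //   g' = grad[i] + l2_weight_decay[i] * param[i]
    //   velocity_out[i] = mu * velocity[i] + g'
    //   param_out[i] = param[i] - lr * velocity_out[i]              (plain)
    //   param_out[i] = param[i] - lr * (g' + mu * velocity_out[i])  (nesterov)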
    int r = xpu::merged_momentum(dev_ctx.x_context(),
                                 param_list,
                                 velocity_list,
                                 grad_list,
                                 param_out_list,
                                 velocity_out_list,
                                 l2_weight_decay,
                                 sizes,
                                 lr->data<float>(),
                                 mu,
                                 use_nesterov);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "merged_momentum");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    merged_momentum,
    ops::MergedMomentumOpXPUKernel<float>,
    ops::MergedMomentumOpXPUKernel<paddle::platform::float16>);
#endif