Unverified commit 7f3c7aeb, authored by ykkk2333 and committed by GitHub

migrate deformable_conv and merged momentum kernels to phi, test=kunlun (#45691)

Parent d8a09e25
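This commit migrates the fluid-style XPU kernels for deformable_conv, deformable_conv_grad, and merged_momentum (the first two files below) to phi kernels (the last three files). The kernel bodies are carried over largely unchanged; the visible differences are the functional kernel signatures taking DenseTensor arguments, output allocation through dev_ctx.template Alloc<T>(), XDNN error checking through PADDLE_ENFORCE_XDNN_SUCCESS, and registration through PD_REGISTER_KERNEL instead of REGISTER_OP_XPU_KERNEL. As a minimal sketch of the registration change, paraphrasing the registrations that appear in the code below:

// Before (fluid): the kernel is an OpKernel subclass, registered per device context.
REGISTER_OP_XPU_KERNEL(deformable_conv,
                       ops::DeformableConvXPUKernel<XPUDeviceContext, float>);

// After (phi): the kernel is a free function template, registered with the phi registry.
PD_REGISTER_KERNEL(
    deformable_conv, XPU, ALL_LAYOUT, phi::DeformableConvKernel, float) {}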
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class DeformableConvXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* offset = ctx.Input<Tensor>("Offset");
auto* mask = ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int groups = ctx.Attr<int>("groups");
const int deformable_groups = ctx.Attr<int>("deformable_groups");
const int im2col_step = ctx.Attr<int>("im2col_step");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    PADDLE_ENFORCE_EQ(
        deformable_groups == 1,
        true,
        platform::errors::InvalidArgument(
            "XPU only supports deformable_groups == 1 in deformable_conv op."));
    PADDLE_ENFORCE_EQ(
        groups == 1,
        true,
        platform::errors::InvalidArgument(
            "XPU only supports groups == 1 in deformable_conv op."));
    PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8,
                      true,
                      platform::errors::InvalidArgument(
                          "Filter height and width should be less than or "
                          "equal to 8 on XPU in deformable_conv op."));
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> output_shape_vec(phi::vectorize(output->dims()));
const T* input_ptr = input->data<T>();
const T* filter_ptr = filter.data<T>();
    const float* offset_ptr = offset->data<float>();
    const float* mask_ptr = mask->data<float>();
    T* output_prt = output->data<T>();
    // zero-initialize the output buffer
const int zero = 0;
int r = xpu::constant<T>(
dev_ctx.x_context(), output_prt, output->numel(), zero);
    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS,
                      true,
                      platform::errors::External(
                          "XPU API returned wrong value[%d], please check "
                          "whether the Baidu Kunlun card is properly installed.",
                          r));
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset->numel() / offset->dims()[0];
int input_mask_dim = mask->numel() / mask->dims()[0];
int output_dim =
output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
static_cast<int>(filter.dims()[3])};
int n = im2col_step;
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
int f = filter.dims()[0];
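    // Process the batch in chunks of im2col_step samples; the input, offset,
    // mask and output pointers advance by one chunk per iteration.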
for (int i = 0; i < batch_size / im2col_step; ++i) {
int r = xpu::deformable_conv<float, float, float, int>(
dev_ctx.x_context(),
input_ptr + i * im2col_step * input_dim,
filter_ptr,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
output_prt + i * im2col_step * output_dim,
n,
c,
h,
w,
f,
ksize,
strides,
paddings,
dilations,
groups,
deformable_groups,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU deformable_conv kernel return wrong value[%d].", r));
}
}
};
template <typename DeviceContext, typename T>
class DeformableConvGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* output_grad =
ctx.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
Tensor* offset_grad = ctx.Output<Tensor>(framework::GradVarName("Offset"));
Tensor* mask_grad = ctx.Output<Tensor>(framework::GradVarName("Mask"));
T* dx_data = nullptr;
T* dw_data = nullptr;
T* dmask_data = nullptr;
T* doffset_data = nullptr;
if (input_grad != nullptr) {
input_grad->mutable_data<T>(ctx.GetPlace());
dx_data = input_grad->data<T>();
}
if (filter_grad != nullptr) {
filter_grad->mutable_data<T>(ctx.GetPlace());
dw_data = filter_grad->data<T>();
}
if (offset_grad != nullptr) {
offset_grad->mutable_data<T>(ctx.GetPlace());
doffset_data = offset_grad->data<T>();
}
if (mask_grad != nullptr) {
mask_grad->mutable_data<T>(ctx.GetPlace());
dmask_data = mask_grad->data<T>();
}
const Tensor* input = ctx.Input<Tensor>("Input");
Tensor offset = *ctx.Input<Tensor>("Offset");
Tensor mask = *ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
int groups = ctx.Attr<int>("groups");
int deformable_groups = ctx.Attr<int>("deformable_groups");
int im2col_step = ctx.Attr<int>("im2col_step");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    PADDLE_ENFORCE_EQ(
        deformable_groups == 1,
        true,
        platform::errors::InvalidArgument(
            "XPU only supports deformable_groups == 1 in deformable_conv op."));
    PADDLE_ENFORCE_EQ(
        groups == 1,
        true,
        platform::errors::InvalidArgument(
            "XPU only supports groups == 1 in deformable_conv op."));
    PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8,
                      true,
                      platform::errors::InvalidArgument(
                          "Filter height and width should be less than or "
                          "equal to 8 on XPU in deformable_conv op."));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> output_shape_vec(phi::vectorize(output_grad->dims()));
const T* output_grad_ptr = output_grad->data<T>();
const T* input_ptr = input->data<T>();
const T* filter_ptr = filter.data<T>();
const float* offset_ptr = offset.data<float>();
const float* mask_ptr = mask.data<float>();
if (dx_data == nullptr) {
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&dx_data),
input->numel() * sizeof(T)),
XPU_SUCCESS,
          platform::errors::ResourceExhausted("XPU is out of memory"));
}
if (dw_data == nullptr) {
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&dw_data),
filter.numel() * sizeof(T)),
XPU_SUCCESS,
          platform::errors::ResourceExhausted("XPU is out of memory"));
}
if (doffset_data == nullptr) {
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&doffset_data),
offset.numel() * sizeof(T)),
XPU_SUCCESS,
          platform::errors::ResourceExhausted("XPU is out of memory"));
}
if (dmask_data == nullptr) {
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&dmask_data),
mask.numel() * sizeof(T)),
XPU_SUCCESS,
          platform::errors::ResourceExhausted("XPU is out of memory"));
}
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask.numel() / mask.dims()[0];
int output_dim =
output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
static_cast<int>(filter.dims()[3])};
int n = im2col_step;
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
int f = filter.dims()[0];
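    // filter_grad_tmp holds the filter gradient of one im2col_step chunk;
    // it is accumulated into dw_data with xpu::api::add after every chunk.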
T* filter_grad_tmp = nullptr;
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&filter_grad_tmp),
filter_grad->numel() * sizeof(T)),
XPU_SUCCESS,
        platform::errors::ResourceExhausted("XPU is out of memory"));
    // zero-initialize the gradient buffers and the temporary filter gradient
const int zero = 0;
int r_dx =
xpu::constant<T>(dev_ctx.x_context(), dx_data, input->numel(), zero);
int r_dw =
xpu::constant<T>(dev_ctx.x_context(), dw_data, filter.numel(), zero);
int r_doffset = xpu::constant<T>(
dev_ctx.x_context(), doffset_data, offset.numel(), zero);
int r_dmask =
xpu::constant<T>(dev_ctx.x_context(), dmask_data, mask.numel(), zero);
int r_filter = xpu::constant<T>(
dev_ctx.x_context(), filter_grad_tmp, filter.numel(), zero);
auto ret = (r_dx == xpu::Error_t::SUCCESS) && (r_dx == r_dw) &&
(r_dx == r_doffset) && (r_dx == r_dmask) && (r_dx == r_filter);
    PADDLE_ENFORCE_EQ(ret,
                      true,
                      platform::errors::External(
                          "XPU API returned wrong value, please check whether "
                          "the Baidu Kunlun card is properly installed."));
for (int i = 0; i < batch_size / im2col_step; ++i) {
int r = xpu::deformable_conv_grad<float, float, float, int>(
dev_ctx.x_context(),
input_ptr + i * im2col_step * input_dim,
filter_ptr,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
output_grad_ptr + i * im2col_step * output_dim,
dx_data + i * im2col_step * input_dim,
filter_grad_tmp,
doffset_data + i * im2col_step * input_offset_dim,
dmask_data + i * im2col_step * input_mask_dim,
n,
c,
h,
w,
f,
ksize,
strides,
paddings,
dilations,
groups,
deformable_groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU deformable_conv_grad kernel return wrong value[%d].", r));
r = baidu::xpu::api::add<T>(dev_ctx.x_context(),
filter_grad_tmp,
dw_data,
dw_data,
filter.numel());
PADDLE_ENFORCE_EQ(r,
XPU_SUCCESS,
platform::errors::External(
"XPU add kernel return wrong value[%d].", r));
}
dev_ctx.Wait();
xpu_free(filter_grad_tmp);
if (input_grad == nullptr) {
xpu_free(dx_data);
}
if (filter_grad == nullptr) {
xpu_free(dw_data);
}
if (offset_grad == nullptr) {
xpu_free(doffset_data);
}
if (mask_grad == nullptr) {
xpu_free(dmask_data);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using XPUDeviceContext = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(deformable_conv,
ops::DeformableConvXPUKernel<XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
deformable_conv_grad,
ops::DeformableConvGradXPUKernel<XPUDeviceContext, float>);
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include <sys/syscall.h>
#include <unistd.h>
#include <iostream>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/phi/kernels/impl/momentum_kernel_impl.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class MergedMomentumOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
T mu = static_cast<T>(ctx.Attr<float>("mu"));
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
auto lr = ctx.Input<framework::Tensor>("LearningRate");
int op_num = params.size();
auto velocity = ctx.MultiInput<framework::Tensor>("Velocity");
auto grad = ctx.MultiInput<framework::Tensor>("Grad");
auto velocity_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
auto use_nesterov = ctx.Attr<bool>("use_nesterov");
auto regularization_method =
ctx.Attr<std::vector<std::string>>("regularization_method");
auto regularization_coeff =
ctx.Attr<std::vector<float>>("regularization_coeff");
PADDLE_ENFORCE_EQ(op_num,
params_out.size(),
platform::errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(),
op_num));
    PADDLE_ENFORCE_EQ(op_num,
                      velocity.size(),
                      platform::errors::InvalidArgument(
                          "The size of Input(Velocity) must be equal to "
                          "Input(Param), but got the size of Input(Velocity) "
                          "is %d, the size of Input(Param) is %d.",
                          velocity.size(),
                          op_num));
PADDLE_ENFORCE_EQ(
op_num,
velocity_out.size(),
platform::errors::InvalidArgument(
"The size of Output(VelocityOut) must be equal to "
"Input(Param), but got the size of Output(VelocityOut) "
"is %d, the size of Input(Param) is %d.",
velocity_out.size(),
op_num));
PADDLE_ENFORCE_EQ(
op_num,
grad.size(),
platform::errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grad.size(),
op_num));
if (regularization_method.size() == 0) {
regularization_method.resize(op_num);
}
std::vector<XPUType*> param_list(op_num);
std::vector<XPUType*> velocity_list(op_num);
std::vector<XPUType*> grad_list(op_num);
std::vector<XPUType*> velocity_out_list(op_num);
std::vector<XPUType*> param_out_list(op_num);
std::vector<int> sizes(op_num);
std::vector<float> l2_weight_decay(op_num);
if (op_num > 0) {
for (int j = 0; j < op_num; j++) {
param_list[j] =
reinterpret_cast<XPUType*>(const_cast<T*>(params[j]->data<T>()));
velocity_list[j] =
reinterpret_cast<XPUType*>(const_cast<T*>(velocity[j]->data<T>()));
grad_list[j] =
reinterpret_cast<XPUType*>(const_cast<T*>(grad[j]->data<T>()));
param_out_list[j] =
reinterpret_cast<XPUType*>(params_out[j]->data<T>());
velocity_out_list[j] =
reinterpret_cast<XPUType*>(velocity_out[j]->data<T>());
sizes[j] = static_cast<int>(params[j]->numel());
if (regularization_method[j] != "l2_decay") {
l2_weight_decay[j] = 0.0f;
} else {
l2_weight_decay[j] = static_cast<float>(regularization_coeff[j]);
}
        PADDLE_ENFORCE_EQ(params[j],
                          params_out[j],
                          platform::errors::InvalidArgument(
                              "Input(Param) and Output(ParamOut) must be "
                              "the same Tensors."));
        PADDLE_ENFORCE_EQ(velocity[j],
                          velocity_out[j],
                          platform::errors::InvalidArgument(
                              "Input(Velocity) and Output(VelocityOut) must be "
                              "the same Tensors."));
}
} else {
return;
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::merged_momentum(dev_ctx.x_context(),
param_list,
velocity_list,
grad_list,
param_out_list,
velocity_out_list,
l2_weight_decay,
sizes,
lr->data<float>(),
mu,
use_nesterov);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merged_momentum");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
merged_momentum,
ops::MergedMomentumOpXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::MergedMomentumOpXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/deformable_conv_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void DeformableConvGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& offset,
const DenseTensor& filter,
const paddle::optional<DenseTensor>& mask,
const DenseTensor& out_grad,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
DenseTensor* dx,
DenseTensor* offset_grad,
DenseTensor* filter_grad,
DenseTensor* mask_grad) {
T* dx_data = nullptr;
T* dw_data = nullptr;
T* dmask_data = nullptr;
T* doffset_data = nullptr;
if (dx != nullptr) {
dx_data = dev_ctx.template Alloc<T>(dx);
}
if (filter_grad != nullptr) {
dw_data = dev_ctx.template Alloc<T>(filter_grad);
}
if (offset_grad != nullptr) {
doffset_data = dev_ctx.template Alloc<T>(offset_grad);
}
if (mask_grad != nullptr) {
dmask_data = dev_ctx.template Alloc<T>(mask_grad);
}
  PADDLE_ENFORCE_EQ(
      deformable_groups == 1,
      true,
      errors::InvalidArgument(
          "XPU only supports deformable_groups == 1 in deformable_conv op."));
  PADDLE_ENFORCE_EQ(
      groups == 1,
      true,
      errors::InvalidArgument(
          "XPU only supports groups == 1 in deformable_conv op."));
  PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8,
                    true,
                    errors::InvalidArgument(
                        "Filter height and width should be less than or "
                        "equal to 8 on XPU in deformable_conv op."));
const int batch_size = static_cast<int>(x.dims()[0]);
std::vector<int64_t> output_shape_vec(phi::vectorize(out_grad.dims()));
const T* output_grad_ptr = out_grad.data<T>();
const T* input_ptr = x.data<T>();
const T* filter_ptr = filter.data<T>();
const float* offset_ptr = offset.data<float>();
const float* mask_ptr = mask->data<float>();
if (dx_data == nullptr) {
PADDLE_ENFORCE_EQ(
xpu_malloc(reinterpret_cast<void**>(&dx_data), x.numel() * sizeof(T)),
XPU_SUCCESS,
        errors::ResourceExhausted("XPU is out of memory"));
}
if (dw_data == nullptr) {
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&dw_data),
filter.numel() * sizeof(T)),
XPU_SUCCESS,
                      errors::ResourceExhausted("XPU is out of memory"));
}
if (doffset_data == nullptr) {
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&doffset_data),
offset.numel() * sizeof(T)),
XPU_SUCCESS,
                      errors::ResourceExhausted("XPU is out of memory"));
}
if (dmask_data == nullptr) {
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&dmask_data),
mask->numel() * sizeof(T)),
XPU_SUCCESS,
                      errors::ResourceExhausted("XPU is out of memory"));
}
int input_dim = x.numel() / x.dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask->numel() / mask->dims()[0];
int output_dim =
output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
static_cast<int>(filter.dims()[3])};
int n = im2col_step;
int c = x.dims()[1];
int h = x.dims()[2];
int w = x.dims()[3];
int f = filter.dims()[0];
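  // Temporary buffer for the per-chunk filter gradient; it is accumulated
  // into dw_data inside the loop below.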
T* filter_grad_tmp = nullptr;
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&filter_grad_tmp),
filter_grad->numel() * sizeof(T)),
XPU_SUCCESS,
                    errors::ResourceExhausted("XPU is out of memory"));
  // zero-initialize the gradient buffers and the temporary filter gradient
const int zero = 0;
int r_dx = xpu::constant<T>(dev_ctx.x_context(), dx_data, x.numel(), zero);
PADDLE_ENFORCE_XDNN_SUCCESS(r_dx, "constant");
int r_dw =
xpu::constant<T>(dev_ctx.x_context(), dw_data, filter.numel(), zero);
PADDLE_ENFORCE_XDNN_SUCCESS(r_dw, "constant");
int r_doffset =
xpu::constant<T>(dev_ctx.x_context(), doffset_data, offset.numel(), zero);
PADDLE_ENFORCE_XDNN_SUCCESS(r_doffset, "constant");
int r_dmask =
xpu::constant<T>(dev_ctx.x_context(), dmask_data, mask->numel(), zero);
PADDLE_ENFORCE_XDNN_SUCCESS(r_dmask, "constant");
int r_filter = xpu::constant<T>(
dev_ctx.x_context(), filter_grad_tmp, filter.numel(), zero);
PADDLE_ENFORCE_XDNN_SUCCESS(r_filter, "constant");
for (int i = 0; i < batch_size / im2col_step; ++i) {
int r = xpu::deformable_conv_grad<float, float, float, int>(
dev_ctx.x_context(),
input_ptr + i * im2col_step * input_dim,
filter_ptr,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
output_grad_ptr + i * im2col_step * output_dim,
dx_data + i * im2col_step * input_dim,
filter_grad_tmp,
doffset_data + i * im2col_step * input_offset_dim,
dmask_data + i * im2col_step * input_mask_dim,
n,
c,
h,
w,
f,
ksize,
strides,
paddings,
dilations,
groups,
deformable_groups,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "deformable_conv_grad");
r = baidu::xpu::api::add<T>(
dev_ctx.x_context(), filter_grad_tmp, dw_data, dw_data, filter.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "add");
}
dev_ctx.Wait();
xpu_free(filter_grad_tmp);
if (dx == nullptr) {
xpu_free(dx_data);
}
if (filter_grad == nullptr) {
xpu_free(dw_data);
}
if (offset_grad == nullptr) {
xpu_free(doffset_data);
}
if (mask_grad == nullptr) {
xpu_free(dmask_data);
}
}
} // namespace phi
PD_REGISTER_KERNEL(deformable_conv_grad,
XPU,
ALL_LAYOUT,
phi::DeformableConvGradKernel,
float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/deformable_conv_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void DeformableConvKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& offset,
const DenseTensor& filter,
const paddle::optional<DenseTensor>& mask,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
int deformable_groups,
int groups,
int im2col_step,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
  PADDLE_ENFORCE_EQ(
      deformable_groups == 1,
      true,
      errors::InvalidArgument(
          "XPU only supports deformable_groups == 1 in deformable_conv op."));
  PADDLE_ENFORCE_EQ(
      groups == 1,
      true,
      errors::InvalidArgument(
          "XPU only supports groups == 1 in deformable_conv op."));
  PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8,
                    true,
                    errors::InvalidArgument(
                        "Filter height and width should be less than or "
                        "equal to 8 on XPU in deformable_conv op."));
const int batch_size = static_cast<int>(x.dims()[0]);
std::vector<int64_t> output_shape_vec(phi::vectorize(out->dims()));
const T* input_ptr = x.data<T>();
const T* filter_ptr = filter.data<T>();
  const float* offset_ptr = offset.data<float>();
  const float* mask_ptr = mask->data<float>();
  T* output_prt = out->data<T>();
  // zero-initialize the output buffer
const int zero = 0;
int r = xpu::constant<T>(dev_ctx.x_context(), output_prt, out->numel(), zero);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
int input_dim = x.numel() / x.dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask->numel() / mask->dims()[0];
int output_dim =
output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
static_cast<int>(filter.dims()[3])};
int n = im2col_step;
int c = x.dims()[1];
int h = x.dims()[2];
int w = x.dims()[3];
int f = filter.dims()[0];
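  // Each iteration runs xpu::deformable_conv on one chunk of
  // n = im2col_step images.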
for (int i = 0; i < batch_size / im2col_step; ++i) {
int r = xpu::deformable_conv<float, float, float, int>(
dev_ctx.x_context(),
input_ptr + i * im2col_step * input_dim,
filter_ptr,
offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
output_prt + i * im2col_step * output_dim,
n,
c,
h,
w,
f,
ksize,
strides,
paddings,
dilations,
groups,
deformable_groups,
nullptr,
nullptr,
nullptr,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "deformable_conv");
}
}
} // namespace phi
PD_REGISTER_KERNEL(
deformable_conv, XPU, ALL_LAYOUT, phi::DeformableConvKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/syscall.h>
#include <unistd.h>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/phi/kernels/merged_momentum_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MergedMomentumKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& params,
const std::vector<const DenseTensor*>& grad,
const std::vector<const DenseTensor*>& velocity,
const std::vector<const DenseTensor*>& learning_rate,
const paddle::optional<std::vector<const DenseTensor*>>& master_param,
float mu_in,
bool use_nesterov,
const std::vector<std::string>& regularization_method,
const std::vector<float>& regularization_coeff,
bool multi_precision,
float rescale_grad,
std::vector<DenseTensor*> params_out,
std::vector<DenseTensor*> velocity_out,
std::vector<DenseTensor*> master_param_out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto lr = learning_rate[0];
T mu = static_cast<T>(mu_in);
int op_num = params.size();
PADDLE_ENFORCE_EQ(op_num,
params_out.size(),
errors::InvalidArgument(
"The size of Output(ParamOut) must be equal to "
"Input(Param), but got the size of Output(ParamOut) "
"is %d, the size of Input(Param) is %d.",
params_out.size(),
op_num));
  PADDLE_ENFORCE_EQ(op_num,
                    velocity.size(),
                    errors::InvalidArgument(
                        "The size of Input(Velocity) must be equal to "
                        "Input(Param), but got the size of Input(Velocity) "
                        "is %d, the size of Input(Param) is %d.",
                        velocity.size(),
                        op_num));
PADDLE_ENFORCE_EQ(op_num,
velocity_out.size(),
errors::InvalidArgument(
"The size of Output(VelocityOut) must be equal to "
"Input(Param), but got the size of Output(VelocityOut) "
"is %d, the size of Input(Param) is %d.",
velocity_out.size(),
op_num));
PADDLE_ENFORCE_EQ(
op_num,
grad.size(),
errors::InvalidArgument(
"The size of Input(Grad) must be equal to Input(Param), but got "
"the size of Input(Grad) is %d, the size of Input(Param) is %d.",
grad.size(),
op_num));
std::vector<XPUType*> param_list(op_num);
std::vector<XPUType*> velocity_list(op_num);
std::vector<XPUType*> grad_list(op_num);
std::vector<XPUType*> velocity_out_list(op_num);
std::vector<XPUType*> param_out_list(op_num);
std::vector<int> sizes(op_num);
std::vector<float> l2_weight_decay(op_num);
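  // l2_weight_decay[j] is the per-parameter L2 coefficient handed to
  // xpu::merged_momentum; parameters without "l2_decay" regularization get 0.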
if (op_num > 0) {
for (int j = 0; j < op_num; j++) {
param_list[j] =
reinterpret_cast<XPUType*>(const_cast<T*>(params[j]->data<T>()));
velocity_list[j] =
reinterpret_cast<XPUType*>(const_cast<T*>(velocity[j]->data<T>()));
grad_list[j] =
reinterpret_cast<XPUType*>(const_cast<T*>(grad[j]->data<T>()));
param_out_list[j] = reinterpret_cast<XPUType*>(params_out[j]->data<T>());
velocity_out_list[j] =
reinterpret_cast<XPUType*>(velocity_out[j]->data<T>());
sizes[j] = static_cast<int>(params[j]->numel());
      if (regularization_method.size() == 0 ||
          regularization_method[j] != "l2_decay") {
        l2_weight_decay[j] = 0.0f;
      } else {
        l2_weight_decay[j] = static_cast<float>(regularization_coeff[j]);
      }
      PADDLE_ENFORCE_EQ(params[j],
                        params_out[j],
                        errors::InvalidArgument(
                            "Input(Param) and Output(ParamOut) must be "
                            "the same Tensors."));
      PADDLE_ENFORCE_EQ(velocity[j],
                        velocity_out[j],
                        errors::InvalidArgument(
                            "Input(Velocity) and Output(VelocityOut) must be "
                            "the same Tensors."));
}
} else {
return;
}
int r = xpu::merged_momentum(dev_ctx.x_context(),
param_list,
velocity_list,
grad_list,
param_out_list,
velocity_out_list,
l2_weight_decay,
sizes,
lr->data<float>(),
mu,
use_nesterov);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merged_momentum");
}
} // namespace phi
PD_REGISTER_KERNEL(merged_momentum,
XPU,
ALL_LAYOUT,
phi::MergedMomentumKernel,
float,
phi::dtype::float16) {}