未验证 提交 db0ea0ce 编写于 作者: Y ykkk2333 提交者: GitHub

add masked_select_grad kernel (#48137)

* add stat tool

* add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun

* add masked_selected_grad kernel,test=kunlun
上级 a0f47350
......@@ -123,13 +123,17 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})},
{"concat",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"conv2d_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv3d",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d_transpose_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"conv2d_transpose",
......@@ -375,6 +379,12 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"masked_select_grad",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
......@@ -502,6 +512,9 @@ XPUOpMap& get_kl2_ops() {
{"sgd",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"sgd_dense_param_sparse_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"sigmoid_cross_entropy_with_logits_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"sigmoid_cross_entropy_with_logits",
......
......@@ -50,6 +50,7 @@ void ConcatKernel(const Context& dev_ctx,
x[0]->dims().size()));
// If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && x[0]->lod().size() > 0) {
size_t lod_size_0 = x[0]->lod().size();
size_t lod_size = lod_size_0;
......@@ -79,7 +80,9 @@ void ConcatKernel(const Context& dev_ctx,
}
}
}
dev_ctx.template Alloc<T>(out);
std::vector<std::vector<int>> xdims_list;
std::vector<const XPUType*> ptrs;
for (unsigned int i = 0; i < x.size(); ++i) {
......@@ -97,6 +100,7 @@ void ConcatKernel(const Context& dev_ctx,
PADDLE_ENFORCE_GT(xdims_list.size(),
0,
phi::errors::InvalidArgument("No tensor need concat"));
int r = xpu::concat<XPUType>(dev_ctx.x_context(),
ptrs,
reinterpret_cast<XPUType*>(out->data<T>()),
......@@ -107,5 +111,10 @@ void ConcatKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
concat, XPU, ALL_LAYOUT, phi::ConcatKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(concat,
XPU,
ALL_LAYOUT,
phi::ConcatKernel,
float,
phi::dtype::float16,
int64_t) {}
......@@ -131,9 +131,96 @@ void DepthwiseConvKernel(const Context& dev_ctx,
out);
}
template <typename T, typename Context>
void Conv3DKernel(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations_t,
const std::string& data_format,
DenseTensor* out) {
using XPUT = typename XPUTypeTrait<T>::Type;
std::vector<int> paddings = paddings_t;
std::vector<int> dilations = dilations_t;
// The filter will be reshaped in the calculations,
// so here use an assignment operation,
// that avoids modifying the variable in the Scope.
dev_ctx.template Alloc<T>(out);
phi::DDim in_data_dims =
phi::slice_ddim(input.dims(), 2, input.dims().size());
phi::DDim filter_data_dims =
phi::slice_ddim(filter.dims(), 2, filter.dims().size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
int batch_size = static_cast<int>(input.dims()[0]);
int img_c = static_cast<int>(input.dims()[1]);
int img_d = static_cast<int>(input.dims()[2]);
int img_h = static_cast<int>(input.dims()[3]);
int img_w = static_cast<int>(input.dims()[4]);
int f = static_cast<int>(filter.dims()[0]);
bool is_ncdhw = true;
if (data_format == "NDHWC") {
img_c = static_cast<int>(input.dims()[4]);
img_d = static_cast<int>(input.dims()[1]);
img_h = static_cast<int>(input.dims()[2]);
img_w = static_cast<int>(input.dims()[3]);
is_ncdhw = false;
}
XPUT* output_data = reinterpret_cast<XPUT*>(out->data<T>());
const XPUT* filter_data = reinterpret_cast<const XPUT*>(filter.data<T>());
const XPUT* input_data = reinterpret_cast<const XPUT*>(input.data<T>());
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUT* filter_data_tmp;
const XPUT* filter_data_ptr = filter_data;
if (data_format == "NDHWC") {
filter_data_tmp = RAII_GUARD.alloc<XPUT>(filter.numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp);
std::vector<int> filter_shape = phi::vectorize<int>(filter.dims());
int r = xpu::transpose<XPUT>(dev_ctx.x_context(),
filter_data,
filter_data_tmp,
filter_shape,
{0, 2, 3, 4, 1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
filter_data_ptr = reinterpret_cast<const XPUT*>(filter_data_tmp);
}
int r = xpu::conv3d<XPUT, XPUT, XPUT, int16_t>(dev_ctx.x_context(),
input_data,
filter_data_ptr,
output_data,
batch_size,
img_c,
img_d,
img_h,
img_w,
f,
ksize,
strides,
paddings,
dilations,
groups,
nullptr,
nullptr,
nullptr,
is_ncdhw);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d");
}
} // namespace phi
PD_REGISTER_KERNEL(
conv2d, XPU, ALL_LAYOUT, phi::ConvKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
depthwise_conv2d, XPU, ALL_LAYOUT, phi::DepthwiseConvKernel, float) {}
PD_REGISTER_KERNEL(
conv3d, XPU, ALL_LAYOUT, phi::Conv3DKernel, float, phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/masked_select_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MaskedSelectGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto* mask_data = mask.data<bool>();
auto* input_data = reinterpret_cast<const XPUType*>(out_grad.data<T>());
auto* out_data =
reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(x_grad));
auto mask_shape = phi::vectorize<int>(mask.dims());
auto xshape = phi::vectorize<int>(x_grad->dims());
int r = xpu::masked_select_grad(dev_ctx.x_context(),
input_data,
mask_data,
out_data,
xshape,
mask_shape,
1);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "masked_select_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(masked_select_grad,
XPU,
ALL_LAYOUT,
phi::MaskedSelectGradKernel,
float,
phi::dtype::float16,
int,
bool,
int64_t) {}
......@@ -20,14 +20,14 @@
namespace phi {
template <typename T, typename Context>
void SGDDenseKernel(const Context &dev_ctx,
const DenseTensor &param,
const DenseTensor &learning_rate,
const DenseTensor &grad,
const paddle::optional<DenseTensor> &master_param,
void SGDDenseKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const DenseTensor& grad,
const paddle::optional<DenseTensor>& master_param,
bool multi_precision,
DenseTensor *param_out,
DenseTensor *master_param_out) {
DenseTensor* param_out,
DenseTensor* master_param_out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto sz = param_out->numel();
PADDLE_ENFORCE_EQ(
......@@ -49,37 +49,103 @@ void SGDDenseKernel(const Context &dev_ctx,
grad.numel(),
sz));
const T *lr_t = learning_rate.data<T>();
const T* lr_t = learning_rate.data<T>();
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
const float *lr = nullptr;
const float* lr = nullptr;
if (std::is_same<T, dtype::float16>::value) {
float *lr_float = RAII_GUARD.alloc_l3_or_gm<float>(learning_rate.numel());
float* lr_float = RAII_GUARD.alloc_l3_or_gm<float>(learning_rate.numel());
int r = xpu::cast<XPUType, float>(dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(lr_t),
reinterpret_cast<const XPUType*>(lr_t),
lr_float,
learning_rate.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
lr = lr_float;
} else {
lr = reinterpret_cast<const float *>(lr_t);
lr = reinterpret_cast<const float*>(lr_t);
}
const T *param_data = param.data<T>();
const T *grad_data = grad.data<T>();
const T* param_data = param.data<T>();
const T* grad_data = grad.data<T>();
dev_ctx.template Alloc<T>(param_out);
T *out_data = param_out->data<T>();
T* out_data = param_out->data<T>();
int r = xpu::sgd(dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(grad_data),
reinterpret_cast<const XPUType *>(param_data),
reinterpret_cast<const XPUType*>(grad_data),
reinterpret_cast<const XPUType*>(param_data),
lr,
reinterpret_cast<XPUType *>(out_data),
reinterpret_cast<XPUType*>(out_data),
sz);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sgd");
}
template <typename T, typename Context>
void SGDDenseParamSparseGradKernel(
const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& learning_rate,
const SelectedRows& grad,
const paddle::optional<DenseTensor>& master_param,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* master_param_out) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(param_out);
PADDLE_ENFORCE_EQ(
&param,
param_out,
phi::errors::InvalidArgument(
"The input tensor Param of SgdOp should be equal with ParamOut "
"if variable's type is SelectedRows."));
auto in_height = grad.height();
auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(in_height,
out_dims[0],
phi::errors::InvalidArgument(
"The input tensor Grad's height of SgdOp should be "
"equal with ParamOut's dims. But received Grad's "
"height [%s] and ParamOut's dims [%s]",
in_height,
out_dims[0]));
auto& in_value = grad.value();
auto& in_rows = grad.rows();
int64_t* in_rows_data = nullptr;
xpu::VectorParam<int64_t> in_rows_vec{
in_rows.data(), static_cast<int>(in_rows.size()), in_rows_data};
int64_t in_row_numel = in_value.numel() / in_rows.size();
PADDLE_ENFORCE_EQ(in_row_numel,
param_out->numel() / in_height,
phi::errors::InvalidArgument(
"The in_row_numel of SgdOp should be equal with "
"param_out's numel / in_height."));
auto* in_data = in_value.data<T>();
auto* out_data = param_out->data<T>();
int r = xpu::sparse_sgd<XPUType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(in_data),
reinterpret_cast<const XPUType*>(param.data<T>()),
learning_rate.data<float>(),
in_rows_vec,
reinterpret_cast<XPUType*>(out_data),
in_row_numel,
in_rows.size());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sparse_sgd");
}
} // namespace phi
PD_REGISTER_KERNEL(
sgd, XPU, ALL_LAYOUT, phi::SGDDenseKernel, phi::dtype::float16, float) {}
PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
XPU,
ALL_LAYOUT,
phi::SGDDenseParamSparseGradKernel,
phi::dtype::float16,
float) {}
......@@ -58,6 +58,9 @@ class XPUTestMaskedSelectOp(XPUOpTestWrapper):
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Y')
def init(self):
self.shape = (50, 3)
......
......@@ -19,6 +19,8 @@ import sys
sys.path.append("..")
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
......@@ -83,6 +85,77 @@ class TestSGDOpWithLargeInput(unittest.TestCase):
result = exe.run(fluid.default_main_program(), fetch_list=[avg_cost])
class TestSparseSGDOp(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Grad Variable
height = 10
rows = [0, 4, 7]
self.conf()
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), self.row_numel)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(np_array, place)
# create and initialize Param Variable
param = scope.var('Param').get_tensor()
param_array = np.full((height, self.row_numel), 5.0).astype("float32")
param.set(param_array, place)
# create and initialize LeraningRate Variable
lr = scope.var('LearningRate').get_tensor()
lr_array = np.full((1), 2.0).astype("float32")
lr.set(lr_array, place)
# create and run sgd operator
sgd_op = Operator(
"sgd",
Param='Param',
Grad='Grad',
ParamOut='Param',
LearningRate='LearningRate',
)
sgd_op.run(scope, place)
# get and compare result
result_array = np.array(param)
# rows[0] = 0, 5.0 - 2.0 * 2.0
self.assertAlmostEqual(1.0, result_array[rows[0], 0])
# rows[0] = 0, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[0], 2])
# 5.0 - 2.0 * 0.0
self.assertAlmostEqual(5.0, result_array[1, 0])
# rows[1] = 4, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[1], 10])
# 5.0 - 2.0 * 0.0
self.assertAlmostEqual(5.0, result_array[5, 8])
# rows[2] = 7, 5.0 - 2.0 * 1.0
self.assertAlmostEqual(3.0, result_array[rows[2], 1])
# rows[2] = 7, 5.0 - 2.0 * 4.0
self.assertAlmostEqual(-3.0, result_array[rows[2], 8])
def test_sparse_sgd(self):
places = [core.XPUPlace(0)]
for place in places:
self.check_with_place(place)
def conf(self):
self.row_numel = 12
class TestSparseSGDOpCase8X(TestSparseSGDOp):
def conf(self):
self.row_numel = 16
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册