Unverified commit 63d9a175, authored by haosicheng, committed by GitHub

add temporal shift and grad *test=kunlun (#45300)

Parent 0bf40070
@@ -552,6 +552,10 @@ XPUOpMap& get_kl2_ops() {
{"tanh",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"temporal_shift",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"temporal_shift_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"tril_triu",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
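With the two entries above, the XPU2 (Kunlun KL2) op map advertises FP32-only support for temporal_shift and its gradient. A minimal sketch of confirming what the test harness will see, using the get_xpu_op_support_types helper that the test file below already imports (the sys.path setup mirrors that file; the exact return format is an assumption):

import sys
sys.path.append("..")  # test-infra helpers live one level up, as in the test below
from xpu.get_test_cover_info import get_xpu_op_support_types

# Only FP32 is registered above, so a single supported type is expected.
print(get_xpu_op_support_types('temporal_shift'))  # e.g. ['float32']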
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/temporal_shift_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
template <typename T, typename Context>
void TemporalShiftGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int seg_num,
float shift_ratio,
const std::string& data_format_str,
DenseTensor* x_grad) {
auto* input_grad = x_grad;
auto* output_grad = &out_grad;
int t = seg_num;
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_format_str);
const int nt = output_grad->dims()[0];
const int n = nt / t;
const int c = (data_layout == DataLayout::kNCHW ? output_grad->dims()[1]
: output_grad->dims()[3]);
const int h = (data_layout == DataLayout::kNCHW ? output_grad->dims()[2]
: output_grad->dims()[1]);
const int w = (data_layout == DataLayout::kNCHW ? output_grad->dims()[3]
: output_grad->dims()[2]);
DDim in_grad_dims =
(data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w})
: phi::make_ddim({nt, h, w, c}));
const T* output_grad_data = output_grad->data<T>();
input_grad->Resize(in_grad_dims);
T* input_grad_data = dev_ctx.template Alloc<T>(input_grad);
if (data_layout == DataLayout::kNCHW) {
int r = xpu::temporal_shift_grad(dev_ctx.x_context(),
output_grad_data,
input_grad_data,
n,
c,
h,
w,
t,
shift_ratio,
false);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "temporal_shift_grad");
} else {
int r = xpu::temporal_shift_grad(dev_ctx.x_context(),
output_grad_data,
input_grad_data,
n,
c,
h,
w,
t,
shift_ratio,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "temporal_shift_grad");
}
}
} // namespace phi
PD_REGISTER_KERNEL(
temporal_shift_grad, XPU, ALL_LAYOUT, phi::TemporalShiftGradKernel, float) {
}
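A minimal dygraph sketch of exercising the backward kernel registered above through the public paddle.nn.functional.temporal_shift API; it assumes a Paddle build with Kunlun/XPU support and an available XPU device:

import paddle
import paddle.nn.functional as F

paddle.set_device('xpu')  # requires an XPU (Kunlun) build of Paddle
x = paddle.rand([6, 4, 4, 4], dtype='float32')  # nt = n * seg_num = 2 * 3
x.stop_gradient = False
y = F.temporal_shift(x, seg_num=3, shift_ratio=0.25)
y.sum().backward()  # routes through phi::TemporalShiftGradKernel on XPU
print(x.grad.shape)  # [6, 4, 4, 4], same shape as x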
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/temporal_shift_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
namespace phi {
template <typename T, typename Context>
void TemporalShiftKernel(const Context& dev_ctx,
const DenseTensor& x,
int seg_num,
float shift_ratio,
const std::string& data_format_str,
DenseTensor* out) {
auto* input = &x;
auto* output = out;
int t = seg_num;
const DataLayout data_layout =
paddle::framework::StringToDataLayout(data_format_str);
const int nt = input->dims()[0];
const int n = nt / t;
const int c =
(data_layout == DataLayout::kNCHW ? input->dims()[1] : input->dims()[3]);
const int h =
(data_layout == DataLayout::kNCHW ? input->dims()[2] : input->dims()[1]);
const int w =
(data_layout == DataLayout::kNCHW ? input->dims()[3] : input->dims()[2]);
DDim out_dims =
(data_layout == DataLayout::kNCHW ? phi::make_ddim({nt, c, h, w})
: phi::make_ddim({nt, h, w, c}));
const T* input_data = input->data<T>();
output->Resize(out_dims);
T* output_data = dev_ctx.template Alloc<T>(output);
if (data_layout == DataLayout::kNCHW) {
int r = xpu::temporal_shift(dev_ctx.x_context(),
input_data,
output_data,
n,
c,
h,
w,
t,
shift_ratio,
false);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "temporal_shift");
} else {
int r = xpu::temporal_shift(dev_ctx.x_context(),
input_data,
output_data,
n,
c,
h,
w,
t,
shift_ratio,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "temporal_shift");
}
}
} // namespace phi
PD_REGISTER_KERNEL(
temporal_shift, XPU, ALL_LAYOUT, phi::TemporalShiftKernel, float) {}
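Likewise, a minimal sketch for the forward kernel; calling with data_format='NHWC' exercises the channel-last branch (the true flag above), again assuming an XPU build:

import paddle
import paddle.nn.functional as F

paddle.set_device('xpu')
x = paddle.rand([6, 5, 5, 4], dtype='float32')  # NHWC: nt=6, h=5, w=5, c=4
y = F.temporal_shift(x, seg_num=3, shift_ratio=0.25, data_format='NHWC')
print(y.shape)  # [6, 5, 5, 4]; channel slices are shifted across frames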
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
import paddle.nn.functional as F
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
np.random.seed(10)
def temporal_shift(x, seg_num, shift_ratio, data_format):
if data_format == "NHWC":
x = np.transpose(x, (0, 3, 1, 2))
shape = x.shape
reshape_x = x.reshape((-1, seg_num, shape[1], shape[2], shape[3]))
pad_x = np.pad(reshape_x, ((0, 0), (1, 1), (0, 0), (0, 0), (0, 0)),
'constant')
c1 = int(shape[1] * shift_ratio)
c2 = int(shape[1] * 2 * shift_ratio)
slice1 = pad_x[:, :seg_num, :c1, :, :]
slice2 = pad_x[:, 2:seg_num + 2, c1:c2, :, :]
slice3 = pad_x[:, 1:seg_num + 1, c2:, :, :]
concat_x = np.concatenate([slice1, slice2, slice3], axis=2)
out = concat_x.reshape(shape)
if data_format == "NHWC":
out = np.transpose(out, (0, 2, 3, 1))
return out
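To make the shift semantics concrete, a small self-contained check of the reference function above: with shift_ratio=0.25, the first quarter of channels is pulled from the previous frame, the next quarter from the following frame, and the remainder stays in place (zero-padded at the sequence ends):

import numpy as np

x = np.arange(8, dtype='float32').reshape(2, 4, 1, 1)  # 2 frames, 4 channels
out = temporal_shift(x, seg_num=2, shift_ratio=0.25, data_format='NCHW')
print(out.reshape(2, 4))
# frame 0: [0., 5., 2., 3.]  -> ch0 zero-padded (t-1 missing), ch1 from t=1
# frame 1: [0., 0., 6., 7.]  -> ch0 from t=0 (value 0), ch1 zero-padded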
class XPUTestTemporalShiftOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = "temporal_shift"
self.use_dynamic_create_class = False
class TestXPUTemporalShift(XPUOpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'temporal_shift'
self.python_api = F.temporal_shift
self.use_xpu = True
x = np.random.random(self.x_shape).astype(self.dtype)
self.attrs = {
"seg_num": self.seg_num,
"shift_ratio": self.shift_ratio,
"data_format": self.data_format
}
self.inputs = {
"X": x,
}
output = temporal_shift(x, self.seg_num, self.shift_ratio,
self.data_format)
self.outputs = {"Out": output}
self.python_out_sig = ["Out"]
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
def initTestCase(self):
self.x_shape = (6, 4, 4, 4)
self.seg_num = 3
self.shift_ratio = 0.25
self.dtype = 'float32'
self.data_format = 'NCHW'
class TestXPUTemporalShift2(TestXPUTemporalShift):
def initTestCase(self):
self.x_shape = (1, 1, 1, 1)
self.seg_num = 1
self.shift_ratio = 0.1
self.dtype = 'float32'
self.data_format = 'NCHW'
class TestXPUTemporalShift3(TestXPUTemporalShift):
def initTestCase(self):
self.x_shape = (4, 9, 1, 1)
self.seg_num = 2
self.shift_ratio = 0.2
self.dtype = 'float32'
self.data_format = 'NCHW'
class TestXPUTemporalShift4(TestXPUTemporalShift):
def initTestCase(self):
self.x_shape = (4, 1, 10, 10)
self.seg_num = 2
self.shift_ratio = 0.3
self.dtype = 'float32'
self.data_format = 'NCHW'
class TestXPUTemporalShift5(TestXPUTemporalShift):
def initTestCase(self):
self.x_shape = (1, 1, 1, 1)
self.seg_num = 1
self.shift_ratio = 0.3
self.dtype = 'float32'
self.data_format = 'NHWC'
class TestXPUTemporalShift6(TestXPUTemporalShift):
def initTestCase(self):
self.x_shape = (6, 5, 5, 1)
self.seg_num = 3
self.shift_ratio = 0.25
self.dtype = 'float32'
self.data_format = 'NHWC'
class TestXPUTemporalShift7(TestXPUTemporalShift):
def initTestCase(self):
self.x_shape = (9, 1, 1, 4)
self.seg_num = 3
self.shift_ratio = 0.45
self.dtype = 'float32'
self.data_format = 'NHWC'
support_types = get_xpu_op_support_types('temporal_shift')
for stype in support_types:
create_test_class(globals(), XPUTestTemporalShiftOp, stype)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()