Unverified commit d5323dab, authored by TTerror, committed by GitHub


add squeeze_op/unsqueeze_op on kunlun; fix conv op and parallel executor; optimize lookup_table op (#31056)

* add squeeze_op/unsqueeze_op on kunlun; fix conv op and parallel executor on kunlun; optimize lookup_table op on kunlun

* update squeeze/unsqueeze op
Parent 16b4260b
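For reference, a minimal usage sketch of the ops this commit touches, run on a Kunlun device (a sketch, not part of this commit; assumes a Paddle build compiled with XPU support and `paddle.set_device`/tensor APIs as in Paddle 2.0):

# Minimal sketch: exercise the new squeeze/unsqueeze kernels on a Kunlun card.
import numpy as np
import paddle

if paddle.is_compiled_with_xpu():
    paddle.set_device("xpu")  # route the ops below to the Kunlun device
    x = paddle.to_tensor(np.random.random((1, 3, 1, 40)).astype("float32"))
    y = paddle.squeeze(x, axis=[0, 2])    # shape [3, 40]
    z = paddle.unsqueeze(y, axis=[1, 2])  # shape [3, 1, 1, 40]
    print(y.shape, z.shape)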
@@ -11,7 +11,7 @@ if(NOT XPU_SDK_ROOT)
elseif(WITH_SUNWAY)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
else()
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_02_03.tar.gz" CACHE STRING "" FORCE)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_02_19.tar.gz" CACHE STRING "" FORCE)
endif()
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
......
@@ -1072,12 +1072,13 @@ void ParallelExecutor::BCastParamsToDevices(
platform::errors::Unavailable("bkcl_group_start failed"));
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &bkcl_ctx = bkcl_ctxs->at(member_->places_[i]);
+      auto broadcast_numel = numel;
if (main_tensor.type() == framework::proto::VarType::INT64) {
-        numel *= 2;
+        broadcast_numel *= 2;
}
PADDLE_ENFORCE_EQ(
-          bkcl_broadcast(bkcl_ctx.comm(), buffers[i], buffers[i], numel,
-                         data_type, 0, NULL),
+          bkcl_broadcast(bkcl_ctx.comm(), buffers[i], buffers[i],
+                         broadcast_numel, data_type, 0, NULL),
BKCL_SUCCESS,
platform::errors::Unavailable("bkcl_broadcast failed"));
}
......
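The parallel-executor fix above: INT64 tensors are broadcast as twice as many 32-bit elements (presumably because BKCL lacks a 64-bit integer type), but the old code doubled the shared `numel` inside the per-device loop, so the count compounded on every device after the first. A toy sketch of the failure mode and the fix (counts hypothetical):

numel = 100  # elements in an INT64 tensor
places = ["xpu:0", "xpu:1", "xpu:2"]

# Old behavior: mutating the shared count doubles it once per device.
n = numel
for p in places:
    n *= 2
    print(p, n)  # 200, 400, 800 -- wrong after the first device

# Fixed behavior: derive a local count, leave the shared one untouched.
for p in places:
    broadcast_numel = numel * 2
    print(p, broadcast_numel)  # always 200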
@@ -32,20 +32,31 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+    const std::string data_format = context.Attr<std::string>("data_format");
+    const std::string padding_algorithm =
+        context.Attr<std::string>("padding_algorithm");
+    PADDLE_ENFORCE_EQ(data_format == "NHWC" || data_format == "NDHWC", false,
+                      platform::errors::InvalidArgument(
+                          "XPU only supports data_format NCHW in conv op."));
+    framework::DDim in_data_dims =
+        framework::slice_ddim(input->dims(), 2, input->dims().size());
+    framework::DDim filter_data_dims =
+        framework::slice_ddim(filter.dims(), 2, filter.dims().size());
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(input->dims()[0]);
const int img_c = static_cast<int>(input->dims()[1]);
const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]);
-    const int win_h = static_cast<int>(filter.dims()[2]);
-    const int win_w = static_cast<int>(filter.dims()[3]);
auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> k_size;
-    k_size.push_back(win_h);
-    k_size.push_back(win_w);
int r = xpu::conv2d<float, float, float, int16_t>(
dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
-        output->data<float>(), batch_size, img_c, img_h, img_w, f, k_size,
+        output->data<float>(), batch_size, img_c, img_h, img_w, f, ksize,
strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
@@ -53,6 +64,7 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
r, XPUAPIErrorMsg[r]));
}
};
template <typename DeviceContext, typename T>
class GemmConvGradXPUKernel : public framework::OpKernel<T> {
public:
@@ -73,13 +85,28 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+    const std::string data_format = context.Attr<std::string>("data_format");
+    const std::string padding_algorithm =
+        context.Attr<std::string>("padding_algorithm");
+    PADDLE_ENFORCE_EQ(
+        data_format == "NHWC" || data_format == "NDHWC", false,
+        platform::errors::InvalidArgument(
+            "XPU only supports data_format NCHW in conv grad op."));
+    framework::DDim in_data_dims =
+        framework::slice_ddim(input->dims(), 2, input->dims().size());
+    framework::DDim filter_data_dims =
+        framework::slice_ddim(filter.dims(), 2, filter.dims().size());
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
const int batch_size = static_cast<int>(input->dims()[0]);
const int img_c = static_cast<int>(input->dims()[1]);
const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]);
const int f = static_cast<int>(filter.dims()[0]);
const int win_h = static_cast<int>(filter.dims()[2]);
const int win_w = static_cast<int>(filter.dims()[3]);
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
}
@@ -87,14 +114,11 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
filter_grad->mutable_data<T>(context.GetPlace());
}
auto& dev_ctx = context.template device_context<DeviceContext>();
-    std::vector<int> k_size;
-    k_size.push_back(win_h);
-    k_size.push_back(win_w);
int r = xpu::conv2d_grad<float, float, float, int16_t>(
dev_ctx.x_context(), input->data<T>(), filter.data<T>(),
output_grad->data<T>(), input_grad ? input_grad->data<T>() : nullptr,
filter_grad ? filter_grad->data<T>() : nullptr, batch_size, img_c,
-        img_h, img_w, f, k_size, strides, paddings, dilations, groups, nullptr,
+        img_h, img_w, f, ksize, strides, paddings, dilations, groups, nullptr,
nullptr, nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
......
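For context on the conv changes: both kernels now honor `padding_algorithm` by calling `UpdatePaddingAndDilation` before invoking the XPU API. A rough Python sketch of the "SAME"/"VALID" rules that helper applies (my paraphrase of its behavior, not the actual implementation in conv_op.h):

import math

def update_padding_and_dilation(paddings, dilations, algo, in_dims, strides, ksize):
    # Sketch only: the real helper also normalizes one-pad-per-dim input
    # into explicit [pad_front, pad_back] pairs.
    if algo == "SAME":
        paddings, dilations = [], [1] * len(ksize)  # SAME forces dilation 1
        for in_sz, stride, k in zip(in_dims, strides, ksize):
            out_sz = math.ceil(in_sz / stride)
            pad = max((out_sz - 1) * stride + k - in_sz, 0)
            paddings += [pad // 2, pad - pad // 2]
    elif algo == "VALID":
        paddings = [0] * (2 * len(ksize))
    return paddings, dilations

# e.g. a 3x3 filter, stride 2, on a 28x28 input:
print(update_padding_and_dilation([0, 0], [1, 1], "SAME", [28, 28], [2, 2], [3, 3]))
# -> ([0, 1, 0, 1], [1, 1])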
@@ -17,11 +17,10 @@ limitations under the License. */
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
+#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace operators {
-#ifdef PADDLE_WITH_XPU
template <typename DeviceContext, typename T>
class LookupTableV2XPUKernel : public framework::OpKernel<T> {
public:
@@ -96,26 +95,19 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
platform::errors::OutOfRange(
"Number of ids greater than int32_t::max , please check "
"number of ids in LookupTableV2GradXPUKernel."));
-    int ids_numel_int32 = static_cast<int>(ids_numel);
-    const int64_t *ids_data = ids_t->data<int64_t>();
-    int D = d_table_t->dims()[1];
+    auto &dev_ctx = context.template device_context<DeviceContext>();
+    const int64_t *ids_data = ids_t->data<int64_t>();
const T *d_output_data = d_output_t->data<T>();
T *d_table_data = d_table_t->mutable_data<T>(context.GetPlace());
-    auto &dev_ctx = context.template device_context<DeviceContext>();
-    // set zeros for d_table_data
-    const int zero = 0;
-    int r = xpu::memset(dev_ctx.x_context(), d_table_data, zero,
-                        d_table_t->numel() * sizeof(T));
-    PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
-                      platform::errors::External(
-                          "XPU API return wrong value[%d], please check where "
-                          "Baidu Kunlun Card is properly installed.",
-                          r));
-    r = xpu::embedding_backward<T, int64_t>(dev_ctx.x_context(),
-                                            ids_numel_int32, ids_data, D,
-                                            d_output_data, d_table_data);
+    int xm = d_table_t->dims()[0];
+    int ym = static_cast<int>(ids_numel);
+    int n = d_table_t->dims()[1];
+    int padding_idx = context.Attr<int64_t>("padding_idx");
+    int r = xpu::embedding_grad<T, int64_t>(dev_ctx.x_context(), d_output_data,
+                                            ids_data, d_table_data, xm, n, ym,
+                                            padding_idx);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External(
"XPU API return wrong value[%d] , please check where "
@@ -123,13 +115,10 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
r));
}
};
-#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
-#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(
lookup_table_v2,
ops::LookupTableV2XPUKernel<paddle::platform::XPUDeviceContext, float>);
......
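The grad kernel now makes a single `xpu::embedding_grad` call, passing the table height `xm`, embedding width `n`, id count `ym`, and `padding_idx`, instead of an explicit `memset` followed by `embedding_backward`. A numpy sketch of the reference semantics those arguments imply (an illustration, not the XPU kernel):

import numpy as np

def embedding_grad_ref(d_output, ids, xm, n, padding_idx=-1):
    # xm: vocabulary rows, n: embedding width; ids holds ym entries.
    d_table = np.zeros((xm, n), dtype=d_output.dtype)  # zero-init (the old code used xpu::memset)
    for row, idx in enumerate(ids):
        if idx != padding_idx:  # padding rows receive no gradient
            d_table[idx] += d_output[row]
    return d_table

ids = np.array([2, 0, 2, 3])
d_out = np.ones((4, 5), dtype=np.float32)
print(embedding_grad_ref(d_out, ids, xm=8, n=5)[2])  # row 2 accumulated twice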
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/squeeze_op.h"
#ifdef PADDLE_WITH_XPU
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
squeeze, ops::SqueezeKernel<paddle::platform::XPUDeviceContext, float>,
ops::SqueezeKernel<paddle::platform::XPUDeviceContext, double>,
ops::SqueezeKernel<paddle::platform::XPUDeviceContext, plat::float16>,
ops::SqueezeKernel<paddle::platform::XPUDeviceContext, bool>,
ops::SqueezeKernel<paddle::platform::XPUDeviceContext, int>,
ops::SqueezeKernel<paddle::platform::XPUDeviceContext, uint8_t>,
ops::SqueezeKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::SqueezeKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
squeeze_grad,
ops::SqueezeGradKernel<paddle::platform::XPUDeviceContext, float>,
ops::SqueezeGradKernel<paddle::platform::XPUDeviceContext, double>,
ops::SqueezeGradKernel<paddle::platform::XPUDeviceContext, plat::float16>,
ops::SqueezeGradKernel<paddle::platform::XPUDeviceContext, bool>,
ops::SqueezeGradKernel<paddle::platform::XPUDeviceContext, int>,
ops::SqueezeGradKernel<paddle::platform::XPUDeviceContext, uint8_t>,
ops::SqueezeGradKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::SqueezeGradKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
squeeze2, ops::Squeeze2Kernel<paddle::platform::XPUDeviceContext, float>,
ops::Squeeze2Kernel<paddle::platform::XPUDeviceContext, double>,
ops::Squeeze2Kernel<paddle::platform::XPUDeviceContext, plat::float16>,
ops::Squeeze2Kernel<paddle::platform::XPUDeviceContext, bool>,
ops::Squeeze2Kernel<paddle::platform::XPUDeviceContext, int>,
ops::Squeeze2Kernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::Squeeze2Kernel<paddle::platform::XPUDeviceContext, uint8_t>,
ops::Squeeze2Kernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
squeeze2_grad,
ops::Squeeze2GradKernel<paddle::platform::XPUDeviceContext, float>,
ops::Squeeze2GradKernel<paddle::platform::XPUDeviceContext, double>,
ops::Squeeze2GradKernel<paddle::platform::XPUDeviceContext, plat::float16>,
ops::Squeeze2GradKernel<paddle::platform::XPUDeviceContext, bool>,
ops::Squeeze2GradKernel<paddle::platform::XPUDeviceContext, int>,
ops::Squeeze2GradKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::Squeeze2GradKernel<paddle::platform::XPUDeviceContext, uint8_t>,
ops::Squeeze2GradKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif
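Note that `squeeze2` carries an extra `XShape` output (the input dims, saved for the grad op), which is why the tests below skip it via `no_check_set=['XShape']`. A numpy sketch of the forward semantics, including `axes=()` meaning "drop every size-1 dim" (an illustration of the expected behavior, not the kernel):

import numpy as np

def squeeze_ref(x, axes=()):
    # Empty axes: drop every size-1 dim. Otherwise drop only the listed
    # axes, and only where the dim really is 1 (others stay unchanged).
    if not axes:
        axes = tuple(i for i, d in enumerate(x.shape) if d == 1)
    axes = tuple(a % x.ndim for a in axes)  # allow negative axes
    kept = [d for i, d in enumerate(x.shape) if i not in axes or d != 1]
    return x.reshape(kept)

x = np.zeros((6, 1, 5, 1, 4, 1))
print(squeeze_ref(x, (1, -1)).shape)  # (6, 5, 1, 4)
print(squeeze_ref(x).shape)           # (6, 5, 4)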
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unsqueeze_op.h"
#ifdef PADDLE_WITH_XPU
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::XPUDeviceContext, float>,
ops::UnsqueezeGradKernel<paddle::platform::XPUDeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::XPUDeviceContext, plat::float16>,
ops::UnsqueezeGradKernel<paddle::platform::XPUDeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::XPUDeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::XPUDeviceContext, uint8_t>,
ops::UnsqueezeGradKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::XPUDeviceContext, int64_t>);
REGISTER_OP_XPU_KERNEL(
unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::XPUDeviceContext, float>,
ops::Unsqueeze2GradKernel<paddle::platform::XPUDeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::XPUDeviceContext,
plat::float16>,
ops::Unsqueeze2GradKernel<paddle::platform::XPUDeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::XPUDeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::XPUDeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif
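The duplicate-axes cases tested below (e.g. axes (0, 3, 3) on shape (10, 2, 5) yielding (1, 10, 2, 1, 1, 5)) follow from inserting the size-1 dims one at a time, each axis resolved against the shape as it has grown so far. A numpy sketch of that rule (an illustration matching the test expectations, not the kernel):

import numpy as np

def unsqueeze_ref(x, axes):
    shape = list(x.shape)
    for a in axes:
        a = a % (len(shape) + 1)  # negative axes count from the new end
        shape.insert(a, 1)        # duplicates like (0, 3, 3) are legal
    return x.reshape(shape)

x = np.zeros((10, 2, 5))
print(unsqueeze_ref(x, (0, 3, 3)).shape)  # (1, 10, 2, 1, 1, 5)
print(unsqueeze_ref(x, (3, 1, 1)).shape)  # (10, 1, 1, 2, 5, 1)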
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle
paddle.enable_static()
# Correct: General.
class TestSqueezeOp(XPUOpTest):
    def setUp(self):
        self.op_type = "squeeze2"
        self.use_xpu = True
        self.use_mkldnn = False
        self.init_test_case()
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": np.random.random(self.ori_shape).astype("float32")
        }

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place, no_check_set=['XShape'])

    def test_check_grad(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_grad_with_place(place, ['X'], 'Out')

    def init_test_case(self):
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, 2)
        self.new_shape = (3, 40)

    def init_attrs(self):
        self.attrs = {"axes": self.axes}


# Correct: There is a negative axis.
class TestSqueezeOp1(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (1, 20, 1, 5)
        self.axes = (0, -2)
        self.new_shape = (20, 5)


# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (1, 20, 1, 5)
        self.axes = ()
        self.new_shape = (20, 5)


# Correct: Only part of the axes are squeezed.
class TestSqueezeOp3(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (6, 1, 5, 1, 4, 1)
        self.axes = (1, -1)
        self.new_shape = (6, 5, 1, 4)


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
# Correct: General.
class TestSqueezeOp(XPUOpTest):
    def setUp(self):
        self.op_type = "squeeze"
        self.use_xpu = True
        self.use_mkldnn = False
        self.init_test_case()
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.init_attrs()
        self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape), }

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_grad_with_place(place, ['X'], 'Out')

    def init_test_case(self):
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, 2)
        self.new_shape = (3, 40)

    def init_attrs(self):
        self.attrs = {"axes": self.axes}


# Correct: There is a negative axis.
class TestSqueezeOp1(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (1, 3, 1, 40)
        self.axes = (0, -2)
        self.new_shape = (3, 40)


# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (1, 20, 1, 5)
        self.axes = ()
        self.new_shape = (20, 5)


# Correct: Only part of the axes are squeezed.
class TestSqueezeOp3(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (6, 1, 5, 1, 4, 1)
        self.axes = (1, -1)
        self.new_shape = (6, 5, 1, 4)


# Correct: A dimension at a given axis that is not of size 1 remains unchanged.
class TestSqueezeOp4(TestSqueezeOp):
    def init_test_case(self):
        self.ori_shape = (6, 1, 5, 1, 4, 1)
        self.axes = (1, 2)
        self.new_shape = (6, 5, 1, 4, 1)


class TestSqueezeOpError(unittest.TestCase):
    def test_errors(self):
        paddle.enable_static()
        with program_guard(Program(), Program()):
            # The input type of squeeze must be Variable.
            x1 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], paddle.XPUPlace(0))
            self.assertRaises(TypeError, paddle.squeeze, x1)
            # The input axes of squeeze must be a list.
            x2 = paddle.static.data(name='x2', shape=[4], dtype="int32")
            self.assertRaises(TypeError, paddle.squeeze, x2, axes=0)
            # The input dtype of squeeze does not support float16.
            x3 = paddle.static.data(name='x3', shape=[4], dtype="float16")
            self.assertRaises(TypeError, paddle.squeeze, x3, axes=0)


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
# Correct: General.
class TestUnsqueezeOp(XPUOpTest):
    def setUp(self):
        self.init_test_case()
        self.use_xpu = True
        self.use_mkldnn = False
        self.op_type = "unsqueeze2"
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": np.random.random(self.ori_shape).astype("float32")
        }

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place, no_check_set=['XShape'])

    def test_check_grad(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_grad_with_place(place, ['X'], 'Out')

    def init_test_case(self):
        self.ori_shape = (3, 40)
        self.axes = (1, 2)
        self.new_shape = (3, 1, 1, 40)

    def init_attrs(self):
        self.attrs = {"axes": self.axes}


# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1, )
        self.new_shape = (20, 5, 1)


# Correct: Mixed input axes.
class TestUnsqueezeOp2(TestUnsqueezeOp):
    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


# Correct: There are duplicated axes.
class TestUnsqueezeOp3(TestUnsqueezeOp):
    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


# axes is a list (of 1-element axis tensors).
class TestUnsqueezeOp_AxesTensorList(XPUOpTest):
    def setUp(self):
        self.init_test_case()
        self.use_xpu = True
        self.use_mkldnn = False
        self.op_type = "unsqueeze2"
        axes_tensor_list = []
        for index, ele in enumerate(self.axes):
            axes_tensor_list.append(("axes" + str(index), np.ones(
                (1)).astype('int32') * ele))
        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float32"),
            "AxesTensorList": axes_tensor_list
        }
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": np.random.random(self.ori_shape).astype("float32")
        }

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place, no_check_set=['XShape'])

    def test_check_grad(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_grad_with_place(place, ['X'], 'Out')

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (1, 2)
        self.new_shape = (20, 1, 1, 5)

    def init_attrs(self):
        self.attrs = {}


class TestUnsqueezeOp1_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1, )
        self.new_shape = (20, 5, 1)


class TestUnsqueezeOp2_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


class TestUnsqueezeOp3_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


class TestUnsqueezeOp4_AxesTensorList(TestUnsqueezeOp_AxesTensorList):
    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


# axes is a Tensor.
class TestUnsqueezeOp_AxesTensor(XPUOpTest):
    def setUp(self):
        self.init_test_case()
        self.use_xpu = True
        self.use_mkldnn = False
        self.op_type = "unsqueeze2"
        self.inputs = {
            "X": np.random.random(self.ori_shape).astype("float32"),
            "AxesTensor": np.array(self.axes).astype("int32")
        }
        self.init_attrs()
        self.outputs = {
            "Out": self.inputs["X"].reshape(self.new_shape),
            "XShape": np.random.random(self.ori_shape).astype("float32")
        }

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place, no_check_set=['XShape'])

    def test_check_grad(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_grad_with_place(place, ['X'], 'Out')

    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (1, 2)
        self.new_shape = (20, 1, 1, 5)

    def init_attrs(self):
        self.attrs = {}


class TestUnsqueezeOp1_AxesTensor(TestUnsqueezeOp_AxesTensor):
    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1, )
        self.new_shape = (20, 5, 1)


class TestUnsqueezeOp2_AxesTensor(TestUnsqueezeOp_AxesTensor):
    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


class TestUnsqueezeOp3_AxesTensor(TestUnsqueezeOp_AxesTensor):
    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):
    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import numpy as np
import paddle
import paddle.fluid as fluid
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
# Correct: General.
class TestUnsqueezeOp(XPUOpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = "unsqueeze"
        self.use_xpu = True
        self.use_mkldnn = False
        self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
        self.init_attrs()
        self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}

    def test_check_output(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_output_with_place(place)

    def test_check_grad(self):
        if paddle.is_compiled_with_xpu():
            place = paddle.XPUPlace(0)
            self.check_grad_with_place(place, ['X'], 'Out')

    def init_test_case(self):
        self.ori_shape = (3, 40)
        self.axes = (1, 2)
        self.new_shape = (3, 1, 1, 40)

    def init_attrs(self):
        self.attrs = {"axes": self.axes}


# Correct: Single input index.
class TestUnsqueezeOp1(TestUnsqueezeOp):
    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (-1, )
        self.new_shape = (20, 5, 1)


# Correct: Mixed input axes.
class TestUnsqueezeOp2(TestUnsqueezeOp):
    def init_test_case(self):
        self.ori_shape = (20, 5)
        self.axes = (0, -1)
        self.new_shape = (1, 20, 5, 1)


# Correct: There are duplicated axes.
class TestUnsqueezeOp3(TestUnsqueezeOp):
    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (0, 3, 3)
        self.new_shape = (1, 10, 2, 1, 1, 5)


# Correct: Reversed axes.
class TestUnsqueezeOp4(TestUnsqueezeOp):
    def init_test_case(self):
        self.ori_shape = (10, 2, 5)
        self.axes = (3, 1, 1)
        self.new_shape = (10, 1, 1, 2, 5, 1)


if __name__ == "__main__":
    unittest.main()