未验证 提交 6bfc5721 编写于 作者: L liuyuhui 提交者: GitHub

[2.0 rc1/cherrypick] cherry-pick kunlun PR:29234/29229/29293/29367/29280/29448 (#29466)

* add deformable_conv op on xpu (#29234)

* rebase develop

* update deformable_conv op on xpu

* update deformable_conv op on xpu

* update kunlun conv2d/softmax/elementwise implemetation (#29229)

* update conv2d & softmax to new xpu api
* test=kunlun

* remove useless comments
* test=kunlun

* remote softmax xpu op
* test=kunlun

* update kunlun softmax
* test=kunlun

* update xpu unitest
* test=kunlun

* fix elementwise_grad bug for kunlun
*test=kunlun

* support global pooling for kunlun (#29293)

* test=kunlun

* update reduce_sum op on xpu (#29367)

* update reduce_sum op on xpu

* update reduce_sum op on xpu

* support running on xpu

* fix expand/uniform_random && concat/transpose to new api on xpu (#29280)

* fix expand && concat/transpose to new api

* update uniform_random_op

* update xpu_header

* 1. fix elementwise ops'bug 2. fix softmax_with_cross_entropy_op 3. add biliner_interp_op (#29448)
Co-authored-by: Nroot <root@bjhw-sys-rpm0223.bjhw.baidu.com>
Co-authored-by: N卖鱼的哲学 <tangzhiyi11@users.noreply.github.com>
Co-authored-by: NQingshuChen <qingshu.chen714@gmail.com>
Co-authored-by: Ntaixiurong <taixiurong@126.com>
Co-authored-by: Nroot <root@bjhw-sys-rpm0223.bjhw.baidu.com>
上级 6b9302a2
......@@ -4,7 +4,7 @@ endif()
INCLUDE(ExternalProject)
SET(XPU_PROJECT "extern_xpu")
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_11_10.tar.gz" CACHE STRING "" FORCE)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_04.tar.gz" CACHE STRING "" FORCE)
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}")
SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu")
......
......@@ -11,18 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/concat_op.h"
#include <memory>
#include <string>
#include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include <paddle/fluid/platform/mkldnn_helper.h>
#endif
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu_header.h"
namespace paddle {
namespace operators {
......@@ -32,8 +26,8 @@ template <typename DeviceContext, typename T>
class ConcatXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
int axis = ctx.Attr<int>("axis");
PADDLE_ENFORCE_NE(ins[0], nullptr, platform::errors::InvalidArgument(
"The input should not be null."));
......@@ -47,6 +41,7 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LT(axis, ins[0]->dims().size(),
platform::errors::InvalidArgument(
"concat: axis shoud < ins[0]->dims()!"));
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
std::vector<int> choose_idx;
......@@ -57,43 +52,54 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
n++;
}
}
PADDLE_ENFORCE_LE(n, 8, platform::errors::InvalidArgument(
"XPU only surpport at most 8 tensors for now"));
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument("No tensor need concat?"));
int h = 1;
int w_except_axis = 1;
for (int i = 0; i < axis; ++i) {
h *= (ins[choose_idx[0]]->dims())[i];
}
for (int i = axis + 1; i < ins[0]->dims().size(); ++i) {
w_except_axis *= (ins[choose_idx[0]]->dims())[i];
}
for (int i = 1; i < n; ++i) {
int hh = 1;
int ww = 1;
for (int j = 0; j < axis; ++j) {
hh *= (ins[choose_idx[i]]->dims())[j];
// If axis is 0, the lod of the output is not the same as inputs.
if (axis == 0 && ins[0]->lod().size() > 0) {
size_t lod_size_0 = ins[0]->lod().size();
size_t lod_size = lod_size_0;
for (size_t i = 1; i < ins.size(); ++i) {
if (ins[i]->lod().size() > 0) {
PADDLE_ENFORCE_EQ(
ins[i]->lod().size(), lod_size_0,
platform::errors::Unimplemented(
"The lod level of all input LoDTensors should be same. "
"Maybe different lod level of input LoDTensors can concat,"
"it is not supported currently. The lod level of %dth input "
"is %d and first input is %d.",
i, ins[i]->lod().size(), lod_size_0));
} else {
lod_size = 0;
break;
}
}
for (int j = axis + 1; j < ins[i]->dims().size(); ++j) {
ww *= (ins[choose_idx[i]]->dims())[j];
if (lod_size) {
auto* out_lod = out->mutable_lod();
for (size_t i = 1; i < ins.size(); ++i) {
auto in_lod = ConvertToLengthBasedLoD(ins[i]->lod());
AppendLoD(out_lod, in_lod);
}
}
PADDLE_ENFORCE_EQ(hh, h, platform::errors::InvalidArgument(
"concat: h should be eual!"));
PADDLE_ENFORCE_EQ(ww, w_except_axis,
platform::errors::InvalidArgument(
"concat: w should be eual except for axis!"));
}
auto input_dims = ins[0]->dims();
std::vector<std::vector<int>> xdims_list(n);
for (int i = 0; i < n; ++i) {
std::vector<int> tmp_dims(input_dims.size());
for (int j = 0; j < input_dims.size(); ++j) {
tmp_dims[j] = ins[i]->dims()[j];
}
xdims_list[i] = tmp_dims;
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::unique_ptr<int[]> in_w_host(new int[n]);
std::unique_ptr<const float* []> ptrs(new const float*[n]);
std::vector<const T*> ptrs;
for (int i = 0; i < n; ++i) {
ptrs[i] = ins[choose_idx[i]]->data<T>();
in_w_host[i] = w_except_axis * (ins[choose_idx[i]]->dims())[axis];
ptrs.push_back(ins[choose_idx[i]]->data<T>());
}
int r =
xpu::concat<float>(dev_ctx.x_context(), h, (const int*)in_w_host.get(),
n, (const float**)ptrs.get(), out->data<T>());
int r = xpu::concat<T>(dev_ctx.x_context(), ptrs, out->data<T>(),
xdims_list, axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
......@@ -102,6 +108,7 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
r));
}
};
template <typename DeviceContext, typename T>
class ConcatGradXPUKernel : public framework::OpKernel<T> {
public:
......@@ -132,13 +139,15 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
static_cast<int64_t>(ins[0]->dims().size()));
// get output tensor that the name is not kEmptyVarName
std::vector<framework::Tensor*> outputs;
std::vector<int> choose_idx;
int n = 0;
for (size_t j = 0; j < outs.size(); ++j) {
if (out_var_names[j] != framework::kEmptyVarName &&
outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace());
outputs.push_back(outs[j]);
} else {
outputs.push_back(nullptr);
choose_idx.push_back(j);
n++;
}
}
PADDLE_ENFORCE_GE(axis, 0, platform::errors::InvalidArgument(
......@@ -146,23 +155,31 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LT(axis, out_grad->dims().size(),
platform::errors::InvalidArgument(
"concat_grad: axis shoud < ins[0]->dims()!"));
auto out_grad_stride = framework::stride_numel(out_grad->dims());
int n = outputs.size();
PADDLE_ENFORCE_LE(n, 16,
platform::errors::InvalidArgument(
"XPU only surpport at most 16 tensors for now"));
int h = out_grad_stride[0] / out_grad_stride[axis];
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::unique_ptr<int[]> in_w_host(new int[n]);
std::unique_ptr<float* []> ptrs(new float*[n]);
auto input_dims = ins[0]->dims();
std::vector<int> split_list(n);
std::vector<int> xdims_list(input_dims.size());
int total_length = 0;
for (int i = 0; i < n; ++i) {
split_list[i] = ins[i]->dims()[axis];
total_length += ins[i]->dims()[axis];
}
for (int i = 0; i < input_dims.size(); ++i) {
if (i == axis) {
continue;
}
xdims_list[i] = input_dims[i];
}
xdims_list[axis] = total_length;
std::vector<T*> ptrs(n);
for (int i = 0; i < n; ++i) {
auto out_stride = framework::stride_numel(outputs[i]->dims());
ptrs[i] = outputs[i]->data<T>();
in_w_host[i] = out_stride[axis];
}
int r = xpu::concat_grad(dev_ctx.x_context(), h, in_w_host.get(), n,
reinterpret_cast<float**>(ptrs.get()),
out_grad->data<T>());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::split<T>(dev_ctx.x_context(), out_grad->data<T>(), ptrs,
xdims_list, split_list, axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
......
......@@ -27,10 +27,6 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
// that avoids modifying the variable in the Scope.
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
// Tensor* max_input = context.Output<Tensor>("MaxInput");
// Tensor* max_filter = context.Output<Tensor>("MaxFilter");
// max_input->mutable_data<T>(context.GetPlace());
// max_filter->mutable_data<T>(context.GetPlace());
output->mutable_data<T>(context.GetPlace());
int groups = context.Attr<int>("groups");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
......@@ -43,52 +39,18 @@ class GemmConvXPUKernel : public framework::OpKernel<T> {
const int f = static_cast<int>(filter.dims()[0]);
const int win_h = static_cast<int>(filter.dims()[2]);
const int win_w = static_cast<int>(filter.dims()[3]);
PADDLE_ENFORCE_EQ(
dilations[0] == 1 && dilations[1] == 1, true,
platform::errors::InvalidArgument("XPU only support dilation == 1."));
auto& dev_ctx = context.template device_context<DeviceContext>();
// PADDLE_ENFORCE_EQ(
// xpu::findmax(dev_ctx.x_context(), input->data<T>(), input->numel(),
// max_input->data<T>()) == xpu::Error_t::SUCCESS,
// true, platform::errors::InvalidArgument(
// "XPU conv kernel error,can not finde max_input,please "
// "check whether Baidu Kunlun "
// "Card is properly installed."));
// PADDLE_ENFORCE_EQ(
// xpu::findmax(dev_ctx.x_context(), filter.data<T>(), filter.numel(),
// max_filter->data<T>()) == xpu::Error_t::SUCCESS,
// true, platform::errors::InvalidArgument(
// "XPU conv kernel error,can not find max_filter,please "
// "check whether Baidu Kunlun "
// "Card is properly installed."));
if (groups == 1) {
int r = xpu::conv2d_forward_int16<float, float, float, float>(
dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
strides[0], strides[1], paddings[0], paddings[1], dilations[0],
dilations[1], groups, input->data<float>(), filter.data<float>(),
output->data<float>(), nullptr, nullptr, xpu::Activation_t::LINEAR,
nullptr, nullptr);
// max_input->data<float>(), max_filter->data<float>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d], "
"please check whether Baidu Kunlun Card "
"is properly installed.",
r));
} else {
int r = xpu::conv2d_int16_with_group<float, float, float>(
dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
output->data<float>(), batch_size, img_c, img_h, img_w, f, win_h,
win_w, groups, strides[0], strides[1], paddings[0], paddings[1],
nullptr, nullptr);
// max_input->data<float>(), max_filter->data<float>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d], "
"please check whether Baidu Kunlun Card "
"is properly installed.",
r));
}
std::vector<int> k_size;
k_size.push_back(win_h);
k_size.push_back(win_w);
int r = xpu::conv2d<float, float, float, int16_t>(
dev_ctx.x_context(), input->data<float>(), filter.data<float>(),
output->data<float>(), batch_size, img_c, img_h, img_w, f, k_size,
strides, paddings, dilations, groups, nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
template <typename DeviceContext, typename T>
......@@ -96,9 +58,6 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
// const Tensor* max_input = context.Input<Tensor>("MaxInput");
// const Tensor* max_filter = context.Input<Tensor>("MaxFilter");
// Tensor* max_output_grad = context.Output<Tensor>("MaxOutputGrad");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad =
......@@ -115,11 +74,6 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(input->dims()[0]);
PADDLE_ENFORCE_EQ(groups == 1, true, platform::errors::InvalidArgument(
"XPU only support groups == 1."));
PADDLE_ENFORCE_EQ(
dilations[0] == 1 && dilations[1] == 1, true,
platform::errors::InvalidArgument("XPU only support dilation == 1."));
const int img_c = static_cast<int>(input->dims()[1]);
const int img_h = static_cast<int>(input->dims()[2]);
const int img_w = static_cast<int>(input->dims()[3]);
......@@ -133,52 +87,24 @@ class GemmConvGradXPUKernel : public framework::OpKernel<T> {
filter_grad->mutable_data<T>(context.GetPlace());
}
auto& dev_ctx = context.template device_context<DeviceContext>();
// max_output_grad->Resize({4});
// max_output_grad->mutable_data<T>(context.GetPlace());
// PADDLE_ENFORCE_EQ(
// xpu::findmax(dev_ctx.x_context(), output_grad->data<T>(),
// output_grad->numel(),
// max_output_grad->data<T>()) == xpu::Error_t::SUCCESS,
// true,
// platform::errors::External(
// "XPU conv kernel error, can not find max_output_grad, please
// check "
// "whether Baidu Kunlun Card is "
// "properly installed."));
if (input_grad) {
int r = xpu::conv2d_backward_int16(
dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
strides[0], strides[1], paddings[0], paddings[1], dilations[0],
dilations[1], groups, output_grad->data<float>(),
filter.data<float>(), input_grad->data<float>(), nullptr, nullptr);
// max_output_grad->data<float>(), max_filter->data<float>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d], "
"please check whether Baidu Kunlun Card "
"is properly installed.",
r));
}
if (filter_grad) {
int r = xpu::conv2d_backward_weight_int16(
dev_ctx.x_context(), batch_size, img_c, img_h, img_w, f, win_h, win_w,
strides[0], strides[1], paddings[0], paddings[1], dilations[0],
dilations[1], groups, output_grad->data<float>(),
input->data<float>(), filter_grad->data<float>(), nullptr, nullptr);
// max_output_grad->data<float>(), max_input->data<float>());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d], "
"please check whether Baidu Kunlun Card "
"is properly installed.",
r));
}
std::vector<int> k_size;
k_size.push_back(win_h);
k_size.push_back(win_w);
int r = xpu::conv2d_grad<float, float, float, int16_t>(
dev_ctx.x_context(), input->data<T>(), filter.data<T>(),
output_grad->data<T>(), input_grad ? input_grad->data<T>() : nullptr,
filter_grad ? filter_grad->data<T>() : nullptr, batch_size, img_c,
img_h, img_w, f, k_size, strides, paddings, dilations, groups, nullptr,
nullptr, nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU conv kernel return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// TODO(xingzhaolong): neon kernel for mobile
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d,
ops::GemmConvXPUKernel<paddle::platform::XPUDeviceContext, float>);
......@@ -187,4 +113,7 @@ REGISTER_OP_XPU_KERNEL(
REGISTER_OP_XPU_KERNEL(
conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
depthwise_conv2d_grad,
ops::GemmConvGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/xpu_header.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class DeformableConvXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto* offset = ctx.Input<Tensor>("Offset");
auto* mask = ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
Tensor* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int groups = ctx.Attr<int>("groups");
const int deformable_groups = ctx.Attr<int>("deformable_groups");
const int im2col_step = ctx.Attr<int>("im2col_step");
const std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
PADDLE_ENFORCE_EQ(
deformable_groups == 1, true,
platform::errors::InvalidArgument((
"XPU only support deformable_groups == 1 in deformable_conv op.")));
PADDLE_ENFORCE_EQ(
groups == 1, true,
platform::errors::InvalidArgument(
("XPU only support groups == 1 in deformable_conv op.")));
PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8, true,
platform::errors::InvalidArgument(
"Filter high and weight should less than 8 on xpu "
"in deformable_conv op."));
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
const T* input_ptr = input->data<T>();
const T* filter_ptr = filter.data<T>();
const float* offset_ptr = offset->data<T>();
const float* mask_ptr = mask->data<T>();
T* output_prt = output->data<T>();
// set zeros for d_table_data
const int zero = 0;
int r = xpu::constant<T>(dev_ctx.x_context(), output_prt, output->numel(),
zero);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External(
"XPU API return wrong value[%d], please check where "
"Baidu Kunlun Card is properly installed.",
r));
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset->numel() / offset->dims()[0];
int input_mask_dim = mask->numel() / mask->dims()[0];
int output_dim =
output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
static_cast<int>(filter.dims()[3])};
int n = im2col_step;
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
int f = filter.dims()[0];
for (int i = 0; i < batch_size / im2col_step; ++i) {
int r = xpu::deformable_conv<float, float, float, int>(
dev_ctx.x_context(), input_ptr + i * im2col_step * input_dim,
filter_ptr, offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
output_prt + i * im2col_step * output_dim, n, c, h, w, f, ksize,
strides, paddings, dilations, groups, deformable_groups, nullptr,
nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU deformable_conv kernel return wrong value[%d].", r));
}
}
};
template <typename DeviceContext, typename T>
class DeformableConvGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* output_grad =
ctx.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
Tensor* offset_grad = ctx.Output<Tensor>(framework::GradVarName("Offset"));
Tensor* mask_grad = ctx.Output<Tensor>(framework::GradVarName("Mask"));
T* dx_data = nullptr;
T* dw_data = nullptr;
T* dmask_data = nullptr;
T* doffset_data = nullptr;
if (input_grad != nullptr) {
input_grad->mutable_data<T>(ctx.GetPlace());
dx_data = input_grad->data<T>();
}
if (filter_grad != nullptr) {
filter_grad->mutable_data<T>(ctx.GetPlace());
dw_data = filter_grad->data<T>();
}
if (offset_grad != nullptr) {
offset_grad->mutable_data<T>(ctx.GetPlace());
doffset_data = offset_grad->data<T>();
}
if (mask_grad != nullptr) {
mask_grad->mutable_data<T>(ctx.GetPlace());
dmask_data = mask_grad->data<T>();
}
const Tensor* input = ctx.Input<Tensor>("Input");
Tensor offset = *ctx.Input<Tensor>("Offset");
Tensor mask = *ctx.Input<Tensor>("Mask");
Tensor filter = *ctx.Input<Tensor>("Filter");
int groups = ctx.Attr<int>("groups");
int deformable_groups = ctx.Attr<int>("deformable_groups");
int im2col_step = ctx.Attr<int>("im2col_step");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
PADDLE_ENFORCE_EQ(
deformable_groups == 1, true,
platform::errors::InvalidArgument((
"XPU only support deformable_groups == 1 in deformable_conv op.")));
PADDLE_ENFORCE_EQ(
groups == 1, true,
platform::errors::InvalidArgument(
("XPU only support groups == 1 in deformable_conv op.")));
PADDLE_ENFORCE_EQ(filter.dims()[2] <= 8 && filter.dims()[3] <= 8, true,
platform::errors::InvalidArgument(
"Filter high and weight should less than 8 on xpu "
"in deformable_conv op."));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> output_shape_vec(
framework::vectorize(output_grad->dims()));
const T* output_grad_ptr = output_grad->data<T>();
const T* input_ptr = input->data<T>();
const T* filter_ptr = filter.data<T>();
const float* offset_ptr = offset.data<float>();
const float* mask_ptr = mask.data<float>();
if (dx_data == nullptr) {
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&dx_data),
input->numel() * sizeof(T)),
XPU_SUCCESS, platform::errors::ResourceExhausted(
"XPU has no enough memory"));
}
if (dw_data == nullptr) {
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&dw_data),
filter.numel() * sizeof(T)),
XPU_SUCCESS, platform::errors::ResourceExhausted(
"XPU has no enough memory"));
}
if (doffset_data == nullptr) {
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&doffset_data),
offset.numel() * sizeof(T)),
XPU_SUCCESS, platform::errors::ResourceExhausted(
"XPU has no enough memory"));
}
if (dmask_data == nullptr) {
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&dmask_data),
mask.numel() * sizeof(T)),
XPU_SUCCESS, platform::errors::ResourceExhausted(
"XPU has no enough memory"));
}
int input_dim = input->numel() / input->dims()[0];
int input_offset_dim = offset.numel() / offset.dims()[0];
int input_mask_dim = mask.numel() / mask.dims()[0];
int output_dim =
output_shape_vec[1] * output_shape_vec[2] * output_shape_vec[3];
std::vector<int> ksize{static_cast<int>(filter.dims()[2]),
static_cast<int>(filter.dims()[3])};
int n = im2col_step;
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
int f = filter.dims()[0];
T* filter_grad_tmp = nullptr;
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(&filter_grad_tmp),
filter_grad->numel() * sizeof(T)),
XPU_SUCCESS, platform::errors::ResourceExhausted(
"XPU has no enough memory"));
// set zeros for d_table_data
const int zero = 0;
int r_dx =
xpu::constant<T>(dev_ctx.x_context(), dx_data, input->numel(), zero);
int r_dw =
xpu::constant<T>(dev_ctx.x_context(), dw_data, filter.numel(), zero);
int r_doffset = xpu::constant<T>(dev_ctx.x_context(), doffset_data,
offset.numel(), zero);
int r_dmask =
xpu::constant<T>(dev_ctx.x_context(), dmask_data, mask.numel(), zero);
int r_filter = xpu::constant<T>(dev_ctx.x_context(), filter_grad_tmp,
filter.numel(), zero);
auto ret = (r_dx == xpu::Error_t::SUCCESS) && (r_dx == r_dw) &&
(r_dx == r_doffset) && (r_dx == r_dmask) && (r_dx == r_filter);
PADDLE_ENFORCE_EQ(ret, true,
platform::errors::External(
"XPU API return wrong value, please check where "
"Baidu Kunlun Card is properly installed."));
for (int i = 0; i < batch_size / im2col_step; ++i) {
int r = xpu::deformable_conv_grad<float, float, float, int>(
dev_ctx.x_context(), input_ptr + i * im2col_step * input_dim,
filter_ptr, offset_ptr + i * im2col_step * input_offset_dim,
mask_ptr + i * im2col_step * input_mask_dim,
output_grad_ptr + i * im2col_step * output_dim,
dx_data + i * im2col_step * input_dim, filter_grad_tmp,
doffset_data + i * im2col_step * input_offset_dim,
dmask_data + i * im2col_step * input_mask_dim, n, c, h, w, f, ksize,
strides, paddings, dilations, groups, deformable_groups, nullptr,
nullptr, nullptr, nullptr, nullptr, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU deformable_conv_grad kernel return wrong value[%d].", r));
r = baidu::xpu::api::add<T>(dev_ctx.x_context(), filter_grad_tmp, dw_data,
dw_data, filter.numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU add kernel return wrong value[%d].", r));
}
dev_ctx.Wait();
xpu_free(filter_grad_tmp);
if (input_grad == nullptr) {
xpu_free(dx_data);
}
if (filter_grad == nullptr) {
xpu_free(dw_data);
}
if (offset_grad == nullptr) {
xpu_free(doffset_data);
}
if (mask_grad == nullptr) {
xpu_free(dmask_data);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using XPUDeviceContext = paddle::platform::XPUDeviceContext;
REGISTER_OP_XPU_KERNEL(deformable_conv,
ops::DeformableConvXPUKernel<XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
deformable_conv_grad,
ops::DeformableConvGradXPUKernel<XPUDeviceContext, float>);
#endif
......@@ -28,9 +28,10 @@ class ElementwiseDivXPUKernel : public framework::OpKernel<T> {
};
template <typename DeviceContext, typename T>
class ElementwiseDivGradXPUKernel : public framework::OpKernel<T> {
class ElementwiseDivGradXPUKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::div_grad<T>, true);
}
};
......
......@@ -29,9 +29,10 @@ class ElementwiseMaxXPUKernel : public framework::OpKernel<T> {
};
template <typename DeviceContext, typename T>
class ElementwiseMaxGradXPUKernel : public framework::OpKernel<T> {
class ElementwiseMaxGradXPUKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::max_grad<T>, true);
}
};
......
......@@ -29,9 +29,10 @@ class ElementwiseMinXPUKernel : public framework::OpKernel<T> {
};
template <typename DeviceContext, typename T>
class ElementwiseMinGradXPUKernel : public framework::OpKernel<T> {
class ElementwiseMinGradXPUKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::min_grad<T>, true);
}
};
......
......@@ -27,9 +27,10 @@ class ElementwiseMulXPUKernel : public framework::OpKernel<T> {
};
// DEFINE_XPU_GRAD_KERNEL(Mul, mul, true);
template <typename DeviceContext, typename T>
class ElementwiseMulGradXPUKernel : public framework::OpKernel<T> {
class ElementwiseMulGradXPUKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::mul_grad<T>, true);
}
};
......
......@@ -65,7 +65,7 @@ static std::pair<std::vector<int>, std::vector<int>> XPUReducesAxisVector(
}
int yidx = 0;
for (size_t i = 0; i < x_vector.size(); ++i) {
if (y[yidx] == 1) {
if (yidx >= y.size() || y[yidx] == 1) {
axis_v.push_back(i);
yidx++;
continue;
......@@ -134,10 +134,10 @@ void XPUElementwise(
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
ret = xpu::broadcast<T>(
dev_ctx.x_context(), x_data,
x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), z->numel()),
bcast_v.first, bcast_v.second);
ret = xpu::broadcast<T>(dev_ctx.x_context(), x_data,
x_broadcast_tensor.mutable_data<T>(
ctx.GetPlace(), z->numel() * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
......@@ -153,10 +153,10 @@ void XPUElementwise(
std::vector<int> bcast_y_v;
std::pair<std::vector<int>, std::vector<int>> bcast_v =
XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
ret = xpu::broadcast<T>(
dev_ctx.x_context(), y_data,
y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), z->numel()),
bcast_v.first, bcast_v.second);
ret = xpu::broadcast<T>(dev_ctx.x_context(), y_data,
y_broadcast_tensor.mutable_data<T>(
ctx.GetPlace(), z->numel() * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
......@@ -231,13 +231,15 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
bool dx_need_reduce = (dx != nullptr) && (dx->numel() != len);
bool dy_need_reduce = (dy != nullptr) && (dy->numel() != len);
T* dx_data = ((dx == nullptr) || dx_need_reduce)
? (dx_local_tensor.mutable_data<T>(ctx.GetPlace(), len))
: (dx->mutable_data<T>(ctx.GetPlace()));
T* dx_data =
((dx == nullptr) || dx_need_reduce)
? (dx_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
: (dx->mutable_data<T>(ctx.GetPlace()));
T* dy_data = ((dy == nullptr) || dy_need_reduce)
? (dy_local_tensor.mutable_data<T>(ctx.GetPlace(), len))
: (dy->mutable_data<T>(ctx.GetPlace()));
T* dy_data =
((dy == nullptr) || dy_need_reduce)
? (dy_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
: (dy->mutable_data<T>(ctx.GetPlace()));
int ret = xpu::SUCCESS;
auto& dev_ctx =
......@@ -250,8 +252,8 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
ret = xpu::broadcast<T>(
dev_ctx.x_context(), x_data,
x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len), bcast_v.first,
bcast_v.second);
x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast error occur! %d", ret));
......@@ -267,8 +269,8 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
ret = xpu::broadcast<T>(
dev_ctx.x_context(), y_data,
y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len), bcast_v.first,
bcast_v.second);
y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
bcast_v.first, bcast_v.second);
PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel broadcast error occur! %d", ret));
......@@ -287,9 +289,9 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
const framework::DDim& dx_dims = dx->dims();
std::pair<std::vector<int>, std::vector<int>> reduce_v =
XPUReducesAxisVector(out_dim, dx_dims);
ret = xpu::reduce_sum(dev_ctx.x_context(), dx_data,
dx->mutable_data<T>(ctx.GetPlace()), reduce_v.first,
reduce_v.second);
ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dx_data,
dx->mutable_data<T>(ctx.GetPlace()),
reduce_v.first, reduce_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
......@@ -302,9 +304,9 @@ void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
const framework::DDim& dy_dims = dy->dims();
std::pair<std::vector<int>, std::vector<int>> reduce_v =
XPUReducesAxisVector(out_dim, dy_dims);
ret = xpu::reduce_sum(dev_ctx.x_context(), dy_data,
dy->mutable_data<T>(ctx.GetPlace()), reduce_v.first,
reduce_v.second);
ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dy_data,
dy->mutable_data<T>(ctx.GetPlace()),
reduce_v.first, reduce_v.second);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External("XPU kernel reduce_sum occur error in "
......
......@@ -56,6 +56,12 @@ inline std::vector<int> get_expand_times(
TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor);
expand_data = cpu_expand_tensor.data<int>();
}
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(expand_tensor->place())) {
TensorCopySync(*expand_tensor, platform::CPUPlace(), &cpu_expand_tensor);
expand_data = cpu_expand_tensor.data<int>();
}
#endif
auto vec_epxand_times =
std::vector<int>(expand_data, expand_data + expand_tensor->numel());
return vec_epxand_times;
......@@ -72,7 +78,15 @@ inline std::vector<int> get_expand_times(
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_epxand_times.push_back(*temp.data<int32_t>());
} else {
}
#ifdef PADDLE_WITH_XPU
else if (platform::is_xpu_place(tensor->place())) { // NOLINT
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_epxand_times.push_back(*temp.data<int32_t>());
}
#endif
else { // NOLINT
vec_epxand_times.push_back(*tensor->data<int32_t>());
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/interpolate_op.h"
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace operators {
using framework::Tensor;
using DataLayout = framework::DataLayout;
inline std::vector<int> get_new_shape_xpu(
const std::vector<const Tensor*>& list_new_shape_tensor) {
// get tensor from
std::vector<int> vec_new_shape;
for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}),
platform::errors::InvalidArgument("shape of dim tensor should be [1]"));
if (platform::is_xpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
} else {
vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
}
}
return vec_new_shape;
}
template <typename T>
inline std::vector<T> get_new_data_from_tensor_xpu(
const Tensor* new_data_tensor) {
std::vector<T> vec_new_data;
auto* new_data = new_data_tensor->data<T>();
framework::Tensor cpu_starts_tensor;
if (platform::is_xpu_place(new_data_tensor->place())) {
TensorCopySync(*new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
new_data = cpu_starts_tensor.data<T>();
}
vec_new_data = std::vector<T>(new_data, new_data + new_data_tensor->numel());
return vec_new_data;
}
template <typename T>
class InterpolateXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto input_dims = input->dims();
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
platform::errors::External("XPU Interpolate kernel only support 2d"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input_dims, data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape_xpu(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
} else {
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor_xpu<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor_xpu<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
}
PADDLE_ENFORCE_GT(out_h, 0, platform::errors::InvalidArgument(
"out_h in Attr(out_shape) of "
"Op(interpolate) "
"should be greater than 0."));
PADDLE_ENFORCE_GT(out_w, 0, platform::errors::InvalidArgument(
"out_w in Attr(out_shape) of "
"Op(interpolate) "
"should be greater than 0."));
framework::DDim dim_out;
if (data_layout == DataLayout::kNCHW) {
dim_out = {n, c, out_h, out_w};
} else {
dim_out = {n, out_h, out_w, c};
}
output->mutable_data<T>(dim_out, ctx.GetPlace());
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*input, ctx.GetPlace(), output);
return;
}
bool nearest = "nearest" == interp_method;
int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2));
auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
if (nearest) {
PADDLE_ENFORCE_EQ((data_layout == DataLayout::kNCHW), true,
platform::errors::InvalidArgument(
"XPU nearest is only support NCHW"));
}
int r = xpu::interpolate2d<float>(dev_ctx.x_context(), input->data<float>(),
output->data<float>(), n, c, in_h, in_w,
out_h, out_w, nearest, trans_mode,
(data_layout == DataLayout::kNCHW));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU interpolate2d kernel "
"return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
template <typename T>
class InterpolateGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto output_grad_dims = output_grad->dims();
PADDLE_ENFORCE_EQ(output_grad_dims.size(), 4,
platform::errors::External(
"XPU Interpolategrad kernel only support 2d"));
auto* input = ctx.Input<Tensor>("X");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
ExtractNCDWH(input->dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
auto interp_method = ctx.Attr<std::string>("interp_method");
bool align_corners = ctx.Attr<bool>("align_corners");
int align_mode = ctx.Attr<int>("align_mode");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
float scale;
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor_xpu<float>(scale_tensor);
scale = scale_data[0];
} else {
scale = ctx.Attr<float>("scale");
}
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = get_new_data_from_tensor_xpu<int>(out_size);
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
if (list_new_size_tensor.size() > 0) {
// have size tensor
auto new_size = get_new_shape_xpu(list_new_size_tensor);
out_h = new_size[0];
out_w = new_size[1];
}
framework::DDim dim_grad;
if (data_layout == DataLayout::kNCHW) {
dim_grad = {n, c, in_h, in_w};
} else {
dim_grad = {n, in_h, in_w, c};
}
input_grad->mutable_data<T>(dim_grad, ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
int r = XPU_SUCCESS;
r = xpu::constant<T>(dev_ctx.x_context(), input_grad->data<T>(),
input_grad->numel(), static_cast<T>(0.0));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU constant in interpolate2d_grad kernel return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
return;
}
bool nearest = "nearest" == interp_method;
int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2));
if (nearest) {
PADDLE_ENFORCE_EQ((data_layout == DataLayout::kNCHW), true,
platform::errors::InvalidArgument(
"XPU nearest is only support NCHW"));
}
r = xpu::interpolate2d_grad<T>(dev_ctx.x_context(), output_grad->data<T>(),
input_grad->data<T>(), n, c, in_h, in_w,
out_h, out_w, nearest, trans_mode,
(data_layout == DataLayout::kNCHW));
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU interpolate2d_grad kernel return "
"wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(bilinear_interp, ops::InterpolateXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(bilinear_interp_grad,
ops::InterpolateGradXPUKernel<float>);
#endif
......@@ -43,16 +43,18 @@ class PoolXPUKernel : public framework::OpKernel<T> {
bool exclusive = context.Attr<bool>("exclusive");
bool is_test = context.Attr<bool>("is_test");
bool adaptive = context.Attr<bool>("adaptive");
PADDLE_ENFORCE_EQ(
!adaptive, true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support adaptive == true!"));
PADDLE_ENFORCE_EQ(
ksize.size(), 2,
platform::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 dimension pooling!"));
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1), true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
int* index_data = nullptr;
if (context.Attr<bool>("global_pooling")) {
bool global_pooling = context.Attr<bool>("global_pooling") ||
(adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
......@@ -104,16 +106,18 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
const int* index_data = nullptr;
PADDLE_ENFORCE_EQ(
!adaptive, true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support adaptive == true!"));
PADDLE_ENFORCE_EQ(ksize.size(), 2, platform::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 "
"dimension pooling!, but received "
"%d-dimension pool kernel size",
ksize.size()));
if (context.Attr<bool>("global_pooling")) {
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1), true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
bool global_pooling = context.Attr<bool>("global_pooling") ||
(adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
......
......@@ -16,6 +16,8 @@
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/platform/xpu_header.h"
namespace paddle {
namespace operators {
......@@ -27,86 +29,120 @@ class ReduceSumXPUKernel : public framework::OpKernel<T> {
platform::is_xpu_place(context.GetPlace()), true,
platform::errors::Unavailable("This kernel only runs on XPU."));
bool reduce_all = context.Attr<bool>("reduce_all");
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto dims = context.Attr<std::vector<int>>("dim");
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
y->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
int out_dtype = context.Attr<int>("out_dtype");
PADDLE_ENFORCE_EQ(
out_dtype == -1, true,
platform::errors::InvalidArgument(
"XPU only support out_dtype == -1 in reduce_sum op."));
const auto* x_data = x->data<T>();
auto* y_data = y->data<T>();
const auto& input_dim_size = x->dims().size();
std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) {
true_dims.push_back(dims[i] + input_dim_size);
} else {
true_dims.push_back(dims[i]);
}
}
std::vector<int> reduce_dims;
std::vector<int> xdims((input_dim_size));
for (int i = 0; i < input_dim_size; ++i) {
xdims[i] = x->dims()[i];
}
if (reduce_all) {
int input_len = input->numel();
int r = xpu::sum(dev_ctx.x_context(), input->data<T>(), output->data<T>(),
input_len);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU kernel error!"));
for (int i = 0; i < input_dim_size; ++i) {
reduce_dims.push_back(i);
}
} else {
int ndim = input->dims().size();
std::vector<int> idims;
for (int i = 0; i < input->dims().size(); i++) {
idims.push_back(input->dims()[i]);
std::set<int> dims_set(true_dims.begin(), true_dims.end());
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) != dims_set.end()) {
if (x->dims()[i] != 1) {
reduce_dims.push_back(i);
}
}
}
auto dims = context.Attr<std::vector<int>>("dim");
int rdim = dims.size();
int r =
xpu::reduce(dev_ctx.x_context(), input->data<T>(), output->data<T>(),
idims.data(), ndim, dims.data(), rdim, xpu::REDUCE_SUM);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU kernel error!"));
}
if (reduce_dims.size() == 0) {
int r = xpu::copy<T>(dev_ctx.x_context(), x_data, y_data,
x->numel() * sizeof(T));
PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU copy in reduce_sum op return "
"wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
} else {
int r = xpu::reduce_sum<T>(dev_ctx.x_context(), x_data, y_data, xdims,
reduce_dims);
PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU reduce_sum in reduce_sum op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
}
}
};
template <typename DeviceContext, typename T>
class ReduceSumGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto dims = context.Attr<std::vector<int>>("dim");
bool reduce_all = context.Attr<bool>("reduce_all");
auto* input0 = context.Input<Tensor>("X");
auto* input2 = context.Input<Tensor>(framework::GradVarName("Out"));
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
const auto* input2_d = input2->data<T>();
auto* output_d = output->data<T>();
auto* x = context.Input<Tensor>("X");
auto* out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int in_dtype = context.Attr<int>("in_dtype");
PADDLE_ENFORCE_EQ(
in_dtype == -1, true,
platform::errors::InvalidArgument(
"XPU only support in_dtype == -1 in reduce_sum_grad op."));
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = 0;
std::vector<int> idims;
int reduce_dim = 0;
if (reduce_all) {
idims.push_back(input0->numel());
idims.push_back(1);
idims.push_back(1);
r = xpu::reduce_grad(dev_ctx.x_context(), input2_d, output_d,
idims.data(), idims.size(), &reduce_dim, 1,
xpu::REDUCE_SUM);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU kernel error!"));
} else if (dims.size() == 1) {
// handle reduce by one dimension
int reduce_dim_index = dims[0];
if (reduce_dim_index < 0) {
reduce_dim_index += input0->dims().size();
}
auto& input_dim = input0->dims();
int before_dim = 1;
for (int i = 0; i < reduce_dim_index; ++i) {
before_dim *= input_dim[i];
x_grad->mutable_data<T>(context.GetPlace());
const auto* out_data = out->data<T>();
auto* x_grad_data = x_grad->data<T>();
const auto& input_dim_size = x->dims().size();
std::vector<int> true_dims;
for (size_t i = 0; i < dims.size(); ++i) {
if (dims[i] < 0) {
true_dims.push_back(dims[i] + input_dim_size);
} else {
true_dims.push_back(dims[i]);
}
int reduce_dim = input_dim[reduce_dim_index];
int after_dim = 1;
for (int i = reduce_dim_index + 1; i < input_dim.size(); ++i) {
after_dim *= input_dim[i];
}
std::vector<int> ydims(input_dim_size);
std::vector<int> xdims((input_dim_size));
std::set<int> dims_set(true_dims.begin(), true_dims.end());
for (auto i = 0; i < input_dim_size; i++) {
xdims[i] = x->dims()[i];
if (dims_set.find(i) != dims_set.end() || reduce_all) {
ydims[i] = 1;
} else {
ydims[i] = x->dims()[i];
}
idims.push_back(before_dim);
idims.push_back(input_dim[reduce_dim_index]);
idims.push_back(after_dim);
reduce_dim = 1;
r = xpu::reduce_grad(dev_ctx.x_context(), input2_d, output_d,
idims.data(), idims.size(), &reduce_dim, 1,
xpu::REDUCE_SUM);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU kernel error!"));
} else {
PADDLE_THROW(
platform::errors::Unimplemented("unsupport reduce sum grad"));
}
int r = xpu::broadcast<T>(dev_ctx.x_context(), out_data, x_grad_data, ydims,
xdims);
PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU broadcast in reduce_sum_grad op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
}
};
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -30,29 +27,27 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
const int rank = x->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
PADDLE_ENFORCE_EQ(axis == -1 || axis == rank - 1, true,
platform::errors::InvalidArgument(
"xpu softmax kernel only support last dimension of x "
"(axis==-1 or axis==x_dims-1), but received axis: "
"%d, x's shape: %s.",
axis, x->dims()));
int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
// allocate memory on device.
out->mutable_data<T>(context.GetPlace());
const int n = SizeToAxis(axis, x->dims());
const int d = SizeFromAxis(axis, x->dims());
std::vector<int> x_dims;
for (int i = 0; i < rank; i++) {
x_dims.push_back(x->dims()[i]);
}
if (axis < 0) {
axis += rank;
}
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::softmax2d_forward(dev_ctx.x_context(), x->data<float>(),
out->data<float>(), n, d, d <= 2048);
int r = xpu::softmax<T>(dev_ctx.x_context(), x->data<float>(),
out->data<float>(), x_dims, axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(softmax2d_forward) return wrong "
"value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
......@@ -64,24 +59,28 @@ class SoftmaxGradXPUKernel : public framework::OpKernel<T> {
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
const int rank = dx->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
// allocate memory on device.
dx->mutable_data<T>(context.GetPlace());
const int n = SizeToAxis(axis, dx->dims());
const int d = SizeFromAxis(axis, dx->dims());
std::vector<int> x_dims;
for (int i = 0; i < rank; i++) {
x_dims.push_back(dx->dims()[i]);
}
if (axis < 0) {
axis += rank;
}
auto& dev_ctx = context.template device_context<DeviceContext>();
int r =
xpu::softmax2d_backward(dev_ctx.x_context(), out->data<float>(),
dout->data<float>(), dx->data<float>(), n, d);
int r = xpu::softmax_grad<T>(dev_ctx.x_context(), out->data<float>(),
dout->data<float>(), dx->data<float>(), x_dims,
axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(softmax2d_backward) return wrong "
"value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
......
......@@ -70,7 +70,8 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
r));
} else {
Tensor labels_int32;
labels_int32.mutable_data<int32_t>(context.GetPlace(), labels->numel());
labels_int32.mutable_data<int32_t>(context.GetPlace(),
labels->numel() * sizeof(int32_t));
r = xpu::cast_v2<int64_t, int32_t>(
dev_ctx.x_context(), labels->data<int64_t>(),
labels_int32.data<int32_t>(), labels->numel());
......
......@@ -17,105 +17,27 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/platform/xpu_header.h"
namespace paddle {
namespace operators {
using framework::Tensor;
bool XPUSupported(int ndims, const std::vector<int>& axis) {
/*
* XPU currently support:
* permute = {0, 2, 1}, permute = {1, 0},
* permute = {0, 2, 1, 3}, permute = {1, 0, 2},
* permute = {0, 2, 3, 1}
*/
bool is_supported = false;
std::vector<int> permute_10(2, 0);
std::vector<int> permute_102(3, 0);
std::vector<int> permute_021(3, 0);
std::vector<int> permute_210(3, 0);
std::vector<int> permute_0213(4, 0);
std::vector<int> permute_0231(4, 0);
std::vector<int> permute_0312(4, 0);
std::vector<int> permute_3201(4, 0);
permute_10[0] = 1;
permute_102[0] = 1;
permute_102[2] = 2;
permute_021[1] = 2;
permute_021[2] = 1;
permute_210[0] = 2;
permute_210[1] = 1;
permute_0213[1] = 2;
permute_0213[2] = 1;
permute_0213[3] = 3;
permute_0231[1] = 2;
permute_0231[2] = 3;
permute_0231[3] = 1;
permute_0312[1] = 3;
permute_0312[2] = 1;
permute_0312[3] = 2;
permute_3201[0] = 3;
permute_3201[1] = 2;
permute_3201[3] = 1;
switch (ndims) {
case 2:
if (axis == permute_10) {
is_supported = true;
}
break;
case 3:
if ((axis == permute_021) || (axis == permute_102) ||
(axis == permute_210)) {
is_supported = true;
}
break;
case 4:
if ((axis == permute_0213) || (axis == permute_0231) ||
(axis == permute_0312) || (axis == permute_3201)) {
is_supported = true;
}
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Tensors with rank only 2, 3 and 4 are supported on XPU"));
}
return is_supported;
}
template <typename DeviceContext, typename T>
class TransposeXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto x = context.Input<framework::Tensor>("X");
auto out = context.Output<framework::Tensor>("Out");
// axis is permute
auto axis = context.Attr<std::vector<int>>("axis");
int ndims = axis.size();
const auto x_dims = x->dims();
const T* x_data = x->data<T>();
T* y_data = out->mutable_data<T>(context.GetPlace());
if (!XPUSupported(ndims, axis)) {
VLOG(0) << "XPU does not support the permute, try to do on cpu";
framework::Tensor x_cpu;
framework::Tensor out_cpu;
auto x_cpu_data = x_cpu.mutable_data<T>(x->dims(), platform::CPUPlace());
auto out_cpu_data =
out_cpu.mutable_data<T>(out->dims(), platform::CPUPlace());
memory::Copy(platform::CPUPlace(), reinterpret_cast<void*>(x_cpu_data),
BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()),
(const void*)x_data, x->numel() * sizeof(T));
const platform::CPUDeviceContext* cpu_dev_ctx =
static_cast<const platform::CPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(
platform::CPUPlace()));
TransCompute<platform::CPUDeviceContext, T>(ndims, *cpu_dev_ctx, x_cpu,
&out_cpu, axis);
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()),
reinterpret_cast<void*>(y_data), platform::CPUPlace(),
(const void*)out_cpu_data, out->numel() * sizeof(T));
if (out->numel() == 0) {
return;
}
......@@ -123,10 +45,9 @@ class TransposeXPUKernel : public framework::OpKernel<T> {
for (int i = 0; i < ndims; ++i) {
x_shape_host[i] = x_dims[i];
}
int* permute_host = axis.data();
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::transpose(dev_ctx.x_context(), x_data, y_data,
x_shape_host.data(), permute_host, ndims);
int r = xpu::transpose<T>(dev_ctx.x_context(), x_data, y_data, x_shape_host,
axis);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error! error code=%d", r));
......@@ -151,20 +72,13 @@ class TransposeGradXPUKernel : public framework::OpKernel<T> {
}
int ndims = axis.size();
if (!XPUSupported(ndims, reversed_axis)) {
PADDLE_THROW(
platform::errors::Unimplemented("XPU does not support the permute"));
}
std::vector<int> out_shape_host(ndims, 0);
for (int i = 0; i < ndims; ++i) {
out_shape_host[i] = out_grad->dims()[i];
}
int* permute_host = reversed_axis.data();
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::transpose(dev_ctx.x_context(), out_grad->data<T>(),
x_grad->data<T>(), out_shape_host.data(),
permute_host, ndims);
int r = xpu::transpose<T>(dev_ctx.x_context(), out_grad->data<T>(),
x_grad->data<T>(), out_shape_host, reversed_axis);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error! error code=%d", r));
......
......@@ -29,37 +29,68 @@ class XPUUniformRandomKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override {
framework::Tensor *tensor = nullptr;
auto out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
std::vector<int64_t> new_shape;
auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("ShapeTensorList");
if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
if (ctx.HasInput("ShapeTensor")) {
auto *shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
new_shape = GetNewDataFromShapeTensor(shape_tensor);
} else if (list_new_shape_tensor.size() > 0) {
new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
}
}
if (out_var->IsType<framework::SelectedRows>()) {
auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
tensor = selected_rows->mutable_value();
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
if (!new_shape.empty()) shape = new_shape;
tensor->Resize(framework::make_ddim(shape));
selected_rows->mutable_rows()->reserve(shape[0]);
} else if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
if (!new_shape.empty()) tensor->Resize(framework::make_ddim(new_shape));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Expected type of Output(out) in uniform_random_op must be "
"LoDTensor, "
"SelectedRows. But got unsupport type: %s.",
"Expected type of Output(out) in uniform_random_op must be Tensor, "
"SelectedRows. But got "
"unsupport type: %s.",
framework::ToTypeName(out_var->Type())));
}
T *data = tensor->mutable_data<T>(ctx.GetPlace());
int64_t size = tensor->numel();
std::unique_ptr<T[]> data_cpu(new T[size]);
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
// TODO(pangyoki): implement GetXPURandomEngine to set different seeds on
// corresponding XPU device.
auto engine = framework::GetCPURandomEngine(seed);
std::unique_ptr<T[]> data_cpu(new T[size]);
for (int64_t i = 0; i < size; ++i) {
data_cpu[i] = dist(*engine);
}
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
auto diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
if (diag_num > 0) {
PADDLE_ENFORCE_GT(
size, (diag_num - 1) * (diag_step + 1),
platform::errors::InvalidArgument(
"ShapeInvalid: the diagonal's elements is equal (num-1) "
"* (step-1) with num %d, step %d,"
"It should be smaller than %d, but received %d",
diag_num, diag_step, (diag_num - 1) * (diag_step + 1), size));
for (int64_t i = 0; i < diag_num; ++i) {
int64_t pos = i * diag_step + i;
data_cpu[pos] = diag_val;
}
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), data,
platform::CPUPlace(), reinterpret_cast<void *>(data_cpu.get()),
size * sizeof(T));
......
......@@ -15,11 +15,13 @@
#pragma once
#ifdef PADDLE_WITH_XPU
#include <map>
#include <string>
#include <unordered_map>
#include "paddle/fluid/platform/errors.h"
#include "xpu/api.h"
#include "xpu/refactor/math.h"
#include "xpu/refactor/nn.h"
#include "xpu/runtime.h"
#include "xpu/runtime_ex.h"
......@@ -48,4 +50,11 @@ class XPUActHelper {
return res->second;
}
};
static std::map<int, std::string> XPUAPIErrorMsg = {
{xpu::Error_t::SUCCESS, "xpu api success"},
{xpu::Error_t::INVALID_PARAM, "xpu api invalid param"},
{xpu::Error_t::RUNTIME_ERROR, "xpu api runtime error"},
{xpu::Error_t::NO_ENOUGH_WORKSPACE, "xpu api no enough workspace"}};
#endif
......@@ -1915,6 +1915,10 @@ def load(program, model_path, executor=None, var_list=None):
place = paddle.fluid.CPUPlace()
elif p.is_cuda_pinned_place():
place = paddle.fluid.CUDAPinnedPlace()
elif p.is_xpu_place():
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.XPUPlace(p.xpu_device_id())
else:
p = paddle.fluid.core.Place()
p.set_place(t._place())
......
......@@ -362,17 +362,6 @@ class XPUOpTest(OpTest):
if not type(output_names) is list:
output_names = [output_names]
numeric_grads = user_defined_grads or [
get_numeric_gradient(
place,
self.scope,
self.op,
self.inputs,
input_to_check,
output_names,
delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check
]
analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set)
return analytic_grads
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import sys
sys.path.append("..")
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import time
paddle.enable_static()
def bilinear_interp_np(input,
out_h,
out_w,
out_size=None,
actual_shape=None,
align_corners=True,
align_mode=0,
data_layout='NCHW'):
"""bilinear interpolation implement in shape [N, C, H, W]"""
if data_layout == "NHWC":
input = np.transpose(input, (0, 3, 1, 2)) # NHWC => NCHW
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
batch_size, channel, in_h, in_w = input.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
if (align_corners):
ratio_h = (in_h - 1.0) / (out_h - 1.0)
else:
ratio_h = 1.0 * in_h / out_h
if out_w > 1:
if (align_corners):
ratio_w = (in_w - 1.0) / (out_w - 1.0)
else:
ratio_w = 1.0 * in_w / out_w
out = np.zeros((batch_size, channel, out_h, out_w))
for i in range(out_h):
if (align_mode == 0 and not align_corners):
h = int(ratio_h * (i + 0.5) - 0.5)
else:
h = int(ratio_h * i)
h = max(0, h)
hid = 1 if h < in_h - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_h = max(ratio_h * (i + 0.5) - 0.5, 0)
h1lambda = idx_src_h - h
else:
h1lambda = ratio_h * i - h
h2lambda = 1.0 - h1lambda
for j in range(out_w):
if (align_mode == 0 and not align_corners):
w = int(ratio_w * (j + 0.5) - 0.5)
else:
w = int(ratio_w * j)
w = max(0, w)
wid = 1 if w < in_w - 1 else 0
if (align_mode == 0 and not align_corners):
idx_src_w = max(ratio_w * (j + 0.5) - 0.5, 0)
w1lambda = idx_src_w - w
else:
w1lambda = ratio_w * j - w
w2lambda = 1.0 - w1lambda
out[:, :, i, j] = h2lambda*(w2lambda*input[:, :, h, w] +
w1lambda*input[:, :, h, w+wid]) + \
h1lambda*(w2lambda*input[:, :, h+hid, w] +
w1lambda*input[:, :, h+hid, w+wid])
if data_layout == "NHWC":
out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC
return out.astype(input.dtype)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpOp(XPUOpTest):
def setUp(self):
self.use_xpu = True
self.out_size = None
self.actual_shape = None
self.data_layout = 'NCHW'
self.init_test_case()
self.op_type = "bilinear_interp"
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW":
in_h = self.input_shape[2]
in_w = self.input_shape[3]
else:
in_h = self.input_shape[1]
in_w = self.input_shape[2]
if self.scale > 0:
out_h = int(in_h * self.scale)
out_w = int(in_w * self.scale)
else:
out_h = self.out_h
out_w = self.out_w
output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners,
self.align_mode, self.data_layout)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'scale': self.scale,
'interp_method': self.interp_method,
'align_corners': self.align_corners,
'align_mode': self.align_mode,
'data_layout': self.data_layout
}
self.outputs = {'Out': output_np}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 5]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpCase1(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpCase2(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpCase3(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpCase4(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.scale = 0.
self.out_size = np.array([2, 2]).astype("int32")
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpCase5(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = np.array([11, 11]).astype("int32")
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpCase6(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 32, 64]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([65, 33]).astype("int32")
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpSame(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 32, 64]
self.out_h = 32
self.out_w = 64
self.scale = 0.
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpActualShape(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpDataLayout(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 5, 5, 3]
self.out_h = 2
self.out_w = 2
self.scale = 0.
self.out_size = np.array([3, 3]).astype("int32")
self.align_corners = True
self.align_mode = 1
self.data_layout = "NHWC"
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpOtherMethod1(TestBilinearInterpOp):
def set_align_mode(self):
self.align_corners = False
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpWithMethod2(TestBilinearInterpOp):
def set_align_mode(self):
self.align_corners = False
self.align_mode = 0
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpWithMethod3(TestBilinearInterpOp):
def set_align_mode(self):
self.align_corners = True
self.align_mode = 0
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpScale1(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 2.
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpScale2(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 1.
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpScale3(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 1.5
self.align_corners = True
self.align_mode = 1
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpZero(TestBilinearInterpOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 7]
self.out_h = 60
self.out_w = 25
self.scale = 0.2
self.align_corners = False
self.align_mode = 0
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpOp_attr_tensor(XPUOpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "bilinear_interp"
self.shape_by_1Dtensor = False
self.scale_by_1Dtensor = False
self.attrs = {
'interp_method': self.interp_method,
'align_corners': self.align_corners,
}
input_np = np.random.random(self.input_shape).astype("float32")
self.inputs = {'X': input_np}
if self.scale_by_1Dtensor:
self.inputs['Scale'] = np.array([self.scale]).astype("float32")
elif self.scale > 0:
out_h = int(self.input_shape[2] * self.scale)
out_w = int(self.input_shape[3] * self.scale)
self.attrs['scale'] = self.scale
else:
out_h = self.out_h
out_w = self.out_w
if self.shape_by_1Dtensor:
self.inputs['OutSize'] = self.out_size
elif self.out_size is not None:
size_tensor = []
for index, ele in enumerate(self.out_size):
size_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs['SizeTensor'] = size_tensor
self.attrs['out_h'] = self.out_h
self.attrs['out_w'] = self.out_w
output_np = bilinear_interp_np(input_np, out_h, out_w, self.out_size,
self.actual_shape, self.align_corners)
self.outputs = {'Out': output_np}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 5]
self.out_h = 3
self.out_w = 3
self.scale = 0.
self.out_size = [3, 3]
self.align_corners = True
# out_size is a 1-D tensor
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterp_attr_tensor_Case1(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.scale = 0.
self.out_size = [8, 12]
self.align_corners = True
# scale is a 1-D tensor
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterp_attr_tensor_Case2(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 0.
self.out_size = np.array([66, 40]).astype("int32")
self.align_corners = True
self.shape_by_1Dtensor = True
# scale is a 1-D tensor
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterp_attr_tensor_Case3(TestBilinearInterpOp_attr_tensor):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.scale = 2.0
self.out_size = None
self.align_corners = True
self.scale_by_1Dtensor = True
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestBilinearInterpOpAPI(unittest.TestCase):
def test_case(self):
x = fluid.data(name="x", shape=[2, 3, 6, 6], dtype="float32")
dim = fluid.data(name="dim", shape=[1], dtype="int32")
shape_tensor = fluid.data(name="shape_tensor", shape=[2], dtype="int32")
actual_size = fluid.data(name="actual_size", shape=[2], dtype="int32")
scale_tensor = fluid.data(
name="scale_tensor", shape=[1], dtype="float32")
out1 = fluid.layers.resize_bilinear(x, out_shape=[12, 12])
out2 = fluid.layers.resize_bilinear(x, out_shape=[12, dim])
out3 = fluid.layers.resize_bilinear(x, out_shape=shape_tensor)
out4 = fluid.layers.resize_bilinear(
x, out_shape=[4, 4], actual_shape=actual_size)
out5 = fluid.layers.resize_bilinear(x, scale=scale_tensor)
x_data = np.random.random((2, 3, 6, 6)).astype("float32")
dim_data = np.array([12]).astype("int32")
shape_data = np.array([12, 12]).astype("int32")
actual_size_data = np.array([12, 12]).astype("int32")
scale_data = np.array([2.0]).astype("float32")
place = core.XPUPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
results = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"dim": dim_data,
"shape_tensor": shape_data,
"actual_size": actual_size_data,
"scale_tensor": scale_data
},
fetch_list=[out1, out2, out3, out4, out5],
return_numpy=True)
expect_res = bilinear_interp_np(
x_data, out_h=12, out_w=12, align_corners=True)
for res in results:
self.assertTrue(np.allclose(res, expect_res))
if __name__ == "__main__":
unittest.main()
......@@ -19,16 +19,20 @@ import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
import paddle
class TestConcatOp(OpTest):
class TestConcatOp(XPUOpTest):
def setUp(self):
self.op_type = "concat"
self.dtype = self.get_dtype()
self.use_xpu = True
self.use_mkldnn = False
self.init_test_data()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
self.attrs = {'axis': self.axis}
......@@ -44,7 +48,7 @@ class TestConcatOp(OpTest):
}
def get_dtype(self):
return "float64"
return "float32"
def test_check_output(self):
if paddle.is_compiled_with_xpu():
......@@ -131,7 +135,7 @@ class TestConcatOp6(TestConcatOp):
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_dygraph=False)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
......@@ -147,94 +151,6 @@ class TestConcatOp6(TestConcatOp):
self.axis = 0
class TestConcatOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of concat_op should be list.
x1 = fluid.layers.data(shape=[4], dtype='int32', name='x1')
fluid.layers.concat(x1)
# The item in input must be Variable.
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.concat, [x2])
# The input dtype of concat_op must be float16, float32, float64, int32, int64.
x4 = fluid.layers.data(shape=[4], dtype='uint8', name='x4')
x5 = fluid.layers.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8')
fluid.layers.concat([x6, x7])
# The type of axis in concat_op should be int or Variable.
def test_axis_type():
fluid.layers.concat([x6, x7], 3.2)
self.assertRaises(TypeError, test_axis_type)
def test_input_same_dtype():
fluid.layers.concat([x7, x8])
self.assertRaises(TypeError, test_input_same_dtype)
class TestConcatAPI(unittest.TestCase):
def test_fluid_api(self):
x_1 = fluid.data(shape=[None, 1, 4, 5], dtype='float32', name='x_1')
fluid.layers.concat([x_1, x_1], 0)
input_2 = np.random.random([2, 1, 4, 5]).astype("float32")
input_3 = np.random.random([2, 2, 4, 5]).astype("float32")
x_2 = fluid.data(shape=[2, 1, 4, 5], dtype='float32', name='x_2')
x_3 = fluid.data(shape=[2, 2, 4, 5], dtype='float32', name='x_3')
positive_1_int32 = fluid.layers.fill_constant([1], "float32", 1)
positive_1_int64 = fluid.layers.fill_constant([1], "float32", 1)
out_1 = fluid.layers.concat(input=[x_2, x_3], axis=1)
out_2 = fluid.layers.concat(input=[x_2, x_3], axis=1)
out_3 = fluid.layers.concat(input=[x_2, x_3], axis=1)
exe = fluid.Executor(place=fluid.XPUPlace(0))
[res_1, res_2, res_3] = exe.run(
fluid.default_main_program(),
feed={"x_1": input_2,
"x_2": input_2,
"x_3": input_3},
fetch_list=[out_1, out_2, out_3])
assert np.array_equal(res_1, np.concatenate((input_2, input_3), axis=1))
assert np.array_equal(res_2, np.concatenate((input_2, input_3), axis=1))
assert np.array_equal(res_3, np.concatenate((input_2, input_3), axis=1))
def test_errors(self):
with program_guard(Program(), Program()):
# The item in input must be Variable.
x2 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
x3 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
self.assertRaises(TypeError, paddle.concat, [x2])
# The input dtype of concat_op must be float32.
x4 = fluid.data(shape=[4], dtype='uint8', name='x4')
x5 = fluid.data(shape=[4], dtype='uint8', name='x5')
self.assertRaises(TypeError, fluid.layers.concat, [x4, x5])
# The type of axis in concat_op should be int or Variable.
x6 = fluid.layers.data(shape=[4], dtype='float16', name='x6')
x7 = fluid.layers.data(shape=[4], dtype='float16', name='x7')
x8 = fluid.layers.data(shape=[4], dtype='float32', name='x8')
def test_axis_type():
paddle.concat([x6, x7], 3.2)
self.assertRaises(TypeError, test_axis_type)
def test_input_same_dtype():
paddle.concat([x7, x8])
self.assertRaises(TypeError, test_input_same_dtype)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test_xpu import OpTest, XPUOpTest
import paddle
from paddle.fluid import Program, program_guard
def dmc_bilinear(data_im, height, width, h, w):
h_low = int(np.floor(h))
w_low = int(np.floor(w))
h_high = h_low + 1
w_high = w_low + 1
lh = h - h_low
lw = w - w_low
hh = 1 - lh
hw = 1 - lw
v1 = 0
if h_low >= 0 and w_low >= 0:
v1 = data_im[h_low, w_low]
v2 = 0
if h_low >= 0 and w_high <= width - 1:
v2 = data_im[h_low, w_high]
v3 = 0
if h_high <= height - 1 and w_low >= 0:
v3 = data_im[h_high, w_low]
v4 = 0
if h_high <= height - 1 and w_high <= width - 1:
v4 = data_im[h_high, w_high]
w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
def dconv_im2col_gemm(input, offset, mask, filter, group, conv_param):
in_n, in_c, in_h, in_w = input.shape
out_c, f_c, f_h, f_w = filter.shape
assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w)
assert mask.shape == (in_n, f_h * f_w, in_h, in_w)
assert f_c * group == in_c
assert np.mod(out_c, group) == 0
stride, pad, dilation = conv_param['stride'], conv_param['pad'],\
conv_param['dilation']
out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
assert out_h == in_h
assert out_w == in_w
col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w))
for n in range(in_n):
for c in range(in_c):
for h in range(out_h):
for w in range(out_w):
for kh in range(f_h):
for kw in range(f_w):
offset_h_table = \
offset[n, ::2, h, w].reshape(f_h, f_w)
offset_w_table = \
offset[n, 1::2, h, w].reshape(f_h, f_w)
mask_table = \
mask[n, :, h, w].reshape(f_h, f_w)
offset_h = offset_h_table[kh, kw]
offset_w = offset_w_table[kh, kw]
val = 0
im_h = h * stride[0] + kh * dilation[0] \
+ offset_h - pad[0]
im_w = w * stride[0] + kw * dilation[0] \
+ offset_w - pad[1]
if im_h > -1 and im_w > -1 and \
im_h < in_h and im_w < in_h:
val = dmc_bilinear(input[n, c], in_h, in_w,
im_h, im_w)
val_out = val * mask_table[kh, kw]
col_buffer[n, c * f_h * f_w + kh * f_w + kw, h *
in_w + w] = val_out
out = np.zeros((in_n, group, int(out_c // group), out_h * out_w))
weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w)
col_buffer = col_buffer.reshape(
(in_n, group, int(in_c // group * f_h * f_w), in_h * in_w))
for n in range(in_n):
for g in range(group):
out[n, g] = np.matmul(weight[g], col_buffer[n, g])
out = out.reshape(in_n, out_c, out_h, out_w)
return out
class TestModulatedDeformableConvOp(XPUOpTest):
def setUp(self):
self.op_type = "deformable_conv"
self.dtype = np.float32
self.init_group()
self.init_dilation()
self.init_test_case()
conv_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
mask = 10 * np.random.random(self.mask_size).astype(self.dtype)
filter = np.random.random(self.filter_size).astype(self.dtype)
output = dconv_im2col_gemm(input, offset, mask, filter, self.groups,
conv_param)
output = output.astype(self.dtype)
self.inputs = {
'Input': OpTest.np_dtype_to_fluid_dtype(input),
'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
'Mask': OpTest.np_dtype_to_fluid_dtype(mask),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'deformable_groups': self.deformable_groups,
'im2col_step': self.im2col_step,
'dilations': self.dilations,
}
self.outputs = {'Output': output}
def has_cuda(self):
return core.is_compiled_with_cuda() and (self.use_cudnn or
self.use_cuda)
def test_check_output(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if core.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, {'Input', 'Offset', 'Mask', 'Filter'},
'Output',
max_relative_error=0.06)
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 8, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [8, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self):
self.groups = 1
class TestWithDilation(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [2, 2]
self.stride = [1, 1]
self.input_size = [4, 3, 4, 4] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
def init_dilation(self):
self.dilations = [2, 2]
class TestWith3x3(TestModulatedDeformableConvOp):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.im2col_step = 1
self.deformable_groups = 1
offset_c = 2 * self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
mask_c = self.deformable_groups * self.filter_size[
2] * self.filter_size[3]
self.offset_size = [
self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
]
self.mask_size = [
self.input_size[0], mask_c, self.input_size[2], self.input_size[3]
]
class TestModulatedDeformableConvInvalidInput(unittest.TestCase):
def test_error(self):
def test_invalid_input():
paddle.enable_static()
input = [1, 3, 32, 32]
offset = fluid.data(
name='offset', shape=[None, 3, 32, 32], dtype='float32')
mask = fluid.data(
name='mask', shape=[None, 3, 32, 32], dtype='float32')
loss = fluid.layers.deformable_conv(
input, offset, mask, num_filters=4, filter_size=1)
self.assertRaises(TypeError, test_invalid_input)
def test_invalid_offset():
paddle.enable_static()
input = fluid.data(
name='input', shape=[None, 3, 32, 32], dtype='int32')
offset = fluid.data(
name='offset', shape=[None, 3, 32, 32], dtype='float32')
mask = fluid.data(
name='mask', shape=[None, 3, 32, 32], dtype='float32')
loss = fluid.layers.deformable_conv(
input, offset, mask, num_filters=4, filter_size=1)
self.assertRaises(TypeError, test_invalid_offset)
if __name__ == '__main__':
unittest.main()
......@@ -18,7 +18,8 @@ import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
from op_test_xpu import OpTest, XPUOpTest
from op_test import skip_check_grad_ci
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
......@@ -26,180 +27,128 @@ from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import convert_np_dtype_to_dtype_
class TestSumOp(OpTest):
class TestXPUReduceSumOp(XPUOpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'use_xpu': True}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def check_grad_(self):
self.check_grad(['X'], 'Out')
class TestSumOp5D(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {
'X': np.random.random((1, 2, 5, 6, 10)).astype("float64")
self.init_op_type()
self.initTestCase()
self.use_xpu = True
self.use_mkldnn = False
self.attrs = {
'dim': self.axis,
'keep_dim': self.keep_dim,
'reduce_all': self.reduce_all
}
self.attrs = {'use_xpu': True}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
if self.attrs['reduce_all']:
self.outputs = {'Out': self.inputs['X'].sum()}
else:
self.outputs = {
'Out': self.inputs['X'].sum(axis=self.axis,
keepdims=self.attrs['keep_dim'])
}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestSumOp6D(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {
'X': np.random.random((1, 1, 2, 5, 6, 10)).astype("float64")
}
self.attrs = {'use_xpu': True}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
self.check_grad_with_place(place, ['X'], 'Out')
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def init_op_type(self):
self.op_type = "reduce_sum"
self.use_mkldnn = False
self.keep_dim = False
self.reduce_all = False
def initTestCase(self):
self.shape = (5, 6, 10)
self.axis = (0, )
class TestSumOp8D(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {
'X': np.random.random((1, 3, 1, 2, 1, 4, 3, 10)).astype("float64")
}
self.attrs = {'dim': (0, 3), 'use_xpu': True}
self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 3))}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
class TestSumOp5D(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (1, 2, 5, 6, 10)
self.axis = (0, )
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestSumOp6D(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (1, 1, 2, 5, 6, 10)
self.axis = (0, )
class Test1DReduce(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random(120).astype("float64")}
self.attrs = {'use_xpu': True}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
class TestSumOp8D(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (1, 3, 1, 2, 1, 4, 3, 10)
self.axis = (0, 3)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class Test1DReduce(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = 120
self.axis = (0, )
class Test2DReduce0(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.attrs = {'dim': [0], 'use_xpu': True}
self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
class Test2DReduce0(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (20, 10)
self.axis = (0, )
class Test2DReduce1(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.attrs = {'dim': [1], 'use_xpu': True}
self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
class Test2DReduce1(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (20, 10)
self.axis = (1, )
class Test3DReduce0(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.attrs = {'dim': [1], 'use_xpu': True}
self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
class Test3DReduce0(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (5, 6, 7)
self.axis = (1, )
class Test3DReduce1(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.attrs = {'dim': [2], 'use_xpu': True}
self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
class Test3DReduce1(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (5, 6, 7)
self.axis = (2, )
class Test3DReduce2(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.attrs = {'dim': [-2], 'use_xpu': True}
self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
class Test3DReduce2(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (5, 6, 7)
self.axis = (-2, )
class Test3DReduce3(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.attrs = {'dim': [1, 2], 'use_xpu': True}
self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
class Test3DReduce3(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (5, 6, 7)
self.axis = (1, 2)
class TestKeepDimReduce(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': [1], 'keep_dim': True, 'use_xpu': True}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']),
keepdims=self.attrs['keep_dim'])
}
class TestKeepDimReduce(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (5, 6, 10)
self.axis = (1, )
self.keep_dim = True
class TestKeepDim8DReduce(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {
'X': np.random.random((2, 5, 3, 2, 2, 3, 4, 2)).astype("float64")
}
self.attrs = {'dim': (3, 4, 5), 'keep_dim': True, 'use_xpu': True}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']),
keepdims=self.attrs['keep_dim'])
}
class TestKeepDim8DReduce(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (2, 5, 3, 2, 2, 3, 4, 2)
self.axis = (3, 4, 5)
self.keep_dim = True
class TestReduceAll(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float64")}
self.attrs = {'reduce_all': True, 'use_xpu': True}
self.outputs = {'Out': self.inputs['X'].sum()}
class TestReduceAll(TestXPUReduceSumOp):
def initTestCase(self):
self.shape = (5, 6, 2, 10)
self.axis = (0, )
self.reduce_all = True
if __name__ == '__main__':
......
......@@ -13,15 +13,16 @@
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
from test_softmax_op import stable_softmax
from op_test import OpTest
from op_test_xpu import XPUOpTest
import paddle.fluid.core as core
import paddle
import unittest
import numpy as np
import sys
sys.path.append("..")
def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
......@@ -44,7 +45,7 @@ def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
return result.reshape(label.shape)
class TestSoftmaxWithCrossEntropyOp(OpTest):
class TestSoftmaxWithCrossEntropyOp(XPUOpTest):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
"""
......
......@@ -19,24 +19,27 @@ import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import OpTest, XPUOpTest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
class TestXPUTransposeOp(OpTest):
class TestXPUTransposeOp(XPUOpTest):
def setUp(self):
self.init_op_type()
self.initTestCase()
self.inputs = {'X': np.random.random(self.shape).astype("float64")}
self.use_xpu = True
self.use_mkldnn = False
self.inputs = {'X': np.random.random(self.shape).astype("float32")}
self.attrs = {
'axis': list(self.axis),
'use_mkldnn': False,
'use_xpu': True
}
self.outputs = {
'XShape': np.random.random(self.shape).astype("float64"),
'XShape': np.random.random(self.shape).astype("float32"),
'Out': self.inputs['X'].transpose(self.axis)
}
......@@ -121,110 +124,5 @@ class TestCase9(TestXPUTransposeOp):
self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
class TestTransposeOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[10, 5, 3], dtype='float64')
def test_x_Variable_check():
# the Input(x)'s type must be Variable
fluid.layers.transpose("not_variable", perm=[1, 0, 2])
self.assertRaises(TypeError, test_x_Variable_check)
def test_x_dtype_check():
# the Input(x)'s dtype must be one of [float16, float32, float64, int32, int64]
x1 = fluid.layers.data(
name='x1', shape=[10, 5, 3], dtype='bool')
fluid.layers.transpose(x1, perm=[1, 0, 2])
self.assertRaises(TypeError, test_x_dtype_check)
def test_perm_list_check():
# Input(perm)'s type must be list
fluid.layers.transpose(x, perm="[1, 0, 2]")
self.assertRaises(TypeError, test_perm_list_check)
def test_perm_length_and_x_dim_check():
# Input(perm) is the permutation of dimensions of Input(input)
# its length should be equal to dimensions of Input(input)
fluid.layers.transpose(x, perm=[1, 0, 2, 3, 4])
self.assertRaises(ValueError, test_perm_length_and_x_dim_check)
def test_each_elem_value_check():
# Each element in Input(perm) should be less than Input(x)'s dimension
fluid.layers.transpose(x, perm=[3, 5, 7])
self.assertRaises(ValueError, test_each_elem_value_check)
class TestTAPI(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program()):
data = fluid.data(shape=[10], dtype="float64", name="data")
data_t = paddle.t(data)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
data_np = np.random.random([10]).astype("float64")
result, = exe.run(feed={"data": data_np}, fetch_list=[data_t])
expected_result = np.transpose(data_np)
self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()):
data = fluid.data(shape=[10, 5], dtype="float64", name="data")
data_t = paddle.t(data)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
data_np = np.random.random([10, 5]).astype("float64")
result, = exe.run(feed={"data": data_np}, fetch_list=[data_t])
expected_result = np.transpose(data_np)
self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()):
data = fluid.data(shape=[1, 5], dtype="float64", name="data")
data_t = paddle.t(data)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
data_np = np.random.random([1, 5]).astype("float64")
result, = exe.run(feed={"data": data_np}, fetch_list=[data_t])
expected_result = np.transpose(data_np)
self.assertEqual((result == expected_result).all(), True)
with fluid.dygraph.guard():
np_x = np.random.random([10]).astype("float64")
data = fluid.dygraph.to_variable(np_x)
z = paddle.t(data)
np_z = z.numpy()
z_expected = np.array(np.transpose(np_x))
self.assertEqual((np_z == z_expected).all(), True)
with fluid.dygraph.guard():
np_x = np.random.random([10, 5]).astype("float64")
data = fluid.dygraph.to_variable(np_x)
z = paddle.t(data)
np_z = z.numpy()
z_expected = np.array(np.transpose(np_x))
self.assertEqual((np_z == z_expected).all(), True)
with fluid.dygraph.guard():
np_x = np.random.random([1, 5]).astype("float64")
data = fluid.dygraph.to_variable(np_x)
z = paddle.t(data)
np_z = z.numpy()
z_expected = np.array(np.transpose(np_x))
self.assertEqual((np_z == z_expected).all(), True)
def test_errors(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name='x', shape=[10, 5, 3], dtype='float64')
def test_x_dimension_check():
paddle.t(x)
self.assertRaises(ValueError, test_x_dimension_check)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册