未验证 提交 03ddf690 编写于 作者: Q QingshuChen 提交者: GitHub

cherry-pick kunlun PR: 29458, 29539 (#29583)

* support mobilenet for kunlun (#29458)

* add xpu ops for training transformer in kunlun (#29539)

* 1.fix matmul bug 2. add one hot

* add xpu error msg
Co-authored-by: Nprocr <procrboo@gmail.com>
Co-authored-by: Ntaixiurong <taixiurong@126.com>
上级 d82d59e6
...@@ -4,7 +4,7 @@ endif() ...@@ -4,7 +4,7 @@ endif()
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
SET(XPU_PROJECT "extern_xpu") SET(XPU_PROJECT "extern_xpu")
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_04.tar.gz" CACHE STRING "" FORCE) SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_07_cdfbf0c.tar.gz" CACHE STRING "" FORCE)
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}")
SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu")
......
...@@ -61,13 +61,38 @@ void xpu_activation_forward(const framework::ExecutionContext &ctx, ...@@ -61,13 +61,38 @@ void xpu_activation_forward(const framework::ExecutionContext &ctx,
const T *x_data = x->data<T>(); const T *x_data = x->data<T>();
T *y_data = y->mutable_data<T>(ctx.GetPlace()); T *y_data = y->mutable_data<T>(ctx.GetPlace());
int r = 0; int r = 0;
if (xpu::Activation_t::ACT_POW == type.type) { auto xpu_context = ctx.device_context<DeviceContext>().x_context();
switch (type.type) {
case xpu::Activation_t::HARD_SWISH: {
float threshold = ctx.Attr<float>("threshold");
float scale = ctx.Attr<float>("scale");
float offset = ctx.Attr<float>("offset");
PADDLE_ENFORCE_EQ(threshold, 6.0f,
platform::errors::External(
"Not support threshold [%f] in XPU", threshold));
PADDLE_ENFORCE_EQ(
scale, 6.0f,
platform::errors::External("Not support scale [%f] in XPU", scale));
PADDLE_ENFORCE_EQ(
offset, 3.0f,
platform::errors::External("Not support offset [%f] in XPU", offset));
r = xpu::hard_swish(xpu_context, reinterpret_cast<const float *>(x_data),
reinterpret_cast<float *>(y_data), x->numel());
break;
}
case xpu::Activation_t::ACT_POW: {
type.pow_factor = ctx.Attr<float>("factor"); type.pow_factor = ctx.Attr<float>("factor");
} }
auto xpu_context = ctx.device_context<DeviceContext>().x_context(); default: {
r = xpu::activation_forward(xpu_context, type, x->numel(), r = xpu::activation_forward(xpu_context, type, x->numel(),
reinterpret_cast<const float *>(x_data), reinterpret_cast<const float *>(x_data),
reinterpret_cast<float *>(y_data)); reinterpret_cast<float *>(y_data));
break;
}
}
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU API return wrong value[%d], please check whether " "XPU API return wrong value[%d], please check whether "
...@@ -90,12 +115,40 @@ void xpu_activation_backward(const framework::ExecutionContext &ctx, ...@@ -90,12 +115,40 @@ void xpu_activation_backward(const framework::ExecutionContext &ctx,
if (y != nullptr) y_data = y->data<T>(); if (y != nullptr) y_data = y->data<T>();
if (dOut != nullptr) y_grad = dOut->data<T>(); if (dOut != nullptr) y_grad = dOut->data<T>();
T *x_grad = dX->mutable_data<T>(ctx.GetPlace()); T *x_grad = dX->mutable_data<T>(ctx.GetPlace());
int r = 0;
auto xpu_context = ctx.device_context<DeviceContext>().x_context(); auto xpu_context = ctx.device_context<DeviceContext>().x_context();
int r = xpu::activation_backward(xpu_context, type, dX->numel(),
switch (type.type) {
case xpu::Activation_t::HARD_SWISH: {
float threshold = ctx.Attr<float>("threshold");
float scale = ctx.Attr<float>("scale");
float offset = ctx.Attr<float>("offset");
PADDLE_ENFORCE_EQ(threshold, 6.0f,
platform::errors::External(
"Not support threshold [%f] in XPU", threshold));
PADDLE_ENFORCE_EQ(
scale, 6.0f,
platform::errors::External("Not support scale [%f] in XPU", scale));
PADDLE_ENFORCE_EQ(
offset, 3.0f,
platform::errors::External("Not support offset [%f] in XPU", offset));
r = xpu::hard_swish_grad(xpu_context,
reinterpret_cast<const float *>(x_data),
reinterpret_cast<const float *>(y_data),
reinterpret_cast<const float *>(y_grad),
reinterpret_cast<float *>(x_grad), dX->numel());
break;
}
default: {
r = xpu::activation_backward(xpu_context, type, dX->numel(),
reinterpret_cast<const float *>(x_data), reinterpret_cast<const float *>(x_data),
reinterpret_cast<const float *>(y_data), reinterpret_cast<const float *>(y_data),
reinterpret_cast<const float *>(y_grad), reinterpret_cast<const float *>(y_grad),
reinterpret_cast<float *>(x_grad)); reinterpret_cast<float *>(x_grad));
break;
}
}
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU API return wrong value[%d], please check whether " "XPU API return wrong value[%d], please check whether "
...@@ -132,6 +185,8 @@ using XPULogFunctor = XPUActivationFunc<T, xpu::Activation_t::LOG>; ...@@ -132,6 +185,8 @@ using XPULogFunctor = XPUActivationFunc<T, xpu::Activation_t::LOG>;
template <typename T> template <typename T>
using XPUSquareFunctor = XPUActivationFunc<T, xpu::Activation_t::SQUARE>; using XPUSquareFunctor = XPUActivationFunc<T, xpu::Activation_t::SQUARE>;
template <typename T> template <typename T>
using XPUHardSwishFunctor = XPUActivationFunc<T, xpu::Activation_t::HARD_SWISH>;
template <typename T>
using XPUSuareGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::SQUARE>; using XPUSuareGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::SQUARE>;
template <typename T> template <typename T>
using XPUReluGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::RELU>; using XPUReluGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::RELU>;
...@@ -147,6 +202,9 @@ using XPUSqrtFunctor = XPUActivationFunc<T, xpu::Activation_t::SQRT>; ...@@ -147,6 +202,9 @@ using XPUSqrtFunctor = XPUActivationFunc<T, xpu::Activation_t::SQRT>;
template <typename T> template <typename T>
using XPUSqrtGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::SQRT>; using XPUSqrtGradFunctor = XPUActivationGradFunc<T, xpu::Activation_t::SQRT>;
template <typename T> template <typename T>
using XPUHardSwishGradFunctor =
XPUActivationGradFunc<T, xpu::Activation_t::HARD_SWISH>;
template <typename T>
using XPUACTPowFunctor = XPUActivationFunc<T, xpu::Activation_t::ACT_POW>; using XPUACTPowFunctor = XPUActivationFunc<T, xpu::Activation_t::ACT_POW>;
template <typename T> template <typename T>
using XPUABSFunctor = XPUActivationFunc<T, xpu::Activation_t::ABS>; using XPUABSFunctor = XPUActivationFunc<T, xpu::Activation_t::ABS>;
...@@ -169,6 +227,8 @@ REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor, ...@@ -169,6 +227,8 @@ REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor,
REGISTER_ACTIVATION_XPU_KERNEL(gelu, XPUGeluFunctor, XPUGeluGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(gelu, XPUGeluFunctor, XPUGeluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSuareGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSuareGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(hard_swish, XPUHardSwishFunctor,
XPUHardSwishGradFunctor)
REGISTER_OP_XPU_KERNEL(log, REGISTER_OP_XPU_KERNEL(log,
ops::XPUActivationKernel<ops::XPULogFunctor<float>>); ops::XPUActivationKernel<ops::XPULogFunctor<float>>);
REGISTER_OP_XPU_KERNEL(pow, REGISTER_OP_XPU_KERNEL(pow,
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "xpu/refactor/math.h"
namespace paddle {
namespace operators {
// Selects which binary logical XPU kernel BinaryLogicalOpXPUKernel invokes.
typedef enum { XPU_OR, XPU_AND } XpuLogicalType;

// Returns a human-readable name for `ty`, used in log and error messages.
// Unrecognised values yield "unknown type".
//
// Declared `inline` because this header is included by more than one .cc
// file; a non-inline definition here would be emitted in every translation
// unit and fail at link time with a multiple-definition error.
inline std::string XpuLogicalType2Str(XpuLogicalType ty) {
  switch (ty) {
    case XpuLogicalType::XPU_OR:
      return std::string("logical or");
    case XpuLogicalType::XPU_AND:
      return std::string("logical and");
    default:
      return std::string("unknown type");
  }
}
// XPU kernel for the element-wise binary logical ops selected by `xpu_type`.
//
// Reads inputs "X" and "Y", writes output "Out". Any input whose element
// count differs from Out's is first expanded to Out's shape with
// xpu::broadcast, then the logical kernel runs over Out->numel() elements.
template <XpuLogicalType xpu_type, typename T>
class BinaryLogicalOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<framework::Tensor>("X");
    auto* y = context.Input<framework::Tensor>("Y");
    auto* out = context.Output<framework::Tensor>("Out");
    T* out_ptr = out->mutable_data<T>(context.GetPlace());
    const T* x_ptr = x->data<T>();
    const T* y_ptr = y->data<T>();
    auto& dev_ctx =
        context.template device_context<paddle::platform::XPUDeviceContext>();

    // Scratch tensors live until the end of Compute so the broadcast results
    // stay valid while the logical kernel reads them.
    framework::Tensor broadcast_x;
    framework::Tensor broadcast_y;
    bool need_broad_cast = false;
    if (x->numel() != out->numel()) {
      x_ptr = BroadcastToOut(context, dev_ctx, *x, *out, &broadcast_x);
      need_broad_cast = true;
    }
    if (y->numel() != out->numel()) {
      y_ptr = BroadcastToOut(context, dev_ctx, *y, *out, &broadcast_y);
      need_broad_cast = true;
    }

    // logical kernel
    int ret = XPU_SUCCESS;
    switch (xpu_type) {
      case XpuLogicalType::XPU_OR:
        ret = xpu::logical_or<bool>(dev_ctx.x_context(), x_ptr, y_ptr, out_ptr,
                                    out->numel());
        break;
      case XpuLogicalType::XPU_AND:
        ret = xpu::logical_and<bool>(dev_ctx.x_context(), x_ptr, y_ptr, out_ptr,
                                     out->numel());
        // This break was missing: a successful logical_and used to fall
        // through into `default` and log a bogus "not support" error.
        break;
      default:
        LOG(ERROR) << "xpu not support logical xpu type = "
                   << XpuLogicalType2Str(xpu_type);
        break;
    }
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External("XPU API return wrong value[%d %s] in "
                                   "op_name[%s].",
                                   ret, XPUAPIErrorMsg[ret],
                                   XpuLogicalType2Str(xpu_type)));

    // NOTE(review): presumably this wait keeps the scratch tensors alive until
    // the asynchronous kernels have finished with them -- confirm.
    if (need_broad_cast && dev_ctx.x_context()->xpu_stream != nullptr) {
      xpu_wait();
    }
  }

 private:
  // Expands `in` to the shape of `out` into `buffer` and returns the pointer
  // to the expanded data.
  //
  // `in`'s dims are right-aligned against `out`'s: missing leading dims are
  // treated as 1. The previous code indexed in.dims()[i] with the output rank
  // and read out of range whenever `in` had fewer dims than `out`.
  // NOTE(review): right-alignment assumes the op infers Out's shape with
  // trailing-dimension broadcasting -- confirm against the op's InferShape.
  static const T* BroadcastToOut(
      const framework::ExecutionContext& context,
      const paddle::platform::XPUDeviceContext& dev_ctx,
      const framework::Tensor& in, const framework::Tensor& out,
      framework::Tensor* buffer) {
    T* dst = buffer->mutable_data<T>(context.GetPlace(), out.numel());
    const auto& out_dim = out.dims();
    const auto& in_dim = in.dims();
    const int out_rank = out_dim.size();
    const int lead = out_rank - in_dim.size();  // implicit leading 1-dims

    // For every axis that differs, the source is described as [1, d] and the
    // destination as [out / d, d]; equal axes are passed through unchanged.
    std::vector<int> src_shape;
    std::vector<int> dst_shape;
    for (int i = 0; i < out_rank; ++i) {
      const int in_extent = i < lead ? 1 : static_cast<int>(in_dim[i - lead]);
      const int out_extent = static_cast<int>(out_dim[i]);
      if (out_extent == in_extent) {
        src_shape.push_back(in_extent);
        dst_shape.push_back(in_extent);
        continue;
      }
      src_shape.push_back(1);
      src_shape.push_back(in_extent);
      dst_shape.push_back(out_extent / in_extent);
      dst_shape.push_back(in_extent);
    }

    int ret = xpu::broadcast<int8_t>(
        dev_ctx.x_context(), reinterpret_cast<const int8_t*>(in.data<T>()),
        reinterpret_cast<int8_t*>(dst), src_shape, dst_shape);
    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                      platform::errors::External(
                          "XPU broadcast kernel return wrong value[%d %s]",
                          ret, XPUAPIErrorMsg[ret]));
    return dst;
  }
};
// XPU kernel for the element-wise logical NOT: reads input "X" and writes
// its negation to output "Out". An empty input is a no-op (Out is not
// allocated in that case).
template <typename T>
class UnaryLogicalOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<framework::Tensor>("X");
    auto* output = context.Output<framework::Tensor>("Out");
    const auto count = input->numel();
    if (count != 0) {
      // Allocate Out on the kernel's place before handing it to the device.
      output->mutable_data<T>(context.GetPlace());
      auto& xpu_dev_ctx =
          context
              .template device_context<paddle::platform::XPUDeviceContext>();
      const int status =
          xpu::logical_not<bool>(xpu_dev_ctx.x_context(), input->data<T>(),
                                 output->data<T>(), count);
      PADDLE_ENFORCE_EQ(
          status, XPU_SUCCESS,
          platform::errors::External("XPU API return wrong value[%d %s].",
                                     status, XPUAPIErrorMsg[status]));
    }
  }
};
} // namespace operators
} // namespace paddle
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
namespace ops = paddle::operators;
// Registers the XPU kernel for the `logical_and` op: the binary logical
// kernel instantiated with XPU_AND over bool elements.
REGISTER_OP_XPU_KERNEL(
logical_and,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, bool>);
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
namespace ops = paddle::operators;
// Registers the XPU kernel for the `logical_not` op over bool elements.
// The op name must be spelled `logical_not` (it was `logicalnot`), matching
// the `logical_and` / `logical_or` registrations and the operator name the
// Python API dispatches to.
REGISTER_OP_XPU_KERNEL(logical_not, ops::UnaryLogicalOpXPUKernel<bool>);
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
namespace ops = paddle::operators;
// Registers the XPU kernel for the `logical_or` op: the binary logical
// kernel instantiated with XPU_OR over bool elements.
REGISTER_OP_XPU_KERNEL(
logical_or,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, bool>);
#endif
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
...@@ -120,30 +121,40 @@ class MatMulXPUKernel : public framework::OpKernel<T> { ...@@ -120,30 +121,40 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
auto &dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
float *data_c = out->data<T>(); float *data_c = out->data<T>();
if (mat_dim_a.batch_size_ == 0 || mat_dim_a.batch_size_ == 1) { int m = mat_dim_a.height_;
int r = int n = mat_dim_b.width_;
xpu::fc_int16(dev_ctx.x_context(), mat_dim_a.trans_, mat_dim_b.trans_, int k = mat_dim_a.width_;
mat_dim_a.height_, mat_dim_b.width_, mat_dim_a.width_, int ldx = mat_dim_a.trans_ ? m : k;
alpha, x->data<T>(), y->data<T>(), 0.0f, data_c); int ldy = mat_dim_b.trans_ ? k : n;
PADDLE_ENFORCE_EQ( int ldout = n;
r, XPU_SUCCESS, int batch_size = mat_dim_a.batch_size_;
if (batch_size == 0 || batch_size == 1) {
int r = xpu::fc_fusion<float, float, float, int16_t>(
dev_ctx.x_context(), x->data<T>(), y->data<T>(), data_c, m, n, k,
mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU API return wrong value[%d], please check whether " "XPU fc_fusion kernel return wrong value[%d %s]", r,
"Baidu Kunlun Card is properly installed.", XPUAPIErrorMsg[r]));
r));
} else { } else {
// batch matmul // batch matmul
int r = xpu::batched_gemm_int16(dev_ctx.x_context(), mat_dim_a.trans_, int x_stride = mat_dim_a.stride_;
mat_dim_b.trans_, mat_dim_a.batch_size_, int y_stride = mat_dim_b.stride_;
mat_dim_a.height_, mat_dim_b.width_, int out_stride = m * n;
mat_dim_a.width_, alpha, x->data<T>(), for (int i = 0; i < batch_size; ++i) {
y->data<T>(), data_c, nullptr, nullptr); const float *x_data = x->data<T>() + x_stride * i;
PADDLE_ENFORCE_EQ( const float *y_data = y->data<T>() + y_stride * i;
r, XPU_SUCCESS, float *out_data = data_c + out_stride * i;
int r = xpu::fc_fusion<float, float, float, int16_t>(
dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU API return wrong value[%d], please check whether " "XPU fc_fusion kernel return wrong value[%d %s]",
"Baidu Kunlun Card is properly installed.", r, XPUAPIErrorMsg[r]));
r)); }
} }
} }
}; };
...@@ -171,9 +182,8 @@ static framework::Tensor XPUFoldHeadAndLastDims( ...@@ -171,9 +182,8 @@ static framework::Tensor XPUFoldHeadAndLastDims(
in_shape_host.data(), axis_host.data(), /*ndims=*/3); in_shape_host.data(), axis_host.data(), /*ndims=*/3);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU API return wrong value[%d], please check whether " "XPU transpose kernel return wrong value[%d %s]", r,
"Baidu Kunlun Card is properly installed.", XPUAPIErrorMsg[r]));
r));
output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output; return output;
...@@ -224,30 +234,41 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> { ...@@ -224,30 +234,41 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
auto &dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
float *data_c = out->data<T>(); float *data_c = out->data<T>();
if (mat_dim_a.batch_size_ == 0 || mat_dim_a.batch_size_ == 1) {
int r = int m = mat_dim_a.height_;
xpu::fc_int16(dev_ctx.x_context(), mat_dim_a.trans_, mat_dim_b.trans_, int n = mat_dim_b.width_;
mat_dim_a.height_, mat_dim_b.width_, mat_dim_a.width_, int k = mat_dim_a.width_;
alpha, a.data<T>(), b.data<T>(), 0.0f, data_c); int ldx = mat_dim_a.trans_ ? m : k;
PADDLE_ENFORCE_EQ( int ldy = mat_dim_b.trans_ ? k : n;
r, XPU_SUCCESS, int ldout = n;
int batch_size = mat_dim_a.batch_size_;
if (batch_size == 0 || batch_size == 1) {
int r = xpu::fc_fusion<float, float, float, int16_t>(
dev_ctx.x_context(), a.data<T>(), b.data<T>(), data_c, m, n, k,
mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU API return wrong value[%d], please check whether " "XPU fc_fusion kernel return wrong value[%d %s]", r,
"Baidu Kunlun Card is properly installed.", XPUAPIErrorMsg[r]));
r));
} else { } else {
// batch matmul // batch matmul
int r = xpu::batched_gemm_int16(dev_ctx.x_context(), mat_dim_a.trans_, int x_stride = mat_dim_a.stride_;
mat_dim_b.trans_, mat_dim_a.batch_size_, int y_stride = mat_dim_b.stride_;
mat_dim_a.height_, mat_dim_b.width_, int out_stride = m * n;
mat_dim_a.width_, alpha, a.data<T>(), for (int i = 0; i < batch_size; ++i) {
b.data<T>(), data_c, nullptr, nullptr); const float *x_data = a.data<T>() + x_stride * i;
PADDLE_ENFORCE_EQ( const float *y_data = b.data<T>() + y_stride * i;
r, XPU_SUCCESS, float *out_data = data_c + out_stride * i;
int r = xpu::fc_fusion<float, float, float, int16_t>(
dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU API return wrong value[%d], please check whether " "XPU fc_fusion kernel return wrong value[%d %s]",
"Baidu Kunlun Card is properly installed.", r, XPUAPIErrorMsg[r]));
r)); }
} }
} }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/one_hot_op.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
// XPU kernel for the `one_hot` op.
//
// Reads the index tensor "X" (element type T) and writes float one-hot rows
// to "Out". The depth comes from the "depth" attribute unless the optional
// "depth_tensor" input is present, in which case its first int32 element
// overrides the attribute and Out's last dimension is resized to it.
template <typename DeviceContext, typename T>
class OneHotXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");
    int depth = context.Attr<int>("depth");

    if (context.HasInput("depth_tensor")) {
      auto* depth_tensor = context.Input<Tensor>("depth_tensor");
      auto* depth_data = depth_tensor->data<int32_t>();
      // NOTE(review): this compares against a default-constructed XPUPlace;
      // confirm that is the intended check for "tensor lives on an XPU".
      if (depth_tensor->place() == platform::XPUPlace()) {
        // The value lives in device memory, so copy it back to the host.
        // The return code is checked: a failed copy would otherwise leave
        // `depth` silently at the attribute value.
        int copy_ret = xpu_memcpy(static_cast<void*>(&depth),
                                  static_cast<const void*>(depth_data),
                                  sizeof(int32_t), XPU_DEVICE_TO_HOST);
        PADDLE_ENFORCE_EQ(
            copy_ret, XPU_SUCCESS,
            platform::errors::External(
                "XPU API return wrong value[%d], please check whether "
                "Baidu Kunlun Card is properly installed.",
                copy_ret));
      } else {
        depth = depth_data[0];
      }
      // The runtime depth is only known here, so fix up Out's last dim.
      auto in_dims = in->dims();
      framework::DDim out_dims(in_dims);
      out_dims[out_dims.size() - 1] = depth;
      out->Resize(out_dims);
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
    int len = in->numel();
    int ret = xpu::one_hot<T>(dev_ctx.x_context(), in->data<T>(),
                              out->mutable_data<float>(context.GetPlace()), len,
                              depth);
    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                      platform::errors::External(
                          "XPU one_hot kernel return wrong value[%d %s]", ret,
                          XPUAPIErrorMsg[ret]));
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Registers the XPU `one_hot` kernel for int and int64_t index tensors.
REGISTER_OP_XPU_KERNEL(
one_hot, ops::OneHotXPUKernel<paddle::platform::XPUDeviceContext, int>,
ops::OneHotXPUKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
#include "xpu/api.h" #include "xpu/api.h"
#include "xpu/refactor/fusion.h"
#include "xpu/refactor/math.h" #include "xpu/refactor/math.h"
#include "xpu/refactor/nn.h" #include "xpu/refactor/nn.h"
#include "xpu/runtime.h" #include "xpu/runtime.h"
......
...@@ -20,6 +20,7 @@ import unittest ...@@ -20,6 +20,7 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
from op_test_xpu import XPUOpTest
from scipy.special import expit, erf from scipy.special import expit, erf
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -30,7 +31,7 @@ from paddle.fluid import compiler, Program, program_guard ...@@ -30,7 +31,7 @@ from paddle.fluid import compiler, Program, program_guard
@unittest.skipIf(not paddle.is_compiled_with_xpu(), @unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU") "core is not compiled with XPU")
class TestXPUActivation(OpTest): class TestXPUActivation(XPUOpTest):
def setUp(self): def setUp(self):
self.op_type = "exp" self.op_type = "exp"
self.init_dtype() self.init_dtype()
...@@ -166,6 +167,33 @@ def gelu(x, approximate): ...@@ -166,6 +167,33 @@ def gelu(x, approximate):
return y_ref.astype(x.dtype) return y_ref.astype(x.dtype)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPUHardSwish(TestXPUActivation):
    """Checks the XPU hard_swish op against the NumPy reference."""

    def setUp(self):
        self.op_type = "hard_swish"
        self.init_dtype()

        x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
        # Reference output computed with offset 3, threshold 6 and scale 6.
        expected = hard_swish(x, 3.0, 6.0, 6.0)

        self.inputs = {'X': x}
        self.outputs = {'Out': expected}
        self.attrs = {'use_xpu': True}

    def test_check_grad(self):
        # The gradient check only runs when an XPU build is available.
        if not paddle.is_compiled_with_xpu():
            return
        self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Out')
def hard_swish(x, offset, threshold, scale):
    """NumPy reference for hard_swish.

    Computes ``min(threshold, max(0, x + offset)) * x / scale`` element-wise
    and returns the result cast to ``x.dtype``.
    """
    shifted = np.maximum(0, x + offset)
    gate = np.minimum(threshold, shifted)
    y_ref = gate * x / scale
    return y_ref.astype(x.dtype)
@unittest.skipIf(not paddle.is_compiled_with_xpu(), @unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU") "core is not compiled with XPU")
class TestXPULog(TestXPUActivation): class TestXPULog(TestXPUActivation):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from paddle.fluid.op import Operator
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle
from op_test_xpu import XPUOpTest
from paddle.static import Program, program_guard
# Ops under test. 'op_str' is looked up by name on both `paddle` and `np`;
# 'binary_op' tells the runners whether a second operand is passed.
TEST_META_OP_DATA = [{
'op_str': 'logical_and',
'binary_op': True
}, {
'op_str': 'logical_or',
'binary_op': True
}, {
'op_str': 'logical_not',
'binary_op': False
}]
# Operand shape pairs used when test_xpu runs with test_error=False; results
# for these are compared against the NumPy op of the same name.
TEST_META_SHAPE_DATA = {
'XDimLargerThanYDim1': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 5]
},
'XDimLargerThanYDim2': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 1]
},
'XDimLargerThanYDim3': {
'x_shape': [2, 3, 4, 5],
'y_shape': [1, 4, 1]
},
'XDimLargerThanYDim4': {
'x_shape': [2, 3, 4, 5],
'y_shape': [3, 4, 1]
},
'XDimLargerThanYDim5': {
'x_shape': [2, 3, 1, 5],
'y_shape': [3, 1, 1]
},
'XDimLessThanYDim1': {
'x_shape': [4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim2': {
'x_shape': [1, 4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim3': {
'x_shape': [3, 4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim4': {
'x_shape': [3, 1, 1],
'y_shape': [2, 3, 1, 5]
},
'XDimLessThanYDim5': {
'x_shape': [4, 5],
'y_shape': [2, 3, 4, 5]
},
'Axis1InLargerDim': {
'x_shape': [1, 4, 5],
'y_shape': [2, 3, 1, 5]
},
'EqualDim1': {
'x_shape': [10, 7],
'y_shape': [10, 7]
},
'EqualDim2': {
'x_shape': [1, 1, 4, 5],
'y_shape': [2, 3, 1, 5]
}
}
# Operand shape pairs used when test_xpu runs with test_error=True; binary
# ops are expected to raise for these.
TEST_META_WRONG_SHAPE_DATA = {
'ErrorDim1': {
'x_shape': [2, 3, 4, 5],
'y_shape': [3, 4]
},
'ErrorDim2': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 3]
}
}
def run_static_xpu(x_np, y_np, op_str, binary_op=True):
    """Runs the paddle op named `op_str` in static-graph mode on XPU 0.

    Builds a fresh program with bool placeholders shaped like `x_np` (and
    `y_np` when `binary_op` is true), feeds the arrays and returns whatever
    the executor returns for the fetched op result.
    """
    paddle.enable_static()
    startup_prog = fluid.Program()
    main_prog = fluid.Program()
    executor = fluid.Executor(paddle.XPUPlace(0))
    with fluid.program_guard(main_prog, startup_prog):
        x_var = paddle.static.data(name='x', shape=x_np.shape, dtype='bool')
        logical_op = getattr(paddle, op_str)
        feed = {'x': x_np}
        if binary_op:
            y_var = paddle.static.data(
                name='y', shape=y_np.shape, dtype='bool')
            feed['y'] = y_np
            fetched = logical_op(x_var, y_var)
        else:
            fetched = logical_op(x_var)
        executor.run(startup_prog)
        outputs = executor.run(main_prog, feed=feed, fetch_list=[fetched])
    return outputs
def run_dygraph_xpu(x_np, y_np, op_str, binary_op=True):
    """Runs the paddle op named `op_str` in dygraph mode on XPU 0.

    `y_np` is only converted to a tensor and used when `binary_op` is true.
    Returns the op's result tensor.
    """
    paddle.disable_static(paddle.XPUPlace(0))
    logical_op = getattr(paddle, op_str)
    operands = [paddle.to_tensor(x_np)]
    if binary_op:
        operands.append(paddle.to_tensor(y_np))
    return logical_op(*operands)
def np_data_generator(np_shape, *args, **kwargs):
    """Returns a random boolean array of shape `np_shape`.

    Extra positional and keyword arguments are accepted and ignored.
    """
    samples = np.random.choice(a=[True, False], size=np_shape)
    return samples.astype(bool)
def test_xpu(unit_test, test_error=False):
    """Runs every op in TEST_META_OP_DATA over a table of shape pairs.

    With `test_error` false the valid-shape table is used and both the static
    and dygraph results are compared with the NumPy op of the same name. With
    `test_error` true the wrong-shape table is used and binary ops are
    expected to raise in static mode; unary ops are still checked normally.
    """
    for op_data in TEST_META_OP_DATA:
        case = dict(op_data)
        reference_op = getattr(np, case['op_str'])
        if test_error:
            shape_table = dict(TEST_META_WRONG_SHAPE_DATA)
        else:
            shape_table = dict(TEST_META_SHAPE_DATA)
        for shapes in shape_table.values():
            case['x_np'] = np_data_generator(shapes['x_shape'])
            case['y_np'] = np_data_generator(shapes['y_shape'])
            if case['binary_op'] and test_error:
                # catch C++ Exception
                unit_test.assertRaises(BaseException, run_static_xpu, **case)
                continue
            static_result = run_static_xpu(**case)
            dygraph_result = run_dygraph_xpu(**case)
            if case['binary_op']:
                expected = reference_op(case['x_np'], case['y_np'])
            else:
                expected = reference_op(case['x_np'])
            unit_test.assertTrue((static_result == expected).all())
            unit_test.assertTrue((dygraph_result.numpy() == expected).all())
def test_type_error(unit_test, type_str_map):
    """Checks that the logical ops reject bad operand types.

    `type_str_map` maps 'x' and 'y' to dtype names. Each op in
    TEST_META_OP_DATA is called once in dygraph mode with NumPy-generated
    operands and once in static mode with placeholders of those dtypes.
    """

    def check_type(op_str, x, y, binary_op):
        op = getattr(paddle, op_str)
        expected_error = TypeError
        if isinstance(x, np.ndarray):
            # NumPy inputs are converted to tensors; any exception is then
            # accepted.
            x = paddle.to_tensor(x)
            y = paddle.to_tensor(y)
            expected_error = BaseException
        if binary_op:
            operands = {'x': x, 'y': y}
            has_non_bool = (type_str_map['x'] != 'bool' or
                            type_str_map['y'] != 'bool')
        else:
            operands = {'x': x}
            has_non_bool = type_str_map['x'] != 'bool'
        if has_non_bool:
            unit_test.assertRaises(expected_error, op, **operands)
        if not fluid.in_dygraph_mode():
            unit_test.assertRaises(expected_error, op, out=1, **operands)

    place = paddle.XPUPlace(0)
    for op_data in TEST_META_OP_DATA:
        op_name = op_data['op_str']
        is_binary = op_data['binary_op']

        # Dygraph pass with concrete arrays.
        paddle.disable_static(place)
        x = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['x'])
        y = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['y'])
        check_type(op_name, x, y, is_binary)

        # Static pass with placeholders.
        paddle.enable_static()
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                name='x', shape=[10], dtype=type_str_map['x'])
            y = paddle.static.data(
                name='y', shape=[10], dtype=type_str_map['y'])
            check_type(op_name, x, y, is_binary)
def type_map_factory():
    """Returns every (x, y) dtype-name pairing as a list of dicts.

    The x dtype varies slowest; each entry has the keys 'x' and 'y'.
    """
    dtype_names = ['float32', 'float64', 'int32', 'int64', 'bool']
    pairings = []
    for x_type in dtype_names:
        for y_type in dtype_names:
            pairings.append({'x': x_type, 'y': y_type})
    return pairings
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPU(unittest.TestCase):
    """Logical-op tests on XPU: valid shapes, invalid shapes, invalid dtypes."""

    def test(self):
        # Valid-shape path: results are compared with NumPy. This used to
        # pass test_error=True, which duplicated test_error below and left
        # the valid shapes untested.
        test_xpu(self)

    def test_error(self):
        # Invalid-shape path: binary ops are expected to raise.
        test_xpu(self, True)

    def test_type_error(self):
        type_map_list = type_map_factory()
        for type_map in type_map_list:
            test_type_error(self, type_map)
if __name__ == '__main__':
unittest.main()
...@@ -19,11 +19,13 @@ sys.path.append("..") ...@@ -19,11 +19,13 @@ sys.path.append("..")
import paddle.fluid.core as core import paddle.fluid.core as core
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test_xpu import XPUOpTest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
paddle.enable_static()
def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y): def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y):
BATCH_SIZE = 2 BATCH_SIZE = 2
...@@ -92,7 +94,9 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): ...@@ -92,7 +94,9 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
class Generator(object): class Generator(object):
def setUp(self): def setUp(self):
self.use_xpu = True
self.op_type = "matmul" self.op_type = "matmul"
# self.init_test_case()
X = np.random.random(self.shape_X).astype("float32") X = np.random.random(self.shape_X).astype("float32")
Y = np.random.random(self.shape_Y).astype("float32") Y = np.random.random(self.shape_Y).astype("float32")
Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y) Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y)
...@@ -104,7 +108,7 @@ class Generator(object): ...@@ -104,7 +108,7 @@ class Generator(object):
self.outputs = {'Out': Out} self.outputs = {'Out': Out}
def test_check_output(self): def test_check_output(self):
self.check_output()
if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len( if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
self.inputs['Y'].shape) and self.inputs['X'].shape[ self.inputs['Y'].shape) and self.inputs['X'].shape[
0] == self.inputs['Y'].shape[0]: 0] == self.inputs['Y'].shape[0]:
...@@ -112,7 +116,7 @@ class Generator(object): ...@@ -112,7 +116,7 @@ class Generator(object):
self.check_output_with_place(place, atol=1e-3) self.check_output_with_place(place, atol=1e-3)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-3)
if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len( if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
self.inputs['Y'].shape) and self.inputs['X'].shape[ self.inputs['Y'].shape) and self.inputs['X'].shape[
0] == self.inputs['Y'].shape[0]: 0] == self.inputs['Y'].shape[0]:
...@@ -121,8 +125,7 @@ class Generator(object): ...@@ -121,8 +125,7 @@ class Generator(object):
place, ['X', 'Y'], 'Out', max_relative_error=5e-2) place, ['X', 'Y'], 'Out', max_relative_error=5e-2)
def test_check_grad_ignore_x(self): def test_check_grad_ignore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=1e-3, no_grad_set=set("X"))
if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len( if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
self.inputs['Y'].shape) and self.inputs['X'].shape[ self.inputs['Y'].shape) and self.inputs['X'].shape[
0] == self.inputs['Y'].shape[0]: 0] == self.inputs['Y'].shape[0]:
...@@ -134,8 +137,7 @@ class Generator(object): ...@@ -134,8 +137,7 @@ class Generator(object):
no_grad_set=set("X")) no_grad_set=set("X"))
def test_check_grad_ignore_y(self): def test_check_grad_ignore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=1e-3, no_grad_set=set('Y'))
if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len( if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
self.inputs['Y'].shape) and self.inputs['X'].shape[ self.inputs['Y'].shape) and self.inputs['X'].shape[
0] == self.inputs['Y'].shape[0]: 0] == self.inputs['Y'].shape[0]:
...@@ -192,7 +194,7 @@ def test_negative_dims_program(obj): ...@@ -192,7 +194,7 @@ def test_negative_dims_program(obj):
for idx in range(len(Ref.shape)): for idx in range(len(Ref.shape)):
if output.shape[idx] != -1: if output.shape[idx] != -1:
obj.assertEqual(Ref.shape[idx], output.shape[idx]) obj.assertEqual(Ref.shape[idx], output.shape[idx])
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.XPUPlace(0))
res, = exe.run(fluid.default_main_program(), res, = exe.run(fluid.default_main_program(),
feed={'x': X, feed={'x': X,
'y': Y}, 'y': Y},
...@@ -221,7 +223,7 @@ def inject_test(dim_x, dim_y, trans_x, trans_y): ...@@ -221,7 +223,7 @@ def inject_test(dim_x, dim_y, trans_x, trans_y):
dim_x, dim_y, trans_x, trans_y)) dim_x, dim_y, trans_x, trans_y))
shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x, shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x,
trans_y) trans_y)
globals()[test_name] = type(test_name, (Generator, OpTest), { globals()[test_name] = type(test_name, (Generator, XPUOpTest), {
'shape_X': shape_x, 'shape_X': shape_x,
'shape_Y': shape_y, 'shape_Y': shape_y,
'transpose_X': trans_x, 'transpose_X': trans_x,
...@@ -231,8 +233,9 @@ def inject_test(dim_x, dim_y, trans_x, trans_y): ...@@ -231,8 +233,9 @@ def inject_test(dim_x, dim_y, trans_x, trans_y):
for dim_X in (1, 2, 3): for dim_X in (1, 2, 3):
for dim_Y in (1, 2, 3): for dim_Y in (1, 2, 3):
for transose_x in (False, True): transose_x = False
for transose_y in (False, True): transose_y = False
if dim_X == 3 and dim_Y == 3:
inject_test(dim_X, dim_Y, transose_x, transose_y) inject_test(dim_X, dim_Y, transose_x, transose_y)
api_test(dim_X, dim_Y, transose_x, transose_y) api_test(dim_X, dim_Y, transose_x, transose_y)
...@@ -267,7 +270,7 @@ for dim in [4]: ...@@ -267,7 +270,7 @@ for dim in [4]:
dim, dim, transpose_X, transpose_Y)) dim, dim, transpose_X, transpose_Y))
shape_X, shape_Y = generate_compatible_shapes(dim, transpose_X, shape_X, shape_Y = generate_compatible_shapes(dim, transpose_X,
transpose_Y) transpose_Y)
globals()[test_name] = type(test_name, (Generator, OpTest), { globals()[test_name] = type(test_name, (Generator, XPUOpTest), {
'shape_X': shape_X, 'shape_X': shape_X,
'shape_Y': shape_Y, 'shape_Y': shape_Y,
'transpose_X': transpose_X, 'transpose_X': transpose_X,
...@@ -282,7 +285,7 @@ class API_TestMm(unittest.TestCase): ...@@ -282,7 +285,7 @@ class API_TestMm(unittest.TestCase):
y = fluid.data(name='y', shape=[2], dtype='float64') y = fluid.data(name='y', shape=[2], dtype='float64')
res = fluid.data(name="output", shape=[1], dtype="float64") res = fluid.data(name="output", shape=[1], dtype="float64")
result = paddle.mm(x, y) result = paddle.mm(x, y)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.XPUPlace(0))
data1 = np.random.rand(2) data1 = np.random.rand(2)
data2 = np.random.rand(2) data2 = np.random.rand(2)
np_res = exe.run(feed={'x': data1, 'y': data2}, fetch_list=[result]) np_res = exe.run(feed={'x': data1, 'y': data2}, fetch_list=[result])
...@@ -296,7 +299,7 @@ class API_TestMm(unittest.TestCase): ...@@ -296,7 +299,7 @@ class API_TestMm(unittest.TestCase):
{}\n{}, check diff!".format(np_res, expected_result)) {}\n{}, check diff!".format(np_res, expected_result))
def test_dygraph_without_out(self): def test_dygraph_without_out(self):
device = fluid.CPUPlace() device = fluid.XPUPlace(0)
with fluid.dygraph.guard(device): with fluid.dygraph.guard(device):
input_array1 = np.random.rand(3, 4).astype("float64") input_array1 = np.random.rand(3, 4).astype("float64")
input_array2 = np.random.rand(4, 3).astype("float64") input_array2 = np.random.rand(4, 3).astype("float64")
...@@ -309,7 +312,7 @@ class API_TestMm(unittest.TestCase): ...@@ -309,7 +312,7 @@ class API_TestMm(unittest.TestCase):
class Test_API_Matmul(unittest.TestCase): class Test_API_Matmul(unittest.TestCase):
def test_dygraph_without_out(self): def test_dygraph_without_out(self):
device = fluid.CPUPlace() device = fluid.XPUPlace(0)
with fluid.dygraph.guard(device): with fluid.dygraph.guard(device):
input_array1 = np.random.rand(3, 4).astype("float64") input_array1 = np.random.rand(3, 4).astype("float64")
input_array2 = np.random.rand(4, 3).astype("float64") input_array2 = np.random.rand(4, 3).astype("float64")
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import sys
sys.path.append("..")
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import time
paddle.enable_static()
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp(XPUOpTest):
    """one_hot with the depth fed through the ``depth_tensor`` input and an
    explicit FP32 ``dtype`` attribute."""

    def setUp(self):
        self.use_xpu = True
        self.op_type = 'one_hot'
        depth = 10
        # Scalar int32 array carrying the depth (same value as ``depth``).
        depth_np = np.array(depth).astype('int32')
        x_lod = [[4, 1, 3, 3]]
        num_rows = sum(x_lod[0])
        # randint's upper bound is exclusive, so labels lie in [0, depth - 2].
        x = [np.random.randint(0, depth - 1) for i in range(num_rows)]
        x = np.array(x).astype('int32').reshape([num_rows, 1])

        # np.prod is used instead of np.product, a deprecated alias that was
        # removed in NumPy 2.0.
        out = np.zeros(shape=(np.prod(x.shape[:-1]),
                              depth)).astype('float32')
        # Set the column selected by each row's label to 1.
        for i in range(np.prod(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp_attr(XPUOpTest):
    """one_hot with the depth given as the ``depth`` attribute and an
    explicit FP32 ``dtype`` attribute."""

    def setUp(self):
        self.op_type = 'one_hot'
        depth = 10
        x_lod = [[4, 1, 3, 3]]
        num_rows = sum(x_lod[0])
        # randint's upper bound is exclusive, so labels lie in [0, depth - 2].
        x = [np.random.randint(0, depth - 1) for i in range(num_rows)]
        x = np.array(x).astype('int32').reshape([num_rows, 1])

        # np.prod is used instead of np.product, a deprecated alias that was
        # removed in NumPy 2.0.
        out = np.zeros(shape=(np.prod(x.shape[:-1]),
                              depth)).astype('float32')
        # Set the column selected by each row's label to 1.
        for i in range(np.prod(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp_default_dtype(XPUOpTest):
    """one_hot with the depth fed through the ``depth_tensor`` input and no
    attributes set, so the op's default ``dtype`` applies."""

    def setUp(self):
        self.op_type = 'one_hot'
        depth = 10
        # Scalar int32 array carrying the depth (same value as ``depth``).
        depth_np = np.array(depth).astype('int32')
        x_lod = [[4, 1, 3, 3]]
        num_rows = sum(x_lod[0])
        # randint's upper bound is exclusive, so labels lie in [0, depth - 2].
        x = [np.random.randint(0, depth - 1) for i in range(num_rows)]
        x = np.array(x).astype('int32').reshape([num_rows, 1])

        # np.prod is used instead of np.product, a deprecated alias that was
        # removed in NumPy 2.0.
        out = np.zeros(shape=(np.prod(x.shape[:-1]),
                              depth)).astype('float32')
        # Set the column selected by each row's label to 1.
        for i in range(np.prod(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
        self.attrs = {}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp_default_dtype_attr(XPUOpTest):
    """one_hot with the depth given as the ``depth`` attribute and no
    ``dtype`` attribute, so the op's default ``dtype`` applies."""

    def setUp(self):
        self.op_type = 'one_hot'
        depth = 10
        x_lod = [[4, 1, 3, 3]]
        num_rows = sum(x_lod[0])
        # randint's upper bound is exclusive, so labels lie in [0, depth - 2].
        x = [np.random.randint(0, depth - 1) for i in range(num_rows)]
        x = np.array(x).astype('int32').reshape([num_rows, 1])

        # np.prod is used instead of np.product, a deprecated alias that was
        # removed in NumPy 2.0.
        out = np.zeros(shape=(np.prod(x.shape[:-1]),
                              depth)).astype('float32')
        # Set the column selected by each row's label to 1.
        for i in range(np.prod(x.shape)):
            out[i, x[i]] = 1.0

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'depth': depth}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOp_out_of_range(XPUOpTest):
    """one_hot with ``allow_out_of_range`` set and every label outside
    ``[0, depth)``; the expected output is all zeros."""

    def setUp(self):
        self.op_type = 'one_hot'
        depth = 10
        x_lod = [[4, 1, 3, 3]]
        num_rows = sum(x_lod[0])
        # Every label is either -1 or depth, i.e. outside the valid range.
        x = [np.random.choice([-1, depth]) for i in range(num_rows)]
        x = np.array(x).astype('int32').reshape([num_rows, 1])

        # np.prod is used instead of np.product, a deprecated alias that was
        # removed in NumPy 2.0. No entry is set: the expected rows stay zero.
        out = np.zeros(shape=(np.prod(x.shape[:-1]),
                              depth)).astype('float32')

        self.inputs = {'X': (x, x_lod)}
        self.attrs = {'depth': depth, 'allow_out_of_range': True}
        self.outputs = {'Out': (out, x_lod)}

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestOneHotOpError(unittest.TestCase):
    """Checks that ``fluid.layers.one_hot`` raises TypeError on bad arguments."""

    def test_errors(self):
        def expect_type_error(*call_args):
            # Every case below expects a TypeError from fluid.layers.one_hot.
            self.assertRaises(TypeError, fluid.layers.one_hot, *call_args)

        with program_guard(Program(), Program()):
            # A raw numpy array is not accepted: the input must be a Variable.
            ndarray_input = np.random.random((4, 1)).astype("int32")
            expect_type_error(ndarray_input)

            # A float32 Variable is not accepted: the input must be int32 or
            # int64.
            float_input = fluid.layers.data(
                name="in_w2",
                shape=[4, 1],
                append_batch_size=False,
                dtype="float32")
            expect_type_error(float_input)

            # The depth must be an int, long or Variable; a float and a numpy
            # array are both rejected.
            int_input = fluid.layers.data(
                name="in_r",
                shape=[4, 1],
                append_batch_size=False,
                dtype="int32")
            ndarray_depth = np.array([4])
            expect_type_error(int_input, 4.1)
            expect_type_error(int_input, ndarray_depth)
# When executed as a script, switch paddle to static mode and run the
# unittest runner.
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册