Commit 25e070ec, authored by tensor-tang

Merge remote-tracking branch 'ups/develop' into fea/jit/vadd

@@ -57,10 +57,10 @@ ThreadPool::ThreadPool(int num_threads) : running_(true) {
 ThreadPool::~ThreadPool() {
   {
     // notify all threads to stop running
-    std::lock_guard<std::mutex> l(mutex_);
+    std::unique_lock<std::mutex> l(mutex_);
     running_ = false;
+    scheduled_.notify_all();
   }
-  scheduled_.notify_all();
   for (auto& t : threads_) {
     t->join();
@@ -70,19 +70,25 @@ ThreadPool::~ThreadPool() {
 void ThreadPool::TaskLoop() {
   while (true) {
-    std::unique_lock<std::mutex> lock(mutex_);
-    scheduled_.wait(
-        lock, [this] { return !this->tasks_.empty() || !this->running_; });
+    Task task;
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      scheduled_.wait(
+          lock, [this] { return !this->tasks_.empty() || !this->running_; });
-    if (!running_ || tasks_.empty()) {
+      if (!running_ && tasks_.empty()) {
       return;
     }
+      if (tasks_.empty()) {
+        PADDLE_THROW("This thread has no task to Run");
+      }
     // pop a task from the task queue
-    auto task = std::move(tasks_.front());
+      task = std::move(tasks_.front());
     tasks_.pop();
-    lock.unlock();
+    }
     // run the task
     task();
......
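For reference, the reworked destructor and TaskLoop follow the standard condition-variable worker pattern: the stop flag is set and the condition variable notified while the mutex is held, the wait predicate covers both the queue and the running flag, and the task is moved out of the queue before the lock is released. A minimal, self-contained sketch of the same pattern (hypothetical names, not Paddle code):

    // Minimal sketch of the same shutdown pattern (illustrative only).
    #include <condition_variable>
    #include <functional>
    #include <mutex>
    #include <queue>

    class SimplePool {
     public:
      void Enqueue(std::function<void()> job) {
        {
          std::unique_lock<std::mutex> lock(mu_);
          jobs_.push(std::move(job));
        }
        cv_.notify_one();
      }

      void Shutdown() {
        std::unique_lock<std::mutex> lock(mu_);
        stop_ = true;
        cv_.notify_all();  // notify while holding the lock, as in the patch
      }

      void WorkerLoop() {
        while (true) {
          std::function<void()> job;
          {
            std::unique_lock<std::mutex> lock(mu_);
            cv_.wait(lock, [this] { return stop_ || !jobs_.empty(); });
            if (stop_ && jobs_.empty()) return;  // drain queued jobs first
            job = std::move(jobs_.front());
            jobs_.pop();
          }  // release the lock before running the job
          job();
        }
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      std::queue<std::function<void()>> jobs_;
      bool stop_ = false;
    };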
@@ -58,7 +58,7 @@ class ThreadPool {
   ~ThreadPool();

   // Run pushes a function to the task queue and returns a std::future
   // object. To wait for the completion of the task, call
   // std::future::wait().
   template <typename Callback>
   std::future<void> Run(Callback fn) {
@@ -69,7 +69,6 @@ class ThreadPool {
   template <typename Callback>
   std::future<std::unique_ptr<platform::EnforceNotMet>> RunAndGetException(
       Callback fn) {
-    std::unique_lock<std::mutex> lock(mutex_);
     Task task([fn]() -> std::unique_ptr<platform::EnforceNotMet> {
       try {
         fn();
@@ -84,7 +83,13 @@ class ThreadPool {
       return nullptr;
     });
     std::future<std::unique_ptr<platform::EnforceNotMet>> f = task.get_future();
-    tasks_.push(std::move(task));
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      if (!running_) {
+        PADDLE_THROW("enqueue on stopped ThreadPool");
+      }
+      tasks_.push(std::move(task));
+    }
     scheduled_.notify_one();
     return f;
   }
......
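Run() and RunAndGetException() hand results back through std::future. The snippet below is a self-contained illustration of that std::packaged_task / std::future hand-off; it is not Paddle code, and with the change above an enqueue on a stopped pool now throws in the caller's thread instead of being silently accepted.

    #include <future>
    #include <iostream>

    int main() {
      std::packaged_task<int()> task([] { return 42; });
      std::future<int> result = task.get_future();  // future obtained before running
      task();                                       // normally executed by a worker thread
      std::cout << result.get() << "\n";            // prints 42
      return 0;
    }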
@@ -26,6 +26,8 @@ namespace plat = paddle::platform;
       act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext,   \
                                                  ops::grad_functor<float>>, \
       ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
-                                ops::grad_functor<double>>);
+                                ops::grad_functor<double>>,                 \
+      ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
+                                ops::grad_functor<plat::float16>>);

 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
@@ -333,8 +333,7 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    const Out out_conj = Eigen::numext::conj(out);
-    dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
+    dx.device(d) = static_cast<T>(0.5) * dout / out;
   }
 };
@@ -740,7 +739,7 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     dx.device(d) = dout * static_cast<T>(factor) *
-                   x.pow(static_cast<T>(factor - static_cast<T>(1)));
+                   x.pow(static_cast<T>(factor) - static_cast<T>(1));
   }
 };
......
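For reference, the two functors touched above implement the following derivatives (restated from the code; the behavioural changes are only dropping the conjugate for sqrt and casting factor before the subtraction for pow):

    % out = sqrt(x):   dx = 0.5 * dout / out
    \frac{\partial}{\partial x}\sqrt{x} = \frac{1}{2\sqrt{x}}
    % out = x^factor:  dx = dout * factor * x^(factor - 1)
    \frac{\partial}{\partial x} x^{f} = f \, x^{f-1}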
@@ -219,8 +219,8 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

     d_x->mutable_data<T>(ctx.GetPlace());
-    d_scale->mutable_data<T>(ctx.GetPlace());
-    d_bias->mutable_data<T>(ctx.GetPlace());
+    d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());

     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     if ((N * H * W * D) == 1) {
@@ -272,8 +272,10 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
     const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
-    const void *saved_mean_data = saved_mean->template data<T>();
-    const void *saved_var_data = saved_var->template data<T>();
+    const void *saved_mean_data =
+        saved_mean->template data<BatchNormParamType<T>>();
+    const void *saved_var_data =
+        saved_var->template data<BatchNormParamType<T>>();

     CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
         dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
@@ -281,10 +283,10 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
         CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
         data_desc_, d_y->template data<T>(), data_desc_,
         d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-        scale->template data<T>(),
-        d_scale->template mutable_data<T>(ctx.GetPlace()),
-        d_bias->template mutable_data<T>(ctx.GetPlace()), epsilon,
-        saved_mean_data, saved_var_data));
+        scale->template data<BatchNormParamType<T>>(),
+        d_scale->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
+        d_bias->template mutable_data<BatchNormParamType<T>>(ctx.GetPlace()),
+        epsilon, saved_mean_data, saved_var_data));

     // clean when exit.
     CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
@@ -304,4 +306,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
-    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>);
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>,
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, plat::float16>);
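BatchNormParamType<T> presumably resolves to float when T is float16, since cuDNN keeps batch-norm scale, bias and the saved statistics in full precision. A minimal sketch of that trait pattern, with hypothetical names (not Paddle's definition):

    struct Half {};  // stand-in for a 16-bit float type

    template <typename T>
    struct ParamType { using type = T; };            // float/double keep their precision

    template <>
    struct ParamType<Half> { using type = float; };  // half kernels keep params in float

    // usage sketch: using ScaleT = typename ParamType<Half>::type;  // -> float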
@@ -143,9 +143,11 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
           cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
       // Currently tensor core is only enabled using this algo
       algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+      VLOG(5) << "use cudnn_tensor_op_math";
     } else {
       CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
           cudnn_conv_desc, CUDNN_DEFAULT_MATH));
+      VLOG(5) << "NOT use cudnn_tensor_op_math";
     }
 #endif
@@ -361,7 +363,8 @@ REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvGradOpKernel<float>,
-                   paddle::operators::CUDNNConvGradOpKernel<double>);
+                   paddle::operators::CUDNNConvGradOpKernel<double>,
+                   paddle::operators::CUDNNConvGradOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                    paddle::operators::CUDNNConvOpKernel<float>,
......
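The VLOG lines added above make the chosen math mode observable. As a hedged illustration only (the surrounding condition is not shown in this hunk), the type-based gating could look roughly like this, using the same cuDNN calls that appear in the diff:

    // Illustrative sketch, not the actual kernel code.
    if (std::is_same<T, platform::float16>::value) {
      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
          cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
      // Currently tensor core is only enabled using this algo
      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
      VLOG(5) << "use cudnn_tensor_op_math";
    } else {
      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
          cudnn_conv_desc, CUDNN_DEFAULT_MATH));
      VLOG(5) << "NOT use cudnn_tensor_op_math";
    }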
@@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/cross_entropy_op.h"
+#include "paddle/fluid/platform/float16.h"

+namespace plat = paddle::platform;
 namespace ops = paddle::operators;
 using CUDACtx = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(cross_entropy,
                         ops::CrossEntropyOpKernel<CUDACtx, float>,
-                        ops::CrossEntropyOpKernel<CUDACtx, double>);
-REGISTER_OP_CUDA_KERNEL(cross_entropy_grad,
-                        ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
-                        ops::CrossEntropyGradientOpKernel<CUDACtx, double>);
+                        ops::CrossEntropyOpKernel<CUDACtx, double>,
+                        ops::CrossEntropyOpKernel<CUDACtx, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
+    ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
+    ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
@@ -30,4 +30,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>);
@@ -365,7 +365,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
   int j = blockIdx.x;
   int i = threadIdx.x;
   int tid = threadIdx.x;
-  T val = 0;
+  T val(0);
   do {
     int x_offset = i * w + j;
@@ -433,7 +433,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
   int tid = threadIdx.x;
   int j = blockIdx.x;
-  T val = 0;
+  T val(0);
   int ttid = tid;
   while (true) {
......
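The T val = 0 → T val(0) changes above (and the similar static_cast<T>(0) changes elsewhere in this merge) matter for float16: copy-initialization from an integer literal relies on an implicit conversion that a 16-bit float wrapper type may not provide, while direct-initialization works for both built-in and wrapper types. A self-contained sketch with a stand-in type:

    struct Half {
      explicit Half(int v) : bits(static_cast<unsigned short>(v)) {}
      unsigned short bits;
    };

    template <typename T>
    void InitDemo() {
      // T a = 0;   // copy-initialization: ill-formed when T's constructor is explicit
      T b(0);       // direct-initialization: fine for both Half and float
      (void)b;
    }

    template void InitDemo<float>();
    template void InitDemo<Half>();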
@@ -21,6 +21,16 @@ namespace operators {
 namespace math {

 namespace {

+__device__ __forceinline__ float real_log(float x) { return logf(x); }
+
+__device__ __forceinline__ double real_log(double x) { return log(x); }
+
+__device__ __forceinline__ platform::float16 real_log(
+    const platform::float16& val) {
+  return static_cast<platform::float16>(hlog(static_cast<half>(val)));
+}
+
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
                                    const int N, const int D,
@@ -29,8 +39,8 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
        i += blockDim.x * gridDim.x) {
     PADDLE_ASSERT(label[i] >= 0 && label[i] < D || label[i] == ignore_index);
     Y[i] = ignore_index == label[i]
-               ? 0
-               : -math::TolerableValue<T>()(log(X[i * D + label[i]]));
+               ? static_cast<T>(0)
+               : -math::TolerableValue<T>()(real_log(X[i * D + label[i]]));
   }
 }

@@ -38,12 +48,12 @@ template <typename T>
 __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
                                        const int class_num) {
   int tid = threadIdx.x;
-  T val = 0;
+  T val(0);

   int idx = blockIdx.x * class_num + tid;
   int end = blockIdx.x * class_num + class_num;
   for (; idx < end; idx += blockDim.x) {
-    val += math::TolerableValue<T>()(std::log(X[idx])) * label[idx];
+    val += math::TolerableValue<T>()(real_log(X[idx])) * label[idx];
   }

   val = paddle::platform::reduceSum(val, tid, blockDim.x);
@@ -53,8 +63,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
   }
 }  // namespace

-using Tensor = framework::Tensor;
-
 template <typename T>
 class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
  public:
@@ -89,6 +97,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {

 template class CrossEntropyFunctor<platform::CUDADeviceContext, float>;
 template class CrossEntropyFunctor<platform::CUDADeviceContext, double>;
+template class CrossEntropyFunctor<platform::CUDADeviceContext,
+                                   platform::float16>;

 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
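Restating what the two kernels above compute, for a batch of N rows over D classes (hard index labels with ignore_index mapped to zero, and soft labels, respectively):

    Y_i = -\log X_{i,\mathrm{label}_i}
    Y_i = -\sum_{j=1}^{D} \mathrm{label}_{i,j}\,\log X_{i,j}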
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include <limits>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/hostdevice.h"
@@ -33,6 +34,26 @@ struct TolerableValue {
   }
 };

+// NOTE(dzh): float16 value clipping behaves differently.
+// 1. Our ValueClipping uses a hard-coded 1e20 threshold for float numbers;
+//    1e20 would overflow in float16.
+// 2. float16 should expose the real overflow to Python, because
+//    mixed-precision training depends on the inf/nan values to decide
+//    whether the loss scale needs to be adjusted.
+// Also, standard cross-entropy implementations in other frameworks do not
+// have this ValueClipping.
+template <>
+struct TolerableValue<platform::float16> {
+  HOSTDEVICE platform::float16 operator()(const platform::float16& x) const {
+    if (platform::isfinite(x))
+      return x;
+    else if (x > static_cast<platform::float16>(0))
+      return std::numeric_limits<platform::float16>::max();
+    else
+      return std::numeric_limits<platform::float16>::min();
+  }
+};

 template <typename DeviceContext, typename T>
 class CrossEntropyFunctor {
  public:
......
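The NOTE above argues that float16 must let inf/nan reach Python because mixed-precision training adjusts its loss scale from them. As a rough, generic illustration of that mechanism (not Paddle's implementation; names are hypothetical):

    struct LossScaler {
      float scale = 1024.0f;
      int good_steps = 0;

      // Call once per step with the result of an overflow check on the gradients.
      void Update(bool found_inf_or_nan) {
        if (found_inf_or_nan) {
          scale *= 0.5f;  // overflow: halve the scale and skip this update
          good_steps = 0;
        } else if (++good_steps >= 1000) {
          scale *= 2.0f;  // long stable stretch: try a larger scale
          good_steps = 0;
        }
      }
    };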
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/float16.h"

 namespace paddle {
 namespace operators {
@@ -118,7 +119,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
     auto* out_data = output->data<T>();

     SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(context, output, 0.0);
+    functor(context, output, static_cast<T>(0));

     const int block_size = 256;
     dim3 threads(block_size, 1);
@@ -136,6 +137,9 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
 template struct SelectedRowsAddTensor<platform::CUDADeviceContext, float>;
 template struct SelectedRowsAddTensor<platform::CUDADeviceContext, double>;
+template struct SelectedRowsAdd<platform::CUDADeviceContext, platform::float16>;
+template struct SelectedRowsAddTensor<platform::CUDADeviceContext,
+                                      platform::float16>;

 template <typename T>
 struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
@@ -175,6 +179,8 @@ template struct SelectedRowsAddTo<platform::CUDADeviceContext, float>;
 template struct SelectedRowsAddTo<platform::CUDADeviceContext, double>;
 template struct SelectedRowsAddTo<platform::CUDADeviceContext, int>;
 template struct SelectedRowsAddTo<platform::CUDADeviceContext, int64_t>;
+template struct SelectedRowsAddTo<platform::CUDADeviceContext,
+                                  platform::float16>;

 namespace {
 template <typename T, int block_size>
@@ -227,6 +233,8 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
+template struct SelectedRowsAddToTensor<platform::CUDADeviceContext,
+                                        platform::float16>;

 namespace scatter {
@@ -287,7 +295,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
                                   context.GetPlace());

     math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), static_cast<T>(0));

     auto* out_data = out.mutable_value()->data<T>();
     auto* input_data = input.value().data<T>();
@@ -347,7 +355,7 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
                                   context.GetPlace());

     math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), static_cast<T>(0));

     auto* out_data = out.mutable_value()->data<T>();
@@ -374,6 +382,7 @@ template struct MergeAdd<platform::CUDADeviceContext, float>;
 template struct MergeAdd<platform::CUDADeviceContext, double>;
 template struct MergeAdd<platform::CUDADeviceContext, int>;
 template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
+template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;

 template <typename T, int block_size>
 __global__ void UpdateToTensorKernel(const T* selected_rows,
......
@@ -96,12 +96,15 @@ template class SoftmaxCUDNNFunctor<float>;
 template class SoftmaxCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<float>;
 template class SoftmaxGradCUDNNFunctor<double>;
+template class SoftmaxGradCUDNNFunctor<platform::float16>;

 template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
+template class SoftmaxGradFunctor<platform::CUDADeviceContext,
+                                  platform::float16>;

 }  // namespace math
 }  // namespace operators
......
@@ -15,11 +15,15 @@ limitations under the License. */
 #define EIGEN_USE_GPU

 #include "paddle/fluid/operators/mean_op.h"
+#include "paddle/fluid/platform/float16.h"

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 REGISTER_OP_CUDA_KERNEL(
     mean, ops::MeanKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MeanKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::MeanKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MeanKernel<paddle::platform::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     mean_grad, ops::MeanGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MeanGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::MeanGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::MeanGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
@@ -55,8 +55,7 @@ class MeanGradKernel : public framework::OpKernel<T> {
     IG->mutable_data<T>(context.GetPlace());

     T ig_size = static_cast<T>(IG->numel());
-    Eigen::DSizes<int, 1> bcast(ig_size);
+    Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
     EigenVector<T>::Flatten(*IG).device(
         *context.template device_context<DeviceContext>().eigen_device()) =
         (EigenVector<T>::From(*OG) / ig_size).broadcast(bcast);
......
@@ -20,6 +20,7 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
                         ops::MulKernel<plat::CUDADeviceContext, double>,
                         ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(mul_grad,
-                        ops::MulGradKernel<plat::CUDADeviceContext, float>,
-                        ops::MulGradKernel<plat::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    mul_grad, ops::MulGradKernel<plat::CUDADeviceContext, float>,
+    ops::MulGradKernel<plat::CUDADeviceContext, double>,
+    ops::MulGradKernel<plat::CUDADeviceContext, plat::float16>);
@@ -178,7 +178,8 @@ REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNGradOpKernel<float>,
-                   ops::PoolCUDNNGradOpKernel<double>);
+                   ops::PoolCUDNNGradOpKernel<double>,
+                   ops::PoolCUDNNGradOpKernel<plat::float16>);
 REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNOpKernel<float>,
......
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/scale_op.h"
+#include "paddle/fluid/platform/float16.h"
+namespace plat = paddle::platform;

 REGISTER_OP_CUDA_KERNEL(
     scale,
@@ -20,4 +22,6 @@ REGISTER_OP_CUDA_KERNEL(
     paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext, double>,
     paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext, int>,
     paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext,
-                                   int64_t>);
+                                   int64_t>,
+    paddle::operators::ScaleKernel<paddle::platform::CUDADeviceContext,
+                                   plat::float16>);
@@ -80,4 +80,5 @@ REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxCUDNNKernel<plat::float16>);
 REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxGradCUDNNKernel<float>,
-                   ops::SoftmaxGradCUDNNKernel<double>);
+                   ops::SoftmaxGradCUDNNKernel<double>,
+                   ops::SoftmaxGradCUDNNKernel<plat::float16>);
@@ -23,4 +23,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
     softmax_grad, ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>,
-    ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>);
+    ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>,
+    ops::SoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>);
@@ -11,10 +11,13 @@ limitations under the License. */
 #define EIGEN_USE_GPU
 #include "paddle/fluid/operators/sum_op.h"
+#include "paddle/fluid/platform/float16.h"

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

 REGISTER_OP_CUDA_KERNEL(
     sum, ops::SumKernel<paddle::platform::CUDADeviceContext, float>,
     ops::SumKernel<paddle::platform::CUDADeviceContext, double>,
     ops::SumKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::SumKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    ops::SumKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>);
@@ -61,7 +61,7 @@ class SumKernel : public framework::OpKernel<T> {
       if (start != 2) {
         math::SetConstant<DeviceContext, T> constant_functor;
         constant_functor(context.template device_context<DeviceContext>(),
-                         out, 0.0);
+                         out, static_cast<T>(0));
       }
     }
......
@@ -65,7 +65,7 @@ def is_persistable(var):
     Examples:
         .. code-block:: python

-            param = fluid.default_main_program().global_block().var('fc.w')
+            param = fluid.default_main_program().global_block().var('fc.b')
             res = fluid.io.is_persistable(param)
     """
     if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
@@ -625,8 +625,13 @@ def save_inference_model(dirname,
             main_program._distributed_lookup_table,
             main_program._endpoints)

-    if not os.path.isdir(dirname):
+    # when a pserver and a trainer run on the same machine, mkdir may conflict
+    try:
         os.makedirs(dirname)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise

     if model_filename is not None:
         model_basename = os.path.basename(model_filename)
     else:
......
@@ -41,9 +41,6 @@ def convert_reader_to_recordio_file(
     """
     Convert a Python Reader to a recordio file.

-    Please see :ref:`api_guide_python_reader` and :ref:`api_guide_reader_op` for
-    details.
-
     Examples:

         >>> import paddle.fluid as fluid
......
@@ -54,14 +54,6 @@ def get_numeric_gradient(place,
     def product(dim):
         return six.moves.reduce(lambda a, b: a * b, dim, 1)

-    def get_output():
-        sum = []
-        op.run(scope, place)
-        for output_name in output_names:
-            sum.append(
-                np.array(scope.find_var(output_name).get_tensor()).mean())
-        return np.array(sum).sum() / len(output_names)
-
     tensor_to_check = scope.find_var(input_to_check).get_tensor()
     tensor_size = product(tensor_to_check.shape())
     tensor_to_check_dtype = tensor_to_check._dtype()
@@ -77,6 +69,15 @@ def get_numeric_gradient(place,
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))

+    def get_output():
+        sum = []
+        op.run(scope, place)
+        for output_name in output_names:
+            sum.append(
+                np.array(scope.find_var(output_name).get_tensor()).astype(
+                    tensor_to_check_dtype).mean())
+        return tensor_to_check_dtype(np.array(sum).sum() / len(output_names))
+
     gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)

     def __get_elem__(tensor, i):
......
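get_output() above averages the mean of every output variable; get_numeric_gradient() then perturbs one input element at a time and, typically, estimates the derivative with a central difference (stated here for orientation; the exact delta handling lives in the elided part of op_test.py):

    \frac{\partial \bar{y}}{\partial x_k} \approx
        \frac{\bar{y}(x_k + \delta) - \bar{y}(x_k - \delta)}{2\delta},
    \qquad
    \bar{y} = \frac{1}{|\mathrm{outputs}|} \sum_{o \in \mathrm{outputs}} \mathrm{mean}(o)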
@@ -223,106 +223,81 @@ class TestWithInput1x1Filter1x1(TestConv2dOp):

 #----------------Conv2dCUDNN----------------
-class TestCUDNN(TestConv2dOp):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNN(TestConv2dOp):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=2e-2)
-
-
-class TestCUDNNWithPad(TestWithPad):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNWithPad(TestWithPad):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=2e-2)
-
-
-class TestCUDNNWithStride(TestWithStride):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNWithStride(TestWithStride):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=2e-2)
-
-
-class TestCUDNNWithGroup(TestWithGroup):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNWithGroup(TestWithGroup):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=2e-2)
-
-
-class TestCUDNNWith1x1(TestWith1x1):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNWith1x1(TestWith1x1):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=2e-2)
-
-
-class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-
-
-class TestFP16CUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=2e-2)
+def create_test_cudnn_class(parent, cls_name):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCUDNNCase(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+
+    cls_name = "{0}".format(cls_name)
+    TestCUDNNCase.__name__ = cls_name
+    globals()[cls_name] = TestCUDNNCase
+
+
+create_test_cudnn_class(TestConv2dOp, "TestPool2DCUDNNOp")
+create_test_cudnn_class(TestWithPad, "TestPool2DCUDNNOpCase1")
+create_test_cudnn_class(TestWithStride, "TestPool2DCUDNNOpCase2")
+create_test_cudnn_class(TestWithGroup, "TestPool2DCUDNNOpCase3")
+create_test_cudnn_class(TestWith1x1, "TestPool2DCUDNNOpCase4")
+create_test_cudnn_class(TestWithInput1x1Filter1x1, "TestPool2DCUDNNOpCase4")
+
+#----------------Conv2dCUDNN----------------
+
+
+def create_test_cudnn_fp16_class(parent, cls_name, grad_check=True):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestConv2DCUDNNFp16(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place, atol=2e-2)
+
+        def test_check_grad_no_filter(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place) and grad_check:
+                self.check_grad_with_place(
+                    place, ['Input'],
+                    'Output',
+                    max_relative_error=0.02,
+                    no_grad_set=set(['Filter']))
+
+        def test_check_grad_no_input(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place) and grad_check:
+                self.check_grad_with_place(
+                    place, ['Filter'],
+                    'Output',
+                    max_relative_error=0.02,
+                    no_grad_set=set(['Input']))
+
+    cls_name = "{0}".format(cls_name)
+    TestConv2DCUDNNFp16.__name__ = cls_name
+    globals()[cls_name] = TestConv2DCUDNNFp16
+
+
+create_test_cudnn_fp16_class(
+    TestConv2dOp, "TestPool2DCUDNNFp16Op", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWithPad, "TestPool2DCUDNNFp16OpCase1", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWithStride, "TestPool2DCUDNNFp16OpCase2", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWithGroup, "TestPool2DCUDNNFp16OpCase3", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWith1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
+create_test_cudnn_fp16_class(
+    TestWithInput1x1Filter1x1, "TestPool2DCUDNNFp16OpCase4", grad_check=False)
+
+# -------TestDepthwiseConv

 class TestDepthwiseConv(TestConv2dOp):
......
@@ -16,28 +16,58 @@ from __future__ import print_function
 import unittest
 import numpy as np
+import paddle.fluid.core as core
 from op_test import OpTest, randomize_probability


-class TestCrossEntropyOp1(OpTest):
+class TestCrossEntropyOp(OpTest):
     """Test cross-entropy with discrete one-hot labels.
     """

     def setUp(self):
         self.op_type = "cross_entropy"
-        batch_size = 30
-        class_num = 10
-
-        X = randomize_probability(batch_size, class_num, dtype='float64')
-
-        label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64")
-        cross_entropy = np.asmatrix(
-            [[-np.log(X[i][label[i][0]])] for i in range(X.shape[0])],
-            dtype="float64")
-
-        self.inputs = {"X": X, "Label": label}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": False}
+        self.soft_label = False
+        self.ignore_index = -100
+        self.dtype = np.float64
+        self.batch_size = 30
+        self.class_num = 10
+
+        self.init_dtype_type()
+        self.init_attr_type()
+        self.init_bs_class_num()
+        self.init_x()
+        self.init_label()
+        self.get_cross_entropy()
+
+        self.inputs = {"X": self.x, "Label": self.label}
+        self.outputs = {"Y": self.cross_entropy}
+        self.attrs = {
+            "soft_label": self.soft_label,
+            "ignore_index": self.ignore_index
+        }
+
+    def init_x(self):
+        self.x = randomize_probability(
+            self.batch_size, self.class_num, dtype=self.dtype)
+
+    def init_label(self):
+        self.label = np.random.randint(
+            0, self.class_num, (self.batch_size, 1), dtype="int64")
+
+    def get_cross_entropy(self):
+        self.cross_entropy = np.asmatrix(
+            [[-np.log(self.x[i][self.label[i][0]])]
+             for i in range(self.x.shape[0])],
+            dtype="float64")
+
+    def init_attr_type(self):
+        pass
+
+    def init_dtype_type(self):
+        pass
+
+    def init_bs_class_num(self):
+        pass

     def test_check_output(self):
         self.check_output()
@@ -46,197 +76,231 @@ class TestCrossEntropyOp1(OpTest):
         self.check_grad(["X"], "Y", numeric_grad_delta=0.001)


-class TestCrossEntropyOp2(OpTest):
+class TestCrossEntropyOp2(TestCrossEntropyOp):
     """Test cross-entropy with vectorized soft labels.
     """

-    def setUp(self):
-        self.op_type = "cross_entropy"
-        batch_size = 5
-        class_num = 37
-
-        X = randomize_probability(batch_size, class_num)
-        label = np.random.uniform(0.1, 1.0,
-                                  [batch_size, class_num]).astype("float32")
-        label /= label.sum(axis=1, keepdims=True)
-        cross_entropy = (-label * np.log(X)).sum(
-            axis=1, keepdims=True).astype("float32")
-
-        self.inputs = {"X": X, "Label": label}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": True}
-
-    def test_check_output(self):
-        self.check_output()
+    def init_label(self):
+        self.label = np.random.uniform(
+            0.1, 1.0, [self.batch_size, self.class_num]).astype(self.dtype)
+        self.label /= self.label.sum(axis=1, keepdims=True)
+
+    def get_cross_entropy(self):
+        self.cross_entropy = (-self.label * np.log(self.x)).sum(
+            axis=1, keepdims=True).astype(self.dtype)
+
+    def init_attr_type(self):
+        self.soft_label = True
+
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def init_bs_class_num(self):
+        self.batch_size = 5
+        self.class_num = 37

     def test_check_grad(self):
         self.check_grad(
             ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)


-class TestCrossEntropyOp3(OpTest):
+class TestCrossEntropyOp3(TestCrossEntropyOp):
     """Test cross-entropy with vectorized one-hot representation of labels.
     """

-    def setUp(self):
-        self.op_type = "cross_entropy"
-        batch_size = 5
-        class_num = 17
-
-        X = randomize_probability(batch_size, class_num)
-        label_index = np.random.randint(
-            0, class_num, (batch_size), dtype="int32")
-        label = np.zeros(X.shape)
-        label[np.arange(batch_size), label_index] = 1
-
-        cross_entropy = np.asmatrix(
-            [[-np.log(X[i][label_index[i]])] for i in range(X.shape[0])],
-            dtype="float32")
-        cross_entropy2 = (-label * np.log(X)).sum(
-            axis=1, keepdims=True).astype("float32")
-
-        self.inputs = {"X": X, "Label": label.astype(np.float32)}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": True}
-
-    def test_check_output(self):
-        self.check_output()
+    def init_label(self):
+        self.label_index = np.random.randint(0, self.class_num,
                                              (self.batch_size))
+        self.label = np.zeros(self.x.shape).astype(self.dtype)
+        self.label[np.arange(self.batch_size), self.label_index] = 1
+
+    def get_cross_entropy(self):
+        self.cross_entropy = np.asmatrix(
+            [[-np.log(self.x[i][self.label_index[i]])]
+             for i in range(self.x.shape[0])]).astype(self.dtype)
+
+    def init_attr_type(self):
+        self.soft_label = True
+
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def init_bs_class_num(self):
+        self.batch_size = 5
+        self.class_num = 17

     def test_check_grad(self):
         self.check_grad(
             ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)


-class TestCrossEntropyOp4(OpTest):
+class TestCrossEntropyOp4(TestCrossEntropyOp):
     """Test high rank tensor cross-entropy with discrete one-hot labels.
     """

-    def setUp(self):
-        self.op_type = "cross_entropy"
-        shape = [10, 2, 4]
-        ins_num = np.prod(np.array(shape))
-        class_num = 10
-
-        X_2d = randomize_probability(ins_num, class_num, dtype='float64')
-
-        label_2d = np.random.randint(0, class_num, (ins_num, 1), dtype="int64")
-        cross_entropy_2d = np.asmatrix(
-            [[-np.log(X_2d[i][label_2d[i][0]])] for i in range(X_2d.shape[0])],
-            dtype="float64")
-
-        X = X_2d.reshape(shape + [class_num])
-        label = label_2d.reshape(shape + [1])
-        cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1])
-
-        self.inputs = {"X": X, "Label": label}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": False}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
+    def init_x(self):
+        self.shape = [10, 2, 4]
+        self.ins_num = np.prod(np.array(self.shape))
+        self.X_2d = randomize_probability(self.ins_num,
+                                          self.class_num).astype(self.dtype)
+        self.x = self.X_2d.reshape(self.shape + [self.class_num])
+
+    def init_label(self):
+        self.label_2d = np.random.randint(
+            0, self.class_num, (self.ins_num, 1), dtype="int64")
+        self.label = self.label_2d.reshape(self.shape + [1])
+
+    def get_cross_entropy(self):
+        cross_entropy_2d = np.asmatrix(
+            [[-np.log(self.X_2d[i][self.label_2d[i][0]])]
+             for i in range(self.X_2d.shape[0])]).astype(self.dtype)
+        self.cross_entropy = np.array(cross_entropy_2d).reshape(self.shape +
+                                                                [1])
+
+    def init_attr_type(self):
+        self.soft_label = False
+
+    def init_dtype_type(self):
+        self.dtype = np.float64
+
+    def init_bs_class_num(self):
+        self.class_num = 10


-class TestCrossEntropyOp5(OpTest):
+class TestCrossEntropyOp5(TestCrossEntropyOp):
     """Test high rank tensor cross-entropy with vectorized soft labels.
     """

-    def setUp(self):
-        self.op_type = "cross_entropy"
-        shape = [4, 3]
-        ins_num = np.prod(np.array(shape))
-        class_num = 37
-
-        X_2d = randomize_probability(ins_num, class_num)
-        label_2d = np.random.uniform(0.1, 1.0,
-                                     [ins_num, class_num]).astype("float32")
-        label_2d /= label_2d.sum(axis=1, keepdims=True)
-        cross_entropy_2d = (-label_2d * np.log(X_2d)).sum(
-            axis=1, keepdims=True).astype("float32")
-
-        X = X_2d.reshape(shape + [class_num])
-        label = label_2d.reshape(shape + [class_num])
-        cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1])
-
-        self.inputs = {"X": X, "Label": label}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": True}
-
-    def test_check_output(self):
-        self.check_output()
+    def init_x(self):
+        self.shape = [4, 3]
+        self.ins_num = np.prod(np.array(self.shape))
+        self.X_2d = randomize_probability(self.ins_num,
+                                          self.class_num).astype(self.dtype)
+        self.x = self.X_2d.reshape(self.shape + [self.class_num])
+
+    def init_label(self):
+        self.label_2d = np.random.uniform(
+            0.1, 1.0, [self.ins_num, self.class_num]).astype(self.dtype)
+        self.label_2d /= self.label_2d.sum(axis=1, keepdims=True)
+        self.label = self.label_2d.reshape(self.shape + [self.class_num])
+
+    def get_cross_entropy(self):
+        cross_entropy_2d = (-self.label_2d * np.log(self.X_2d)).sum(
+            axis=1, keepdims=True).astype(self.dtype)
+        self.cross_entropy = np.array(cross_entropy_2d).reshape(self.shape +
+                                                                [1])
+
+    def init_attr_type(self):
+        self.soft_label = True
+
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def init_bs_class_num(self):
+        self.class_num = 37

     def test_check_grad(self):
         self.check_grad(
             ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)


-class TestCrossEntropyOp6(OpTest):
+class TestCrossEntropyOp6(TestCrossEntropyOp):
     """Test high rank tensor cross-entropy with vectorized one-hot representation of labels.
     """

-    def setUp(self):
-        self.op_type = "cross_entropy"
-        shape = [4, 3, 2]
-        ins_num = np.prod(np.array(shape))
-        class_num = 17
-
-        X_2d = randomize_probability(ins_num, class_num)
-        label_index_2d = np.random.randint(
-            0, class_num, (ins_num), dtype="int32")
-        label_2d = np.zeros(X_2d.shape)
-        label_2d[np.arange(ins_num), label_index_2d] = 1
-
-        cross_entropy_2d = np.asmatrix(
-            [[-np.log(X_2d[i][label_index_2d[i]])]
-             for i in range(X_2d.shape[0])],
-            dtype="float32")
-
-        X = X_2d.reshape(shape + [class_num])
-        label = label_2d.reshape(shape + [class_num])
-        cross_entropy = np.array(cross_entropy_2d).reshape(shape + [1])
-
-        self.inputs = {"X": X, "Label": label.astype(np.float32)}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": True}
-
-    def test_check_output(self):
-        self.check_output()
+    def init_x(self):
+        self.shape = [4, 3, 2]
+        self.ins_num = np.prod(np.array(self.shape))
+        self.X_2d = randomize_probability(self.ins_num,
+                                          self.class_num).astype(self.dtype)
+        self.x = self.X_2d.reshape(self.shape + [self.class_num])
+
+    def init_label(self):
+        self.label_index_2d = np.random.randint(
+            0, self.class_num, (self.ins_num), dtype="int64")
+        label_2d = np.zeros(self.X_2d.shape)
+        label_2d[np.arange(self.ins_num), self.label_index_2d] = 1
+        self.label = label_2d.reshape(self.shape + [self.class_num]).astype(
+            self.dtype)
+
+    def get_cross_entropy(self):
+        cross_entropy_2d = np.asmatrix(
+            [[-np.log(self.X_2d[i][self.label_index_2d[i]])]
+             for i in range(self.X_2d.shape[0])])
+        self.cross_entropy = np.array(cross_entropy_2d).reshape(
+            self.shape + [1]).astype(self.dtype)
+
+    def init_attr_type(self):
+        self.soft_label = True
+
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def init_bs_class_num(self):
+        self.class_num = 17

     def test_check_grad(self):
         self.check_grad(
             ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)


-class TestCrossEntropyOp7(OpTest):
+class TestCrossEntropyOp7(TestCrossEntropyOp):
     """Test cross-entropy with ignore index.
     """

-    def setUp(self):
-        self.op_type = "cross_entropy"
-        batch_size = 30
-        class_num = 10
-        ignore_index = 3
-
-        X = randomize_probability(batch_size, class_num, dtype='float64')
-
-        label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64")
-        cross_entropy = np.asmatrix(
-            [[-np.log(X[i][label[i][0]])]
-             if label[i][0] != ignore_index else [0]
-             for i in range(X.shape[0])],
-            dtype="float64")
-        self.inputs = {"X": X, "Label": label}
-        self.outputs = {"Y": cross_entropy}
-        self.attrs = {"soft_label": False, "ignore_index": ignore_index}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
+    def init_label(self):
+        self.label = np.random.randint(
+            0, self.class_num, (self.batch_size, 1), dtype="int64")
+
+    def get_cross_entropy(self):
+        self.cross_entropy = np.asmatrix(
+            [[-np.log(self.x[i][self.label[i][0]])]
+             if self.label[i][0] != self.ignore_index else [0]
+             for i in range(self.x.shape[0])]).astype(self.dtype)
+
+    def init_attr_type(self):
+        self.soft_label = False
+        self.ignore_index = 3
+
+    def init_dtype_type(self):
+        self.dtype = np.float64
+
+    def init_bs_class_num(self):
+        self.batch_size = 30
+        self.class_num = 10
+
+
+# Add Fp16 test
+def create_test_class(parent, cls_name):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCrossEntropyFP16Op(parent):
+        def init_dtype_type(self):
+            return np.float16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-1)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(
+                    place, ['X'], 'Y', max_relative_error=0.9)
+
+    cls_name = "{0}".format(cls_name)
+    TestCrossEntropyFP16Op.__name__ = cls_name
+    globals()[cls_name] = TestCrossEntropyFP16Op
+
+
+create_test_class(TestCrossEntropyOp, "TestCrossEntropyF16Op")
+#create_test_class(TestCrossEntropyOp2, "TestCrossEntropyF16Op2")
+create_test_class(TestCrossEntropyOp3, "TestCrossEntropyF16Op3")
+create_test_class(TestCrossEntropyOp4, "TestCrossEntropyF16Op4")
+#create_test_class(TestCrossEntropyOp5, "TestCrossEntropyF16Op5")
+create_test_class(TestCrossEntropyOp6, "TestCrossEntropyF16Op6")
+create_test_class(TestCrossEntropyOp7, "TestCrossEntropyF16Op7")

 if __name__ == "__main__":
     unittest.main()
@@ -17,14 +17,20 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle.fluid.core as core


 class TestMeanOp(OpTest):
     def setUp(self):
         self.op_type = "mean"
-        self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
+        self.dtype = np.float32
+        self.init_dtype_type()
+        self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
         self.outputs = {'Out': np.mean(self.inputs["X"])}

+    def init_dtype_type(self):
+        pass
+
     def test_check_output(self):
         self.check_output()

@@ -32,5 +38,23 @@ class TestMeanOp(OpTest):
         self.check_grad(['X'], 'Out')


+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFP16MeanOp(TestMeanOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place, atol=2e-3)
+
+    def test_checkout_grad(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X'], 'Out', max_relative_error=0.8)
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -23,12 +23,17 @@ from op_test import OpTest
 class TestMulOp(OpTest):
     def setUp(self):
         self.op_type = "mul"
+        self.dtype = np.float32
+        self.init_dtype_type()
         self.inputs = {
-            'X': np.random.random((2, 5)).astype("float32"),
-            'Y': np.random.random((5, 3)).astype("float32")
+            'X': np.random.random((2, 5)).astype(self.dtype),
+            'Y': np.random.random((5, 3)).astype(self.dtype)
         }
         self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}

+    def init_dtype_type(self):
+        pass
+
     def test_check_output(self):
         self.check_output()

@@ -47,9 +52,11 @@ class TestMulOp(OpTest):
 class TestMulOp2(OpTest):
     def setUp(self):
         self.op_type = "mul"
+        self.dtype = np.float32
+        self.init_dtype_type()
         self.inputs = {
-            'X': np.random.random((3, 4, 4, 3)).astype("float32"),
-            'Y': np.random.random((2, 6, 1, 2, 3)).astype("float32")
+            'X': np.random.random((3, 4, 4, 3)).astype(self.dtype),
+            'Y': np.random.random((2, 6, 1, 2, 3)).astype(self.dtype)
         }
         self.attrs = {
             'x_num_col_dims': 2,
@@ -60,6 +67,9 @@ class TestMulOp2(OpTest):
         result = result.reshape(3, 4, 1, 2, 3)
         self.outputs = {'Out': result}

+    def init_dtype_type(self):
+        pass
+
     def test_check_output(self):
         self.check_output()

@@ -75,40 +85,76 @@ class TestMulOp2(OpTest):
             ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))


-class TestFP16MulOp1(OpTest):
-    def setUp(self):
-        self.op_type = "mul"
-        x = np.random.random((3, 5)).astype("float16")
-        y = np.random.random((5, 4)).astype("float16")
-        self.inputs = {'X': x.view(np.float16), 'Y': y.view(np.float16)}
-        self.outputs = {'Out': np.dot(x, y)}
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFP16MulOp1(TestMulOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16

     def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-1)
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place, atol=1e-1)
+
+    def test_check_grad_normal(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X', 'Y'], 'Out', max_relative_error=0.5)
+
+    def test_check_grad_ingore_x(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['Y'],
+                'Out',
+                max_relative_error=0.5,
+                no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X'],
+                'Out',
+                max_relative_error=0.5,
+                no_grad_set=set('Y'))


-class TestFP16MulOp2(OpTest):
-    def setUp(self):
-        self.op_type = "mul"
-        x = np.random.random((3, 4, 4, 3)).astype("float16")
-        y = np.random.random((2, 6, 1, 2, 3)).astype("float16")
-        self.inputs = {'X': x.view(np.float16), 'Y': y.view(np.float16)}
-        self.attrs = {
-            'x_num_col_dims': 2,
-            'y_num_col_dims': 2,
-        }
-        result = np.dot(x.reshape(3 * 4, 4 * 3), y.reshape(2 * 6, 1 * 2 * 3))
-        result = result.reshape(3, 4, 1, 2, 3)
-        self.outputs = {'Out': result}
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFP16MulOp2(TestMulOp2):
+    def init_dtype_type(self):
+        self.dtype = np.float16

     def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=2e-1)
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place, atol=2e-1)
+
+    def test_check_grad_normal(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X', 'Y'], 'Out', max_relative_error=0.9)
+
+    def test_check_grad_ingore_x(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['Y'],
+                'Out',
+                max_relative_error=0.5,
+                no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ['X'],
+                'Out',
+                max_relative_error=0.9,
+                no_grad_set=set('Y'))


 if __name__ == "__main__":
......
@@ -15,10 +15,10 @@
 from __future__ import print_function

 import unittest
-from test_pool2d_op import TestPool2d_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5
+from test_pool2d_op import TestPool2D_Op, TestCase1, TestCase2, TestCase3, TestCase4, TestCase5


-class TestMKLDNNCase1(TestPool2d_Op):
+class TestMKLDNNCase1(TestPool2D_Op):
     def init_kernel_type(self):
         self.use_mkldnn = True

...
@@ -81,7 +81,7 @@ def avg_pool2D_forward_naive(x,
     return out


-class TestPool2d_Op(OpTest):
+class TestPool2D_Op(OpTest):
     def setUp(self):
         self.op_type = "pool2d"
         self.use_cudnn = False
@@ -160,7 +160,7 @@ class TestPool2d_Op(OpTest):
         self.exclusive = True


-class TestCase1(TestPool2d_Op):
+class TestCase1(TestPool2D_Op):
     def init_test_case(self):
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
@@ -175,7 +175,7 @@ class TestCase1(TestPool2d_Op):
         self.global_pool = False


-class TestCase2(TestPool2d_Op):
+class TestCase2(TestPool2D_Op):
     def init_test_case(self):
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
@@ -190,7 +190,7 @@ class TestCase2(TestPool2d_Op):
         self.global_pool = False


-class TestCase3(TestPool2d_Op):
+class TestCase3(TestPool2D_Op):
     def init_pool_type(self):
         self.pool_type = "max"
         self.pool2D_forward_naive = max_pool2D_forward_naive
@@ -208,127 +208,98 @@ class TestCase5(TestCase2):
         self.pool2D_forward_naive = max_pool2D_forward_naive


-#--------------------test pool2d--------------------
-class TestCUDNNCase1(TestPool2d_Op):
-    def init_kernel_type(self):
-        self.use_cudnn = True

-class TestFP16CUDNNCase1(TestPool2d_Op):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)

-class TestCUDNNCase2(TestCase1):
-    def init_kernel_type(self):
-        self.use_cudnn = True

-class TestFP16CUDNNCase2(TestCase1):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)

-class TestCUDNNCase3(TestCase2):
-    def init_kernel_type(self):
-        self.use_cudnn = True

-class TestFP16CUDNNCase3(TestCase2):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)

-class TestCUDNNCase4(TestCase3):
-    def init_kernel_type(self):
-        self.use_cudnn = True

-class TestFP16CUDNNCase4(TestCase3):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)

-class TestCUDNNCase5(TestCase4):
-    def init_kernel_type(self):
-        self.use_cudnn = True

-class TestFP16CUDNNCase5(TestCase4):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)

-class TestCUDNNCase6(TestCase5):
-    def init_kernel_type(self):
-        self.use_cudnn = True

-class TestFP16CUDNNCase6(TestCase5):
-    def init_kernel_type(self):
-        self.use_cudnn = True
-        self.dtype = np.float16
-
-    def test_check_output(self):
-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            if core.is_float16_supported(place):
-                self.check_output_with_place(place, atol=1e-3)

-class TestCeilModeCase1(TestCUDNNCase1):
-    def init_ceil_mode(self):
-        self.ceil_mode = True

-class TestCeilModeCase2(TestCUDNNCase2):
-    def init_ceil_mode(self):
-        self.ceil_mode = True

-class TestCeilModeCase3(TestCase1):
-    def init_ceil_mode(self):
-        self.ceil_mode = True

-class TestCeilModeCase4(TestCase2):
-    def init_ceil_mode(self):
-        self.ceil_mode = True
+#--------------------test pool2d cudnn--------------------

+def create_test_cudnn_class(parent):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCUDNNCase(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CUDNNOp")
+    TestCUDNNCase.__name__ = cls_name
+    globals()[cls_name] = TestCUDNNCase

+create_test_cudnn_class(TestPool2D_Op)
+create_test_cudnn_class(TestCase1)
+create_test_cudnn_class(TestCase2)
+create_test_cudnn_class(TestCase3)
+create_test_cudnn_class(TestCase4)
+create_test_cudnn_class(TestCase5)

+#--------------------test pool2d cudnn_fp16--------------------

+def create_test_cudnn_fp16_class(parent, check_grad=True):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestCUDNNFp16Case(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place, atol=1e-3)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(
+                    place) and self.pool_type != "max" and check_grad:
+                self.check_grad_with_place(
+                    place, set(['X']), 'Out', max_relative_error=0.07)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CUDNNFp16Op")
+    TestCUDNNFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestCUDNNFp16Case

+create_test_cudnn_fp16_class(TestPool2D_Op)
+create_test_cudnn_fp16_class(TestCase1, check_grad=False)
+create_test_cudnn_fp16_class(TestCase2)
+create_test_cudnn_fp16_class(TestCase3)
+create_test_cudnn_fp16_class(TestCase4)
+create_test_cudnn_fp16_class(TestCase5)

+#--------------------test pool2d use ceil mode--------------------

+def create_test_cudnn_use_ceil_class(parent):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestPool2DUseCeilCase(parent):
+        def init_kernel_type(self):
+            self.use_cudnn = True
+
+        def init_ceil_mode(self):
+            self.ceil_mode = True
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CUDNNOpCeilMode")
+    TestPool2DUseCeilCase.__name__ = cls_name
+    globals()[cls_name] = TestPool2DUseCeilCase

+create_test_cudnn_use_ceil_class(TestPool2D_Op)
+create_test_cudnn_use_ceil_class(TestCase1)

+def create_test_use_ceil_class(parent):
+    class TestPool2DUseCeilCase(parent):
+        def init_ceil_mode(self):
+            self.ceil_mode = True
+
+    cls_name = "{0}_{1}".format(parent.__name__, "CeilModeCast")
+    TestPool2DUseCeilCase.__name__ = cls_name
+    globals()[cls_name] = TestPool2DUseCeilCase

+create_test_use_ceil_class(TestCase1)
+create_test_use_ceil_class(TestCase2)


 class TestAvgInclude(TestCase2):
@@ -336,7 +307,10 @@ class TestAvgInclude(TestCase2):
         self.exclusive = False


-class TestCUDNNAvgInclude(TestCUDNNCase3):
+class TestCUDNNAvgInclude(TestCase2):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+
     def init_exclusive(self):
         self.exclusive = False

...
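In the pool2d rewrite above, the long list of hand-written TestCUDNNCase* / TestFP16CUDNNCase* classes is replaced by factory functions (create_test_cudnn_class, create_test_cudnn_fp16_class, ...) that build a subclass of the given parent, rename it, and register it in globals() so unittest discovery still sees one class per configuration. A minimal sketch of that registration trick with a plain unittest.TestCase base (the names create_test_double_class, BaseCase, and the factor attribute are illustrative only, not part of the patch):

import unittest


def create_test_double_class(parent):
    # Build a variant of `parent`, give it a unique name, and publish it at
    # module scope so test discovery treats it as a separate test class.
    class DoubledCase(parent):
        factor = 2

    cls_name = "{0}_{1}".format(parent.__name__, "Doubled")
    DoubledCase.__name__ = cls_name
    globals()[cls_name] = DoubledCase


class BaseCase(unittest.TestCase):
    factor = 1

    def test_factor_is_positive(self):
        self.assertGreater(self.factor, 0)


create_test_double_class(BaseCase)

if __name__ == "__main__":
    unittest.main()  # runs BaseCase and BaseCase_Doubled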
@@ -24,9 +24,16 @@ from paddle.fluid.op import Operator
 class TestScaleOp(OpTest):
     def setUp(self):
         self.op_type = "scale"
-        self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
+        self.dtype = np.float32
+        self.init_dtype_type()
+        self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
         self.attrs = {'scale': -2.3}
-        self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']}
+        self.outputs = {
+            'Out': self.inputs['X'] * self.dtype(self.attrs['scale'])
+        }
+
+    def init_dtype_type(self):
+        pass

     def test_check_output(self):
         self.check_output()

@@ -36,9 +43,15 @@ class TestScaleOp(OpTest):


 class TestScaleOpSelectedRows(unittest.TestCase):
+    def init_dtype_type(self):
+        pass
+
     def check_with_place(self, place, in_name, out_name):
         scope = core.Scope()
+        self.dtype = np.float32
+        self.init_dtype_type()

         # create and initialize Grad Variable
         in_height = 10
         in_rows = [0, 4, 7]
@@ -49,7 +62,7 @@ class TestScaleOpSelectedRows(unittest.TestCase):
         in_selected_rows.set_height(in_height)
         in_selected_rows.set_rows(in_rows)
         in_array = np.random.random(
-            (len(in_rows), in_row_numel)).astype("float32")
+            (len(in_rows), in_row_numel)).astype(self.dtype)

         in_tensor = in_selected_rows.get_tensor()
         in_tensor.set(in_array, place)
@@ -87,5 +100,41 @@ class TestScaleOpSelectedRows(unittest.TestCase):
             self.check_with_place(place, 'in', 'in')


+# Add FP16 test
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScaleFp16Op(TestScaleOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_output_with_place(place, atol=0.002)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_grad_with_place(
+                place, ["X"], "Out", max_relative_error=0.05)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_scale_selected_rows(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_with_place(place, 'in', 'out')
+
+    def test_scale_selected_rows_inplace(self):
+        place = core.CUDAPlace(0)
+        if core.is_float16_supported(place):
+            self.check_with_place(place, 'in', 'in')
+
+
 if __name__ == "__main__":
     unittest.main()
...
@@ -62,12 +62,11 @@ class TestSoftmaxOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        if self.dtype == np.float16:
-            return
-        if self.use_cudnn:
+        if self.use_cudnn or self.dtype == np.float16:
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place, ["X"], "Out", max_relative_error=0.01)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(
+                    place, ["X"], "Out", max_relative_error=0.01)
         else:
             self.check_grad(["X"], "Out", max_relative_error=0.01)

@@ -103,10 +102,23 @@ class TestSoftmaxFP16Op(TestSoftmaxOp):
             if core.is_float16_supported(place):
                 self.check_output_with_place(place, atol=1e-3)

+    # FIXME: If the x_shape is [10, 10], gradient failed.
+    def test_check_grad(self):
+        pass
+

 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
-class TestSoftmaxFP16Op2(TestSoftmaxFP16Op):
+class TestSoftmaxFP16Op2(TestSoftmaxOp):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=1e-3)
+
     def get_x_shape(self):
         return [2, 3, 4, 5]

...
@@ -24,16 +24,20 @@ from paddle.fluid.op import Operator
 class TestSumOp(OpTest):
     def setUp(self):
         self.op_type = "sum"
+        self.init_kernel_type()
         self.use_mkldnn = False
         self.init_kernel_type()
-        x0 = np.random.random((3, 4)).astype('float32')
-        x1 = np.random.random((3, 4)).astype('float32')
-        x2 = np.random.random((3, 4)).astype('float32')
+        x0 = np.random.random((3, 4)).astype(self.dtype)
+        x1 = np.random.random((3, 4)).astype(self.dtype)
+        x2 = np.random.random((3, 4)).astype(self.dtype)
         self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
         y = x0 + x1 + x2
         self.outputs = {'Out': y}
         self.attrs = {'use_mkldnn': self.use_mkldnn}

+    def init_kernel_type(self):
+        self.dtype = np.float32
+
     def test_check_output(self):
         self.check_output()

@@ -59,8 +63,11 @@ class TestSelectedRowsSumOp(OpTest):
         self.check_input_and_optput(core.Scope(), place, inplace, False, False,
                                     False)

+    def init_kernel_type(self):
+        self.dtype = np.float32
+
     def _get_array(self, row_num, row_numel):
-        array = np.ones((row_num, row_numel)).astype("float32")
+        array = np.ones((row_num, row_numel)).astype(self.dtype)
         for i in range(row_num):
             array[i] *= i
         return array
@@ -129,5 +136,36 @@ class TestSelectedRowsSumOp(OpTest):
             self.check_with_place(place, inplace)


+class TestFP16SumOp(TestSumOp):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=2e-2)
+
+    # FIXME: Because of the precision fp16, max_relative_error
+    # should be 0.15 here.
+    def test_check_grad(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad(['x0'], 'Out', max_relative_error=0.15)
+
+
+class TestFP16SelectedRowsSumOp(TestSelectedRowsSumOp):
+    def init_kernel_type(self):
+        self.dtype = np.float16
+
+    def test_w_is_selected_rows(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                for inplace in [True, False]:
+                    self.check_with_place(place, inplace)
+
+
 if __name__ == "__main__":
     unittest.main()
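All of the new FP16 cases above share the same guard shape: a class-level @unittest.skipIf(...) when Paddle was built without CUDA, plus a runtime core.is_float16_supported(place) check before any float16 kernel is exercised, so the suite degrades to a no-op on hardware without FP16 support. A minimal sketch of that double guard, assuming only numpy and paddle.fluid.core are importable (TestFp16Guard and test_fp16_add are illustrative names, not part of the patch):

import unittest

import numpy as np
import paddle.fluid.core as core


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestFp16Guard(unittest.TestCase):
    def test_fp16_add(self):
        place = core.CUDAPlace(0)
        if not core.is_float16_supported(place):
            return  # quietly pass on devices without FP16 kernels
        x = np.random.random((3, 4)).astype(np.float16)
        self.assertEqual((x + x).dtype, np.dtype(np.float16))


if __name__ == "__main__":
    unittest.main()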