未验证 提交 128adf53 编写于 作者: D dzhwinter 提交者: GitHub

[Speed]implement cudnn sequence softmax cudnn (#8978)

* "add softmax cudnn functor support"

* "add testing"

* "refine cmakelist"

* "sequence softmax forward speed up"

* "add softmax grad"

* "fix sequence softmax test"

* "add double precision'

* "fix softmax test"

* "add softmax cudnn support"

* "fix softmax cudnn test"

* "add softmax to nn.py"

* "fix compile bug"

* "refine cmakelist"

* "fix ci"

* "fix based on comment"

* "fix based on comments"

* "fix ci"
上级 9b9f3f09
......@@ -14,13 +14,86 @@ limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/math/softmax_impl.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
void SoftmaxCUDNNFunctor<T>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor xDesc;
ScopedTensorDescriptor yDesc;
std::vector<int> cudnn_tensor_dims = framework::vectorize2int(X->dims());
DataLayout layout = DataLayout::kNCHW;
if (cudnn_tensor_dims.size() == 5) {
layout = DataLayout::kNCDHW;
}
// NOTE(*) : cudnn softmax only support >= 4D Tensor,
// fill 1 at unused dims
if (cudnn_tensor_dims.size() <= 2) {
cudnn_tensor_dims.resize(4, 1);
}
cudnnTensorDescriptor_t cudnn_x_desc =
xDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_y_desc =
xDesc.descriptor<T>(layout, cudnn_tensor_dims);
PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxForward(
context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_x_desc,
X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
Y->mutable_data<T>(context.GetPlace())));
}
template <typename T>
void SoftmaxGradCUDNNFunctor<T>::operator()(
const platform::CUDADeviceContext& context, const framework::Tensor* Y,
const framework::Tensor* YGrad, framework::Tensor* XGrad) {
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor yDesc;
ScopedTensorDescriptor dyDesc;
ScopedTensorDescriptor dxDesc;
std::vector<int> cudnn_tensor_dims = framework::vectorize2int(Y->dims());
DataLayout layout = DataLayout::kNCHW;
if (cudnn_tensor_dims.size() == 5) {
layout = DataLayout::kNCDHW;
}
// NOTE(*) : cudnn softmax only support >= 4D Tensor,
// fill 1 at unused dims
if (cudnn_tensor_dims.size() <= 2) {
cudnn_tensor_dims.resize(4, 1);
}
cudnnTensorDescriptor_t cudnn_y_desc =
yDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_xgrad_desc =
dxDesc.descriptor<T>(layout, cudnn_tensor_dims);
cudnnTensorDescriptor_t cudnn_ygrad_desc =
dyDesc.descriptor<T>(layout, cudnn_tensor_dims);
PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxBackward(
context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_y_desc,
Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
XGrad->mutable_data<T>(context.GetPlace())));
}
template class SoftmaxCUDNNFunctor<float>;
template class SoftmaxCUDNNFunctor<double>;
template class SoftmaxGradCUDNNFunctor<float>;
template class SoftmaxGradCUDNNFunctor<double>;
template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
......
......@@ -33,6 +33,23 @@ class SoftmaxGradFunctor {
const framework::Tensor* y_grad, framework::Tensor* x_grad);
};
#ifdef PADDLE_WITH_CUDA
template <typename T>
class SoftmaxCUDNNFunctor {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor* X, framework::Tensor* Y);
};
template <typename T>
class SoftmaxGradCUDNNFunctor {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor* Y, const framework::Tensor* y_grad,
framework::Tensor* x_grad);
};
#endif
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/softmax.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T>
class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = x->lod();
auto dims = x->dims();
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
"The first dimension of Input(X) should be equal to the "
"sum of all sequences' lengths.");
PADDLE_ENFORCE_EQ(dims[0], x->numel(),
"The width of each timestep in Input(X) of "
"SequenceSoftmaxOp should be 1.");
out->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor x_i = x->Slice(start_pos, end_pos);
Tensor out_i = out->Slice(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i =
// framework::make_ddim({1UL, end_pos - start_pos, 1UL, 1UL});
framework::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims_i);
out_i.Resize(dims_i);
math::SoftmaxCUDNNFunctor<T>()(
ctx.template device_context<platform::CUDADeviceContext>(), &x_i,
&out_i);
}
}
};
template <typename T>
class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<LoDTensor>("Out");
auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<LoDTensor>("X");
auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto lod = x->lod();
const size_t level = lod.size() - 1;
x_grad->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor out_i = out->Slice(start_pos, end_pos);
Tensor out_grad_i = out_grad->Slice(start_pos, end_pos);
Tensor x_grad_i = x_grad->Slice(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
out_i.Resize(dims_i);
out_grad_i.Resize(dims_i);
x_grad_i.Resize(dims_i);
math::SoftmaxGradCUDNNFunctor<T>()(
ctx.template device_context<platform::CUDADeviceContext>(), &out_i,
&out_grad_i, &x_grad_i);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(sequence_softmax, CUDNN, ::paddle::platform::CUDAPlace,
ops::SequenceSoftmaxCUDNNKernel<float>,
ops::SequenceSoftmaxCUDNNKernel<double>)
REGISTER_OP_KERNEL(sequence_softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
ops::SequenceSoftmaxGradCUDNNKernel<float>,
ops::SequenceSoftmaxGradCUDNNKernel<double>)
......@@ -29,6 +29,29 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -41,6 +64,17 @@ class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out",
"(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
"of length 1.");
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC(
Sequence Softmax Operator.
......@@ -91,6 +125,29 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
} // namespace operators
......@@ -102,7 +159,9 @@ REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp,
ops::SequenceSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -17,7 +17,10 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>)
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>)
REGISTER_OP_CUDA_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
math::SoftmaxCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(), X, Out);
}
};
template <typename T>
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
math::SoftmaxGradCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(), Out,
dOut, dX);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>);
......@@ -33,6 +33,29 @@ class SoftmaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -43,6 +66,17 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Out", "The normalized values with the same shape as X.");
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC(
Softmax Operator.
......@@ -80,6 +114,29 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
} // namespace operators
......
......@@ -289,7 +289,7 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (use_cudnn) {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
......
......@@ -132,7 +132,7 @@ def detection_output(loc,
old_shape = scores.shape
scores = ops.reshape(x=scores, shape=(-1, old_shape[-1]))
scores = ops.softmax(x=scores)
scores = nn.softmax(input=scores)
scores = ops.reshape(x=scores, shape=old_shape)
scores = nn.transpose(scores, perm=[0, 2, 1])
......
......@@ -39,6 +39,8 @@ __all__ = [
'sequence_conv',
'conv2d',
'sequence_pool',
'sequence_softmax',
'softmax',
'pool2d',
'batch_norm',
'beam_search_decode',
......@@ -1085,6 +1087,30 @@ def sequence_conv(input,
return helper.append_activation(pre_act)
def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
helper = LayerHelper('sequence_softmax', **locals())
dtype = helper.input_dtype()
softmax_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="sequence_softmax",
inputs={"X": input},
outputs={"Out": softmax_out},
attrs={"use_cudnn": use_cudnn})
return softmax_out
def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True):
helper = LayerHelper('softmax', **locals())
dtype = helper.input_dtype()
softmax_out = helper.create_tmp_variable(dtype)
helper.append_op(
type="softmax",
inputs={"X": input},
outputs={"Out": softmax_out},
attrs={"use_cudnn": use_cudnn})
return softmax_out
def conv2d(input,
num_filters,
filter_size,
......
......@@ -45,13 +45,30 @@ __activations__ = [
]
__all__ = [
'mean', 'mul', 'reshape', 'scale', 'sigmoid_cross_entropy_with_logits',
'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul',
'elementwise_max', 'elementwise_min', 'elementwise_pow', 'clip',
'clip_by_norm', 'softmax', 'sequence_softmax', 'logical_and', 'logical_or',
'logical_xor', 'logical_not', 'uniform_random',
'uniform_random_batch_size_like', 'gaussian_random',
'gaussian_random_batch_size_like', 'cumsum', 'scatter'
'mean',
'mul',
'reshape',
'scale',
'sigmoid_cross_entropy_with_logits',
'elementwise_add',
'elementwise_div',
'elementwise_sub',
'elementwise_mul',
'elementwise_max',
'elementwise_min',
'elementwise_pow',
'clip',
'clip_by_norm',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
'uniform_random',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'cumsum',
'scatter',
] + __activations__
for _OP in set(__all__):
......
......@@ -220,7 +220,7 @@ class TestBook(unittest.TestCase):
seq_data = layers.data(
name='seq_data', shape=[10, 10], dtype='float32', lod_level=1)
seq = layers.fc(input=seq_data, size=20)
self.assertIsNotNone(layers.sequence_softmax(x=seq))
self.assertIsNotNone(layers.sequence_softmax(seq))
print(str(program))
def test_softmax(self):
......@@ -228,7 +228,7 @@ class TestBook(unittest.TestCase):
with program_guard(program):
data = layers.data(name='data', shape=[10], dtype='float32')
hid = layers.fc(input=data, size=20)
self.assertIsNotNone(layers.softmax(x=hid))
self.assertIsNotNone(layers.softmax(hid))
print(str(program))
def test_get_places(self):
......
......@@ -16,11 +16,15 @@ import unittest
import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax
import paddle.fluid.core as core
class TestSequenceSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "sequence_softmax"
self.use_cudnn = False
self.init_op_type()
x = np.random.uniform(0.1, 1, (11, 1)).astype("float32")
lod = [[0, 4, 5, 8, 11]]
......@@ -34,12 +38,31 @@ class TestSequenceSoftmaxOp(OpTest):
self.inputs = {"X": (x, lod)}
self.outputs = {"Out": out}
self.attrs = {'use_cudnn': self.use_cudnn, }
def init_op_type(self):
pass
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", max_relative_error=0.01)
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ["X"], "Out", max_relative_error=0.01)
else:
self.check_grad(["X"], "Out", max_relative_error=0.01)
# ----------------cudnn Sequencesoftmax----------------
class TestSequenceSoftmaxCUDNNOp(TestSequenceSoftmaxOp):
def init_op_type(self):
self.use_cudnn = True
if __name__ == "__main__":
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
def stable_softmax(x):
......@@ -27,18 +28,37 @@ def stable_softmax(x):
class TestSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = False
self.inputs = {
'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
}
self.outputs = {
'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
}
self.attrs = {'use_cudnn': self.use_cudnn, }
def init_op_type(self):
pass
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ["X"], "Out", max_relative_error=0.01)
else:
self.check_grad(["X"], "Out", max_relative_error=0.01)
class TestSoftmaxCUDNNOp(TestSoftmaxOp):
def init_op_type(self):
self.use_cudnn = True
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册