提交 6cb66721 编写于 作者: D dengkaipeng

add cudnn support. test=develop

上级 518325f1
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -24,22 +25,40 @@ template <typename T>
// Softmax forward kernel that dispatches to math::SoftmaxCUDNNFunctor on a
// CUDADeviceContext.
// NOTE(review): this listing looks like a rendered diff in which removed and
// added lines are interleaved without +/- markers (both the flattened_*
// tensors and the *_2d tensors are handed to the functor below). Confirm
// against the real file before relying on it.
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
public:
// Reads input "X" and attribute "axis", allocates output "Out", and runs the
// functor on 2-D views. When axis is neither -1 nor rank - 1, X is first
// transposed into a temporary and the result is transposed back into Out.
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
// auto dims = X->dims();
const int axis = context.Attr<int>("axis");
int rank = X->dims().size();
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
// Apparently the pre-change path: share X/Out storage and view it with the
// dims returned by flatten_to_2d(dims, dims.size() - 1).
auto dims = X->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_x;
framework::LoDTensor flattened_out;
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
// perm/shape describe swapping dimension `axis` with the last dimension; they
// stay empty when axis already addresses the last dimension.
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
Tensor X_2d, Out_2d;
Tensor X_trans, Out_trans;
if (axis != -1 && axis != rank - 1) {
// Softmax axis is not the last dimension: work on transposed temporaries.
X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
// NOTE(review): Out was only allocated above, so this transposes whatever
// Out currently holds into Out_trans; presumably redundant — confirm.
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
} else {
// Axis is already the last dimension: reshape X and Out directly.
X_2d = framework::ReshapeToMatrix(*X, rank - 1);
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
}
// NOTE(review): two argument tails follow; the flattened_* line appears to be
// the removed one and the *_2d line the added one.
math::SoftmaxCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(),
&flattened_x, &flattened_out);
&X_2d, &Out_2d);
if (axis != -1 && axis != rank - 1) {
// Transpose the result from Out_trans back into Out with the same perm.
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
}
}
};
......@@ -47,25 +66,44 @@ template <typename T>
// Softmax backward kernel that dispatches to math::SoftmaxGradCUDNNFunctor on
// a CUDADeviceContext.
// NOTE(review): this listing looks like a rendered diff in which removed and
// added lines are interleaved without +/- markers (both the flattened_*
// tensors and the *_2d tensors are handed to the functor below). Confirm
// against the real file before relying on it.
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
public:
// Reads "Out", the gradient of "Out" and attribute "axis", allocates the
// gradient of "X", and runs the functor on 2-D views. When axis is neither -1
// nor rank - 1, the operands are transposed into temporaries first and the
// result is transposed back into dX.
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
const int axis = context.Attr<int>("axis");
int rank = Out->dims().size();
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
// Apparently the pre-change path: share Out/dOut/dX storage and view it with
// the dims returned by flatten_to_2d(dims, dims.size() - 1).
auto dims = Out->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_out;
framework::LoDTensor flattened_d_out;
framework::LoDTensor flattened_d_x;
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
// perm/shape describe swapping dimension `axis` with the last dimension; they
// stay empty when axis already addresses the last dimension.
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
Tensor dX_2d, Out_2d, dOut_2d;
Tensor dX_trans, Out_trans, dOut_trans;
if (axis != -1 && axis != rank - 1) {
// Softmax axis is not the last dimension: work on transposed temporaries.
dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
// NOTE(review): dX was only allocated above, so this transposes whatever dX
// currently holds into dX_trans; presumably redundant — confirm.
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
} else {
// Axis is already the last dimension: reshape the operands directly.
dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
}
// NOTE(review): two argument tails follow; the flattened_* line appears to be
// the removed one and the *_2d line the added one.
math::SoftmaxGradCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(),
&flattened_out, &flattened_d_out, &flattened_d_x);
&Out_2d, &dOut_2d, &dX_2d);
if (axis != -1 && axis != rank - 1) {
// Transpose the gradient from dX_trans back into dX with the same perm.
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
}
}
};
......
......@@ -23,59 +23,58 @@ namespace operators {
using Tensor = framework::Tensor;
// Helper that works out how to move the softmax axis to the last dimension.
// NOTE(review): this listing looks like a rendered diff in which the removed
// helper (TransposeAxisToEnd, which also allocated and transposed the tensors)
// and the added helper (CalcTransPermAndShapeByAxis, which only fills perm and
// shape) are interleaved without +/- markers. The comments below describe the
// CalcTransPermAndShapeByAxis lines.
//
// x:     tensor whose dims() are inspected.
// axis:  softmax axis; -1 and rank - 1 both mean "already last".
// perm:  receives the dimension order with `axis` and rank - 1 exchanged.
// shape: receives x's dims in that same exchanged order.
// Entries are appended with push_back and nothing is cleared first, so the
// caller is presumably expected to pass empty vectors.
// NOTE(review): negative axes other than -1 are not normalised here; with such
// a value `i == axis` never matches and dim_x[axis] is indexed with a negative
// number — verify that callers validate the attribute.
template <typename DeviceContext, typename T>
static inline void TransposeAxisToEnd(const Tensor& x, const Tensor& out,
Tensor* x_trans, Tensor* out_trans,
const int axis, std::vector<int> perm,
const framework::ExecutionContext& ctx) {
static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis,
std::vector<int>* perm, std::vector<int>* shape) {
auto dim_x = x.dims();
int rank = dim_x.size();
// Nothing to compute when the axis is already the last dimension.
if (axis == -1 || axis == rank - 1) {
*x_trans = x;
*out_trans = out;
return;
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<int> shape;
// Every leading dimension keeps its place, except that position `axis`
// receives the last dimension.
for (int i = 0; i < rank - 1; i++) {
if (i == axis) {
perm.push_back(rank - 1);
shape.push_back(dim_x[rank - 1]);
perm->push_back(rank - 1);
shape->push_back(dim_x[rank - 1]);
} else {
perm.push_back(i);
shape.push_back(dim_x[i]);
perm->push_back(i);
shape->push_back(dim_x[i]);
}
}
// The final position receives the original `axis` dimension.
perm.push_back(axis);
shape.push_back(dim_x[axis]);
x_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
out_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
TransCompute<DeviceContext, T>(rank, dev_ctx, x, x_trans, perm);
TransCompute<DeviceContext, T>(rank, dev_ctx, out, out_trans, perm);
perm->push_back(axis);
shape->push_back(dim_x[axis]);
}
template <typename DeviceContext, typename T>
class SoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx = context.template device_context<DeviceContext>();
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
const int axis = context.Attr<int>("axis");
int rank = X->dims().size();
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
Tensor X_2d, Out_2d;
Tensor X_trans, Out_trans;
std::vector<int> perm;
TransposeAxisToEnd<DeviceContext, T>(*X, *Out, &X_trans, &Out_trans, axis,
perm, context);
if (axis != -1 && axis != rank - 1) {
X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
} else {
X_2d = framework::ReshapeToMatrix(*X, rank - 1);
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
}
int rank = X->dims().size();
Tensor X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
Tensor Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
......@@ -86,7 +85,6 @@ class SoftmaxKernel : public framework::OpKernel<T> {
#endif
if (axis != -1 && axis != rank - 1) {
auto& dev_ctx = context.template device_context<DeviceContext>();
TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
}
}
......@@ -96,21 +94,44 @@ template <typename DeviceContext, typename T>
// Device-generic softmax backward kernel built on math::SoftmaxGradFunctor.
// NOTE(review): this listing looks like a rendered diff in which removed and
// added lines are interleaved without +/- markers (`int rank` and the *_2d
// tensors are each declared twice). Confirm against the real file before
// relying on it.
class SoftmaxGradKernel : public framework::OpKernel<T> {
public:
// Reads "Out", the gradient of "Out" and attribute "axis", allocates the
// gradient of "X", and runs the functor on 2-D views. When axis is neither -1
// nor rank - 1, the operands are transposed into temporaries first and the
// result is transposed back into dX.
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx = context.template device_context<DeviceContext>();
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
const int axis = context.Attr<int>("axis");
int rank = Out->dims().size();
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
// Apparently the pre-change path: reshape the operands directly without
// looking at `axis`.
int rank = Out->dims().size();
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
// perm/shape describe swapping dimension `axis` with the last dimension; they
// stay empty when axis already addresses the last dimension.
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
Tensor dX_2d, Out_2d, dOut_2d;
Tensor dX_trans, Out_trans, dOut_trans;
if (axis != -1 && axis != rank - 1) {
// Softmax axis is not the last dimension: work on transposed temporaries.
dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
// NOTE(review): dX was only allocated above, so this transposes whatever dX
// currently holds into dX_trans; presumably redundant — confirm.
TransCompute<DeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
TransCompute<DeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
} else {
// Axis is already the last dimension: reshape the operands directly.
dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
}
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
&dX_2d);
if (axis != -1 && axis != rank - 1) {
// Transpose the gradient from dX_trans back into dX with the same perm.
TransCompute<DeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
}
}
};
......
......@@ -31,6 +31,9 @@ class TestSoftmaxOp(OpTest):
def get_x_shape(self):
return [10, 10]
def get_axis(self):
return -1
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = False
......@@ -38,15 +41,15 @@ class TestSoftmaxOp(OpTest):
self.dtype = np.float32
self.init_kernel_type()
self.shape = self.get_x_shape()
self.axis = self.get_axis()
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.apply_along_axis(stable_softmax, 1,
x.reshape([-1, self.shape[-1]]))
out = out.reshape(self.shape)
out = np.apply_along_axis(stable_softmax, self.axis, x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {
'axis': self.axis,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn
}
......@@ -76,6 +79,38 @@ class TestSoftmaxOp2(TestSoftmaxOp):
return [2, 3, 4, 5]
class TestSoftmaxOp3(TestSoftmaxOp):
    """Softmax case: input shape [2, 3, 4, 5], softmax taken along axis 0."""

    def get_axis(self):
        # Leading dimension, so the op has to handle a non-last axis.
        return 0

    def get_x_shape(self):
        return [2, 3, 4, 5]
class TestSoftmaxOp4(TestSoftmaxOp):
    """Softmax case: input shape [2, 3, 4, 5], softmax taken along axis 1."""

    def get_axis(self):
        return 1

    def get_x_shape(self):
        return [2, 3, 4, 5]
class TestSoftmaxOp5(TestSoftmaxOp):
    """Softmax case: input shape [2, 3, 4, 5], softmax taken along axis 2."""

    def get_axis(self):
        return 2

    def get_x_shape(self):
        return [2, 3, 4, 5]
class TestSoftmaxOp6(TestSoftmaxOp):
    """Softmax case: input shape [2, 3, 4, 5], softmax taken along axis 3.

    Named TestSoftmaxOp6 rather than TestSoftmaxOp5: reusing the earlier
    name rebinds it at module level, so the axis=2 case defined just before
    would be replaced and never collected by the test runner.
    """

    def get_x_shape(self):
        return [2, 3, 4, 5]

    def get_axis(self):
        # Explicit index of the last dimension (the default case uses -1).
        return 3
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxCUDNNOp(TestSoftmaxOp):
......@@ -90,6 +125,26 @@ class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp):
return [2, 3, 4, 5]
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSoftmaxCUDNNOp3(TestSoftmaxCUDNNOp):
    """cuDNN softmax case: input shape [2, 3, 4, 5], softmax along axis 1."""

    def get_axis(self):
        return 1

    def get_x_shape(self):
        return [2, 3, 4, 5]
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSoftmaxCUDNNOp4(TestSoftmaxCUDNNOp):
    """cuDNN softmax case: input shape [2, 3, 4, 5], softmax along axis 2.

    Named TestSoftmaxCUDNNOp4 rather than TestSoftmaxCUDNNOp2: reusing the
    earlier name rebinds it at module level, so the existing
    TestSoftmaxCUDNNOp2 case would be replaced and never collected by the
    test runner.
    """

    def get_x_shape(self):
        return [2, 3, 4, 5]

    def get_axis(self):
        return 2
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftmaxFP16Op(TestSoftmaxOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册