Commit 6c641827 authored by dengkaipeng

refine softmax kernel. test=develop

Parent 412b7cbd
@@ -23,15 +23,16 @@ template <typename DeviceContext, typename T, bool is_test,
           typename Enable = void>
 class SoftmaxFunctor {
  public:
-  void operator()(const DeviceContext& context, const framework::Tensor* X,
-                  framework::Tensor* Y);
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* X, framework::Tensor* Y);
 };

 template <typename DeviceContext, typename T>
 class SoftmaxGradFunctor {
  public:
-  void operator()(const DeviceContext& context, const framework::Tensor* y,
-                  const framework::Tensor* y_grad, framework::Tensor* x_grad);
+  void operator()(const DeviceContext& context, const int axis_dim,
+                  const framework::Tensor* y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad);
 };

 #ifdef PADDLE_WITH_CUDA
...
@@ -36,8 +36,8 @@ struct ValueClip {
 template <typename DeviceContext, typename T, bool is_test, typename Enable>
 void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
-    const DeviceContext& context, const framework::Tensor* X,
-    framework::Tensor* Y) {
+    const DeviceContext& context, const int axis_dim,
+    const framework::Tensor* X, framework::Tensor* Y) {
   auto logits = EigenMatrix<T>::From(*X);
   auto softmax = EigenMatrix<T>::From(*Y);
@@ -46,10 +46,13 @@ void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
   const int batch_size = logits.dimension(kBatchDim);
   const int num_classes = logits.dimension(kClassDim);
+  const int num_remain = num_classes / axis_dim;

   Eigen::DSizes<int, 1> along_class(kClassDim);
   Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
   Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+  Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+  Eigen::DSizes<int, 2> one_axis(1, axis_dim);

   auto shifted_logits = (logits -
                          logits.maximum(along_class)
@@ -60,11 +63,11 @@ void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
   softmax.device(*context.eigen_device()) = shifted_logits.exp();
   softmax.device(*context.eigen_device()) = (softmax *
-                                             softmax.sum(along_class)
+                                             softmax.reshape(batch_axis_remain)
+                                                 .sum(along_class)
                                                  .inverse()
                                                  .eval()
-                                                 .reshape(batch_by_one)
-                                                 .broadcast(one_by_class));
+                                                 .broadcast(one_axis));
 }

 template <class DeviceContext>
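For readers of the Eigen expressions above: the refactored functor now views its (batch, num_classes) input as (batch, axis_dim, num_remain) and normalizes over the middle dimension. A rough reference loop with the same semantics, assuming a row-major buffer; the helper name and signature are illustrative only and not part of the commit:

```cpp
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Softmax over the middle ("axis") dimension of a row-major buffer whose
// logical shape is (batch, axis_dim, remain); d == axis_dim * remain is the
// flattened class dimension the functor actually receives.
void NaiveAxisSoftmax(const std::vector<float>& x, std::vector<float>* y,
                      int batch, int axis_dim, int remain) {
  const int d = axis_dim * remain;
  for (int b = 0; b < batch; ++b) {
    for (int r = 0; r < remain; ++r) {
      // Shift by the group max for numerical stability (the Eigen code shifts
      // by the row-wide max, which gives the same result).
      float max_v = -std::numeric_limits<float>::infinity();
      for (int a = 0; a < axis_dim; ++a)
        max_v = std::max(max_v, x[b * d + a * remain + r]);
      float sum = 0.f;
      for (int a = 0; a < axis_dim; ++a) {
        float e = std::exp(x[b * d + a * remain + r] - max_v);
        (*y)[b * d + a * remain + r] = e;
        sum += e;
      }
      for (int a = 0; a < axis_dim; ++a) (*y)[b * d + a * remain + r] /= sum;
    }
  }
}
```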
@@ -90,7 +93,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
 template <typename DeviceContext, typename T>
 void SoftmaxGradFunctor<DeviceContext, T>::operator()(
-    const DeviceContext& context, const framework::Tensor* y,
+    const DeviceContext& context, const int axis_dim, const framework::Tensor* y,
     const framework::Tensor* y_grad, framework::Tensor* x_grad) {
   auto softmax = EigenMatrix<T>::From(*y);
   auto softmax_grad = EigenMatrix<T>::From(*y_grad);
@@ -101,16 +104,19 @@ void SoftmaxGradFunctor<DeviceContext, T>::operator()(
   const int batch_size = softmax.dimension(kBatchDim);
   const int num_classes = softmax.dimension(kClassDim);
+  const int num_remain = num_classes / axis_dim;

   Eigen::DSizes<int, 1> along_class(kClassDim);
   Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
   Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+  Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+  Eigen::DSizes<int, 2> one_axis(1, axis_dim);

   auto dot = (softmax * softmax_grad)
+                 .reshape(batch_axis_remain)
                  .sum(along_class)
                  .eval()
-                 .reshape(batch_by_one)
-                 .broadcast(one_by_class);
+                 .broadcast(one_axis);
   logits_grad.device(*context.eigen_device()) = (softmax_grad - dot) * softmax;
 }
...
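The backward functor applies the usual softmax Jacobian within each axis group: dx = (dy - dot) * y, where dot sums y * dy over the axis dimension of the group. A matching reference loop, again with an illustrative name and row-major layout assumed:

```cpp
#include <vector>

// Reference loops for the grouped softmax gradient above: within each
// (batch, remain) group, dx = (dy - dot) * y, where dot = sum_a y[a] * dy[a].
void NaiveAxisSoftmaxGrad(const std::vector<float>& y,
                          const std::vector<float>& dy, std::vector<float>* dx,
                          int batch, int axis_dim, int remain) {
  const int d = axis_dim * remain;
  for (int b = 0; b < batch; ++b) {
    for (int r = 0; r < remain; ++r) {
      float dot = 0.f;
      for (int a = 0; a < axis_dim; ++a) {
        const int i = b * d + a * remain + r;
        dot += y[i] * dy[i];
      }
      for (int a = 0; a < axis_dim; ++a) {
        const int i = b * d + a * remain + r;
        (*dx)[i] = (dy[i] - dot) * y[i];
      }
    }
  }
}
```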
@@ -110,46 +110,28 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
                    "It must use CPUPlace.");
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
     auto mkldnn_engine = dev_ctx.GetEngine();
-    const Tensor* X = ctx.Input<Tensor>("X");
-    Tensor* Out = ctx.Output<Tensor>("Out");
+    const Tensor* input = ctx.Input<Tensor>("X");
+    Tensor* output = ctx.Output<Tensor>("Out");
     PADDLE_ENFORCE_EQ(
-        X->dims(), Out->dims(),
+        input->dims(), output->dims(),
         "The shape of softmax's input and output must be identical.");
-    const int axis = ctx.Attr<int>("axis");
-    int rank = X->dims().size();

     // make sure 'output' holds memory, which will be shared by
     // 'flattened_output' later.
-    Out->mutable_data<T>(ctx.GetPlace());
+    output->mutable_data<T>(ctx.GetPlace());

-    std::vector<int> perm, shape;
-    CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
-
-    Tensor X_2d, Out_2d;
-    Tensor X_trans, Out_trans;
-    if (axis != -1 && axis != rank - 1) {
-      X_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-      Out_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *X, &X_trans,
-                                                  perm);
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out,
-                                                  &Out_trans, perm);
-      auto dims = X_trans.dims();
-      auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-      X_2d.ShareDataWith(X_trans).Resize(flattened_dims);
-      Out_2d.ShareDataWith(Out_trans).Resize(flattened_dims);
-    } else {
-      auto dims = X->dims();
-      auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-      X_2d.ShareDataWith(*X).Resize(flattened_dims);
-      Out_2d.ShareDataWith(*Out).Resize(flattened_dims);
-    }
+    // flatten input and output to 2-D matrixs
+    auto dims = input->dims();  // input and output share the same shape
+    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
+    framework::Tensor flattened_input;
+    framework::Tensor flattened_output;
+    flattened_input.ShareDataWith(*input).Resize(flattened_dims);
+    flattened_output.ShareDataWith(*output).Resize(flattened_dims);

-    const T* input_data = X_2d.data<T>();
-    T* output_data = Out_2d.mutable_data<T>(ctx.GetPlace());
+    const T* input_data = flattened_input.data<T>();
+    T* output_data = flattened_output.mutable_data<T>(ctx.GetPlace());

-    std::vector<int> src_tz = paddle::framework::vectorize2int(X_2d.dims());
+    std::vector<int> src_tz = paddle::framework::vectorize2int(flattened_dims);
     std::vector<int> dst_tz = src_tz;
     // Same memory descriptor to be used for input and output
     memory::dims softmax_tz = {src_tz[0], src_tz[1]};
@@ -179,16 +161,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     // We cannot use softmax_dst_memory_p to get prim desc as
     // it contains flattened dims (2D) while output tensor can
     // have 2,3,4+ dims
-    if (axis != -1 && axis != rank - 1) {
-      auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
-          shape, mkldnn::memory::format::blocked);
-      Out_trans.set_mkldnn_prim_desc(output_mem_pd);
-    } else {
-      auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
-          paddle::framework::vectorize2int(Out->dims()),
-          mkldnn::memory::format::blocked);
-      Out->set_mkldnn_prim_desc(output_mem_pd);
-    }
+    auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
+        paddle::framework::vectorize2int(output->dims()),
+        mkldnn::memory::format::blocked);
+    output->set_mkldnn_prim_desc(output_mem_pd);

     std::vector<primitive> pipeline{
         *(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
@@ -202,11 +178,6 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
             output_data[i] < threshold ? threshold : output_data[i];
       }
     }
-
-    if (axis != -1 && axis != rank - 1) {
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, Out_trans, Out,
-                                                  perm);
-    }
   }
 };
@@ -219,55 +190,33 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
     auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
     auto mkldnn_engine = dev_ctx.GetEngine();
-    const Tensor* Out = ctx.Input<Tensor>("Out");
-    auto* dOut = ctx.template Input<Tensor>(framework::GradVarName("Out"));
-    auto* dX =
+    const Tensor* output = ctx.Input<Tensor>("Out");
+    auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx =
         ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
     PADDLE_ENFORCE_EQ(
-        dOut->dims(), dX->dims(),
+        dout->dims(), dx->dims(),
         "The shape of softmax_grad's input and output must be identical.");
-    const int axis = ctx.Attr<int>("axis");
-    int rank = Out->dims().size();

     // make sure 'dx' holds memory, which will be shared by 'flattened_dx'
     // later.
-    dX->template mutable_data<T>(ctx.GetPlace());
+    dx->template mutable_data<T>(ctx.GetPlace());

-    std::vector<int> perm, shape;
-    CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
-
-    Tensor dX_2d, Out_2d, dOut_2d;
-    Tensor dX_trans, Out_trans, dOut_trans;
-    if (axis != -1 && axis != rank - 1) {
-      dX_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-      Out_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-      dOut_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dX, &dX_trans,
-                                                  perm);
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out,
-                                                  &Out_trans, perm);
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dOut,
-                                                  &dOut_trans, perm);
-      auto dims = dX_trans.dims();
-      auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-      dX_2d.ShareDataWith(dX_trans).Resize(flattened_dims);
-      Out_2d.ShareDataWith(Out_trans).Resize(flattened_dims);
-      dOut_2d.ShareDataWith(dOut_trans).Resize(flattened_dims);
-    } else {
-      auto dims = dX->dims();
-      auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
-      dX_2d.ShareDataWith(*dX).Resize(flattened_dims);
-      Out_2d.ShareDataWith(*Out).Resize(flattened_dims);
-      dOut_2d.ShareDataWith(*dOut).Resize(flattened_dims);
-    }
-
-    const T* dst_data = Out_2d.data<T>();
-    const T* diff_dst_ptr = dOut_2d.template data<T>();
-    T* diff_src_ptr = dX_2d.template mutable_data<T>(ctx.GetPlace());
-
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(Out_2d.dims());
+    auto dims = dout->dims();  // input and output share the same shape
+    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
+    framework::Tensor flattened_output;
+    framework::Tensor flattened_dout;
+    framework::Tensor flattened_dx;
+    flattened_output.ShareDataWith(*output).Resize(flattened_dims);
+    flattened_dout.ShareDataWith(*dout).Resize(flattened_dims);
+    flattened_dx.ShareDataWith(*dx).Resize(flattened_dims);
+
+    const T* dst_data = flattened_output.data<T>();
+    const T* diff_dst_ptr = flattened_dout.template data<T>();
+    T* diff_src_ptr = flattened_dx.template mutable_data<T>(ctx.GetPlace());
+
+    std::vector<int> dst_tz = paddle::framework::vectorize2int(flattened_dims);
     std::vector<int> src_tz(dst_tz);
     // Same memory descriptor to be used for input and output
@@ -312,11 +261,6 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
     std::vector<primitive> pipeline{*softmax_bwd_p};
     stream(stream::kind::eager).submit(pipeline).wait();
-
-    if (axis != -1 && axis != rank - 1) {
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, dX_trans, dX,
-                                                  perm);
-    }
   }
 };
 }  // namespace operators
...
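Both MKL-DNN kernels now simply flatten the N-D tensors to 2-D at the last axis before calling into the primitive. A small sketch of that shape bookkeeping, as a hypothetical standalone helper rather than the framework's flatten_to_2d:

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Split an N-D shape into {rows, cols} at position `axis`: rows is the
// product of dims before `axis`, cols the product of dims from `axis` on.
// The kernels above always pass axis == rank - 1, so everything except the
// last dimension is folded into the row (batch) dimension.
std::array<int64_t, 2> FlattenTo2D(const std::vector<int64_t>& dims, int axis) {
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < axis; ++i) rows *= dims[i];
  for (int i = axis; i < static_cast<int>(dims.size()); ++i) cols *= dims[i];
  return {rows, cols};
}

// e.g. a {2, 3, 4, 5} tensor with axis = 3 becomes a {24, 5} matrix.
```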
@@ -14,7 +14,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/softmax_op.h"

 namespace paddle {
 namespace operators {
@@ -25,44 +24,22 @@ template <typename T>
 class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx =
-        context.template device_context<platform::CUDADeviceContext>();
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
-    const int axis = context.Attr<int>("axis");
-    int rank = X->dims().size();

     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());

-    std::vector<int> perm, shape;
-    CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
-
-    Tensor X_2d, Out_2d;
-    Tensor X_trans, Out_trans;
-    if (axis != -1 && axis != rank - 1) {
-      X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      Out_trans.mutable_data<T>(framework::make_ddim(shape),
-                                context.GetPlace());
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans,
-                                                   perm);
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
-                                                   &Out_trans, perm);
-      X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
-      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
-    } else {
-      X_2d = framework::ReshapeToMatrix(*X, rank - 1);
-      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
-    }
+    auto dims = X->dims();
+    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
+    framework::LoDTensor flattened_x;
+    framework::LoDTensor flattened_out;
+    flattened_x.ShareDataWith(*X).Resize(flattened_dims);
+    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);

     math::SoftmaxCUDNNFunctor<T>()(
-        context.template device_context<platform::CUDADeviceContext>(), &X_2d,
-        &Out_2d);
-
-    if (axis != -1 && axis != rank - 1) {
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans,
-                                                   Out, perm);
-    }
+        context.template device_context<platform::CUDADeviceContext>(),
+        &flattened_x, &flattened_out);
   }
 };
@@ -70,51 +47,25 @@ template <typename T>
 class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx =
-        context.template device_context<platform::CUDADeviceContext>();
     auto* Out = context.Input<Tensor>("Out");
     auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
-    const int axis = context.Attr<int>("axis");
-    int rank = Out->dims().size();

     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());

-    std::vector<int> perm, shape;
-    CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
-
-    Tensor dX_2d, Out_2d, dOut_2d;
-    Tensor dX_trans, Out_trans, dOut_trans;
-    if (axis != -1 && axis != rank - 1) {
-      dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      Out_trans.mutable_data<T>(framework::make_ddim(shape),
-                                context.GetPlace());
-      dOut_trans.mutable_data<T>(framework::make_ddim(shape),
-                                 context.GetPlace());
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX,
-                                                   &dX_trans, perm);
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
-                                                   &Out_trans, perm);
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut,
-                                                   &dOut_trans, perm);
-      dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
-      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
-      dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
-    } else {
-      dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
-      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
-      dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
-    }
+    auto dims = Out->dims();
+    auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
+    framework::LoDTensor flattened_out;
+    framework::LoDTensor flattened_d_out;
+    framework::LoDTensor flattened_d_x;
+    flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
+    flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
+    flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);

     math::SoftmaxGradCUDNNFunctor<T>()(
-        context.template device_context<platform::CUDADeviceContext>(), &Out_2d,
-        &dOut_2d, &dX_2d);
-
-    if (axis != -1 && axis != rank - 1) {
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX,
-                                                   perm);
-    }
+        context.template device_context<platform::CUDADeviceContext>(),
+        &flattened_out, &flattened_d_out, &flattened_d_x);
   }
 };
...
@@ -13,81 +13,66 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
-#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/softmax.h"
-#include "paddle/fluid/operators/transpose_op.h"

 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;
+using DDim = framework::DDim;

-static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis,
-                                               std::vector<int>* perm,
-                                               std::vector<int>* shape) {
-  auto dim_x = x.dims();
-  int rank = dim_x.size();
-
-  if (axis == -1 || axis == rank - 1) {
-    return;
-  }
-
-  for (int i = 0; i < rank - 1; i++) {
-    if (i == axis) {
-      perm->push_back(rank - 1);
-      shape->push_back(dim_x[rank - 1]);
-    } else {
-      perm->push_back(i);
-      shape->push_back(dim_x[i]);
-    }
-  }
-  perm->push_back(axis);
-  shape->push_back(dim_x[axis]);
+static inline int CanonicalAxis(const int axis, const int rank) {
+  if (axis < 0) {
+    return axis + rank;
+  }
+  return axis;
+}
+
+static inline int SizeToAxis(const int axis, DDim dims) {
+  int size = 1;
+  for (int i = 0; i < axis; i++) {
+    size *= dims[i];
+  }
+  return size;
+}
+
+static inline int SizeFromAxis(const int axis, DDim dims) {
+  int size = 1;
+  for (int i = axis; i < dims.size(); i++) {
+    size *= dims[i];
+  }
+  return size;
 }
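A quick check of the new helpers with concrete numbers, using standalone re-implementations for illustration only (the real functions take a DDim): for a [2, 3, 4, 5] input and attribute axis = -3, CanonicalAxis yields 1, SizeToAxis gives n = 2, SizeFromAxis gives d = 3 * 4 * 5 = 60, and axis_dim = 3, which is exactly the (n, d) view plus group size the kernels below hand to the math functors.

```cpp
// Hypothetical illustration of the helper arithmetic above (not framework code).
#include <cassert>
#include <vector>

int CanonicalAxis(int axis, int rank) { return axis < 0 ? axis + rank : axis; }

int SizeToAxis(int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = 0; i < axis; ++i) size *= dims[i];
  return size;
}

int SizeFromAxis(int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = axis; i < static_cast<int>(dims.size()); ++i) size *= dims[i];
  return size;
}

int main() {
  std::vector<int> dims = {2, 3, 4, 5};
  int axis = CanonicalAxis(-3, static_cast<int>(dims.size()));  // -> 1
  assert(axis == 1);
  assert(SizeToAxis(axis, dims) == 2);    // n: product of dims before axis
  assert(SizeFromAxis(axis, dims) == 60); // d: product of dims from axis on
  int axis_dim = dims[axis];              // 3: group size normalized by softmax
  assert(axis_dim == 3);
  return 0;
}
```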
 template <typename DeviceContext, typename T>
 class SoftmaxKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
-    const int axis = context.Attr<int>("axis");
-    int rank = X->dims().size();
+    const int rank = X->dims().size();
+    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    int axis_dim = X->dims()[axis];

     // allocate memory on device.
     Out->mutable_data<T>(context.GetPlace());

-    std::vector<int> perm, shape;
-    CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
-
+    const int n = SizeToAxis(axis, X->dims());
+    const int d = SizeFromAxis(axis, X->dims());
     Tensor X_2d, Out_2d;
-    Tensor X_trans, Out_trans;
-    if (axis != -1 && axis != rank - 1) {
-      X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      Out_trans.mutable_data<T>(framework::make_ddim(shape),
-                                context.GetPlace());
-      TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
-      TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
-      X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
-      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
-    } else {
-      X_2d = framework::ReshapeToMatrix(*X, rank - 1);
-      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
-    }
+    X_2d.ShareDataWith(*X).Resize({n, d});
+    Out_2d.ShareDataWith(*Out).Resize({n, d});
+    // Tensor X_2d = framework::ReshapeToMatrix(*X, axis - 1);
+    // Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1);

 #ifdef PADDLE_ON_INFERENCE
     math::SoftmaxFunctor<DeviceContext, T, true>()(
-        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
+        context.template device_context<DeviceContext>(), axis_dim, &X_2d, &Out_2d);
 #else
     math::SoftmaxFunctor<DeviceContext, T, false>()(
-        context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
+        context.template device_context<DeviceContext>(), axis_dim, &X_2d, &Out_2d);
 #endif
-
-    if (axis != -1 && axis != rank - 1) {
-      TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
-    }
   }
 };
@@ -95,46 +80,29 @@ template <typename DeviceContext, typename T>
 class SoftmaxGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
     auto* Out = context.Input<Tensor>("Out");
     auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
-    const int axis = context.Attr<int>("axis");
-    int rank = Out->dims().size();
+    const int rank = dX->dims().size();
+    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+    int axis_dim = dX->dims()[axis];

     // allocate memory on device.
     dX->mutable_data<T>(context.GetPlace());

-    std::vector<int> perm, shape;
-    CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
-
+    const int n = SizeToAxis(axis, dX->dims());
+    const int d = SizeFromAxis(axis, dX->dims());
     Tensor dX_2d, Out_2d, dOut_2d;
-    Tensor dX_trans, Out_trans, dOut_trans;
-    if (axis != -1 && axis != rank - 1) {
-      dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      Out_trans.mutable_data<T>(framework::make_ddim(shape),
-                                context.GetPlace());
-      dOut_trans.mutable_data<T>(framework::make_ddim(shape),
-                                 context.GetPlace());
-      TransCompute<DeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
-      TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
-      TransCompute<DeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
-      dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
-      Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
-      dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
-    } else {
-      dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
-      Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
-      dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
-    }
+    dX_2d.ShareDataWith(*dX).Resize({n, d});
+    Out_2d.ShareDataWith(*Out).Resize({n, d});
+    dOut_2d.ShareDataWith(*dOut).Resize({n, d});
+    // Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1);
+    // Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, axis - 1);
+    // Tensor dX_2d = framework::ReshapeToMatrix(*dX, axis - 1);

     math::SoftmaxGradFunctor<DeviceContext, T>()(
-        context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
+        context.template device_context<DeviceContext>(), axis_dim, &Out_2d, &dOut_2d,
         &dX_2d);
-
-    if (axis != -1 && axis != rank - 1) {
-      TransCompute<DeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
-    }
   }
 };
...
@@ -43,7 +43,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     auto& dev_ctx =
         context.template device_context<platform::CPUDeviceContext>();
     math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
-        dev_ctx, logits, softmax);
+        dev_ctx, -1, logits, softmax);
     math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
         dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
         context.Attr<int>("ignore_index"));
...
@@ -69,7 +69,7 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
     int rank = logits->dims().size();
     Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
     Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
-    math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, &in_2d, &out_2d);
+    math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, -1, &in_2d, &out_2d);
     // ctc needs sequences data stored in transposed padding format
     // logits and grad using padding data of layout 'TNC'
...