Commit 6c641827 authored by dengkaipeng

refine softmax kernel. test=develop

Parent 412b7cbd
@@ -23,15 +23,16 @@ template <typename DeviceContext, typename T, bool is_test,
typename Enable = void>
class SoftmaxFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y);
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y);
};
template <typename DeviceContext, typename T>
class SoftmaxGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor* y,
const framework::Tensor* y_grad, framework::Tensor* x_grad);
void operator()(const DeviceContext& context, const int axis_dim,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad);
};
#ifdef PADDLE_WITH_CUDA
......
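Both functors now take the size of the softmax axis (`axis_dim`) alongside the flattened 2-D tensors. The callers (see softmax_op.h below) collapse an N-D input into an [n, d] matrix split at `axis`; `axis_dim` is what lets the functor recover the axis inside `d`. A standalone sketch of the arithmetic involved (plain C++ with an illustrative shape, not part of this commit):

#include <cassert>

int main() {
  // Suppose X has shape [2, 3, 4] and softmax is taken over axis = 1.
  const int dims[3] = {2, 3, 4}, rank = 3, axis = 1;
  int n = 1, d = 1;
  for (int i = 0; i < axis; ++i) n *= dims[i];     // n = 2  (rows of the 2-D view)
  for (int i = axis; i < rank; ++i) d *= dims[i];  // d = 12 (cols of the 2-D view)
  const int axis_dim = dims[axis];                 // 3, passed to the functor
  assert(d % axis_dim == 0);  // num_remain = d / axis_dim = 4 inside the functor
  return 0;
}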
@@ -36,8 +36,8 @@ struct ValueClip {
template <typename DeviceContext, typename T, bool is_test, typename Enable>
void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
const DeviceContext& context, const int axis_dim,
const framework::Tensor* X, framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
@@ -46,10 +46,13 @@ void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
auto shifted_logits = (logits -
logits.maximum(along_class)
@@ -60,11 +63,11 @@ void SoftmaxFunctor<DeviceContext, T, is_test, Enable>::operator()(
softmax.device(*context.eigen_device()) = shifted_logits.exp();
softmax.device(*context.eigen_device()) = (softmax *
softmax.sum(along_class)
softmax.reshape(batch_axis_remain)
.sum(along_class)
.inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
.broadcast(one_axis));
}
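The `batch_axis_remain` reshape is the heart of the change: viewing the [batch_size, num_classes] matrix as [batch_size, axis_dim, num_remain] makes `sum(along_class)` reduce over the softmax axis only, so every (batch, remain) slice is normalized independently and no transpose is needed. A scalar reference of the same computation (standalone C++, not the Paddle code):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Softmax over the middle axis of an [n, axis_dim, remain] row-major view;
// element (i, j, k) lives at index (i * axis_dim + j) * remain + k.
void SoftmaxAxisRef(const std::vector<float>& x, int n, int axis_dim,
                    int remain, std::vector<float>* y) {
  y->resize(x.size());
  for (int i = 0; i < n; ++i) {
    for (int k = 0; k < remain; ++k) {
      float max_v = x[(i * axis_dim + 0) * remain + k];
      for (int j = 1; j < axis_dim; ++j)
        max_v = std::max(max_v, x[(i * axis_dim + j) * remain + k]);
      float sum = 0.f;
      for (int j = 0; j < axis_dim; ++j) {
        float e = std::exp(x[(i * axis_dim + j) * remain + k] - max_v);
        (*y)[(i * axis_dim + j) * remain + k] = e;
        sum += e;
      }
      for (int j = 0; j < axis_dim; ++j)
        (*y)[(i * axis_dim + j) * remain + k] /= sum;
    }
  }
}

int main() {
  // shape [1, 2, 2]: axis_dim = 2, remain = 2
  std::vector<float> x = {0.f, 1.f, 2.f, 3.f}, y;
  SoftmaxAxisRef(x, 1, 2, 2, &y);
  for (float v : y) std::printf("%.4f ", v);  // 0.1192 0.1192 0.8808 0.8808
  std::printf("\n");
  return 0;
}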
template <class DeviceContext>
@@ -90,7 +93,7 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
template <typename DeviceContext, typename T>
void SoftmaxGradFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor* y,
const DeviceContext& context, const int axis_dim, const framework::Tensor* y,
const framework::Tensor* y_grad, framework::Tensor* x_grad) {
auto softmax = EigenMatrix<T>::From(*y);
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
@@ -101,16 +104,19 @@ void SoftmaxGradFunctor<DeviceContext, T>::operator()(
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
auto dot = (softmax * softmax_grad)
.reshape(batch_axis_remain)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
.broadcast(one_axis);
logits_grad.device(*context.eigen_device()) = (softmax_grad - dot) * softmax;
}
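The gradient follows the same slice decomposition. For each (batch, remain) pair, `dot` is the sum of y_j * dy_j over the softmax axis, and dx_j = (dy_j - dot) * y_j, which is the standard softmax Jacobian-vector product. A scalar reference (standalone C++, not the Paddle code):

#include <vector>

// Backward of softmax over the middle axis of an [n, axis_dim, remain] view:
//   dot[i][k]   = sum_j y[i][j][k] * dy[i][j][k]   (the batch_axis_remain sum)
//   dx[i][j][k] = (dy[i][j][k] - dot[i][k]) * y[i][j][k]
std::vector<float> SoftmaxGradAxisRef(const std::vector<float>& y,
                                      const std::vector<float>& dy, int n,
                                      int axis_dim, int remain) {
  std::vector<float> dx(y.size());
  for (int i = 0; i < n; ++i) {
    for (int k = 0; k < remain; ++k) {
      float dot = 0.f;
      for (int j = 0; j < axis_dim; ++j) {
        const int idx = (i * axis_dim + j) * remain + k;
        dot += y[idx] * dy[idx];
      }
      for (int j = 0; j < axis_dim; ++j) {
        const int idx = (i * axis_dim + j) * remain + k;
        dx[idx] = (dy[idx] - dot) * y[idx];
      }
    }
  }
  return dx;
}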
......
@@ -110,46 +110,28 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* X = ctx.Input<Tensor>("X");
Tensor* Out = ctx.Output<Tensor>("Out");
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE_EQ(
X->dims(), Out->dims(),
input->dims(), output->dims(),
"The shape of softmax's input and output must be identical.");
const int axis = ctx.Attr<int>("axis");
int rank = X->dims().size();
// make sure 'output' holds memory, which will be shared by
// 'flattened_output' later.
Out->mutable_data<T>(ctx.GetPlace());
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
Tensor X_2d, Out_2d;
Tensor X_trans, Out_trans;
if (axis != -1 && axis != rank - 1) {
X_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *X, &X_trans,
perm);
TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out,
&Out_trans, perm);
auto dims = X_trans.dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
X_2d.ShareDataWith(X_trans).Resize(flattened_dims);
Out_2d.ShareDataWith(Out_trans).Resize(flattened_dims);
} else {
auto dims = X->dims();
output->mutable_data<T>(ctx.GetPlace());
// flatten input and output to 2-D matrices
auto dims = input->dims(); // input and output share the same shape
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
X_2d.ShareDataWith(*X).Resize(flattened_dims);
Out_2d.ShareDataWith(*Out).Resize(flattened_dims);
}
framework::Tensor flattened_input;
framework::Tensor flattened_output;
flattened_input.ShareDataWith(*input).Resize(flattened_dims);
flattened_output.ShareDataWith(*output).Resize(flattened_dims);
const T* input_data = X_2d.data<T>();
T* output_data = Out_2d.mutable_data<T>(ctx.GetPlace());
const T* input_data = flattened_input.data<T>();
T* output_data = flattened_output.mutable_data<T>(ctx.GetPlace());
std::vector<int> src_tz = paddle::framework::vectorize2int(X_2d.dims());
std::vector<int> src_tz = paddle::framework::vectorize2int(flattened_dims);
std::vector<int> dst_tz = src_tz;
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
@@ -179,16 +161,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
// We cannot use softmax_dst_memory_p to get prim desc as
// it contains flattened dims (2D) while output tensor can
// have 2,3,4+ dims
if (axis != -1 && axis != rank - 1) {
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
shape, mkldnn::memory::format::blocked);
Out_trans.set_mkldnn_prim_desc(output_mem_pd);
} else {
auto output_mem_pd = paddle::platform::create_prim_desc_from_dims(
paddle::framework::vectorize2int(Out->dims()),
paddle::framework::vectorize2int(output->dims()),
mkldnn::memory::format::blocked);
Out->set_mkldnn_prim_desc(output_mem_pd);
}
output->set_mkldnn_prim_desc(output_mem_pd);
std::vector<primitive> pipeline{
*(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
@@ -202,11 +178,6 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
output_data[i] < threshold ? threshold : output_data[i];
}
}
if (axis != -1 && axis != rank - 1) {
TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, Out_trans, Out,
perm);
}
}
};
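The kernel returns to the flattened 2-D path: everything before the last axis becomes the batch dimension, and `ShareDataWith(...).Resize(...)` only reinterprets the existing allocation without copying. A sketch of what `flatten_to_2d(dims, dims.size() - 1)` computes (standalone, illustrative, not the Paddle implementation):

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Collapse dims[0:num_col_dims) into rows and dims[num_col_dims:) into cols.
// With num_col_dims == dims.size() - 1, only the last axis stays a column.
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t>& dims,
                                        size_t num_col_dims) {
  int64_t rows = 1, cols = 1;
  for (size_t i = 0; i < dims.size(); ++i)
    (i < num_col_dims ? rows : cols) *= dims[i];
  return {rows, cols};
}

int main() {
  auto flat = FlattenTo2D({8, 16, 10}, 2);
  std::printf("%lld x %lld\n", (long long)flat.first, (long long)flat.second);
  // prints: 128 x 10
  return 0;
}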
@@ -219,55 +190,33 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* Out = ctx.Input<Tensor>("Out");
auto* dOut = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* dX =
const Tensor* output = ctx.Input<Tensor>("Out");
auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* dx =
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(
dOut->dims(), dX->dims(),
dout->dims(), dx->dims(),
"The shape of softmax_grad's input and output must be identical.");
const int axis = ctx.Attr<int>("axis");
int rank = Out->dims().size();
// make sure 'dx' holds memory, which will be shared by 'flattened_dx'
// later.
dX->template mutable_data<T>(ctx.GetPlace());
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
Tensor dX_2d, Out_2d, dOut_2d;
Tensor dX_trans, Out_trans, dOut_trans;
if (axis != -1 && axis != rank - 1) {
dX_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
dOut_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dX, &dX_trans,
perm);
TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out,
&Out_trans, perm);
TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dOut,
&dOut_trans, perm);
auto dims = dX_trans.dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
dX_2d.ShareDataWith(dX_trans).Resize(flattened_dims);
Out_2d.ShareDataWith(Out_trans).Resize(flattened_dims);
dOut_2d.ShareDataWith(dOut_trans).Resize(flattened_dims);
} else {
auto dims = dX->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
dX_2d.ShareDataWith(*dX).Resize(flattened_dims);
Out_2d.ShareDataWith(*Out).Resize(flattened_dims);
dOut_2d.ShareDataWith(*dOut).Resize(flattened_dims);
}
dx->template mutable_data<T>(ctx.GetPlace());
const T* dst_data = Out_2d.data<T>();
const T* diff_dst_ptr = dOut_2d.template data<T>();
T* diff_src_ptr = dX_2d.template mutable_data<T>(ctx.GetPlace());
std::vector<int> dst_tz = paddle::framework::vectorize2int(Out_2d.dims());
auto dims = dout->dims(); // input and output share the same shape
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::Tensor flattened_output;
framework::Tensor flattened_dout;
framework::Tensor flattened_dx;
flattened_output.ShareDataWith(*output).Resize(flattened_dims);
flattened_dout.ShareDataWith(*dout).Resize(flattened_dims);
flattened_dx.ShareDataWith(*dx).Resize(flattened_dims);
const T* dst_data = flattened_output.data<T>();
const T* diff_dst_ptr = flattened_dout.template data<T>();
T* diff_src_ptr = flattened_dx.template mutable_data<T>(ctx.GetPlace());
std::vector<int> dst_tz = paddle::framework::vectorize2int(flattened_dims);
std::vector<int> src_tz(dst_tz);
// Same memory descriptor to be used for input and output
@@ -312,11 +261,6 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
std::vector<primitive> pipeline{*softmax_bwd_p};
stream(stream::kind::eager).submit(pipeline).wait();
if (axis != -1 && axis != rank - 1) {
TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, dX_trans, dX,
perm);
}
}
};
} // namespace operators
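The gradient kernel mirrors the forward one: `output`, `dout`, and `dx` all receive aliasing 2-D views of one flattened shape, and the transpose round-trips disappear. The aliasing is what makes this cheap; a minimal standalone sketch of the idea (hypothetical View2D type, not Paddle's Tensor):

#include <cstdio>
#include <vector>

// A 2-D "view" over an existing buffer: same storage, new logical shape.
struct View2D {
  float* data;  // aliases the owner's buffer, like ShareDataWith
  int rows, cols;
  float& at(int r, int c) { return data[r * cols + c]; }
};

int main() {
  std::vector<float> owner(2 * 3 * 4, 1.f);  // logically [2, 3, 4]
  View2D flat{owner.data(), 6, 4};           // Resize({6, 4}); no copy made
  flat.at(5, 3) = 7.f;                       // writes through to 'owner'
  std::printf("%f\n", owner[23]);            // prints 7.000000
  return 0;
}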
......
@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/softmax_op.h"
namespace paddle {
namespace operators {
@@ -25,44 +24,22 @@ template <typename T>
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
const int axis = context.Attr<int>("axis");
int rank = X->dims().size();
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
Tensor X_2d, Out_2d;
Tensor X_trans, Out_trans;
if (axis != -1 && axis != rank - 1) {
X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape),
context.GetPlace());
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans,
perm);
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
&Out_trans, perm);
X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
} else {
X_2d = framework::ReshapeToMatrix(*X, rank - 1);
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
}
auto dims = X->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_x;
framework::LoDTensor flattened_out;
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
math::SoftmaxCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(), &X_2d,
&Out_2d);
if (axis != -1 && axis != rank - 1) {
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans,
Out, perm);
}
context.template device_context<platform::CUDADeviceContext>(),
&flattened_x, &flattened_out);
}
};
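`SoftmaxCUDNNFunctor` hands the flattened [N, D] matrix to cuDNN. A hedged sketch of the kind of call it wraps (assumes a cuDNN context; error handling omitted; an approximation for illustration, not Paddle's actual functor):

#include <cudnn.h>

// Softmax over the columns of an [n, d] matrix via cuDNN. MODE_INSTANCE
// normalizes over C*H*W per N, which is why the tensors are flattened first.
void SoftmaxForward2D(cudnnHandle_t handle, int n, int d, const float* x,
                      float* y) {
  cudnnTensorDescriptor_t desc;
  cudnnCreateTensorDescriptor(&desc);
  // Treat the flattened matrix as an n x d x 1 x 1 NCHW tensor.
  cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, d,
                             1, 1);
  const float alpha = 1.f, beta = 0.f;
  cudnnSoftmaxForward(handle, CUDNN_SOFTMAX_ACCURATE,
                      CUDNN_SOFTMAX_MODE_INSTANCE, &alpha, desc, x, &beta,
                      desc, y);
  cudnnDestroyTensorDescriptor(desc);
}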
@@ -70,51 +47,25 @@ template <typename T>
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
const int axis = context.Attr<int>("axis");
int rank = Out->dims().size();
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
Tensor dX_2d, Out_2d, dOut_2d;
Tensor dX_trans, Out_trans, dOut_trans;
if (axis != -1 && axis != rank - 1) {
dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape),
context.GetPlace());
dOut_trans.mutable_data<T>(framework::make_ddim(shape),
context.GetPlace());
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX,
&dX_trans, perm);
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
&Out_trans, perm);
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut,
&dOut_trans, perm);
dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
} else {
dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
}
auto dims = Out->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_out;
framework::LoDTensor flattened_d_out;
framework::LoDTensor flattened_d_x;
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
math::SoftmaxGradCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(), &Out_2d,
&dOut_2d, &dX_2d);
if (axis != -1 && axis != rank - 1) {
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX,
perm);
}
context.template device_context<platform::CUDADeviceContext>(),
&flattened_out, &flattened_d_out, &flattened_d_x);
}
};
......
@@ -13,81 +13,66 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/transpose_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis,
std::vector<int>* perm,
std::vector<int>* shape) {
auto dim_x = x.dims();
int rank = dim_x.size();
if (axis == -1 || axis == rank - 1) {
return;
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
for (int i = 0; i < rank - 1; i++) {
if (i == axis) {
perm->push_back(rank - 1);
shape->push_back(dim_x[rank - 1]);
} else {
perm->push_back(i);
shape->push_back(dim_x[i]);
static inline int SizeToAxis(const int axis, DDim dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
perm->push_back(axis);
shape->push_back(dim_x[axis]);
return size;
}
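Together the three helpers view an N-D shape as an [n, d] matrix split at `axis`, with negative axes wrapped around. A quick standalone check of the arithmetic, using std::vector in place of DDim:

#include <cassert>
#include <vector>

static int CanonicalAxis(int axis, int rank) {
  return axis < 0 ? axis + rank : axis;
}

static int SizeToAxis(int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = 0; i < axis; ++i) size *= dims[i];
  return size;
}

static int SizeFromAxis(int axis, const std::vector<int>& dims) {
  int size = 1;
  for (int i = axis; i < static_cast<int>(dims.size()); ++i) size *= dims[i];
  return size;
}

int main() {
  std::vector<int> dims = {2, 3, 4, 5};
  int axis = CanonicalAxis(-2, dims.size());  // -2 wraps to 2
  assert(axis == 2);
  assert(SizeToAxis(axis, dims) == 6);        // n = 2 * 3
  assert(SizeFromAxis(axis, dims) == 20);     // d = 4 * 5
  assert(SizeToAxis(axis, dims) * SizeFromAxis(axis, dims) == 2 * 3 * 4 * 5);
  return 0;
}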
template <typename DeviceContext, typename T>
class SoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx = context.template device_context<DeviceContext>();
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
const int axis = context.Attr<int>("axis");
int rank = X->dims().size();
const int rank = X->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = X->dims()[axis];
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
const int n = SizeToAxis(axis, X->dims());
const int d = SizeFromAxis(axis, X->dims());
Tensor X_2d, Out_2d;
Tensor X_trans, Out_trans;
if (axis != -1 && axis != rank - 1) {
X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape),
context.GetPlace());
TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
} else {
X_2d = framework::ReshapeToMatrix(*X, rank - 1);
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
}
X_2d.ShareDataWith(*X).Resize({n, d});
Out_2d.ShareDataWith(*Out).Resize({n, d});
// Tensor X_2d = framework::ReshapeToMatrix(*X, axis - 1);
// Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1);
#ifdef PADDLE_ON_INFERENCE
math::SoftmaxFunctor<DeviceContext, T, true>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
context.template device_context<DeviceContext>(), axis_dim, &X_2d, &Out_2d);
#else
math::SoftmaxFunctor<DeviceContext, T, false>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
context.template device_context<DeviceContext>(), axis_dim, &X_2d, &Out_2d);
#endif
if (axis != -1 && axis != rank - 1) {
TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
}
}
};
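A worked shape example for this kernel (illustrative values, following the helpers above):

// X.dims() = [2, 3, 4], attr axis = 1:
//   axis     = CanonicalAxis(1, 3)    -> 1
//   axis_dim = X->dims()[1]           -> 3
//   n        = SizeToAxis(1, dims)    -> 2
//   d        = SizeFromAxis(1, dims)  -> 3 * 4 = 12
// X_2d / Out_2d are [2, 12] aliases of X / Out (no copy, no transpose);
// SoftmaxFunctor gets axis_dim = 3 and recovers the [2, 3, 4] layout
// internally, so the old TransCompute round-trip is gone.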
@@ -95,46 +80,29 @@ template <typename DeviceContext, typename T>
class SoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx = context.template device_context<DeviceContext>();
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
const int axis = context.Attr<int>("axis");
int rank = Out->dims().size();
const int rank = dX->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = dX->dims()[axis];
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
std::vector<int> perm, shape;
CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
const int n = SizeToAxis(axis, dX->dims());
const int d = SizeFromAxis(axis, dX->dims());
Tensor dX_2d, Out_2d, dOut_2d;
Tensor dX_trans, Out_trans, dOut_trans;
if (axis != -1 && axis != rank - 1) {
dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape),
context.GetPlace());
dOut_trans.mutable_data<T>(framework::make_ddim(shape),
context.GetPlace());
TransCompute<DeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
TransCompute<DeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
} else {
dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
}
dX_2d.ShareDataWith(*dX).Resize({n, d});
Out_2d.ShareDataWith(*Out).Resize({n, d});
dOut_2d.ShareDataWith(*dOut).Resize({n, d});
// Tensor Out_2d = framework::ReshapeToMatrix(*Out, axis - 1);
// Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, axis - 1);
// Tensor dX_2d = framework::ReshapeToMatrix(*dX, axis - 1);
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
context.template device_context<DeviceContext>(), axis_dim, &Out_2d, &dOut_2d,
&dX_2d);
if (axis != -1 && axis != rank - 1) {
TransCompute<DeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
}
}
};
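The backward rule can be sanity-checked numerically: for L = sum_j dy_j * softmax(x)_j, the analytic gradient (dy - dot) * y must match central differences. A standalone check (plain C++, single row, illustrative values):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const int n = 3;
  double x[n] = {0.3, -1.2, 2.0}, dy[n] = {0.5, 1.0, -0.25};
  auto softmax = [&](const double* in, double* out) {
    double m = std::max(std::max(in[0], in[1]), in[2]), s = 0;
    for (int i = 0; i < n; ++i) s += (out[i] = std::exp(in[i] - m));
    for (int i = 0; i < n; ++i) out[i] /= s;
  };
  double y[n];
  softmax(x, y);
  double dot = 0;
  for (int i = 0; i < n; ++i) dot += dy[i] * y[i];
  for (int i = 0; i < n; ++i) {
    double analytic = (dy[i] - dot) * y[i];
    // numeric: d/dx_i of sum_j dy_j * softmax(x)_j, by central differences
    double eps = 1e-6, xp[n], yp[n], ym[n];
    for (int j = 0; j < n; ++j) xp[j] = x[j];
    xp[i] += eps;
    softmax(xp, yp);
    xp[i] -= 2 * eps;
    softmax(xp, ym);
    double numeric = 0;
    for (int j = 0; j < n; ++j) numeric += dy[j] * (yp[j] - ym[j]);
    numeric /= 2 * eps;
    std::printf("i=%d analytic=%.6f numeric=%.6f\n", i, analytic, numeric);
  }
  return 0;
}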
......
@@ -43,7 +43,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, logits, softmax);
dev_ctx, -1, logits, softmax);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
context.Attr<int>("ignore_index"));
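For context, the chained functors compute, per row, loss = -log(softmax(logits)_label) for hard labels (soft-label and ignore_index handling aside). A standalone scalar sketch of that composition (not the Paddle code):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Numerically stable softmax + cross-entropy for one row and a hard label:
// log(softmax_label) = (logits[label] - max) - log(sum_j exp(logits[j] - max))
double SoftmaxCrossEntropy(const double* logits, int num_classes, int label) {
  double m = logits[0];
  for (int j = 1; j < num_classes; ++j) m = std::max(m, logits[j]);
  double sum = 0.0;
  for (int j = 0; j < num_classes; ++j) sum += std::exp(logits[j] - m);
  return -((logits[label] - m) - std::log(sum));
}

int main() {
  const double logits[3] = {1.0, 2.0, 0.5};
  std::printf("%.4f\n", SoftmaxCrossEntropy(logits, 3, 1));  // 0.4644
  return 0;
}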
......
@@ -69,7 +69,7 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
int rank = logits->dims().size();
Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, &in_2d, &out_2d);
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, -1, &in_2d, &out_2d);
// ctc needs sequences data stored in transposed padding format
// logits and grad using padding data of layout 'TNC'
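The 'TNC' note above refers to warp-ctc's time-major layout; its relation to the 2-D reshape used for the softmax (illustrative only):

// 'TNC': element (t, n, c) of a [T, N, C] buffer sits at (t * N + n) * C + c,
// so the class axis C is contiguous and ReshapeToMatrix(*logits, rank - 1)
// yields a [T * N, C] view whose rows are exactly the softmax slices.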
......