Commit 2ddd23da authored by D dengkaipeng

fix format. test=develop

Parent 3e4f3434
@@ -131,8 +131,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
 if (axis != -1 && axis != rank - 1) {
 X_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
 Out_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
-TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *X, &X_trans,
+perm);
+TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out,
+&Out_trans, perm);
 auto dims = X_trans.dims();
 auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
 X_2d.ShareDataWith(X_trans).Resize(flattened_dims);
@@ -202,7 +204,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
 }
 if (axis != -1 && axis != rank - 1) {
-TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
+TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, Out_trans, Out,
+perm);
 }
 }
 };
@@ -241,9 +244,12 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
 dX_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
 Out_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
 dOut_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
-TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
-TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
+TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dX, &dX_trans,
+perm);
+TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out,
+&Out_trans, perm);
+TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dOut,
+&dOut_trans, perm);
 auto dims = dX_trans.dims();
 auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
 dX_2d.ShareDataWith(dX_trans).Resize(flattened_dims);
@@ -308,7 +314,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
 stream(stream::kind::eager).submit(pipeline).wait();
 if (axis != -1 && axis != rank - 1) {
-TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
+TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, dX_trans, dX,
+perm);
 }
 }
 };
......
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/math/softmax.h"
-#include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/softmax_op.h"
 namespace paddle {
 namespace operators {
@@ -25,7 +25,8 @@ template <typename T>
 class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
 public:
 void Compute(const framework::ExecutionContext& context) const override {
-auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
+auto& dev_ctx =
+context.template device_context<platform::CUDADeviceContext>();
 auto* X = context.Input<Tensor>("X");
 auto* Out = context.Output<Tensor>("Out");
 const int axis = context.Attr<int>("axis");
@@ -41,9 +42,12 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
 Tensor X_trans, Out_trans;
 if (axis != -1 && axis != rank - 1) {
 X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
-TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+Out_trans.mutable_data<T>(framework::make_ddim(shape),
+context.GetPlace());
+TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans,
+perm);
+TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
+&Out_trans, perm);
 X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
 Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
 } else {
@@ -52,11 +56,12 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
 }
 math::SoftmaxCUDNNFunctor<T>()(
-context.template device_context<platform::CUDADeviceContext>(),
-&X_2d, &Out_2d);
+context.template device_context<platform::CUDADeviceContext>(), &X_2d,
+&Out_2d);
 if (axis != -1 && axis != rank - 1) {
-TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
+TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans,
+Out, perm);
 }
 }
 };
@@ -65,7 +70,8 @@ template <typename T>
 class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
 public:
 void Compute(const framework::ExecutionContext& context) const override {
-auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
+auto& dev_ctx =
+context.template device_context<platform::CUDADeviceContext>();
 auto* Out = context.Input<Tensor>("Out");
 auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
 auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
@@ -82,11 +88,16 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
 Tensor dX_trans, Out_trans, dOut_trans;
 if (axis != -1 && axis != rank - 1) {
 dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
-TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
-TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
+Out_trans.mutable_data<T>(framework::make_ddim(shape),
+context.GetPlace());
+dOut_trans.mutable_data<T>(framework::make_ddim(shape),
+context.GetPlace());
+TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX,
+&dX_trans, perm);
+TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
+&Out_trans, perm);
+TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut,
+&dOut_trans, perm);
 dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
 Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
 dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
@@ -97,11 +108,12 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
 }
 math::SoftmaxGradCUDNNFunctor<T>()(
-context.template device_context<platform::CUDADeviceContext>(),
-&Out_2d, &dOut_2d, &dX_2d);
+context.template device_context<platform::CUDADeviceContext>(), &Out_2d,
+&dOut_2d, &dX_2d);
 if (axis != -1 && axis != rank - 1) {
-TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
+TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX,
+perm);
 }
 }
 };
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/softmax_op.h"
 #include <memory>
+#include <string>
 #ifdef PADDLE_WITH_CUDA
......
@@ -24,7 +24,8 @@ namespace operators {
 using Tensor = framework::Tensor;
 static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis,
-std::vector<int>* perm, std::vector<int>* shape) {
+std::vector<int>* perm,
+std::vector<int>* shape) {
 auto dim_x = x.dims();
 int rank = dim_x.size();
@@ -65,7 +66,8 @@ class SoftmaxKernel : public framework::OpKernel<T> {
 Tensor X_trans, Out_trans;
 if (axis != -1 && axis != rank - 1) {
 X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+Out_trans.mutable_data<T>(framework::make_ddim(shape),
+context.GetPlace());
 TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
 TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
 X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
@@ -75,7 +77,6 @@ class SoftmaxKernel : public framework::OpKernel<T> {
 Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
 }
-
 #ifdef PADDLE_ON_INFERENCE
 math::SoftmaxFunctor<DeviceContext, T, true>()(
 context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
@@ -111,8 +112,10 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
 Tensor dX_trans, Out_trans, dOut_trans;
 if (axis != -1 && axis != rank - 1) {
 dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+Out_trans.mutable_data<T>(framework::make_ddim(shape),
+context.GetPlace());
+dOut_trans.mutable_data<T>(framework::make_ddim(shape),
+context.GetPlace());
 TransCompute<DeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
 TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
 TransCompute<DeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
......
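The wrapped calls above all belong to one pattern used throughout this PR: to apply softmax along an arbitrary `axis`, the kernels build a permutation that moves that axis to the last position (`CalcTransPermAndShapeByAxis`), transpose with `TransCompute`, flatten to 2-D, run the ordinary last-axis softmax, and transpose the result back. Below is a minimal NumPy sketch of that idea, assuming a simple swap-style permutation; the function name and the exact permutation are illustrative only, not Paddle code.

```python
import numpy as np

def softmax_along_axis(x, axis=-1):
    """Softmax over `axis` via transpose-to-last, flatten-to-2D, softmax, transpose back."""
    rank = x.ndim
    if axis < 0:
        axis += rank
    # Permutation that swaps `axis` with the last dimension (a swap is its own inverse).
    perm = list(range(rank))
    perm[axis], perm[-1] = perm[-1], perm[axis]
    x_trans = np.transpose(x, perm)                    # analogous to TransCompute(X -> X_trans)
    x_2d = x_trans.reshape(-1, x_trans.shape[-1])      # analogous to flatten_to_2d / ReshapeToMatrix
    shifted = x_2d - x_2d.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    e = np.exp(shifted)
    out_2d = e / e.sum(axis=1, keepdims=True)          # row-wise (last-axis) softmax
    out_trans = out_2d.reshape(x_trans.shape)
    return np.transpose(out_trans, perm)               # transpose back with the same swap

print(softmax_along_axis(np.random.rand(2, 3, 4), axis=1).sum(axis=1))  # each entry ~1.0
```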
@@ -1872,10 +1872,8 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
 type="softmax",
 inputs={"X": input},
 outputs={"Out": softmax_out},
-attrs={
-"axis": axis,
-"use_cudnn": use_cudnn
-})
+attrs={"axis": axis,
+"use_cudnn": use_cudnn})
 return softmax_out
......
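On the Python side the change only re-wraps the `attrs` dict; what users see is the `axis` argument in the `softmax` signature shown in the hunk header. A rough fluid-1.x usage sketch follows, assuming the standard `data`/`Executor` boilerplate, which is not part of this diff.

```python
import numpy as np
import paddle.fluid as fluid

# shape=[3, 4] plus the implicit batch dimension -> runtime shape (N, 3, 4)
x = fluid.layers.data(name='x', shape=[3, 4], dtype='float32')
# Softmax over the size-3 dimension instead of the default last axis.
y = fluid.layers.softmax(x, axis=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(feed={'x': np.random.rand(2, 3, 4).astype('float32')},
               fetch_list=[y])
print(out.sum(axis=1))  # each slice along axis=1 sums to ~1
```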