Commit 2ddd23da authored by dengkaipeng

fix format. test=develop

Parent 3e4f3434
@@ -131,8 +131,10 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     if (axis != -1 && axis != rank - 1) {
       X_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
       Out_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
+      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *X, &X_trans,
+                                                  perm);
+      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out,
+                                                  &Out_trans, perm);
       auto dims = X_trans.dims();
       auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
       X_2d.ShareDataWith(X_trans).Resize(flattened_dims);
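In the hunk above, flatten_to_2d(dims, dims.size() - 1) collapses every dimension except the (now last) softmax dimension into rows. A hypothetical Python mirror of that helper, for illustration only:

from functools import reduce
from operator import mul

def flatten_to_2d(dims, num_col_dims):
    # Collapse the first `num_col_dims` dims into rows, the rest into columns.
    rows = reduce(mul, dims[:num_col_dims], 1)
    cols = reduce(mul, dims[num_col_dims:], 1)
    return [rows, cols]

# With dims.size() - 1 as above: flatten_to_2d([2, 4, 3], 2) -> [8, 3],
# so the softmax runs independently over rows of length 3.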
@@ -202,7 +204,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     }
     if (axis != -1 && axis != rank - 1) {
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
+      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, Out_trans, Out,
+                                                  perm);
     }
   }
 };
@@ -241,9 +244,12 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
       dX_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
       Out_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
       dOut_trans.mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
+      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dX, &dX_trans,
+                                                  perm);
+      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *Out,
+                                                  &Out_trans, perm);
+      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, *dOut,
+                                                  &dOut_trans, perm);
       auto dims = dX_trans.dims();
       auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
       dX_2d.ShareDataWith(dX_trans).Resize(flattened_dims);
@@ -308,7 +314,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
     stream(stream::kind::eager).submit(pipeline).wait();
     if (axis != -1 && axis != rank - 1) {
-      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
+      TransCompute<platform::CPUDeviceContext, T>(rank, dev_ctx, dX_trans, dX,
+                                                  perm);
     }
   }
 };
......
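All four hunks above follow the same pattern for a non-last axis: swap that axis with the last dimension, flatten the tensor to 2-D, run a row-wise softmax, and undo the transpose. A minimal NumPy sketch of that computation, for reference only (softmax_along_axis is a hypothetical name, not Paddle API):

import numpy as np

def softmax_along_axis(x, axis=-1):
    rank = x.ndim
    axis = axis % rank                             # normalize a negative axis
    perm = list(range(rank))
    perm[axis], perm[-1] = perm[-1], perm[axis]    # swap `axis` with the last dim
    x_trans = np.transpose(x, perm)
    x_2d = x_trans.reshape(-1, x_trans.shape[-1])  # flatten to 2-D
    e = np.exp(x_2d - x_2d.max(axis=1, keepdims=True))  # shift for stability
    out_2d = e / e.sum(axis=1, keepdims=True)
    # A swap permutation is its own inverse, so the same perm undoes the
    # transpose -- which is why the kernels pass `perm` to TransCompute twice.
    return np.transpose(out_2d.reshape(x_trans.shape), perm)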
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/math/softmax.h"
-#include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/softmax_op.h"

 namespace paddle {
 namespace operators {
@@ -25,7 +25,8 @@ template <typename T>
 class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
     auto* X = context.Input<Tensor>("X");
     auto* Out = context.Output<Tensor>("Out");
     const int axis = context.Attr<int>("axis");
...@@ -41,9 +42,12 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> { ...@@ -41,9 +42,12 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
Tensor X_trans, Out_trans; Tensor X_trans, Out_trans;
if (axis != -1 && axis != rank - 1) { if (axis != -1 && axis != rank - 1) {
X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace()); X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace()); Out_trans.mutable_data<T>(framework::make_ddim(shape),
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm); context.GetPlace());
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm); TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans,
perm);
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
&Out_trans, perm);
X_2d = framework::ReshapeToMatrix(X_trans, rank - 1); X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1); Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
} else { } else {
@@ -52,11 +56,12 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
     }
     math::SoftmaxCUDNNFunctor<T>()(
-        context.template device_context<platform::CUDADeviceContext>(),
-        &X_2d, &Out_2d);
+        context.template device_context<platform::CUDADeviceContext>(), &X_2d,
+        &Out_2d);
     if (axis != -1 && axis != rank - 1) {
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans,
+                                                   Out, perm);
     }
   }
 };
@@ -65,7 +70,8 @@ template <typename T>
 class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
+    auto& dev_ctx =
+        context.template device_context<platform::CUDADeviceContext>();
     auto* Out = context.Input<Tensor>("Out");
     auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
@@ -82,11 +88,16 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
     Tensor dX_trans, Out_trans, dOut_trans;
     if (axis != -1 && axis != rank - 1) {
       dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
+      Out_trans.mutable_data<T>(framework::make_ddim(shape),
+                                context.GetPlace());
+      dOut_trans.mutable_data<T>(framework::make_ddim(shape),
+                                 context.GetPlace());
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX,
+                                                   &dX_trans, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
+                                                   &Out_trans, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut,
+                                                   &dOut_trans, perm);
       dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
       Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
       dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
@@ -97,11 +108,12 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
     }
     math::SoftmaxGradCUDNNFunctor<T>()(
-        context.template device_context<platform::CUDADeviceContext>(),
-        &Out_2d, &dOut_2d, &dX_2d);
+        context.template device_context<platform::CUDADeviceContext>(), &Out_2d,
+        &dOut_2d, &dX_2d);
     if (axis != -1 && axis != rank - 1) {
-      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
+      TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX,
+                                                   perm);
     }
   }
 };
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/softmax_op.h"
+#include <memory>
 #include <string>
 #ifdef PADDLE_WITH_CUDA
......
@@ -24,7 +24,8 @@ namespace operators {
 using Tensor = framework::Tensor;

 static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis,
-                                               std::vector<int>* perm, std::vector<int>* shape) {
+                                               std::vector<int>* perm,
+                                               std::vector<int>* shape) {
   auto dim_x = x.dims();
   int rank = dim_x.size();
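The helper's body is collapsed in this view. A hypothetical Python mirror of what it computes, under the assumption (consistent with the kernels reusing the same perm for the inverse transpose, which requires a self-inverse permutation) that it swaps axis with the last dimension and reports the transposed shape:

def calc_trans_perm_and_shape_by_axis(dims, axis):
    # Hypothetical mirror of CalcTransPermAndShapeByAxis, illustration only.
    rank = len(dims)
    perm = list(range(rank))
    perm[axis], perm[-1] = perm[-1], perm[axis]  # move `axis` to the end
    shape = [dims[i] for i in perm]              # tensor shape after transpose
    return perm, shape

# e.g. dims = [2, 3, 4], axis = 1  ->  perm = [0, 2, 1], shape = [2, 4, 3]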
...@@ -65,7 +66,8 @@ class SoftmaxKernel : public framework::OpKernel<T> { ...@@ -65,7 +66,8 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Tensor X_trans, Out_trans; Tensor X_trans, Out_trans;
if (axis != -1 && axis != rank - 1) { if (axis != -1 && axis != rank - 1) {
X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace()); X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace()); Out_trans.mutable_data<T>(framework::make_ddim(shape),
context.GetPlace());
TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm); TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm); TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
X_2d = framework::ReshapeToMatrix(X_trans, rank - 1); X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
@@ -75,7 +77,6 @@ class SoftmaxKernel : public framework::OpKernel<T> {
       Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
     }
-
 #ifdef PADDLE_ON_INFERENCE
     math::SoftmaxFunctor<DeviceContext, T, true>()(
         context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
@@ -111,8 +112,10 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     Tensor dX_trans, Out_trans, dOut_trans;
     if (axis != -1 && axis != rank - 1) {
       dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
-      dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
+      Out_trans.mutable_data<T>(framework::make_ddim(shape),
+                                context.GetPlace());
+      dOut_trans.mutable_data<T>(framework::make_ddim(shape),
+                                 context.GetPlace());
       TransCompute<DeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
       TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
       TransCompute<DeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
......
@@ -1872,10 +1872,8 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
         type="softmax",
         inputs={"X": input},
         outputs={"Out": softmax_out},
-        attrs={
-            "axis": axis,
-            "use_cudnn": use_cudnn
-        })
+        attrs={"axis": axis,
+               "use_cudnn": use_cudnn})
     return softmax_out
......
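For completeness, a usage sketch of the layer touched above, assuming the fluid-era layers API (names and shapes are illustrative, not part of the commit):

import paddle.fluid as fluid

x = fluid.layers.data(name="x", shape=[3, 4, 5], dtype="float32")
# Softmax along axis 1 of the full [batch, 3, 4, 5] tensor instead of the
# default last axis; use_cudnn=False selects the plain CPU/CUDA kernel.
out = fluid.layers.softmax(x, use_cudnn=False, axis=1)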