提交 affce733 编写于 作者: C chengduoZH

refine elementwise_op

上级 4e7e39b4
...@@ -54,7 +54,15 @@ class CompareOpKernel ...@@ -54,7 +54,15 @@ class CompareOpKernel
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE; using T = typename Functor::ELEM_TYPE;
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context); using Tensor = framework::Tensor;
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
z->mutable_data<T>(context.GetPlace());
int axis = context.Attr<int>("axis");
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
z);
} }
}; };
......
...@@ -28,7 +28,14 @@ template <typename DeviceContext, typename T> ...@@ -28,7 +28,14 @@ template <typename DeviceContext, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> { class ElementwiseAddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx); using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
} }
}; };
...@@ -92,9 +99,19 @@ template <typename DeviceContext, typename T> ...@@ -92,9 +99,19 @@ template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel<T> { class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseAddGradFunctor<T>,
ElementwiseAddBroadCastGradFunctor<T>, ElementwiseAddBroadCastGradFunctor<T>,
ElementwiseAddBroadCast2GradFunctor<T>>(ctx); ElementwiseAddBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
} }
}; };
......
...@@ -28,7 +28,14 @@ template <typename DeviceContext, typename T> ...@@ -28,7 +28,14 @@ template <typename DeviceContext, typename T>
class ElementwiseDivKernel : public framework::OpKernel<T> { class ElementwiseDivKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx); using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
} }
}; };
...@@ -111,9 +118,19 @@ template <typename DeviceContext, typename T> ...@@ -111,9 +118,19 @@ template <typename DeviceContext, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel<T> { class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>,
ElementwiseDivBroadCastGradFunctor<T>, ElementwiseDivBroadCastGradFunctor<T>,
ElementwiseDivBroadCast2GradFunctor<T>>(ctx); ElementwiseDivBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
} }
}; };
......
...@@ -28,7 +28,14 @@ template <typename DeviceContext, typename T> ...@@ -28,7 +28,14 @@ template <typename DeviceContext, typename T>
class ElementwiseMaxKernel : public framework::OpKernel<T> { class ElementwiseMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx); using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
} }
}; };
...@@ -110,9 +117,19 @@ template <typename DeviceContext, typename T> ...@@ -110,9 +117,19 @@ template <typename DeviceContext, typename T>
class ElementwiseMaxGradKernel : public framework::OpKernel<T> { class ElementwiseMaxGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseMaxGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseMaxGradFunctor<T>,
ElementwiseMaxBroadCastGradFunctor<T>, ElementwiseMaxBroadCastGradFunctor<T>,
ElementwiseMaxBroadCast2GradFunctor<T>>(ctx); ElementwiseMaxBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
} }
}; };
......
...@@ -28,7 +28,14 @@ template <typename DeviceContext, typename T> ...@@ -28,7 +28,14 @@ template <typename DeviceContext, typename T>
class ElementwiseMinKernel : public framework::OpKernel<T> { class ElementwiseMinKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx); using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
} }
}; };
...@@ -110,9 +117,19 @@ template <typename DeviceContext, typename T> ...@@ -110,9 +117,19 @@ template <typename DeviceContext, typename T>
class ElementwiseMinGradKernel : public framework::OpKernel<T> { class ElementwiseMinGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseMinGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseMinGradFunctor<T>,
ElementwiseMinBroadCastGradFunctor<T>, ElementwiseMinBroadCastGradFunctor<T>,
ElementwiseMinBroadCast2GradFunctor<T>>(ctx); ElementwiseMinBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
} }
}; };
......
...@@ -27,7 +27,14 @@ template <typename DeviceContext, typename T> ...@@ -27,7 +27,14 @@ template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> { class ElementwiseMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx); using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
} }
}; };
...@@ -110,9 +117,19 @@ template <typename DeviceContext, typename T> ...@@ -110,9 +117,19 @@ template <typename DeviceContext, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel<T> { class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseMulGradFunctor<T>,
ElementwiseMulBroadCastGradFunctor<T>, ElementwiseMulBroadCastGradFunctor<T>,
ElementwiseMulBroadCast2GradFunctor<T>>(ctx); ElementwiseMulBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
} }
}; };
......
...@@ -313,21 +313,18 @@ EIGEN_FUNCTOR(Div, EIGEN_DIV); ...@@ -313,21 +313,18 @@ EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename DeviceContext, typename T, typename functor, template <typename DeviceContext, typename T, typename functor,
typename broadcastfunctor, typename broadcast2functor> typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, int axis,
framework::Tensor* dx, framework::Tensor* dy) {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device(); auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
if (dx) { if (dx) {
dx->mutable_data<T>(ctx.GetPlace()); dx->mutable_data<T>(ctx.GetPlace());
} }
...@@ -348,7 +345,6 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { ...@@ -348,7 +345,6 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
x_dims = framework::make_ddim(extended_dims); x_dims = framework::make_ddim(extended_dims);
} }
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
int pre, n, post; int pre, n, post;
...@@ -367,13 +363,10 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { ...@@ -367,13 +363,10 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
template <typename Functor, typename DeviceContext, typename T, template <typename Functor, typename DeviceContext, typename T,
typename OutType = T> typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
using Tensor = framework::Tensor; const framework::Tensor* x,
const framework::Tensor* y, int axis,
auto* x = ctx.Input<Tensor>("X"); framework::Tensor* z) {
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<OutType>(ctx.GetPlace());
TransformFunctor<Functor, T, DeviceContext, OutType> functor( TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), Functor()); x, y, z, ctx.template device_context<DeviceContext>(), Functor());
...@@ -394,7 +387,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { ...@@ -394,7 +387,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
x_dims = framework::make_ddim(extended_dims); x_dims = framework::make_ddim(extended_dims);
} }
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)"); "Axis should be in range [0, x_dims)");
......
...@@ -29,7 +29,14 @@ template <typename DeviceContext, typename T> ...@@ -29,7 +29,14 @@ template <typename DeviceContext, typename T>
class ElementwisePowKernel : public framework::OpKernel<T> { class ElementwisePowKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx); using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
} }
}; };
......
...@@ -27,7 +27,14 @@ template <typename DeviceContext, typename T> ...@@ -27,7 +27,14 @@ template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> { class ElementwiseSubKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx); using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
} }
}; };
...@@ -93,9 +100,19 @@ template <typename DeviceContext, typename T> ...@@ -93,9 +100,19 @@ template <typename DeviceContext, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel<T> { class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>, ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>,
ElementwiseSubBroadCastGradFunctor<T>, ElementwiseSubBroadCastGradFunctor<T>,
ElementwiseSubBroadCast2GradFunctor<T>>(ctx); ElementwiseSubBroadCast2GradFunctor<T>>(
ctx, x, y, out, dout, axis, dx, dy);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册