提交 ead7059b 编写于 作者: F fengjiayi

Refine code

上级 ee8e5374
......@@ -28,39 +28,7 @@ template <typename DeviceContext, typename T>
class ElementwiseMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
TransformFunctor<MaxFunctor<T>, T, DeviceContext> functor(
x, y, z, ctx.template device_context<DeviceContext>(), MaxFunctor<T>());
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.");
if (x_dims == y_dims) {
functor.Run();
return;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
functor.RunRowWise(n, pre);
return;
} else {
functor.RunMidWise(n, pre, post);
return;
}
ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx);
}
};
......
......@@ -28,39 +28,7 @@ template <typename DeviceContext, typename T>
class ElementwiseMinKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
TransformFunctor<MinFunctor<T>, T, DeviceContext> functor(
x, y, z, ctx.template device_context<DeviceContext>(), MinFunctor<T>());
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.");
if (x_dims == y_dims) {
functor.Run();
return;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
functor.RunRowWise(n, pre);
return;
} else {
functor.RunMidWise(n, pre, post);
return;
}
ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx);
}
};
......
......@@ -356,5 +356,43 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
return;
}
}
template <typename Functor, typename DeviceContext, typename T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
TransformFunctor<Functor, T, DeviceContext> functor(
x, y, z, ctx.template device_context<DeviceContext>(), Functor());
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.");
if (x_dims == y_dims) {
functor.Run();
return;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
functor.RunRowWise(n, pre);
return;
} else {
functor.RunMidWise(n, pre, post);
return;
}
}
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册