提交 0e7baabe 编写于 作者: D danleifeng 提交者: gongweibao

extend elementwise broadcast function (#20957)

上级 d623e863
...@@ -99,8 +99,8 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add); ...@@ -99,8 +99,8 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
elementwise_add_grad, ops::ElementwiseOpExplicitGrad, elementwise_add_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace,
ops::ElementwiseGradOpInplace, ops::ElementwiseGradNoBufVarsInference, ops::ElementwiseGradNoBufVarsInference,
ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>, ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>); ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>);
......
...@@ -25,8 +25,13 @@ void default_elementwise_add(const framework::ExecutionContext &ctx, ...@@ -25,8 +25,13 @@ void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) { const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis, if (x->numel() >= y->numel()) {
AddFunctor<T>(), z); ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseAddFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseAddFunctor<T>(), z);
}
} }
template <typename DeviceContext, typename T, class Enable = void> template <typename DeviceContext, typename T, class Enable = void>
...@@ -128,12 +133,13 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> { ...@@ -128,12 +133,13 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Y");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
// skip out, x, y // skip out
auto *out = dout; auto *out = dout;
auto *x = dout, *y = dout;
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
......
...@@ -76,6 +76,7 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -76,6 +76,7 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
std::unique_ptr<T> Apply() const override { std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T()); std::unique_ptr<T> op(new T());
op->SetType("elementwise_div_grad"); op->SetType("elementwise_div_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y")); op->SetInput("Y", this->Input("Y"));
op->SetInput("Out", this->Output("Out")); op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
......
...@@ -31,8 +31,13 @@ void default_elementwise_div(const framework::ExecutionContext& ctx, ...@@ -31,8 +31,13 @@ void default_elementwise_div(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) { const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis, if (x->numel() >= y->numel()) {
DivFunctor<T>(), z); ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
DivFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseDivFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseDivFunctor<T>(), z);
}
} }
template <typename DeviceContext, typename T, class Enable = void> template <typename DeviceContext, typename T, class Enable = void>
...@@ -112,13 +117,13 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> { ...@@ -112,13 +117,13 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out"); auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
auto* x = dout; // Fake x, not used
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
...@@ -191,7 +196,7 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> { ...@@ -191,7 +196,7 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
// ddX_safe == null ? 0 : ddX // ddX_safe == null ? 0 : ddX
// ddY_safe == null ? 0 : ddY // ddY_safe == null ? 0 : ddY
Tensor ddX_safe, ddY_safe; Tensor ddX_safe, ddY_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Out, ddX, &ddX_safe); GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dX, ddX, &ddX_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe); GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe);
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
...@@ -209,8 +214,7 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> { ...@@ -209,8 +214,7 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
if (dY) { if (dY) {
// dX_div_Y = dX / Y; // dX_div_Y = dX / Y;
Tensor dX_div_Y = tmp; Tensor dX_div_Y = tmp;
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>( default_elementwise_div<DeviceContext, T>(ctx, dX, Y, &dX_div_Y);
ctx, dX, Y, axis, DivFunctor<T>(), &dX_div_Y);
// NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first // first output tensor is nullptr, the branch to calculate first
...@@ -227,10 +231,8 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> { ...@@ -227,10 +231,8 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
if (ddOut) { if (ddOut) {
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, &tmp); default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, &tmp);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>( default_elementwise_sub<DeviceContext, T>(ctx, &ddX_safe, &tmp, &tmp);
ctx, &ddX_safe, &tmp, 0, SubFunctor<T>(), &tmp); default_elementwise_div<DeviceContext, T>(ctx, &tmp, Y, ddOut);
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
ctx, &tmp, Y, axis, DivFunctor<T>(), ddOut);
} }
if (dOut) { if (dOut) {
......
...@@ -26,9 +26,15 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx, ...@@ -26,9 +26,15 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) { const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis, if (x->numel() >= y->numel()) {
MulFunctor<T>(), z); ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseMulFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseMulFunctor<T>(), z);
}
} }
template <typename DeviceContext, typename T, class Enable = void> template <typename DeviceContext, typename T, class Enable = void>
struct SameDimsElemwiseMul { struct SameDimsElemwiseMul {
void operator()(const framework::ExecutionContext& ctx, void operator()(const framework::ExecutionContext& ctx,
......
...@@ -14,12 +14,15 @@ limitations under the License. */ ...@@ -14,12 +14,15 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm> // for max
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
...@@ -35,12 +38,12 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -35,12 +38,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of elementwise op should not be null."); "Input(X) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
"Input(Y) of elementwise op should not be null."); "Input(Y) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of elementwise op should not be null."); "Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->GetInputsVarType("Y").front() == ctx->GetInputsVarType("Y").front() ==
...@@ -49,18 +52,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -49,18 +52,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front()); ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
if (ctx->GetInputsVarType("X").front() == if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR) { framework::proto::VarType::SELECTED_ROWS) {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(
x_dim.size(), y_dim.size(),
"ShapeError: the dimension of input X must greater than or equal to "
"the one of input Y. But received: the shape of input X = [%s], the "
"dimension of input X = %d, the shape of input Y = [%s], the "
"dimension of input Y = %d",
x_dim, x_dim.size(), y_dim, y_dim.size());
} else if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Y").size(), 1u, ctx->GetInputDim("Y").size(), 1u,
"ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS" "ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
...@@ -71,13 +63,31 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -71,13 +63,31 @@ class ElementwiseOp : public framework::OperatorWithKernel {
"ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS" "ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
"), Y must be scalar. But reveived the first dimension of Y = %s", "), Y must be scalar. But reveived the first dimension of Y = %s",
ctx->GetInputDim("Y")[0]); ctx->GetInputDim("Y")[0]);
} else { } else if (ctx->GetInputsVarType("X").front() !=
framework::proto::VarType::LOD_TENSOR) {
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.", PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
ctx->GetInputsVarType("X").front()); ctx->GetInputsVarType("X").front());
} }
ctx->ShareDim("X", /*->*/ "Out"); if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
} else {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int max_dim = std::max(x_dims.size(), y_dims.size());
int axis = ctx->Attrs().Get<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(),
max_dim, axis);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
// to do
ctx->ShareLoD("X", /*->*/ "Out");
}
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
...@@ -207,26 +217,14 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -207,26 +217,14 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out"); auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(out_grad_name), PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null.");
auto x_dims = ctx->GetInputDim(out_grad_name);
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(
x_dims.size(), y_dims.size(),
"ShapeError: the dimension of Out@GRAD must greater than or equal to "
"the one of input Y. But received: the shape of Out@GRAD = [%s], the "
"dimension of Out@GRAD = %d, the shape of input Y = [%s], the "
"dimension of of input Y = %d",
x_dims, x_dims.size(), y_dims, y_dims.size());
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim(out_grad_name, /*->*/ x_grad_name); ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name); ctx->ShareLoD("X", /*->*/ x_grad_name);
} }
if (ctx->HasOutput(y_grad_name)) { if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name); ctx->ShareDim("Y", /*->*/ y_grad_name);
...@@ -326,32 +324,6 @@ class ElementwiseOpDoubleGradWithoutDXDY ...@@ -326,32 +324,6 @@ class ElementwiseOpDoubleGradWithoutDXDY
} }
}; };
// For Add, Sub op, the X, Out is not needed.
class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
public:
using operators::ElementwiseOpGrad::ElementwiseOpGrad;
using operators::ElementwiseOpGrad::GetExpectedKernelType;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
}
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
ctx->ShareDim("Y", /*->*/ y_grad_name);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
}
}
};
template <typename T> template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> { class ElemwiseGradKernel : public framework::OpKernel<T> {
public: public:
...@@ -372,13 +344,13 @@ DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace, ...@@ -372,13 +344,13 @@ DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
framework::GradVarName("X")}); framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"}); DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y"); DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "X",
"Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
"Y", "DOut"); "Y", "DOut");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name) \ #define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name) \
template <typename T> \ template <typename T> \
class kernel_type##GradMaker \ class kernel_type##GradMaker \
...@@ -390,6 +362,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference, ...@@ -390,6 +362,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
std::unique_ptr<T> Apply() const override { \ std::unique_ptr<T> Apply() const override { \
auto *op = new T(); \ auto *op = new T(); \
op->SetType(#kernel_type "_grad"); \ op->SetType(#kernel_type "_grad"); \
op->SetInput("X", this->Input("X")); \
op->SetInput("Y", this->Input("Y")); \ op->SetInput("Y", this->Input("Y")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \ op->SetInput(::paddle::framework::GradVarName("Out"), \
this->OutputGrad("Out")); \ this->OutputGrad("Out")); \
...@@ -402,41 +375,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference, ...@@ -402,41 +375,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
} \ } \
} }
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::framework::DefaultGradOpMaker<::paddle::framework::OpDesc, \
true>, \
::paddle::framework::DefaultGradOpMaker<::paddle::imperative::OpBase, \
true>); \
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
#define REGISTER_ELEMWISE_EXPLICIT_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
op_type##GradMaker<::paddle::framework::OpDesc>, \
op_type##GradMaker<::paddle::imperative::OpBase>, \
::paddle::operators::ElementwiseOpInplace); \
REGISTER_OPERATOR(op_type##_grad, \
::paddle::operators::ElementwiseOpExplicitGrad, \
::paddle::operators::ElementwiseGradOpInplace, \
::paddle::operators::ElementwiseGradNoBufVarsInference)
#define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name) \ #define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name) \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \ REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
::paddle::operators::Elementwise##op_name##OpMaker, \ ::paddle::operators::Elementwise##op_name##OpMaker, \
......
...@@ -44,6 +44,12 @@ namespace operators { ...@@ -44,6 +44,12 @@ namespace operators {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { \ inline HOSTDEVICE T operator()(const T& a, const T& b) const { \
return a expr b; \ return a expr b; \
} \ } \
}; \
template <typename T, class Enable = void> \
struct Inverse##Func##Functor { \
inline HOSTDEVICE T operator()(const T& a, const T& b) const { \
return b expr a; \
} \
}; };
DEFINE_SIMPLE_BINARY_FUNCTOR(Add, +) DEFINE_SIMPLE_BINARY_FUNCTOR(Add, +)
......
...@@ -93,13 +93,14 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -93,13 +93,14 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub); REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub);
REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub); REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub);
namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
elementwise_sub_grad, ops::ElementwiseOpExplicitGrad, elementwise_sub_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace,
ops::ElementwiseGradOpInplace, ops::ElementwiseGradNoBufVarsInference, ops::ElementwiseGradNoBufVarsInference,
ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>, ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>); ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_sub_grad_grad, REGISTER_OPERATOR(elementwise_sub_grad_grad,
......
...@@ -26,8 +26,13 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx, ...@@ -26,8 +26,13 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) { const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis, if (x->numel() >= y->numel()) {
SubFunctor<T>(), z); ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseSubFunctor<T>(), z);
}
} }
template <typename DeviceContext, typename T, class Enable = void> template <typename DeviceContext, typename T, class Enable = void>
...@@ -98,13 +103,14 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> { ...@@ -98,13 +103,14 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
// skip out, x, y // skip out
auto* out = dout; auto* out = dout;
auto *x = dout, *y = dout;
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else { } else {
......
...@@ -108,8 +108,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -108,8 +108,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto y_dims = trim_trailing_singular_dims(y_dims_untrimed); auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis; axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post; int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post,
&is_run_common_broadcast);
if (post == 1) { if (post == 1) {
functor.RunRowWise(n, pre); functor.RunRowWise(n, pre);
...@@ -212,6 +213,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -212,6 +213,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} }
} else { } else {
// Execute default kernel when broadcast is needed // Execute default kernel when broadcast is needed
x = ctx.Input<Tensor>("X");
y = ctx.Input<Tensor>("Y");
ElemwiseExplicitGradCompute<paddle::platform::CPUDeviceContext, T, ElemwiseExplicitGradCompute<paddle::platform::CPUDeviceContext, T,
IdentityGrad<T>, IdentityGrad<T>>( IdentityGrad<T>, IdentityGrad<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(), ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
......
...@@ -91,8 +91,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -91,8 +91,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
const bool is_y_format_correct = y->format() == MKLDNNMemoryFormat::nc; const bool is_y_format_correct = y->format() == MKLDNNMemoryFormat::nc;
if (is_x_format_correct && is_y_format_correct && are_dims_divisable && if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
is_avx512_enabled) { is_avx512_enabled) {
int pre, n, post; int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post,
&is_run_common_broadcast);
if (post == 1) { if (post == 1) {
PADDLE_THROW("Not implemented when post is 1"); PADDLE_THROW("Not implemented when post is 1");
...@@ -168,8 +169,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -168,8 +169,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed); auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis; axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post; int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post,
&is_run_common_broadcast);
if (post == 1) { if (post == 1) {
functor.RunRowWise(n, pre); functor.RunRowWise(n, pre);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ #if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
...@@ -139,21 +140,6 @@ struct DivAndSqrtFunctor { ...@@ -139,21 +140,6 @@ struct DivAndSqrtFunctor {
T epsilon_; T epsilon_;
}; };
template <typename T>
struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T> template <typename T>
struct MulInvVarFunctor { struct MulInvVarFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { inline HOSTDEVICE T operator()(T a, T b) const {
......
...@@ -32,6 +32,7 @@ TEST(op_debug_str, test_unknown_dtype) { ...@@ -32,6 +32,7 @@ TEST(op_debug_str, test_unknown_dtype) {
framework::Scope scope; framework::Scope scope;
desc.SetType("elementwise_add_grad"); desc.SetType("elementwise_add_grad");
desc.SetInput("X", {"X"});
desc.SetInput("Y", {"Y"}); desc.SetInput("Y", {"Y"});
desc.SetInput(framework::GradVarName("Out"), {framework::GradVarName("Out")}); desc.SetInput(framework::GradVarName("Out"), {framework::GradVarName("Out")});
desc.SetOutput(framework::GradVarName("X"), {framework::GradVarName("X")}); desc.SetOutput(framework::GradVarName("X"), {framework::GradVarName("X")});
...@@ -41,6 +42,10 @@ TEST(op_debug_str, test_unknown_dtype) { ...@@ -41,6 +42,10 @@ TEST(op_debug_str, test_unknown_dtype) {
desc.SetAttr("x_data_format", ""); desc.SetAttr("x_data_format", "");
desc.SetAttr("y_data_format", ""); desc.SetAttr("y_data_format", "");
auto x_tensor = scope.Var("X")->GetMutable<framework::LoDTensor>();
x_tensor->Resize(dim);
x_tensor->mutable_data<float>(place);
auto y_tensor = scope.Var("Y")->GetMutable<framework::LoDTensor>(); auto y_tensor = scope.Var("Y")->GetMutable<framework::LoDTensor>();
y_tensor->Resize(dim); y_tensor->Resize(dim);
y_tensor->mutable_data<float>(place); y_tensor->mutable_data<float>(place);
......
...@@ -7,8 +7,8 @@ TURN_ON_MKL=$2 # use MKL or Openblas ...@@ -7,8 +7,8 @@ TURN_ON_MKL=$2 # use MKL or Openblas
# download models # download models
function download() { function download() {
wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR/main_program wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR-1-7/main_program
wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR/startup_program wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR-1-7/startup_program
} }
download download
......
...@@ -308,6 +308,36 @@ class TestFP16ElementwiseAddOp_channelwise_add(TestFP16ElementwiseAddOp): ...@@ -308,6 +308,36 @@ class TestFP16ElementwiseAddOp_channelwise_add(TestFP16ElementwiseAddOp):
self.axis = -1 self.axis = -1
class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1, 4).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 1, 5).astype(self.dtype)
self.y = np.random.rand(2, 1, 4, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(4, 5).astype(self.dtype)
self.y = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = 2
class TestElementwiseAddOpError(OpTest): class TestElementwiseAddOpError(OpTest):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
...@@ -151,6 +151,39 @@ class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp): ...@@ -151,6 +151,39 @@ class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp):
    """Broadcast divide: (2, 3, 4) / (1, 1, 4)."""

    def setUp(self):
        self.op_type = "elementwise_div"
        # Draw from [0.1, 1) so the divisor is never near zero.
        x = np.random.uniform(0.1, 1, [2, 3, 4]).astype("float32")
        y = np.random.uniform(0.1, 1, [1, 1, 4]).astype("float32")
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': np.divide(x, y)}
class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp):
    """Broadcast divide with interleaved singleton dims: (2, 3, 1, 5) / (2, 1, 4, 1)."""

    def setUp(self):
        self.op_type = "elementwise_div"
        # Draw from [0.1, 1) so the divisor is never near zero.
        x = np.random.uniform(0.1, 1, [2, 3, 1, 5]).astype("float32")
        y = np.random.uniform(0.1, 1, [2, 1, 4, 1]).astype("float32")
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': np.divide(x, y)}
class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp):
    """Broadcast divide where X has FEWER dims than Y: (4, 5) / (2, 3, 4, 5)."""

    def setUp(self):
        self.op_type = "elementwise_div"
        # Draw from [0.1, 1) so the divisor is never near zero.
        x = np.random.uniform(0.1, 1, [4, 5]).astype("float32")
        y = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32")
        self.inputs = {'X': x, 'Y': y}
        # X's dims align with Y starting at axis 2.
        self.attrs = {'axis': 2}
        self.outputs = {'Out': np.divide(x, y)}
class TestElementwiseDivOp_INT(OpTest): class TestElementwiseDivOp_INT(OpTest):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
......
...@@ -162,6 +162,41 @@ class TestElementwiseMulOpFp16(ElementwiseMulOp): ...@@ -162,6 +162,41 @@ class TestElementwiseMulOpFp16(ElementwiseMulOp):
self.dtype = np.float16 self.dtype = np.float16
class TestElementwiseMulOp_commonuse_1(ElementwiseMulOp):
    """Broadcast multiply: (2, 3, 4) * (1, 1, 4)."""

    def setUp(self):
        self.op_type = "elementwise_mul"
        x = np.random.rand(2, 3, 4).astype(np.float64)
        y = np.random.rand(1, 1, 4).astype(np.float64)
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': np.multiply(x, y)}
class TestElementwiseMulOp_commonuse_2(ElementwiseMulOp):
    """Broadcast multiply with interleaved singleton dims: (2, 3, 1, 5) * (2, 1, 4, 1)."""

    def setUp(self):
        self.op_type = "elementwise_mul"
        x = np.random.rand(2, 3, 1, 5).astype(np.float64)
        y = np.random.rand(2, 1, 4, 1).astype(np.float64)
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': np.multiply(x, y)}
class TestElementwiseMulOp_xsize_lessthan_ysize(ElementwiseMulOp):
    """Broadcast multiply where X has FEWER dims than Y: (4, 5) * (2, 3, 4, 5)."""

    def setUp(self):
        self.op_type = "elementwise_mul"
        x = np.random.rand(4, 5).astype(np.float64)
        y = np.random.rand(2, 3, 4, 5).astype(np.float64)
        self.inputs = {'X': x, 'Y': y}
        # X's dims align with Y starting at axis 2.
        self.attrs = {'axis': 2}
        # Reference result: pad X's shape to rank 4 before broadcasting.
        self.outputs = {'Out': np.multiply(x.reshape(1, 1, 4, 5), y)}
class TestElementwiseMulOpError(OpTest): class TestElementwiseMulOpError(OpTest):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
......
...@@ -127,5 +127,40 @@ class TestElementwiseSubOp_broadcast_4(TestElementwiseOp): ...@@ -127,5 +127,40 @@ class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']} self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
class TestElementwiseSubOp_commonuse_1(TestElementwiseOp):
    """Broadcast subtract: (2, 3, 4) - (1, 1, 4)."""

    def setUp(self):
        self.op_type = "elementwise_sub"
        x = np.random.rand(2, 3, 4).astype(np.float32)
        y = np.random.rand(1, 1, 4).astype(np.float32)
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': np.subtract(x, y)}
class TestElementwiseSubOp_commonuse_2(TestElementwiseOp):
    """Broadcast subtract with interleaved singleton dims: (2, 3, 1, 5) - (2, 1, 4, 1)."""

    def setUp(self):
        self.op_type = "elementwise_sub"
        x = np.random.rand(2, 3, 1, 5).astype(np.float32)
        y = np.random.rand(2, 1, 4, 1).astype(np.float32)
        self.inputs = {'X': x, 'Y': y}
        self.outputs = {'Out': np.subtract(x, y)}
class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
    """Broadcast subtract where X has FEWER dims than Y: (4, 5) - (2, 3, 4, 5)."""

    def setUp(self):
        self.op_type = "elementwise_sub"
        x = np.random.rand(4, 5).astype(np.float32)
        y = np.random.rand(2, 3, 4, 5).astype(np.float32)
        self.inputs = {'X': x, 'Y': y}
        # X's dims align with Y starting at axis 2.
        self.attrs = {'axis': 2}
        # Reference result: pad X's shape to rank 4 before broadcasting.
        self.outputs = {'Out': np.subtract(x.reshape(1, 1, 4, 5), y)}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -47,7 +47,7 @@ class TestExecutorReturnTensorNotOverwritingWithOptest(OpTest): ...@@ -47,7 +47,7 @@ class TestExecutorReturnTensorNotOverwritingWithOptest(OpTest):
'Y': OpTest.np_dtype_to_fluid_dtype(self.y) 'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
} }
self.outputs = {'Out': self.out} self.outputs = {'Out': self.out}
self.op_type = "elementwise_mul" self.op_type = "mul"
self.dtype = np.float32 self.dtype = np.float32
outs, fetch_list = self._calc_output(place, parallel=parallel) outs, fetch_list = self._calc_output(place, parallel=parallel)
return outs return outs
......
...@@ -57,7 +57,7 @@ def run_trainer(use_cuda, sync_mode, ip, port, trainers, trainer_id): ...@@ -57,7 +57,7 @@ def run_trainer(use_cuda, sync_mode, ip, port, trainers, trainer_id):
exe.run(trainer_startup_program) exe.run(trainer_startup_program)
for i in range(5): for i in range(5):
exe.run(recv_program) exe.run(recv_program)
exe.run(main_program, exe.run(fluid.default_main_program(),
feed={ feed={
"x": numpy.array([1, 2]).astype('float32').reshape(2, 1), "x": numpy.array([1, 2]).astype('float32').reshape(2, 1),
"y": numpy.array([2, 3]).astype('float32').reshape(2, 1) "y": numpy.array([2, 3]).astype('float32').reshape(2, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册