Commit 0e7baabe authored by danleifeng, committed by gongweibao

extend elementwise broadcast function (#20957)

Parent d623e863
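
At a glance, this change teaches the elementwise forward kernels to accept a Y that is larger than X: each kernel now compares x->numel() with y->numel() and, when Y is the bigger operand, reuses the same broadcast path with the argument order swapped through the Inverse*Functor variants generated by DEFINE_SIMPLE_BINARY_FUNCTOR later in the diff. A minimal, self-contained C++ sketch of that dispatch follows; BroadcastScalar and ElementwiseDiv are hypothetical stand-ins for illustration, not Paddle APIs, and the scalar restriction is a simplification of the real ElementwiseComputeEx.

// ---- illustrative sketch (not part of the diff) ----
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-ins for the DivFunctor / InverseDivFunctor pair: the
// "inverse" variant receives its operands swapped and restores x / y.
template <typename T>
struct DivFunctor {
  T operator()(const T& a, const T& b) const { return a / b; }
};
template <typename T>
struct InverseDivFunctor {
  T operator()(const T& a, const T& b) const { return b / a; }
};

// Hypothetical helper standing in for the broadcast routine when the second
// operand is a single scalar; the first argument is always the larger tensor.
template <typename Functor, typename T>
void BroadcastScalar(const std::vector<T>& big, T small, std::vector<T>* out,
                     Functor func) {
  out->resize(big.size());
  for (std::size_t i = 0; i < big.size(); ++i) {
    (*out)[i] = func(big[i], small);
  }
}

// Mirrors the dispatch added in default_elementwise_div: pick the operand
// with more elements as the outer tensor, swap roles via the inverse functor.
template <typename T>
void ElementwiseDiv(const std::vector<T>& x, const std::vector<T>& y,
                    std::vector<T>* z) {
  if (x.size() >= y.size()) {
    assert(y.size() == 1);
    BroadcastScalar(x, y[0], z, DivFunctor<T>());
  } else {
    assert(x.size() == 1);
    BroadcastScalar(y, x[0], z, InverseDivFunctor<T>());  // still x / y
  }
}

int main() {
  std::vector<double> x{2.0};
  std::vector<double> y{1.0, 2.0, 4.0};
  std::vector<double> z;
  ElementwiseDiv(x, y, &z);  // z == {2.0, 1.0, 0.5}
  return 0;
}
// ---- end sketch ----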
......@@ -99,8 +99,8 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add);
namespace ops = paddle::operators;
REGISTER_OPERATOR(
elementwise_add_grad, ops::ElementwiseOpExplicitGrad,
ops::ElementwiseGradOpInplace, ops::ElementwiseGradNoBufVarsInference,
elementwise_add_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace,
ops::ElementwiseGradNoBufVarsInference,
ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>);
......
......@@ -25,8 +25,13 @@ void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z);
if (x->numel() >= y->numel()) {
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseAddFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseAddFunctor<T>(), z);
}
}
template <typename DeviceContext, typename T, class Enable = void>
......@@ -128,12 +133,13 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
using Tensor = framework::Tensor;
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Y");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
// skip out, x, y
// skip out
auto *out = dout;
auto *x = dout, *y = dout;
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
......
......@@ -76,6 +76,7 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("elementwise_div_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
......
......@@ -31,8 +31,13 @@ void default_elementwise_div(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
DivFunctor<T>(), z);
if (x->numel() >= y->numel()) {
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
DivFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseDivFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseDivFunctor<T>(), z);
}
}
template <typename DeviceContext, typename T, class Enable = void>
......@@ -112,13 +117,13 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto* x = dout; // Fake x, not used
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
......@@ -191,7 +196,7 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
// ddX_safe == null ? 0 : ddX
// ddY_safe == null ? 0 : ddY
Tensor ddX_safe, ddY_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Out, ddX, &ddX_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dX, ddX, &ddX_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe);
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
......@@ -209,8 +214,7 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
if (dY) {
// dX_div_Y = dX / Y;
Tensor dX_div_Y = tmp;
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
ctx, dX, Y, axis, DivFunctor<T>(), &dX_div_Y);
default_elementwise_div<DeviceContext, T>(ctx, dX, Y, &dX_div_Y);
// NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
// first output tensor is nullptr, the branch to calculate first
......@@ -227,10 +231,8 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
if (ddOut) {
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, &tmp);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &ddX_safe, &tmp, 0, SubFunctor<T>(), &tmp);
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
ctx, &tmp, Y, axis, DivFunctor<T>(), ddOut);
default_elementwise_sub<DeviceContext, T>(ctx, &ddX_safe, &tmp, &tmp);
default_elementwise_div<DeviceContext, T>(ctx, &tmp, Y, ddOut);
}
if (dOut) {
......
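For reference, the identity quoted in the double-grad kernel's comments above, ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y, is plain calculus on Out = X / Y and involves nothing Paddle-specific:

\[
\mathrm{ddOut}
  = \frac{\partial \mathrm{Out}}{\partial X}\,\mathrm{ddX}
    + \frac{\partial \mathrm{Out}}{\partial Y}\,\mathrm{ddY}
  = \frac{\mathrm{ddX}}{Y} - \frac{X}{Y^{2}}\,\mathrm{ddY}
  = \frac{\mathrm{ddX} - \mathrm{Out}\cdot\mathrm{ddY}}{Y}.
\]

Because every factor in the final form is a full tensor expression, the kernel can be rewritten in terms of default_elementwise_mul, default_elementwise_sub and default_elementwise_div, which is what the replaced ElementwiseComputeEx calls above are changed into.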
......@@ -26,9 +26,15 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z);
if (x->numel() >= y->numel()) {
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseMulFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseMulFunctor<T>(), z);
}
}
template <typename DeviceContext, typename T, class Enable = void>
struct SameDimsElemwiseMul {
void operator()(const framework::ExecutionContext& ctx,
......
......@@ -14,12 +14,15 @@ limitations under the License. */
#pragma once
#include <algorithm> // for max
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -35,12 +38,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of elementwise op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
"Input(Y) of elementwise op should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Y").front() ==
......@@ -49,18 +52,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR) {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(
x_dim.size(), y_dim.size(),
"ShapeError: the dimension of input X must greater than or equal to "
"the one of input Y. But received: the shape of input X = [%s], the "
"dimension of input X = %d, the shape of input Y = [%s], the "
"dimension of input Y = %d",
x_dim, x_dim.size(), y_dim, y_dim.size());
} else if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::SELECTED_ROWS) {
framework::proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Y").size(), 1u,
"ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
......@@ -71,13 +63,31 @@ class ElementwiseOp : public framework::OperatorWithKernel {
"ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
"), Y must be scalar. But reveived the first dimension of Y = %s",
ctx->GetInputDim("Y")[0]);
} else {
} else if (ctx->GetInputsVarType("X").front() !=
framework::proto::VarType::LOD_TENSOR) {
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
ctx->GetInputsVarType("X").front());
}
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
} else {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int max_dim = std::max(x_dims.size(), y_dims.size());
int axis = ctx->Attrs().Get<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(),
max_dim, axis);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
// to do
ctx->ShareLoD("X", /*->*/ "Out");
}
}
framework::OpKernelType GetExpectedKernelType(
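
The InferShape change above no longer copies X's shape into Out unconditionally: when the two inputs differ, it aligns the shorter shape at `axis`, pads with 1s, and takes a per-dimension maximum (via GetBroadcastDimsArrays), propagating -1 for unknown extents. A rough standalone sketch of that shape rule; BroadcastShape is a hypothetical helper written for illustration, not the Paddle function.

// ---- illustrative sketch (not part of the diff) ----
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<int> BroadcastShape(std::vector<int> x, std::vector<int> y,
                                int axis) {
  const bool x_larger = x.size() >= y.size();
  std::vector<int>& longer = x_larger ? x : y;
  std::vector<int>& shorter = x_larger ? y : x;
  const int max_dim = static_cast<int>(longer.size());
  if (axis == -1) axis = max_dim - static_cast<int>(shorter.size());
  // Align the shorter shape at `axis` and pad the rest with 1s.
  std::vector<int> aligned(max_dim, 1);
  std::copy(shorter.begin(), shorter.end(), aligned.begin() + axis);
  // Output extent is the per-dimension maximum; -1 stays unknown.
  std::vector<int> out(max_dim);
  for (int i = 0; i < max_dim; ++i) {
    out[i] = (longer[i] == -1 || aligned[i] == -1)
                 ? -1
                 : std::max(longer[i], aligned[i]);
  }
  return out;
}

int main() {
  // shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), axis = 1  ->  (2, 3, 4, 5)
  auto out = BroadcastShape({2, 3, 4, 5}, {3, 4}, 1);
  for (int d : out) std::printf("%d ", d);  // prints: 2 3 4 5
  std::printf("\n");
  return 0;
}
// ---- end sketch ----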
......@@ -207,26 +217,14 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(out_grad_name),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim(out_grad_name);
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(
x_dims.size(), y_dims.size(),
"ShapeError: the dimension of Out@GRAD must greater than or equal to "
"the one of input Y. But received: the shape of Out@GRAD = [%s], the "
"dimension of Out@GRAD = %d, the shape of input Y = [%s], the "
"dimension of of input Y = %d",
x_dims, x_dims.size(), y_dims, y_dims.size());
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim(out_grad_name, /*->*/ x_grad_name);
ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name);
ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
......@@ -326,32 +324,6 @@ class ElementwiseOpDoubleGradWithoutDXDY
}
};
// For Add, Sub op, the X, Out is not needed.
class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
public:
using operators::ElementwiseOpGrad::ElementwiseOpGrad;
using operators::ElementwiseOpGrad::GetExpectedKernelType;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
}
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
ctx->ShareDim("Y", /*->*/ y_grad_name);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
}
}
};
template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> {
public:
......@@ -372,13 +344,13 @@ DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
framework::GradVarName("X")});
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"});
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "X",
"Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
"Y", "DOut");
} // namespace operators
} // namespace paddle
#define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name) \
template <typename T> \
class kernel_type##GradMaker \
......@@ -390,6 +362,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
std::unique_ptr<T> Apply() const override { \
auto *op = new T(); \
op->SetType(#kernel_type "_grad"); \
op->SetInput("X", this->Input("X")); \
op->SetInput("Y", this->Input("Y")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \
this->OutputGrad("Out")); \
......@@ -402,41 +375,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
} \
}
#define REGISTER_ELEMWISE_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
::paddle::framework::DefaultGradOpMaker<::paddle::framework::OpDesc, \
true>, \
::paddle::framework::DefaultGradOpMaker<::paddle::imperative::OpBase, \
true>); \
REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
#define REGISTER_ELEMWISE_EXPLICIT_OP(op_type, op_name, equation) \
class __ElemwiseOp##op_type##Maker__ \
: public ::paddle::operators::ElementwiseOpMaker { \
protected: \
virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \
}; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \
::paddle::operators::ElementwiseOpInferVarType, \
op_type##GradMaker<::paddle::framework::OpDesc>, \
op_type##GradMaker<::paddle::imperative::OpBase>, \
::paddle::operators::ElementwiseOpInplace); \
REGISTER_OPERATOR(op_type##_grad, \
::paddle::operators::ElementwiseOpExplicitGrad, \
::paddle::operators::ElementwiseGradOpInplace, \
::paddle::operators::ElementwiseGradNoBufVarsInference)
#define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name) \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
::paddle::operators::Elementwise##op_name##OpMaker, \
......
......@@ -44,6 +44,12 @@ namespace operators {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { \
return a expr b; \
} \
}; \
template <typename T, class Enable = void> \
struct Inverse##Func##Functor { \
inline HOSTDEVICE T operator()(const T& a, const T& b) const { \
return b expr a; \
} \
};
DEFINE_SIMPLE_BINARY_FUNCTOR(Add, +)
......
......@@ -16,11 +16,15 @@ limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <functional> // for multiplies
#include <iterator>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/transform.h"
#ifdef __NVCC__
......@@ -33,6 +37,12 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
#define GetDivMod(dividend, divisor, div, mod) \
do { \
const auto dividend_copy = dividend; \
*div = dividend_copy / divisor; \
*mod = dividend_copy % divisor; \
} while (0)
namespace paddle {
namespace operators {
......@@ -48,72 +58,453 @@ namespace operators {
* pre=2*3, n=4*5, post=1
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*
* New parameter: *mid_flag* is added to solve m*n*k & m*1*k
* broadcast cases.
* 3. shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5)
* mid_flag should not be NULL.
* x.shape(2, 3, 20) * y.shape(2, 1, 20).broadcast(2, 3, 20)
* New parameter: *is_run_common_broadcast* is a flag to record whether to run
* common broadcast code.
*/
inline void get_mid_dims(const framework::DDim &x_dims,
const framework::DDim &y_dims, const int axis,
int *pre, int *n, int *post, int *mid_flag = NULL) {
int *pre, int *n, int *post,
int *is_run_common_broadcast) {
*pre = 1;
*n = 1;
*post = 1;
if (mid_flag != NULL) {
*mid_flag = 0;
int mid = 0;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
if (x_dims[i + axis] != y_dims[i]) {
// only support single y_dims[i] = 1 now.
PADDLE_ENFORCE_EQ(*mid_flag, 0,
"Broadcast support y_dims with single 1.");
PADDLE_ENFORCE_EQ(y_dims[i], 1,
"ShapeError: broadcast dimension mismatch. Operands "
"could not be broadcast together with the shape of "
"X = [%s] and the shape of Y = [%s]. Received [%d] "
"in X is not equal to [%d] in Y",
x_dims, y_dims, x_dims[i + axis], y_dims[i]);
// m*n*k m*1*k
for (int j = 0; j < i; ++j) {
(*pre) *= y_dims[j];
}
*n = std::max(x_dims[i + axis], y_dims[i]);
*mid_flag = 1;
mid = i;
break;
}
(*n) *= y_dims[i];
*is_run_common_broadcast = 0;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
if (x_dims[i + axis] != y_dims[i]) {
PADDLE_ENFORCE(y_dims[i] == 1 || x_dims[i + axis] == 1,
"ShapeError: broadcast dimension mismatch. Operands "
"could not be broadcast together with the shape of "
"X = [%s] and the shape of Y = [%s]. Received [%d] "
"in X is not equal to [%d] in Y",
x_dims, y_dims, x_dims[i + axis], y_dims[i]);
*is_run_common_broadcast = 1;
return;
}
if (*mid_flag) {
for (int i = mid + 1; i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
}
inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim,
const int *index_array) {
int index_ = 0;
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] > 1) {
index_ = index_ * x_dims_array[i] + index_array[i];
}
}
return index_;
}
inline void UpdateElementwiseIndexArray(const int *out_dims_array,
const int max_dim, int *index_array) {
for (int i = max_dim - 1; i >= 0; --i) {
++index_array[i];
if (index_array[i] >= out_dims_array[i]) {
index_array[i] -= out_dims_array[i];
} else {
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
break;
}
}
}
inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
const framework::DDim &y_dims,
int *x_dims_array, int *y_dims_array,
int *out_dims_array, const int max_dim,
const int axis) {
PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1);
}
} else { // for fused_elementwise_activation_op. keep the old version.
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array);
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array + axis);
} else {
std::fill(x_dims_array, x_dims_array + axis, 1);
if (axis + x_dims.size() < max_dim) {
std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1);
}
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
}
for (int i = 0; i < y_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
"Broadcast dimension mismatch.");
(*n) *= y_dims[i];
for (int i = 0; i < max_dim; i++) {
PADDLE_ENFORCE(x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
y_dims_array[i] <= 1,
"ShapeError: broadcast dimension mismatch. Operands could "
"not be broadcast together with the shape of X = [%s] and "
"the shape of Y = [%s]. Received [%d] in X is not equal to "
"[%d] in Y",
x_dims, y_dims, x_dims_array[i], y_dims_array[i]);
if (x_dims_array[i] == -1 || y_dims_array[i] == -1) {
out_dims_array[i] = -1;
} else {
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
}
}
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
template <typename Functor, typename T, typename OutType = T>
void CommonForwardBroadcastCPU(const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z,
int *x_dims_array, int *y_dims_array,
int *out_dims_array, int max_dim,
const platform::CPUDeviceContext &ctx,
Functor func,
const bool is_xsize_larger = true) {
std::vector<int> index_array(max_dim, 0);
const T *x_data = x->data<T>();
const T *y_data = y->data<T>();
OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
1, std::multiplies<int>());
int x_index, y_index;
for (int out_index = 0; out_index < out_size; ++out_index) {
x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data());
if (is_xsize_larger) {
out_data[out_index] = func(x_data[x_index], y_data[y_index]);
} else {
out_data[out_index] = func(y_data[y_index], x_data[x_index]);
}
UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data());
}
}
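
CommonForwardBroadcastCPU above walks one odometer-style counter over the output and derives the matching x and y offsets by ignoring dimensions of size 1. Below is a tiny standalone trace of that indexing scheme; the helper names are hypothetical, but the logic matches GetElementwiseIndex and UpdateElementwiseIndexArray.

// ---- illustrative sketch (not part of the diff) ----
#include <cstddef>
#include <cstdio>
#include <vector>

// Same idea as GetElementwiseIndex: dimensions of size <= 1 contribute
// nothing to the flat offset, which is how broadcasting is realised on CPU.
int FlatIndex(const std::vector<int>& dims, const std::vector<int>& idx) {
  int flat = 0;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] > 1) flat = flat * dims[i] + idx[i];
  }
  return flat;
}

// Same idea as UpdateElementwiseIndexArray: odometer-style increment over
// the output dimensions, fastest dimension last.
void Next(const std::vector<int>& out_dims, std::vector<int>* idx) {
  for (int i = static_cast<int>(out_dims.size()) - 1; i >= 0; --i) {
    if (++(*idx)[i] < out_dims[i]) break;
    (*idx)[i] = 0;
  }
}

int main() {
  // x dims = [2, 3], y dims = [1, 3], out dims = [2, 3]
  std::vector<int> x_dims{2, 3}, y_dims{1, 3}, out_dims{2, 3};
  std::vector<int> idx(2, 0);
  for (int out = 0; out < 6; ++out) {
    std::printf("out %d -> x %d, y %d\n", out, FlatIndex(x_dims, idx),
                FlatIndex(y_dims, idx));
    Next(out_dims, &idx);
  }
  // y offsets repeat 0,1,2,0,1,2: the single row of y is broadcast over
  // both rows of x.
  return 0;
}
// ---- end sketch ----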
#ifdef __NVCC__
template <typename Functor, typename T>
__global__ void CommonForwardBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array,
const int *out_dims_array, const T *x, const T *y, T *out, int out_size,
int max_dim, Functor func, const bool is_xsize_larger) {
for (int out_index = blockIdx.x * blockDim.x + threadIdx.x;
out_index < out_size; out_index += blockDim.x * gridDim.x) {
int x_index = 0;
int y_index = 0;
int out_index_quotient = out_index;
int remainder = 0;
#pragma unroll
for (int i = max_dim - 1; i >= 0; --i) {
GetDivMod(out_index_quotient, out_dims_array[i], &out_index_quotient,
&remainder);
x_index += remainder * x_strides_array[i];
y_index += remainder * y_strides_array[i];
}
if (is_xsize_larger) {
out[out_index] = func(x[x_index], y[y_index]);
} else {
out[out_index] = func(y[y_index], x[x_index]);
}
}
}
template <typename Functor, typename T>
void CommonForwardBroadcastCUDA(
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, int *x_dims_array, int *y_dims_array,
int *out_dims_array, int max_dim, const platform::CUDADeviceContext &ctx,
Functor func, const bool is_xsize_larger = true) {
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
auto cplace = platform::CPUPlace();
const T *x_data = x->data<T>();
const T *y_data = y->data<T>();
T *out_data = z->mutable_data<T>(ctx.GetPlace());
std::vector<int> x_strides_array(max_dim);
std::vector<int> y_strides_array(max_dim);
int x_stride = 1;
int y_stride = 1;
for (int i = max_dim - 1; i >= 0; i--) {
x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
x_stride *= x_dims_array[i];
y_stride *= y_dims_array[i];
}
int bytes = max_dim * sizeof(int);
auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
int *x_strides_array_gpu =
reinterpret_cast<int *>(x_strides_array_tmp->ptr());
memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(),
bytes, ctx.stream());
auto y_strides_array_tmp = memory::Alloc(ctx, bytes);
int *y_strides_array_gpu =
reinterpret_cast<int *>(y_strides_array_tmp->ptr());
memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(),
bytes, ctx.stream());
auto out_dims_array_tmp = memory::Alloc(ctx, bytes);
int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes,
ctx.stream());
const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
1, std::multiplies<int>());
dim3 gird_size = dim3(
(out_size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
CommonForwardBroadcastCUDAKernel<
Functor, T><<<gird_size, block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, x_data,
y_data, out_data, out_size, max_dim, func, is_xsize_larger);
}
#endif // __NVCC__
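
On the GPU, the same mapping is done without a running counter: the host precomputes strides that are 0 for every broadcast (size-1) dimension, and CommonForwardBroadcastCUDAKernel decodes each flat output index with a div/mod chain. A host-only sketch of that decoding, assuming plain int arrays; DecodeOffset is an illustrative helper, not a Paddle symbol.

// ---- illustrative sketch (not part of the diff) ----
#include <cstdio>

// Mirrors the loop in CommonForwardBroadcastCUDAKernel: peel off the fastest
// varying dimension first, and weight each remainder by a stride that is 0
// for broadcast dimensions, so those dimensions never advance the offset.
int DecodeOffset(int out_index, const int* out_dims, const int* strides,
                 int max_dim) {
  int offset = 0;
  for (int i = max_dim - 1; i >= 0; --i) {
    int remainder = out_index % out_dims[i];
    out_index /= out_dims[i];
    offset += remainder * strides[i];
  }
  return offset;
}

int main() {
  // out dims = [2, 3]; y dims = [1, 3] -> y strides = [0, 1]
  const int out_dims[2] = {2, 3};
  const int y_strides[2] = {0, 1};
  for (int out = 0; out < 6; ++out) {
    std::printf("out %d -> y %d\n", out,
                DecodeOffset(out, out_dims, y_strides, 2));
  }
  // prints y offsets 0,1,2,0,1,2: the single row of y is reused per row.
  return 0;
}
// ---- end sketch ----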
template <typename T, typename DX_OP, typename DY_OP>
void CommonGradBroadcastCPU(
const framework::Tensor &x, const framework::Tensor &y,
const framework::Tensor &out, const framework::Tensor &dout,
framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array,
int *y_dims_array, int *out_dims_array, int max_dim,
const platform::CPUDeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) {
std::vector<int> index_array(max_dim, 0);
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
const T *out_data = out.data<T>();
const T *dout_data = dout.data<T>();
T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
if (dx_data != nullptr) {
memset(dx_data, 0, dx->numel() * sizeof(T));
}
if (dy_data != nullptr) {
memset(dy_data, 0, dy->numel() * sizeof(T));
}
const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
1, std::multiplies<int>());
int x_index, y_index;
for (int out_index = 0; out_index < out_size; ++out_index) {
x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data());
if (dx_data != nullptr) {
dx_data[x_index] += dx_op(x_data[x_index], y_data[y_index],
out_data[out_index], dout_data[out_index]);
}
if (dy_data != nullptr) {
dy_data[y_index] += dy_op(x_data[x_index], y_data[y_index],
out_data[out_index], dout_data[out_index]);
}
UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data());
}
}
inline void ComputeBroadcastKernelSize(int *x_dims_array, int *out_dims_array,
int *x_blocks, int *x_threads,
int max_dim) {
*x_blocks = 1;
*x_threads = 1;
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] == out_dims_array[i]) {
*x_blocks *= x_dims_array[i];
} else {
*x_threads *= out_dims_array[i];
}
}
}
inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs,
int *x_trans_indexs,
const int max_dim,
const int x_one_size) {
int diff = max_dim - x_one_size;
std::copy_n(x_one_indexs, x_one_size, x_trans_indexs + diff);
int p = 0;
int q = diff;
for (int i = 0; i < max_dim; ++i) {
if (q < max_dim && i == x_trans_indexs[q]) {
++q;
} else {
x_trans_indexs[p++] = i;
}
}
}
#ifdef __NVCC__
template <typename T, typename DX_OP>
__global__ void CommonGradBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array,
const int *out_dims_array, const int *y_strides_order,
const int *y_dims_order, const T *x, const T *y, const T *out,
const T *dout, T *dx, int out_size, int max_dim, int thread_num,
DX_OP dx_op) {
T val(0);
int i = blockIdx.x;
int tid = threadIdx.x;
for (int j = tid; j < thread_num; j += blockDim.x) {
const int X_index = i * thread_num + j;
int out_index = X_index;
int C_index = 0;
int B_index = i * thread_num + j;
int remainder = 0;
#pragma unroll
for (int d = max_dim - 1; d >= 0; --d) {
GetDivMod(B_index, y_dims_order[d], &B_index, &remainder);
C_index += remainder * y_strides_order[d];
}
int x_index = 0;
int y_index = 0;
int C_index_val = C_index;
#pragma unroll
for (int d = max_dim - 1; d >= 0; --d) {
GetDivMod(C_index_val, out_dims_array[d], &C_index_val, &remainder);
x_index += remainder * x_strides_array[d];
y_index += remainder * y_strides_array[d];
}
out_index = C_index;
val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]);
}
val = paddle::platform::reduceSum(val, tid, thread_num);
if (threadIdx.x == 0) {
dx[i] = val;
}
}
template <typename T, typename DX_OP, typename DY_OP>
void CommonGradBroadcastCUDA(
const framework::Tensor &x, const framework::Tensor &y,
const framework::Tensor &out, const framework::Tensor &dout,
framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array,
int *y_dims_array, int *out_dims_array, int max_dim,
const platform::CUDADeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) {
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
auto cplace = platform::CPUPlace();
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
const T *out_data = out.data<T>();
const T *dout_data = dout.data<T>();
T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
std::vector<int> x_one_indexs;
std::vector<int> y_one_indexs;
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] != y_dims_array[i]) {
if (x_dims_array[i] == 1) {
x_one_indexs.push_back(i);
}
if (y_dims_array[i] == 1) {
y_one_indexs.push_back(i);
}
}
}
std::vector<int> x_trans_indexs(max_dim);
std::vector<int> y_trans_indexs(max_dim);
ComputeBroadcastTranspositionArray(x_one_indexs.data(), x_trans_indexs.data(),
max_dim, x_one_indexs.size());
ComputeBroadcastTranspositionArray(y_one_indexs.data(), y_trans_indexs.data(),
max_dim, y_one_indexs.size());
// compute array stride for cuda kernel;
// e.g. x.dims=[2,3,4], x_stride=[12,4,1]
std::vector<int> x_strides_array(max_dim);
std::vector<int> y_strides_array(max_dim);
std::vector<int> out_strides_array(max_dim);
int x_stride = 1;
int y_stride = 1;
int z_stride = 1;
for (int i = max_dim - 1; i >= 0; i--) {
x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
out_strides_array[i] = z_stride;
x_stride *= x_dims_array[i];
y_stride *= y_dims_array[i];
z_stride *= out_dims_array[i];
}
std::vector<int> x_strides_order(max_dim);
std::vector<int> y_strides_order(max_dim);
std::vector<int> x_dims_order(max_dim);
std::vector<int> y_dims_order(max_dim);
for (int i = 0; i < max_dim; ++i) {
x_strides_order[i] = out_strides_array[x_trans_indexs[i]];
y_strides_order[i] = out_strides_array[y_trans_indexs[i]];
x_dims_order[i] = out_dims_array[x_trans_indexs[i]];
y_dims_order[i] = out_dims_array[y_trans_indexs[i]];
}
int x_blocks = 0;
int x_threads = 0;
ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
&x_threads, max_dim);
int y_blocks = 0;
int y_threads = 0;
ComputeBroadcastKernelSize(y_dims_array, out_dims_array, &y_blocks,
&y_threads, max_dim);
int bytes = max_dim * sizeof(int);
auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
int *x_strides_array_gpu =
reinterpret_cast<int *>(x_strides_array_tmp->ptr());
memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(),
bytes, ctx.stream());
auto y_strides_array_tmp = memory::Alloc(ctx, bytes);
int *y_strides_array_gpu =
reinterpret_cast<int *>(y_strides_array_tmp->ptr());
memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(),
bytes, ctx.stream());
auto out_dims_array_tmp = memory::Alloc(ctx, bytes);
int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes,
ctx.stream());
const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
1, std::multiplies<int>());
int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
if (dx) {
auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
int *x_strides_order_gpu =
reinterpret_cast<int *>(x_strides_order_tmp->ptr());
memory::Copy(gplace, x_strides_order_gpu, cplace, x_strides_order.data(),
bytes, ctx.stream());
auto x_dims_order_tmp = memory::Alloc(ctx, bytes);
int *x_dims_order_gpu = reinterpret_cast<int *>(x_dims_order_tmp->ptr());
memory::Copy(gplace, x_dims_order_gpu, cplace, x_dims_order.data(), bytes,
ctx.stream());
CommonGradBroadcastCUDAKernel<
T, DX_OP><<<x_blocks, x_block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
}
if (dy) {
auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
int *y_strides_order_gpu =
reinterpret_cast<int *>(y_strides_order_tmp->ptr());
memory::Copy(gplace, y_strides_order_gpu, cplace, y_strides_order.data(),
bytes, ctx.stream());
auto y_dims_order_tmp = memory::Alloc(ctx, bytes);
int *y_dims_order_gpu = reinterpret_cast<int *>(y_dims_order_tmp->ptr());
memory::Copy(gplace, y_dims_order_gpu, cplace, y_dims_order.data(), bytes,
ctx.stream());
CommonGradBroadcastCUDAKernel<
T, DY_OP><<<y_blocks, y_block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
y_strides_order_gpu, y_dims_order_gpu, x_data, y_data, out_data,
dout_data, dy_data, out_size, max_dim, y_threads, dy_op);
}
}
#endif // __NVCC__
inline framework::DDim trim_trailing_singular_dims(
const framework::DDim &dims) {
// Remove trailing dimensions of size 1 for y
......@@ -121,7 +512,7 @@ inline framework::DDim trim_trailing_singular_dims(
for (; actual_dims_size != 0; --actual_dims_size) {
if (dims[actual_dims_size - 1] != 1) break;
}
if (actual_dims_size == dims.size()) return dims;
std::vector<int> trim_dims;
trim_dims.resize(actual_dims_size);
for (int i = 0; i < actual_dims_size; ++i) {
......@@ -287,13 +678,19 @@ template <typename Functor, typename T, typename DeviceContext,
class TransformFunctor {
public:
TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, const DeviceContext &ctx, Functor func)
framework::Tensor *z, const DeviceContext &ctx, Functor func,
const bool is_xsize_larger = true)
: x_(x->data<T>()),
y_(y->data<T>()),
z_(z->mutable_data<OutType>(ctx.GetPlace())),
nx_(x->numel()),
ctx_(ctx),
func_(func) {}
func_(func),
is_xsize_larger_(is_xsize_larger) {
if (is_xsize_larger_ == false) {
nx_ = y->numel();
}
}
inline void Run() const {
platform::Transform<DeviceContext> trans;
......@@ -302,22 +699,23 @@ class TransformFunctor {
inline void RunRowWise(int n, int pre) const {
platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
z_, func_);
if (is_xsize_larger_) {
trans(ctx_, x_, x_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(y_, n), z_, func_);
} else {
trans(ctx_, y_, y_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(x_, n), z_, func_);
}
}
inline void RunMidWise(int n, int pre, int post) const {
platform::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
}
inline void RunMidRowWise(int n, int pre, int post) const {
platform::Transform<DeviceContext> trans;
for (int i = 0; i < pre; i++) {
trans(ctx_, x_ + i * n * post, x_ + (i + 1) * n * post,
RowwiseTransformIterator<T, DeviceContext>(y_ + i * post, post),
z_ + i * n * post, func_);
if (is_xsize_larger_) {
trans(ctx_, x_, x_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
} else {
trans(ctx_, y_, y_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(x_, n, post), z_, func_);
}
}
......@@ -328,6 +726,7 @@ class TransformFunctor {
int64_t nx_;
const DeviceContext &ctx_;
Functor func_;
bool is_xsize_larger_;
};
template <typename T, typename DX_OP, typename DY_OP>
......@@ -354,20 +753,42 @@ struct ElemwiseGradNoBroadcast {
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
const T *dout, int h, int w, DX_OP dx_op,
const T *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int x_offset = i * w + j;
if (dx != nullptr) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (is_xsize_larger) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int x_offset = i * w + j;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
}
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
} else { // x.dims < y.dims, broadcast for x.
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int y_offset = i * w + j;
if (dy != nullptr) {
dy[y_offset] =
dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
if (i == 0) {
dx[j] = tmp;
} else {
dx[j] += tmp;
}
}
}
}
......@@ -378,28 +799,48 @@ static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
T val(0);
if (is_xsize_larger) {
do {
int x_offset = i * w + j;
if (dx) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy) {
val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
do {
int x_offset = i * w + j;
if (dx) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy) {
val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
} else { // x.dims < y.dims, broadcast for x.
do {
int y_offset = i * w + j;
if (dy) {
dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx) {
val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
if (dy) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
if (dx) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
}
......@@ -412,7 +853,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
__shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
T val(0);
......@@ -422,33 +863,66 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
(w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
size_t full_height =
(h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int x_offset = n * w + m;
if (dx && m < w && n < h) {
dx[x_offset] = dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
if (is_xsize_larger) {
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int x_offset = n * w + m;
if (dx && m < w && n < h) {
dx[x_offset] =
dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
}
if (dy) {
if (m < w && n < h) {
T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
}
__syncthreads();
}
}
if (dy) {
if (m < w && n < h) {
T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
T my_val = sdata[threadIdx.x][threadIdx.y];
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
}
__syncthreads();
if (threadIdx.y == 0 && m < w) {
dy[m] = sdata[0][threadIdx.x];
}
}
}
if (dy) {
T my_val = sdata[threadIdx.x][threadIdx.y];
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
} else { // x.dims < y.dims, broadcast for x.
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int y_offset = n * w + m;
if (dy && m < w && n < h) {
dy[y_offset] =
dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx) {
if (m < w && n < h) {
T val = dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
}
__syncthreads();
}
}
__syncthreads();
if (threadIdx.y == 0 && m < w) {
dy[m] = sdata[0][threadIdx.x];
if (dx) {
T my_val = sdata[threadIdx.x][threadIdx.y];
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
}
__syncthreads();
if (threadIdx.y == 0 && m < w) {
dx[m] = sdata[0][threadIdx.x];
}
}
}
}
......@@ -457,21 +931,21 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int h, int w, DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
int h, int w, bool is_xsize_larger,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
// For small case use 1D block
constexpr int half_walf = 16;
if (w < half_walf || h < half_walf) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
} else {
// suppose performance improves as h increases.
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
}
}
......@@ -480,21 +954,44 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
const T *dout, int pre, int n, int post,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int x_offset = i * n * post + j * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
bool is_xsize_larger, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
if (is_xsize_larger) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int x_offset = i * n * post + j * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0 && k == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
}
}
if (dy != nullptr) {
T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (i == 0 && k == 0) {
dy[j] = tmp;
} else {
dy[j] += tmp;
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int y_offset = i * n * post + j * post + k;
if (dy != nullptr) {
dy[y_offset] =
dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
if (i == 0 && k == 0) {
dx[j] = tmp;
} else {
dx[j] += tmp;
}
}
}
}
......@@ -506,37 +1003,66 @@ static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int tid = threadIdx.x;
int j = blockIdx.x;
T val(0);
int ttid = tid;
while (true) {
int i = ttid / post;
int k = ttid % post;
if (i >= pre) break;
if (is_xsize_larger) {
while (true) {
int i = ttid / post;
int k = ttid % post;
if (i >= pre) break;
int x_offset = i * n * post + j * post + k;
int x_offset = i * n * post + j * post + k;
if (dx != nullptr) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (dx != nullptr) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
if (dy != nullptr) {
val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
if (dy) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else { // x.dims < y.dims, broadcast for x.
while (true) {
int i = ttid / post;
int k = ttid % post;
if (i >= pre) break;
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
int y_offset = i * n * post + j * post + k;
if (dy) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
if (dy != nullptr) {
dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx != nullptr) {
val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
if (dx) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
}
......@@ -544,98 +1070,57 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int pre, int n, int post, DX_OP dx_op,
int pre, int n, int post,
bool is_xsize_larger, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int gird_size = n;
ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy);
}
#endif
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcastMid2CPU(const T *x, const T *y, const T *out,
const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int x_offset = i * n * post + j * post + k;
int y_offset = i * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
T tmp =
dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
if (j == 0) {
dy[y_offset] = tmp;
} else {
dy[y_offset] += tmp;
}
}
}
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void CommonElementwiseBroadcastBackward(
const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
const framework::DDim &y_dims, const framework::Tensor &x,
const framework::Tensor &y, const framework::Tensor &out,
const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
int max_dim = std::max(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
// for inplace strategy. memset will make dx and dout clear and get wrong
// result.
if (dx && dout.Holder() == dx->Holder()) {
dx->clear();
dx->mutable_data<T>(x_dims, ctx.GetPlace());
}
}
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcastMid2CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = threadIdx.x;
int tid = blockIdx.x;
T val(0);
int ttid = tid;
while (true) {
int i = ttid / post;
int k = ttid % post;
if (i >= pre) break;
int x_offset = i * n * post + j * post + k;
int y_offset = i * post + k;
if (dx != nullptr) {
dx[x_offset] =
dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
}
if (dy != nullptr) {
val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
}
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
if (dy) {
int h = n;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, j, h);
if (threadIdx.x == 0) {
dy[tid] = val;
}
CommonGradBroadcastCUDA<T, DX_OP, DY_OP>(
x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CUDADeviceContext>(), dx_op,
dy_op);
#endif
} else {
CommonGradBroadcastCPU<T, DX_OP, DY_OP>(
x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CPUDeviceContext>(), dx_op,
dy_op);
}
}
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcastMid2CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out,
const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, n);
int gird_size = pre * post;
ElemwiseGradBroadcastMid2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}
#endif
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
......@@ -659,47 +1144,54 @@ void ElemwiseGradComputeNoBroadcast(
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
const framework::DDim &y_dims, const framework::Tensor &x,
const framework::Tensor &y, const framework::Tensor &out,
const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
bool is_xsize_larger = true;
int max_dim = x_dims.size();
if (x_dims.size() < y_dims.size()) {
is_xsize_larger = false;
max_dim = y_dims.size();
}
int pre, n, post, mid_flag = 0;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &mid_flag);
if (mid_flag) {
PADDLE_ENFORCE_EQ(mid_flag, 1, "mid_flag should be no more than 1.");
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
ElemwiseGradBroadcastMid2CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
ElemwiseGradBroadcastMid2CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
} else if (post == 1) {
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
auto y_dims_trimed = trim_trailing_singular_dims(y_dims);
axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
} else {
auto x_dims_trimed = trim_trailing_singular_dims(x_dims);
axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
}
// special case for common backward implementation.
if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
return;
}
if (post == 1) {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
ElemwiseGradBroadcast1CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, dx_op, dy_op,
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, is_xsize_larger,
dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
ElemwiseGradBroadcast1CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
dx_op, dy_op,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
......@@ -708,20 +1200,56 @@ void ElemwiseGradComputeWithBroadcast(
#ifdef __NVCC__
ElemwiseGradBroadcast2CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post, dx_op,
dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
ElemwiseGradBroadcast2CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
dx_op, dy_op,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
}
}
template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void CommonElementwiseBroadcastForward(
const framework::ExecutionContext &ctx, const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z,
const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func,
int axis, const bool is_xsize_larger = true) {
int max_dim = std::max(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(), max_dim,
axis);
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
CommonForwardBroadcastCUDA<Functor, T>(
x, y, z, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CUDADeviceContext>(), func,
is_xsize_larger);
#endif
} else {
CommonForwardBroadcastCPU<Functor, T, OutType>(
x, y, z, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CPUDeviceContext>(), func,
is_xsize_larger);
}
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor &x, const framework::Tensor &y,
......@@ -734,7 +1262,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
if (x.dims() == y.dims()) {
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else { // Y is a scalar
} else {
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
......@@ -752,103 +1280,61 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor &dout, int axis,
framework::Tensor *dx, framework::Tensor *dy,
DX_OP dx_op, DY_OP dy_op) {
if (dy == nullptr) {
const framework::DDim &dx_dims = dout.dims();
auto dy_dims = dx_dims;
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
if (x.dims() == y.dims()) {
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else {
if (dout.dims() == dy->dims()) {
const framework::DDim &dx_dims = dout.dims();
const framework::DDim &dy_dims = dy->dims();
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else { // Y is a scalar
auto dx_dims = dout.dims();
const framework::DDim &dy_dims = dy->dims();
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
}
}
// Deprecated
template <typename DeviceContext, typename T, typename functor,
typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy) {
auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
auto x_dims = x->dims();
auto y_dims = y->dims();
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
}
if (x_dims == y_dims) {
functor f;
f(place, x, y, out, dx, dy, dout);
return;
}
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
trim_trailing_singular_dims(y_dims);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
if (post == 1) {
broadcastfunctor f;
f(place, x, y, out, dx, dy, dout, pre, n);
return;
ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op);
} else {
broadcast2functor f;
f(place, x, y, out, dx, dy, dout, pre, n, post);
return;
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op);
}
}
template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) {
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), func);
auto x_dims = x->dims();
auto y_dims_untrimed = y->dims();
PADDLE_ENFORCE_GE(
x_dims.size(), y_dims_untrimed.size(),
"ShapeError: the dimension of input X must greater than or equal to "
"the one of input Y. But received: the shape of input X = [%s], the "
"dimension of input X = %d, the shape of input Y = [%s], the dimension "
"of of input Y = %d",
x_dims, x_dims.size(), y_dims_untrimed, y_dims_untrimed.size());
if (x_dims == y_dims_untrimed) {
auto y_dims = y->dims();
bool is_xsize_larger = true;
int max_dim = x_dims.size();
if (x_dims.size() < y_dims.size()) {
is_xsize_larger = false;
max_dim = y_dims.size();
}
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), func,
is_xsize_larger);
if (x_dims == y_dims) {
functor.Run();
return;
}
axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post, mid_flag = 0;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &mid_flag);
if (mid_flag) {
functor.RunMidRowWise(n, pre, post);
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
int pre, n, post, is_run_common_broadcast, axis_trim = 0;
if (is_xsize_larger) {
auto y_dims_trimed = trim_trailing_singular_dims(y_dims);
axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
} else {
auto x_dims_trimed = trim_trailing_singular_dims(x_dims);
axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
&is_run_common_broadcast);
}
// special case for common implementation.
// case 1: x=[2,3,1,5], y=[2,1,4,1]
// case 2: x=[2,3,4], y=[1,1,4]
if (is_run_common_broadcast == 1) {
CommonElementwiseBroadcastForward<Functor, DeviceContext, T, OutType>(
ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger);
return;
}
if (post == 1) {
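The two cases named in the comment above match plain numpy broadcasting, which makes the expected output shapes easy to sanity-check:

import numpy as np

# case 1: x=[2,3,1,5], y=[2,1,4,1] -> both operands are expanded on some axes
x1, y1 = np.random.rand(2, 3, 1, 5), np.random.rand(2, 1, 4, 1)
assert (x1 + y1).shape == (2, 3, 4, 5)

# case 2: x=[2,3,4], y=[1,1,4] -> y is expanded over the two leading axes
x2, y2 = np.random.rand(2, 3, 4), np.random.rand(1, 1, 4)
assert (x2 + y2).shape == (2, 3, 4)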
......@@ -1114,9 +1600,8 @@ void FusedElemwiseAndActComputeWithBroadcast(
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast);
if (post == 1) {
int h = pre;
int w = n;
......@@ -1628,8 +2113,8 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast);
if (post == 1) {
int h = pre;
int w = n;
......@@ -1763,16 +2248,7 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
} else {
// Whether the shape of Y is a continuous subsequence of X,
// For more information please refer to the op's introduction.
bool bcast_y = x.dims().size() >= y.dims().size();
if (x.dims().size() == y.dims().size()) {
for (int i = 0; i < x.dims().size(); ++i) {
if (x.dims()[i] < y.dims()[i]) {
bcast_y = false;
break;
}
}
}
bool bcast_y = x.numel() >= y.numel();
// z = f1(x, f2(y))
// z = f1(f2(x, y))
if (bcast_y) { // Y should be broadcast.
......
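The bcast_y simplification above decides which operand to broadcast by comparing element counts, which also covers the new case where X is the smaller tensor; a minimal sketch of that rule with hypothetical shapes:

from functools import reduce
from operator import mul

def numel(shape):
    return reduce(mul, shape, 1)

# Y is the smaller operand -> broadcast Y, as before
assert numel((2, 3, 4, 5)) >= numel((3, 4))
# X is the smaller operand -> bcast_y is False and X is broadcast instead
assert not numel((4, 5)) >= numel((2, 3, 4, 5))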
......@@ -93,13 +93,14 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub);
REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub);
namespace ops = paddle::operators;
REGISTER_OPERATOR(
elementwise_sub_grad, ops::ElementwiseOpExplicitGrad,
ops::ElementwiseGradOpInplace, ops::ElementwiseGradNoBufVarsInference,
elementwise_sub_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace,
ops::ElementwiseGradNoBufVarsInference,
ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_sub_grad_grad,
......
......@@ -26,8 +26,13 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
if (x->numel() >= y->numel()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseSubFunctor<T>(), z);
}
}
template <typename DeviceContext, typename T, class Enable = void>
......@@ -98,13 +103,14 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
// skip out, x, y
// skip out
auto* out = dout;
auto *x = dout, *y = dout;
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
......
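For subtraction the gradients do not depend on X or Y themselves (dL/dX = dOut, and dL/dY = -dOut reduced over any broadcast axes), which is why the kernel above can pass dout as a stand-in for both inputs; a quick numpy check under that assumption:

import numpy as np

x = np.random.rand(2, 3, 4)
y = np.random.rand(4)
dout = np.random.rand(2, 3, 4)   # upstream gradient for out = x - y

dx = dout                        # dL/dX is the upstream gradient unchanged
dy = -dout.sum(axis=(0, 1))      # dL/dY is negated and reduced over broadcast axes

assert dx.shape == x.shape and dy.shape == y.shape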
......@@ -108,8 +108,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post,
&is_run_common_broadcast);
if (post == 1) {
functor.RunRowWise(n, pre);
......@@ -212,6 +213,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
}
} else {
// Execute default kernel when broadcast is needed
x = ctx.Input<Tensor>("X");
y = ctx.Input<Tensor>("Y");
ElemwiseExplicitGradCompute<paddle::platform::CPUDeviceContext, T,
IdentityGrad<T>, IdentityGrad<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
......
......@@ -91,8 +91,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
const bool is_y_format_correct = y->format() == MKLDNNMemoryFormat::nc;
if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
is_avx512_enabled) {
int pre, n, post;
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post,
&is_run_common_broadcast);
if (post == 1) {
PADDLE_THROW("Not implemented when post is 1");
......@@ -168,8 +169,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
axis = (y_dims.size() == 0) ? x_dims.size() : axis;
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
int pre, n, post, is_run_common_broadcast;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post,
&is_run_common_broadcast);
if (post == 1) {
functor.RunRowWise(n, pre);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
......@@ -139,21 +140,6 @@ struct DivAndSqrtFunctor {
T epsilon_;
};
template <typename T>
struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T>
struct MulInvVarFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
......
......@@ -32,6 +32,7 @@ TEST(op_debug_str, test_unknown_dtype) {
framework::Scope scope;
desc.SetType("elementwise_add_grad");
desc.SetInput("X", {"X"});
desc.SetInput("Y", {"Y"});
desc.SetInput(framework::GradVarName("Out"), {framework::GradVarName("Out")});
desc.SetOutput(framework::GradVarName("X"), {framework::GradVarName("X")});
......@@ -41,6 +42,10 @@ TEST(op_debug_str, test_unknown_dtype) {
desc.SetAttr("x_data_format", "");
desc.SetAttr("y_data_format", "");
auto x_tensor = scope.Var("X")->GetMutable<framework::LoDTensor>();
x_tensor->Resize(dim);
x_tensor->mutable_data<float>(place);
auto y_tensor = scope.Var("Y")->GetMutable<framework::LoDTensor>();
y_tensor->Resize(dim);
y_tensor->mutable_data<float>(place);
......
......@@ -7,8 +7,8 @@ TURN_ON_MKL=$2 # use MKL or Openblas
# download models
function download() {
wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR/main_program
wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR/startup_program
wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR-1-7/main_program
wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR-1-7/startup_program
}
download
......
......@@ -308,6 +308,36 @@ class TestFP16ElementwiseAddOp_channelwise_add(TestFP16ElementwiseAddOp):
self.axis = -1
class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1, 4).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(2, 3, 1, 5).astype(self.dtype)
self.y = np.random.rand(2, 1, 4, 1).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = -1
class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
def init_input_output(self):
self.x = np.random.rand(4, 5).astype(self.dtype)
self.y = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.out = self.x + self.y
def init_axis(self):
self.axis = 2
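In the xsize_lessthan_ysize case above, X of shape (4, 5) is aligned at offset axis=2 inside Y's shape (2, 3, 4, 5), i.e. it behaves like X reshaped to (1, 1, 4, 5); a quick numpy check:

import numpy as np

x = np.random.rand(4, 5)
y = np.random.rand(2, 3, 4, 5)
out = x.reshape(1, 1, 4, 5) + y   # what aligning X at axis=2 inside Y amounts to

assert out.shape == (2, 3, 4, 5)
assert np.allclose(out, x + y)    # numpy's right-aligned broadcast gives the same result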
class TestElementwiseAddOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
......
......@@ -151,6 +151,39 @@ class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [1, 1, 4]).astype("float32"),
}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 1, 5]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [2, 1, 4, 1]).astype("float32"),
}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [4, 5]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32"),
}
self.attrs = {'axis': 2}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseDivOp_INT(OpTest):
def setUp(self):
self.op_type = "elementwise_div"
......
......@@ -162,6 +162,41 @@ class TestElementwiseMulOpFp16(ElementwiseMulOp):
self.dtype = np.float16
class TestElementwiseMulOp_commonuse_1(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float64),
'Y': np.random.rand(1, 1, 4).astype(np.float64)
}
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
class TestElementwiseMulOp_commonuse_2(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
'X': np.random.rand(2, 3, 1, 5).astype(np.float64),
'Y': np.random.rand(2, 1, 4, 1).astype(np.float64)
}
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
class TestElementwiseMulOp_xsize_lessthan_ysize(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
'X': np.random.rand(4, 5).astype(np.float64),
'Y': np.random.rand(2, 3, 4, 5).astype(np.float64)
}
self.attrs = {'axis': 2}
self.outputs = {
'Out': self.inputs['X'].reshape(1, 1, 4, 5) * self.inputs['Y']
}
class TestElementwiseMulOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
......
......@@ -127,5 +127,40 @@ class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
class TestElementwiseSubOp_commonuse_1(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(1, 1, 4).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
class TestElementwiseSubOp_commonuse_2(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(2, 3, 1, 5).astype(np.float32),
'Y': np.random.rand(2, 1, 4, 1).astype(np.float32)
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(4, 5).astype(np.float32),
'Y': np.random.rand(2, 3, 4, 5).astype(np.float32)
}
self.attrs = {'axis': 2}
self.outputs = {
'Out': self.inputs['X'].reshape(1, 1, 4, 5) - self.inputs['Y']
}
if __name__ == '__main__':
unittest.main()
......@@ -47,7 +47,7 @@ class TestExecutorReturnTensorNotOverwritingWithOptest(OpTest):
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.outputs = {'Out': self.out}
self.op_type = "elementwise_mul"
self.op_type = "mul"
self.dtype = np.float32
outs, fetch_list = self._calc_output(place, parallel=parallel)
return outs
......
......@@ -57,7 +57,7 @@ def run_trainer(use_cuda, sync_mode, ip, port, trainers, trainer_id):
exe.run(trainer_startup_program)
for i in range(5):
exe.run(recv_program)
exe.run(main_program,
exe.run(fluid.default_main_program(),
feed={
"x": numpy.array([1, 2]).astype('float32').reshape(2, 1),
"y": numpy.array([2, 3]).astype('float32').reshape(2, 1)
......