提交 f99841dd 编写于 作者: G gongweibao 提交者: GitHub

Elementwise operator. (#4139)

Elementwise operator add/sub/mul/div
上级 2d8467ee
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_add_op.h"
namespace paddle {
namespace operators {
class ElementwiseAddOpMaker : public ElementwiseOpMaker {
public:
ElementwiseAddOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("add", "Out = X + Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
elementwise_add_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/elementwise_add_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
elementwise_add,
ops::ElementwiseAddKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_op.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class ElementwiseAddKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenAddFunctor, Place, T>(ctx);
}
};
template <typename T>
struct ElementwiseAddGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = dz_e;
}
}
};
template <typename T>
struct ElementwiseAddOneGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = dz_e.sum();
}
}
};
template <typename T>
struct ElementwiseAddBroadCastGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = dz_e.reshape(Eigen::DSizes<int, 2>(pre, n))
.sum(Eigen::array<int, 1>{{0}});
}
}
};
template <typename T>
struct ElementwiseAddBroadCast2GradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N, typename Post>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
Post post) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = dz_e.reshape(Eigen::DSizes<int, 3>(pre, n, post))
.sum(Eigen::array<int, 2>{{0, 2}});
}
}
};
template <typename Place, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseAddGradFunctor<T>,
ElementwiseAddOneGradFunctor<T>,
ElementwiseAddBroadCastGradFunctor<T>,
ElementwiseAddBroadCast2GradFunctor<T>>(ctx);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_div_op.h"
namespace paddle {
namespace operators {
class ElementwiseDivOpMaker : public ElementwiseOpMaker {
public:
ElementwiseDivOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Div", "Out = X / Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker,
elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/elementwise_div_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_op.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class ElementwiseDivKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenDivFunctor, Place, T>(ctx);
}
};
template <typename T>
struct ElementwiseDivGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto z_e = framework::EigenVector<T>::Flatten(*z);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e / y_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = -1.0 * dz_e * z_e / y_e;
}
}
};
template <typename T>
struct ElementwiseDivBroadCastGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e / y_e_bcast;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (-1.0 * (x_e * dz_e) / (y_e_bcast * y_e_bcast))
.reshape(Eigen::DSizes<int, 2>(pre, n))
.sum(Eigen::array<int, 1>{{0}});
}
}
};
template <typename T>
struct ElementwiseDivBroadCast2GradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N, typename Post>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
Post post) {
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e / y_e_bcast;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (-1.0 * (x_e * dz_e) / (y_e_bcast * y_e_bcast))
.reshape(Eigen::DSizes<int, 3>(pre, n, post))
.sum(Eigen::array<int, 2>{{0, 2}});
}
}
};
template <typename Place, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseDivGradFunctor<T>,
ElementwiseDivGradFunctor<T>,
ElementwiseDivBroadCastGradFunctor<T>,
ElementwiseDivBroadCast2GradFunctor<T>>(ctx);
}
};
} // namespace operators
} // namespace paddle
......@@ -17,104 +17,25 @@
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class ElementWiseMulOp : public framework::OperatorWithKernel {
class ElementwiseMulOpMaker : public ElementwiseOpMaker {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of ElementWiseMulOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
"Input(Y) of ElementWiseMulOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Out"),
"Output(Out) of ElementWiseMulOp should not be null.");
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.")
ctx.Output<framework::Tensor>("Out")->Resize(x_dim);
ctx.ShareLoD("X", /*->*/ "Out");
}
};
class ElementWiseMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementWiseMulOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The first input of elementwise mul op");
AddInput("Y", "The second input of elementwise mul op");
AddAttr<int>("axis",
R"DOC(
When shape(Y) does not equal shape(X),Y will be broadcasted
to match the shape of X and axis should be dimension index Y in X
)DOC")
.SetDefault(-1)
.EqualGreaterThan(-1);
AddOutput("Out", "The output of elementwise mul op");
AddComment(R"DOC(
Limited elementwise multiple operator.The equation is: Out = X ⊙ Y.
1. The shape of Y should be same with X or
2. Y's shape is a subset of X.
Y will be broadcasted to match the shape of X and axis should be dimension index Y in X.
example:
shape(X) = (2, 3, 4, 5), shape(Y) = (,)
shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
Both the input X and Y can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD with input X.
)DOC");
ElementwiseMulOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Mul", "Out = X ⊙ Y");
AddComment(comment_);
}
};
class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
if (x_grad) {
x_grad->Resize(x_dims);
}
if (y_grad) {
y_grad->Resize(y_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(elementwise_mul, ops::ElementWiseMulOp, ops::ElementWiseMulOpMaker,
elementwise_mul_grad, ops::ElementWiseMulOpGrad);
REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
ops::ElementWiseMulKernel<paddle::platform::CPUPlace, float>);
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementWiseMulGradKernel<paddle::platform::CPUPlace, float>);
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>);
......@@ -19,7 +19,7 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
elementwise_mul,
ops::ElementWiseMulKernel<paddle::platform::GPUPlace, float>);
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
elementwise_mul_grad,
ops::ElementWiseMulGradKernel<paddle::platform::GPUPlace, float>);
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>);
......@@ -13,171 +13,104 @@
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op.h"
namespace paddle {
namespace operators {
/*
* Out = X ⊙ Y
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* pre=2, n=3*4, post=5
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1
*/
inline void get_mid_dims(const framework::DDim& x_dims,
const framework::DDim& y_dims, const int axis,
int& pre, int& n, int& post) {
pre = 1;
n = 1;
post = 1;
for (int i = 0; i < axis; ++i) {
pre *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
"Broadcast dimension mismatch.");
n *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
post *= x_dims[i];
}
}
template <typename Place, typename T>
class ElementWiseMulKernel : public framework::OpKernel {
class ElementwiseMulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
ElementwiseCompute<EigenMulFunctor, Place, T>(ctx);
}
};
template <typename T>
struct ElementwiseMulGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto z_e = framework::EigenVector<T>::Flatten(*z);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
if (x_dims == y_dims || product(y_dims) == 1) {
z_e.device(ctx.GetEigenDevice<Place>()) = x_e * y_e;
return;
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e * y_e;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
z_e.device(ctx.GetEigenDevice<Place>()) = x_e * y_bcast;
return;
} else {
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
z_e.device(ctx.GetEigenDevice<Place>()) = x_e * y_bcast;
return;
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = x_e * dz_e;
}
}
};
template <typename Place, typename T>
class ElementWiseMulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
template <typename T>
struct ElementwiseMulBroadCastGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto dout_e = framework::EigenVector<T>::Flatten(*dout);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
auto x_dims = x->dims();
auto y_dims = y->dims();
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e * y_e_bcast;
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (x_e * dz_e)
.reshape(Eigen::DSizes<int, 2>(pre, n))
.sum(Eigen::array<int, 1>{{0}});
}
}
};
if (x_dims == y_dims || product(y_dims) == 1) {
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(ctx.GetEigenDevice<Place>()) = x_e * dout_e;
}
return;
template <typename T>
struct ElementwiseMulBroadCast2GradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N, typename Post>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
Post post) {
auto x_e = framework::EigenVector<T>::Flatten(*x);
auto y_e = framework::EigenVector<T>::Flatten(*y);
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e * y_e_bcast;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
// TODO(gongweibao): wrap reshape to a function.
if (post == 1) {
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n))
.broadcast(Eigen::DSizes<int, 2>(pre, 1))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e_bcast;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(ctx.GetEigenDevice<Place>()) =
(x_e * dout_e)
.reshape(Eigen::DSizes<int, 2>(pre, n))
.sum(Eigen::array<int, 1>{{0}});
}
return;
} else {
auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1))
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post))
.reshape(Eigen::DSizes<int, 1>(x_e.size()));
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(ctx.GetEigenDevice<Place>()) = dout_e * y_e_bcast;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(ctx.GetEigenDevice<Place>()) =
(x_e * dout_e)
.reshape(Eigen::DSizes<int, 3>(pre, n, post))
.sum(Eigen::array<int, 2>{{0, 2}});
}
return;
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (x_e * dz_e)
.reshape(Eigen::DSizes<int, 3>(pre, n, post))
.sum(Eigen::array<int, 2>{{0, 2}});
}
}
};
template <typename Place, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseMulGradFunctor<T>,
ElementwiseMulGradFunctor<T>,
ElementwiseMulBroadCastGradFunctor<T>,
ElementwiseMulBroadCast2GradFunctor<T>>(ctx);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
/*
* Out = X ⊙ Y
* If Y's shape does not match X' shape, they will be reshaped.
* For example:
* 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
* pre=2, n=3*4, post=5
* x.shape(2, 12, 5) * y.shape(1,12,1).broadcast(2,12,5)
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1
* x.shape(2, 3, 20) * y.shape(1,1,20).broadcast(2,3,20)
*/
inline void get_mid_dims(const framework::DDim& x_dims,
const framework::DDim& y_dims, const int axis,
int& pre, int& n, int& post) {
pre = 1;
n = 1;
post = 1;
for (int i = 0; i < axis; ++i) {
pre *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
"Broadcast dimension mismatch.");
n *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
post *= x_dims[i];
}
}
#define EIGEN_FUNCTOR(name, eigen_op) \
struct Eigen##name##Functor { \
template <typename Place, typename T> \
inline void Run(const framework::Tensor* x, const framework::Tensor* y, \
framework::Tensor* z, \
const framework::ExecutionContext& ctx) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_e); \
} \
template <typename Place, typename T> \
inline void RunBroadCast(const framework::Tensor* x, \
const framework::Tensor* y, framework::Tensor* z, \
const framework::ExecutionContext& ctx, int pre, \
int n) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n)) \
.broadcast(Eigen::DSizes<int, 2>(pre, 1)) \
.reshape(Eigen::DSizes<int, 1>(x_e.size())); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast); \
} \
template <typename Place, typename T> \
inline void RunBroadCast2(const framework::Tensor* x, \
const framework::Tensor* y, \
framework::Tensor* z, \
const framework::ExecutionContext& ctx, int pre, \
int n, int post) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
auto y_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1)) \
.broadcast(Eigen::DSizes<int, 3>(pre, 1, post)) \
.reshape(Eigen::DSizes<int, 1>(x_e.size())); \
z_e.device(ctx.GetEigenDevice<Place>()) = eigen_op(x_e, y_bcast); \
} \
}
template <class functor, typename Place, typename T>
void ElementwiseCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
if (x_dims == y_dims || product(y_dims) == 1) {
functor f;
f.template Run<Place, T>(x, y, z, ctx);
return;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
functor f;
f.template RunBroadCast<Place, T>(x, y, z, ctx, pre, n);
return;
} else {
functor f;
f.template RunBroadCast2<Place, T>(x, y, z, ctx, pre, n, post);
return;
}
}
#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);
#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);
#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename Place, typename T, typename functor, typename functor1,
typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto place = ctx.GetEigenDevice<Place>();
auto x_dims = x->dims();
auto y_dims = y->dims();
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
}
if (x_dims == y_dims) {
functor f;
f(place, x, y, out, dx, dy, dout);
return;
}
if (product(y_dims) == 1) {
functor1 f;
f(place, x, y, out, dx, dy, dout);
return;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
broadcastfunctor f;
f(place, x, y, out, dx, dy, dout, pre, n);
return;
} else {
broadcast2functor f;
f(place, x, y, out, dx, dy, dout, pre, n, post);
return;
}
}
class ElementwiseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
using Tensor = framework::Tensor;
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of elementwise op should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"),
"Input(Y) of elementwise op should not be null");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Out"),
"Output(Out) of elementwise op should not be null.");
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.")
ctx.Output<framework::Tensor>("Out")->Resize(x_dim);
ctx.ShareLoD("X", /*->*/ "Out");
}
};
class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ElementwiseOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", R"DOC(
The first input of elementwise op, it's a tensor of any dimensions.
)DOC");
AddInput("Y", R"DOC(
The sencond input of elementwise op, it's a tensor and it's dimensions
must be small or equal to X's dimensions.
)DOC");
AddAttr<int>("axis",
R"DOC(
When the shape(Y) does not equal the shape(X),Y will be broadcasted
to match the shape of X and axis should be dimension index Y in X
)DOC")
.SetDefault(-1)
.EqualGreaterThan(-1);
AddOutput("Out", "The output of elementwise op");
comment_ = R"DOC(
Limited elementwise {name} operator.The equation is: Out = {equation}.
1. The shape of Y should be same with X or
2. Y's shape is a subset of X.
Y will be broadcasted to match the shape of X and axis should be dimension index Y in X.
example:
shape(X) = (2, 3, 4, 5), shape(Y) = (,)
shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
Both the input X and Y can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD with input X.
)DOC";
AddComment(comment_);
}
protected:
std::string comment_;
void Replace(std::string& src, std::string from, std::string to) {
std::size_t len_from = std::strlen(from.c_str());
std::size_t len_to = std::strlen(to.c_str());
for (std::size_t pos = src.find(from); pos != std::string::npos;
pos = src.find(from, pos + len_to)) {
src.replace(pos, len_from, to);
}
}
void SetComment(std::string name, std::string equation) {
Replace(comment_, "{name}", name);
Replace(comment_, "{equation}", equation);
}
};
class ElementwiseOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
if (x_grad) {
x_grad->Resize(x_dims);
}
if (y_grad) {
y_grad->Resize(y_dims);
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_sub_op.h"
namespace paddle {
namespace operators {
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
public:
ElementwiseSubOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: ElementwiseOpMaker(proto, op_checker) {
SetComment("Sub", "Out = X - Y");
AddComment(comment_);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker,
elementwise_sub_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/elementwise_sub_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_op.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class ElementwiseSubKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenSubFunctor, Place, T>(ctx);
}
};
template <typename T>
struct ElementwiseSubGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (-1.0) * dz_e;
}
}
};
template <typename T>
struct ElementwiseSubOneGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (-1.0) * dz_e.sum();
}
}
};
template <typename T>
struct ElementwiseSubBroadCastGradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (-1.0) *
dz_e.reshape(Eigen::DSizes<int, 2>(pre, n))
.sum(Eigen::array<int, 1>{{0}});
}
}
};
template <typename T>
struct ElementwiseSubBroadCast2GradFunctor {
template <typename Device, typename X, typename Y, typename Z, typename dX,
typename dY, typename dZ, typename Pre, typename N, typename Post>
void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n,
Post post) {
auto dz_e = framework::EigenVector<T>::Flatten(*dz);
if (dx) {
auto dx_e = framework::EigenVector<T>::Flatten(*dx);
dx_e.device(d) = dz_e;
}
if (dy) {
auto dy_e = framework::EigenVector<T>::Flatten(*dy);
dy_e.device(d) = (-1.0) *
dz_e.reshape(Eigen::DSizes<int, 3>(pre, n, post))
.sum(Eigen::array<int, 2>{{0, 2}});
}
}
};
template <typename Place, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseSubGradFunctor<T>,
ElementwiseSubOneGradFunctor<T>,
ElementwiseSubBroadCastGradFunctor<T>,
ElementwiseSubBroadCast2GradFunctor<T>>(ctx);
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestElementwiseOp(OpTest):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
class TestElementwiseAddOp_Vector(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.random((32, )).astype("float32"),
'Y': np.random.random((32, )).astype("float32")
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseAddOp_broadcast_0(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(2).astype(np.float32)
}
self.attrs = {'axis': 0}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(2, 1, 1)
}
class TestElementwiseAddOp_broadcast_1(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(3).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 1)
}
class TestElementwiseAddOp_broadcast_2(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(4).astype(np.float32)
}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1, 4)
}
class TestElementwiseAddOp_broadcast_3(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4, 5).astype(np.float32),
'Y': np.random.rand(3, 4).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4, 1)
}
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class ElementwiseDivOp(OpTest):
def setUp(self):
self.op_type = "elementwise_div"
""" Warning
CPU gradient check error!
'X': np.random.random((32,84)).astype("float32"),
'Y': np.random.random((32,84)).astype("float32")
"""
self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.05, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
class TestElementwiseDivOp_Vector(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [32]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [32]).astype("float32")
}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseDivOp_broadcast_0(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [2]).astype("float32")
}
self.attrs = {'axis': 0}
self.outputs = {
'Out':
np.divide(self.inputs['X'], self.inputs['Y'].reshape(2, 1, 1))
}
class TestElementwiseDivOp_broadcast_1(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [3]).astype("float32")
}
self.attrs = {'axis': 1}
self.outputs = {
'Out':
np.divide(self.inputs['X'], self.inputs['Y'].reshape(1, 3, 1))
}
class TestElementwiseDivOp_broadcast_2(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [4]).astype("float32")
}
self.outputs = {
'Out':
np.divide(self.inputs['X'], self.inputs['Y'].reshape(1, 1, 4))
}
class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
def setUp(self):
self.op_type = "elementwise_div"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [3, 4]).astype("float32")
}
self.attrs = {'axis': 1}
self.outputs = {
'Out':
np.divide(self.inputs['X'], self.inputs['Y'].reshape(1, 3, 4, 1))
}
if __name__ == '__main__':
unittest.main()
......@@ -3,14 +3,9 @@ import numpy as np
from op_test import OpTest
class TestElementwiseMulOp_Matrix(OpTest):
class ElementwiseMulOp(OpTest):
def setUp(self):
self.op_type = "elementwise_mul"
""" Warning
CPU gradient check error!
'X': np.random.random((32,84)).astype("float32"),
'Y': np.random.random((32,84)).astype("float32")
"""
self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
......@@ -32,7 +27,7 @@ class TestElementwiseMulOp_Matrix(OpTest):
['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
class TestElementwiseMulOp_Vector(OpTest):
class TestElementwiseMulOp_Vector(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
......@@ -41,22 +36,8 @@ class TestElementwiseMulOp_Vector(OpTest):
}
self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
class TestElementwiseMulOp_broadcast_0(OpTest):
class TestElementwiseMulOp_broadcast_0(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
......@@ -69,22 +50,8 @@ class TestElementwiseMulOp_broadcast_0(OpTest):
'Out': self.inputs['X'] * self.inputs['Y'].reshape(2, 1, 1)
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
class TestElementwiseMulOp_broadcast_1(OpTest):
class TestElementwiseMulOp_broadcast_1(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
......@@ -97,22 +64,8 @@ class TestElementwiseMulOp_broadcast_1(OpTest):
'Out': self.inputs['X'] * self.inputs['Y'].reshape(1, 3, 1)
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
class TestElementwiseMulOp_broadcast_2(OpTest):
class TestElementwiseMulOp_broadcast_2(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
......@@ -124,22 +77,8 @@ class TestElementwiseMulOp_broadcast_2(OpTest):
'Out': self.inputs['X'] * self.inputs['Y'].reshape(1, 1, 4)
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
class TestElementwiseMulOp_broadcast_3(OpTest):
class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.inputs = {
......
import unittest
import numpy as np
from op_test import OpTest
class TestElementwiseOp(OpTest):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
class TestElementwiseSubOp_Vector(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.random((32, )).astype("float32"),
'Y': np.random.random((32, )).astype("float32")
}
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
class TestElementwiseSubOp_broadcast_0(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(2).astype(np.float32)
}
self.attrs = {'axis': 0}
self.outputs = {
'Out': self.inputs['X'] - self.inputs['Y'].reshape(2, 1, 1)
}
class TestElementwiseSubOp_broadcast_1(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(3).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 3, 1)
}
class TestElementwiseSubOp_broadcast_2(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(4).astype(np.float32)
}
self.outputs = {
'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 1, 4)
}
class TestElementwiseSubOp_broadcast_3(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_sub"
self.inputs = {
'X': np.random.rand(2, 3, 4, 5).astype(np.float32),
'Y': np.random.rand(3, 4).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] - self.inputs['Y'].reshape(1, 3, 4, 1)
}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册