提交 b7697f62 编写于 作者: D danleifeng 提交者: Yi Liu

fix broadcast bug;test=develop (#21898)

上级 e0d8b8f5
......@@ -25,7 +25,9 @@ void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) {
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z);
} else {
......
......@@ -31,7 +31,9 @@ void default_elementwise_div(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) {
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
DivFunctor<T>(), z);
} else {
......
......@@ -71,7 +71,9 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) {
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z);
} else {
......@@ -118,7 +120,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
}
z->mutable_data<T>(ctx.GetPlace());
if (x.numel() == y->numel()) {
auto dims_equal = x.dims() == y->dims();
if (dims_equal) {
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
same_dims_mul(ctx, &x, y, z);
} else {
......
......@@ -26,7 +26,9 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) {
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册