提交 b7697f62 编写于 作者: D danleifeng 提交者: Yi Liu

fix broadcast bug;test=develop (#21898)

上级 e0d8b8f5
...@@ -25,7 +25,9 @@ void default_elementwise_add(const framework::ExecutionContext &ctx, ...@@ -25,7 +25,9 @@ void default_elementwise_add(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *x,
const framework::Tensor *y, framework::Tensor *z) { const framework::Tensor *y, framework::Tensor *z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) { auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
AddFunctor<T>(), z); AddFunctor<T>(), z);
} else { } else {
......
...@@ -31,7 +31,9 @@ void default_elementwise_div(const framework::ExecutionContext& ctx, ...@@ -31,7 +31,9 @@ void default_elementwise_div(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) { const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) { auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
DivFunctor<T>(), z); DivFunctor<T>(), z);
} else { } else {
......
...@@ -71,7 +71,9 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx, ...@@ -71,7 +71,9 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) { const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) { auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z); MulFunctor<T>(), z);
} else { } else {
...@@ -118,7 +120,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> { ...@@ -118,7 +120,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
} }
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
if (x.numel() == y->numel()) { auto dims_equal = x.dims() == y->dims();
if (dims_equal) {
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul; SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
same_dims_mul(ctx, &x, y, z); same_dims_mul(ctx, &x, y, z);
} else { } else {
......
...@@ -26,7 +26,9 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx, ...@@ -26,7 +26,9 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) { const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
if (x->numel() >= y->numel()) { auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis, ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z); SubFunctor<T>(), z);
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册