未验证 提交 1f34f7ec 编写于 作者: J jakpiase 提交者: GitHub

Fix for expand_v2 op (#35101)

* temporary change

* fix for expand_v2

* changes after review, activated ppyolov inference test
上级 d618de2d
...@@ -101,7 +101,7 @@ TEST(test_ppyolo_mbv3, multi_thread4_trt_fp32_bz2) { ...@@ -101,7 +101,7 @@ TEST(test_ppyolo_mbv3, multi_thread4_trt_fp32_bz2) {
std::cout << "finish multi-thread test" << std::endl; std::cout << "finish multi-thread test" << std::endl;
} }
TEST(DISABLED_test_ppyolo_mbv3, multi_thread4_mkl_bz2) { TEST(test_ppyolo_mbv3, multi_thread4_mkl_bz2) {
// TODO(OliverLPH): mkldnn multi thread will fail // TODO(OliverLPH): mkldnn multi thread will fail
int thread_num = 4; int thread_num = 4;
// init input data // init input data
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/expand_v2_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace { namespace {
...@@ -37,16 +38,20 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -37,16 +38,20 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
auto x_vec_dims = vectorize(x->dims()); auto x_vec_dims = vectorize(x->dims());
auto out_vec_dims = vectorize(out->dims());
auto out_new_dims = paddle::operators::get_expand_shape(ctx);
for (size_t i = 0; i < out_new_dims.size(); ++i) {
out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i];
}
dnnl::memory::format_tag x_format_tag = x->format(); dnnl::memory::format_tag x_format_tag = x->format();
if (x_vec_dims.size() != out_vec_dims.size()) { if (x_vec_dims.size() != out_new_dims.size()) {
x_format_tag = x_format_tag =
GetExtendedFormatTag(x_vec_dims, out_vec_dims.size(), x_format_tag); GetExtendedFormatTag(x_vec_dims, out_new_dims.size(), x_format_tag);
} }
out->Resize(paddle::framework::make_ddim(out_new_dims));
out->set_format(x_format_tag); out->set_format(x_format_tag);
paddle::platform::BroadcastDataMKLDNNHandler<T> handler( paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
dnnl::algorithm::binary_add, dev_ctx, onednn_engine, ctx.GetPlace(), dnnl::algorithm::binary_add, dev_ctx, onednn_engine, ctx.GetPlace(),
out, x, 0.0f, 1.0f, ctx.InputName("X"), x_vec_dims); out, x, 0.0f, 1.0f, ctx.InputName("X"), x_vec_dims);
......
...@@ -69,6 +69,50 @@ class TestExpandV2CopyScenarioShapeNotGivenOneDNNOp(TestExpandV2OneDNNOp): ...@@ -69,6 +69,50 @@ class TestExpandV2CopyScenarioShapeNotGivenOneDNNOp(TestExpandV2OneDNNOp):
self.expand_times = (1, 1, 1, 1) self.expand_times = (1, 1, 1, 1)
class TestExpandV2ExpandShapesTensor1OneDNNOp(TestExpandV2OneDNNOp):
def init_data(self):
self.ori_shape = [100, 1]
self.expand_times = [1, 2]
self.expand_shape = [100, 2]
self.shape = [-1, -1]
def calc_expand_shapes_tensor(self):
self.expand_shapes_tensor = []
for index, ele in enumerate(self.expand_shape):
self.expand_shapes_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
def set_inputs(self):
self.calc_expand_shapes_tensor()
self.inputs = {
'X': self.x,
'expand_shapes_tensor': self.expand_shapes_tensor
}
class TestExpandV2ExpandShapesTensor2OneDNNOp(
TestExpandV2ExpandShapesTensor1OneDNNOp):
def init_data(self):
self.ori_shape = [12, 14]
self.expand_times = [1, 1]
self.expand_shape = [12, 14]
self.shape = [12, -1]
class TestExpandV2ShapesTensorOneDNNOp(TestExpandV2OneDNNOp):
def init_data(self):
self.ori_shape = [100]
self.expand_times = [2, 1]
self.expand_shape = [2, 100]
self.shape = [-1, -1]
def set_inputs(self):
self.inputs = {
'X': self.x,
'Shape': np.array(self.expand_shape).astype("int32")
}
# BF16 TESTS # BF16 TESTS
def create_expand_v2_bf16_test_class(parent): def create_expand_v2_bf16_test_class(parent):
@OpTestTool.skip_if_not_cpu_bf16() @OpTestTool.skip_if_not_cpu_bf16()
...@@ -101,6 +145,9 @@ create_expand_v2_bf16_test_class(TestExpandV2OneDNNOp) ...@@ -101,6 +145,9 @@ create_expand_v2_bf16_test_class(TestExpandV2OneDNNOp)
create_expand_v2_bf16_test_class(TestExpandV2ExpandDimOneDNNOp) create_expand_v2_bf16_test_class(TestExpandV2ExpandDimOneDNNOp)
create_expand_v2_bf16_test_class(TestExpandV2CopyScenarioOneDNNOp) create_expand_v2_bf16_test_class(TestExpandV2CopyScenarioOneDNNOp)
create_expand_v2_bf16_test_class(TestExpandV2CopyScenarioShapeNotGivenOneDNNOp) create_expand_v2_bf16_test_class(TestExpandV2CopyScenarioShapeNotGivenOneDNNOp)
create_expand_v2_bf16_test_class(TestExpandV2ExpandShapesTensor1OneDNNOp)
create_expand_v2_bf16_test_class(TestExpandV2ExpandShapesTensor2OneDNNOp)
create_expand_v2_bf16_test_class(TestExpandV2ShapesTensorOneDNNOp)
if __name__ == '__main__': if __name__ == '__main__':
paddle.enable_static() paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册