Commit ecd9f330 authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] Fix to face model on AVX512 platforms (#19282)

- Refactor step 1

- Compilation fix

- Yet another compilation fix

- Even more compilation fix

- Lint fixes

test=develop

- Removed deprecated PADDLE_ENFORCE occurrence

test=develop

- Candidate fix to BN forward

- Lint fixes

test=develop

- Refactoring in data_layout_transform

- compilation fix

- Another compilation fix

- Step further into darkness

- Yet another compilation fix

- Yet another compilation fix

- missing header

- compilation fix

- Added MKLDNN -> Paddle conversion in fetch op

test=develop

- Compilation fix

test=develop

- Lint

test=develop

- Mul fix

- Fix to MKLDNN MUL op and Elementwise MUL UT

test=develop

- Workaround for the different representation of weights with groups in
  Paddle vs MKL-DNN.

test=develop

- Candidate fix for 5D convolution with groups

- Refactor of fix for conv3d and conv2d in fetch op

test=develop

- Compilation fix

- Still same compilation fix

- Compilation fix

- Compilation fix

- Reverted refactoring of fixes

- Adapted test_conv2d_int8_mkldnn so it expects data in NCHW format,
  not NHWC

test=develop

- minor fix in UT

test=develop

- Lint fixes

test=develop
Parent commit: e8405e5c
@@ -121,12 +121,19 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
                                const Tensor& in, Tensor* out) {
   auto in_layout = kernel_type_for_var.data_layout_;
   auto out_layout = expected_kernel_type.data_layout_;
+  auto place = expected_kernel_type.place_;
 
   PADDLE_ENFORCE(
       in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN,
       "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
       "non-MKLDNN");
 
+  innerTransDataLayoutFromMKLDNN(in_layout, out_layout, in, out, place);
+}
+
+void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
+                                    const Tensor& in, Tensor* out,
+                                    platform::Place place) {
 #ifdef PADDLE_WITH_MKLDNN
   PADDLE_ENFORCE(in.format() != memory::format::format_undef &&
                      in.format() != memory::format::any,
@@ -137,8 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
       out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
 
   auto& pool = platform::DeviceContextPool::Instance();
-  auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(
-      pool.Get(expected_kernel_type.place_));
+  auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
   auto& cpu_engine = dev_ctx->GetEngine();
 
   std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
@@ -165,7 +171,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
   auto reorder_src_memory_p = handler.AcquireSrcMemory(in_format, in_data);
   auto reorder_dst_memory_p =
-      handler.AcquireDstMemory(out, out_format, expected_kernel_type.place_);
+      handler.AcquireDstMemory(out, out_format, place);
   auto reorder_p =
       handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
...
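Background for this transform: on AVX512 machines MKL-DNN keeps tensors in blocked layouts such as nChw16c, and the reorder acquired above rewrites them into plain NCHW. A standalone sketch of the index arithmetic such a reorder performs (illustrative only, not Paddle code; assumes the blocked buffer pads channels up to the block size):

```cpp
#include <cstddef>
#include <vector>

// Convert a blocked nChw16c buffer (channels grouped in blocks of 16, as
// MKL-DNN prefers on AVX512) into plain nchw. This mirrors what the
// MKL-DNN reorder primitive acquired above computes.
std::vector<float> nChw16c_to_nchw(const std::vector<float>& src,
                                   int n, int c, int h, int w) {
  const int kBlock = 16;
  const int c_blocks = (c + kBlock - 1) / kBlock;  // channel blocks, padded
  std::vector<float> dst(static_cast<size_t>(n) * c * h * w);
  for (int in = 0; in < n; ++in)
    for (int ic = 0; ic < c; ++ic)
      for (int ih = 0; ih < h; ++ih)
        for (int iw = 0; iw < w; ++iw) {
          // blocked index: [n][c/16][h][w][c%16]
          size_t s = (((static_cast<size_t>(in) * c_blocks + ic / kBlock) * h +
                       ih) * w + iw) * kBlock + ic % kBlock;
          // plain index: [n][c][h][w]
          size_t d = ((static_cast<size_t>(in) * c + ic) * h + ih) * w + iw;
          dst[d] = src[s];
        }
  return dst;
}
```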
@@ -69,6 +69,10 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
                                const OpKernelType& expected_kernel_type,
                                const Tensor& in, Tensor* out);
 
+void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
+                                    const Tensor& in, Tensor* out,
+                                    platform::Place place);
+
 std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
 
 void TransDataLayout(const OpKernelType& kernel_type_for_var,
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -55,7 +56,16 @@ class FetchOp : public framework::OperatorBase {
     // FIXME(yuyang18): Should we assume the fetch operator always generate
     // CPU outputs?
     if (src_item.IsInitialized() && src_item.numel() > 0) {
-      TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
+      // Conversion from MKL-DNN to Paddle
+      if (src_item.layout() == framework::DataLayout::kMKLDNN) {
+        framework::Tensor out;
+        framework::innerTransDataLayoutFromMKLDNN(
+            src_item.layout(), framework::DataLayout::kNCHW, src_item, &out,
+            platform::CPUPlace());
+        TensorCopySync(out, platform::CPUPlace(), &dst_item);
+      } else {
+        TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
+      }
     } else {
       // Not copy, if the src tensor is empty.
       dst_item.clear();
...
@@ -87,7 +87,6 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   auto *y = ctx.Output<Tensor>("Out");
 
   const T *x_data = x->data<T>();
-  T *y_data = y->mutable_data<T>(ctx.GetPlace());
 
   const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
   const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
@@ -119,7 +118,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   auto src_memory_p = handler.AcquireSrcMemory(md, to_void_cast<T>(x_data));
   auto dst_memory_p =
-      handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(y_data));
+      handler.AcquireDstMemoryFromPrimitive<T>(y, ctx.GetPlace());
   auto activation_p = handler.AcquireActivation(dst_memory_p, src_memory_p);
 
   // push primitive to stream and wait until it's executed
...
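The change above moves output allocation behind AcquireDstMemoryFromPrimitive, which sizes the buffer from dst_primitive_desc().get_size() rather than from the element count. A minimal sketch of why those two sizes can differ for blocked formats (illustrative numbers, not Paddle code):

```cpp
#include <cstddef>
#include <iostream>

// For a blocked destination such as nChw16c the channel dimension is padded
// to a multiple of the block, so the required byte size can exceed
// numel * sizeof(T), which is why the kernels now allocate the output with
// the size reported by the primitive descriptor instead of the default
// numel-based size of mutable_data<T>().
size_t BlockedBufferBytes(int n, int c, int h, int w, int block,
                          size_t elem_size) {
  const int c_padded = (c + block - 1) / block * block;  // round C up
  return static_cast<size_t>(n) * c_padded * h * w * elem_size;
}

int main() {
  // A 1x3x224x224 float tensor held as nChw16c pads 3 channels up to 16:
  std::cout << BlockedBufferBytes(1, 3, 224, 224, 16, sizeof(float))  // 3211264
            << " vs plain " << 1ul * 3 * 224 * 224 * sizeof(float)    // 602112
            << " bytes\n";
}
```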
@@ -58,6 +58,15 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
         batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p");
   }
 
+  template <typename T>
+  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(
+      framework::Tensor *output, platform::Place place) {
+    T *ptr = output->mutable_data<T>(
+        place, batch_norm_pd_->dst_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(
+        batch_norm_pd_->dst_primitive_desc(), ptr, "@dst_mem_p");
+  }
+
   std::shared_ptr<batch_norm_fwd::primitive_desc>
   AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc,
                                       const mkldnn::engine &engine) {
@@ -189,7 +198,6 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const T *x_data = x->data<T>();
     const T *mean_data = mean->data<T>();
     const T *variance_data = variance->data<T>();
-    T *y_data = y->mutable_data<T>(ctx.GetPlace());
 
     T *mean_out_data = mean_out->mutable_data<T>(ctx.GetPlace());
     T *variance_out_data = variance_out->mutable_data<T>(ctx.GetPlace());
     T *batch_mean_data = nullptr;
@@ -250,8 +258,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         handler.AcquireScaleshiftMemoryFromPrimitive(scaleshift_data.data());
 
     // create mkldnn memory for output y tensor
-    auto dst_memory = handler.AcquireDstMemory(
-        batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data);
+    auto dst_memory =
+        handler.AcquireDstMemoryFromPrimitive<T>(y, ctx.GetPlace());
 
     std::shared_ptr<batch_norm_fwd> batch_norm_p;
     if (global_stats) {
@@ -334,6 +342,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const T *scale_data = scale->data<T>();
     const T *shift_data = shift->data<T>();
     T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
+
     T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
     T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
...
@@ -421,7 +421,8 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
       out->Resize(out_dims);
     }
     out->set_layout(DataLayout::kMKLDNN);
-    out->set_format(out->format());
+    out->set_format(platform::MKLDNNFormatForSize(
+        out_dims.size(), mkldnn::memory::format::nchw));
   }
 };
...
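The removed line was a self-assignment, so the consumer of the mul output never learned a usable format; the fix asks MKLDNNFormatForSize for the plain format matching the output rank. A simplified stand-in for that mapping (hypothetical enum; the real code returns mkldnn::memory::format values, and the rank-1/rank-2 cases here are assumptions for illustration):

```cpp
// Pick the plain memory format that matches a tensor's rank, which is what
// the mul kernel now reports instead of echoing the not-yet-set
// out->format().
enum class PlainFormat { x, nc, ncw, nchw, ncdhw };

PlainFormat PlainFormatForRank(int rank) {
  switch (rank) {
    case 1:  return PlainFormat::x;
    case 2:  return PlainFormat::nc;
    case 3:  return PlainFormat::ncw;
    case 5:  return PlainFormat::ncdhw;
    default: return PlainFormat::nchw;  // 4-D and the fallback case
  }
}
```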
@@ -131,7 +131,14 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
   } else if (data_format == mkldnn::memory::format::nhwc) {
     return mkldnn::memory::format::nwc;
   }
+  } else if (dims_size == 4) {
+    if (data_format == mkldnn::memory::format::goihw) {
+      return mkldnn::memory::format::oihw;
+    }
   } else if (dims_size == 5) {
+    if (data_format == mkldnn::memory::format::goidhw) {
+      return mkldnn::memory::format::oidhw;
+    }
     if (data_format == mkldnn::memory::format::nchw) {
       return mkldnn::memory::format::ncdhw;
     } else if (data_format == mkldnn::memory::format::nhwc) {
...
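These goihw/goidhw branches carry the workaround mentioned in the commit message: Paddle folds convolution groups into the O dimension of a plain oihw weight tensor, while MKL-DNN keeps groups explicit. A sketch of the underlying dimension collapse (illustrative, not Paddle code):

```cpp
#include <vector>

// Premise of the goihw -> oihw mapping above: grouped weights stored as
// 5-D {g, o/g, i, h, w} and 4-D {g * o/g, i, h, w} have the same element
// order because the two leading dimensions are adjacent, so reinterpreting
// only requires collapsing them.
std::vector<int> CollapseGroupDim(const std::vector<int>& goihw_dims) {
  return {goihw_dims[0] * goihw_dims[1], goihw_dims[2], goihw_dims[3],
          goihw_dims[4]};
}
```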
@@ -337,27 +337,26 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
     // may be executed by diffrent thread, hence
     // for that one we use key that does not contain TID
     const std::string key_activation_pd = key_common_ + "@activation_pd";
-    activation_pd_ =
-        std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
-            dev_ctx_.GetBlob(key_activation_pd));
-    if (activation_pd_ == nullptr) {
+    fwd_pd_ = std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_activation_pd));
+    if (fwd_pd_ == nullptr) {
       static std::mutex acquire_barrier;
       std::lock_guard<std::mutex> block_threads_until_finish_this_job(
           acquire_barrier);
-      activation_pd_ =
+      fwd_pd_ =
           std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
               dev_ctx_.GetBlob(key_activation_pd));
-      if (activation_pd_ == nullptr) {
+      if (fwd_pd_ == nullptr) {
         auto activation_desc = mkldnn::eltwise_forward::desc(
             prop_kind, algorithm, md, alpha, beta);
-        activation_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
+        fwd_pd_.reset(new mkldnn::eltwise_forward::primitive_desc(
             activation_desc, engine_));
-        dev_ctx_.SetBlob(key_activation_pd, activation_pd_);
+        dev_ctx_.SetBlob(key_activation_pd, fwd_pd_);
       }
     }
-    return activation_pd_;
+    return fwd_pd_;
   }
@@ -366,23 +365,22 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
       const mkldnn::memory::desc& src_md, float alpha, float beta) {
     const std::string key_activation_pd = key_common_ + "@activation_pd";
     const std::string key_activation_bwd_pd = key_ + "@activation_bwd_pd";
-    activation_bwd_pd_ =
+    bwd_pd_ =
         std::static_pointer_cast<mkldnn::eltwise_backward::primitive_desc>(
             dev_ctx_.GetBlob(key_activation_bwd_pd));
-    if (activation_bwd_pd_ == nullptr) {
-      activation_pd_ =
+    if (bwd_pd_ == nullptr) {
+      fwd_pd_ =
           std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
               dev_ctx_.GetBlob(key_activation_pd));
       // PD from FWD op has to exist.
-      PADDLE_ENFORCE(activation_pd_ != nullptr,
-                     "Eltwise MKL-DNN not found in cache!");
+      PADDLE_ENFORCE_NOT_NULL(fwd_pd_, "Eltwise MKL-DNN not found in cache!");
       auto backward_desc = mkldnn::eltwise_backward::desc(
           algorithm, diff_dst_md, src_md, alpha, beta);
-      activation_bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
-          backward_desc, engine_, *activation_pd_));
-      dev_ctx_.SetBlob(key_activation_bwd_pd, activation_bwd_pd_);
+      bwd_pd_.reset(new mkldnn::eltwise_backward::primitive_desc(
+          backward_desc, engine_, *fwd_pd_));
+      dev_ctx_.SetBlob(key_activation_bwd_pd, bwd_pd_);
     }
-    return activation_bwd_pd_;
+    return bwd_pd_;
   }
@@ -395,22 +393,25 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
             dev_ctx_.GetBlob(prim_key));
     if (eltwise_p == nullptr) {
       eltwise_p = std::make_shared<mkldnn::eltwise_forward>(
-          *activation_pd_, *(src_memory_p), *(dst_memory_p));
+          *fwd_pd_, *(src_memory_p), *(dst_memory_p));
       dev_ctx_.SetBlob(prim_key, eltwise_p);
     }
     return eltwise_p;
   }
 
-  // TODO(jczaja): Merge all AcquireDstMemoryFromPrimitive into one
-  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
-    return this->AcquireMemoryFromPrimitive(
-        activation_pd_->dst_primitive_desc(), ptr, "@dst_mem_p");
+  template <typename T>
+  std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(
+      framework::Tensor* output, platform::Place place) {
+    T* ptr = output->mutable_data<T>(place,
+                                     fwd_pd_->dst_primitive_desc().get_size());
+    return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr,
+                                            "@dst_mem_p");
   }
 
   std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromPrimitive(void* ptr) {
-    return this->AcquireMemoryFromPrimitive(
-        activation_bwd_pd_->diff_src_primitive_desc(), ptr, "@diff_src_mem_p");
+    return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
+                                            ptr, "@diff_src_mem_p");
   }
@@ -424,7 +425,7 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
             dev_ctx_.GetBlob(prim_key));
     if (eltwise_bwd_p == nullptr) {
       eltwise_bwd_p = std::make_shared<mkldnn::eltwise_backward>(
-          *activation_bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
+          *bwd_pd_, *(src_memory_p), *(diff_dst_memory_p),
           *(diff_src_memory_p));
       dev_ctx_.SetBlob(prim_key, eltwise_bwd_p);
     }
@@ -449,8 +450,8 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
   }
 
  private:
-  std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> activation_pd_;
-  std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> activation_bwd_pd_;
+  std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd_;
+  std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd_;
 };
 
 class LRNMKLDNNHandler : public MKLDNNHandler {
...
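The renamed fwd_pd_ and bwd_pd_ members still follow the handler's create-once caching scheme visible above. A generic sketch of that double-checked pattern (Lookup and Publish are hypothetical stand-ins for dev_ctx_.GetBlob and SetBlob, which Paddle synchronizes internally):

```cpp
#include <memory>
#include <mutex>
#include <string>

// An optimistic lookup, then a second lookup under a mutex so exactly one
// thread builds and publishes the primitive descriptor; later callers get
// the cached instance without taking the lock.
template <typename PD, typename Cache, typename MakeFn>
std::shared_ptr<PD> AcquireCached(Cache* cache, const std::string& key,
                                  MakeFn make) {
  auto pd = std::static_pointer_cast<PD>(cache->Lookup(key));
  if (pd == nullptr) {
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> guard(acquire_barrier);
    pd = std::static_pointer_cast<PD>(cache->Lookup(key));  // re-check
    if (pd == nullptr) {
      pd = make();              // build the descriptor exactly once
      cache->Publish(key, pd);  // make it visible to other threads
    }
  }
  return pd;
}
```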
@@ -20,14 +20,12 @@ import numpy as np
 import paddle.fluid.core as core
 from paddle.fluid.tests.unittests.op_test import OpTest
 from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2dOp
-from mkldnn_op_test import format_reorder
 
 
 def conv2d_forward_refer(input, filter, group, conv_param):
     out, in_n, out_h, out_w, out_c = conv2d_forward_naive(input, filter, group,
                                                           conv_param)
-    size = [in_n, out_c, out_h, out_w]
-    return format_reorder(out, size)
+    return out
 
 
 class TestConv2dInt8Op(TestConv2dOp):
@@ -79,10 +77,8 @@ class TestConv2dInt8Op(TestConv2dOp):
         if self.fuse_residual:
             input_residual = np.random.randint(
                 -5, 5, self.input_residual_size).astype(self.srctype)
-            output_tmp = np.round(output1 - output2 + format_reorder(
-                input_residual, self.input_residual_size).astype(
-                    self.srctype) * (self.scale_out / self.scale_in_eltwise
-                                     ))
+            output_tmp = np.round(output1 - output2 + input_residual.astype(
+                self.srctype) * (self.scale_out / self.scale_in_eltwise))
             if self.fuse_activation == "relu":
                 output = np.maximum(output_tmp, 0).astype(self.dsttype)
             else:
@@ -109,10 +105,9 @@ class TestConv2dInt8Op(TestConv2dOp):
                 input_residual = np.random.randint(
                     0, 10, self.input_residual_size).astype(self.srctype)
-                output_tmp_res = np.round(output1 * (self.scale_out / (
-                    self.scale_in * self.scale_weights[0])) + format_reorder(
-                        input_residual, self.input_residual_size).astype(
-                            np.int32) * (self.scale_out / self.scale_in_eltwise
-                                         ))
+                output_tmp_res = np.round(output1 * (self.scale_out / (
+                    self.scale_in * self.scale_weights[
+                        0])) + input_residual.astype(np.int32) * (
+                            self.scale_out / self.scale_in_eltwise))
                 if self.fuse_activation == "relu":
                     output = np.maximum(output_tmp_res, 0).astype(self.dsttype)
                 else:
...
@@ -182,7 +182,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp):
         y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
         self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
 
-        self.out = self.x * self.y
+        self.out = x * y
 
     def setUp(self):
         super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp()
@@ -213,7 +213,7 @@ class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp):
         y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
         self.y = y.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
 
-        self.out = self.x * self.y
+        self.out = x * y
 
     def setUp(self):
         super(TestElementwiseMulMKLDNNOp_FallbackNoReorders, self).setUp()
...