未验证 提交 fc6eed5b 编写于 作者: J jakpiase 提交者: GitHub

Added mul BF16/FP32 FWD/BWD oneDNN kernel (#38552)

* base changes for mul reimplementation

* empty commit

* tmp save

* full implementation of mul bf16/fp32 fwd bwd

* CI fix

* CI rerun

* changed unity build cmake to avoid gpu issues

* removed mul mkldnn from unity build

* added skipping tests if not cpu_bf16

* CI fix

* CI fix

* CI fix
上级 281644cd
......@@ -20,6 +20,7 @@ using dnnl::memory;
using dnnl::primitive;
using paddle::framework::DataLayout;
using paddle::framework::ExecutionContext;
using paddle::platform::MatMulV2MKLDNNHandler;
using paddle::platform::GetMKLDNNFormat;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
......@@ -107,114 +108,6 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext& ctx,
return strides;
}
template <typename T>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const dnnl::engine engine,
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
bool is_output_fused,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;
if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];
std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());
if (!x_strides_override.empty()) {
x_strides = x_strides_override;
} else {
if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
}
if (!y_strides_override.empty()) {
y_strides = y_strides_override;
} else {
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
}
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) {
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
if (is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data));
}
};
bool IsOutputFused(const ExecutionContext& ctx) {
auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto& fused_transpose_Out = ctx.Attr<std::vector<int>>("fused_transpose_Out");
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace framework {
......@@ -32,13 +32,17 @@ namespace operators {
using framework::DataLayout;
using framework::DDim;
using framework::ExecutionContext;
using framework::LoDTensor;
using framework::Tensor;
using platform::MatMulV2MKLDNNHandler;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
using dnnl::inner_product_forward;
using dnnl::memory;
using dnnl::prop_kind;
using dnnl::stream;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
template <typename XT, typename YT, typename OT>
class MulPrimitiveFactory {
......@@ -345,7 +349,7 @@ inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,
/* XT: input x data type, YT: input y data type */
template <typename XT, typename YT>
class MulMKLDNNKernel : public framework::OpKernel<XT> {
class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
public:
void Compute(const ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
......@@ -371,17 +375,175 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
}
};
template <typename XT, typename YT>
class MulMKLDNNKernel : public framework::OpKernel<XT> {
public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
protected:
void ExecuteMatMul(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine &onednn_engine,
const platform::Place &cpu_place, const Tensor *x,
const std::vector<int64_t> &x_dims, bool trans_x,
const Tensor *y, const std::vector<int64_t> &y_dims,
bool trans_y, Tensor *out) const {
static const std::vector<int64_t> vec_placeholder;
MatMulV2MKLDNNHandler<XT> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y, false,
vec_placeholder, vec_placeholder);
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
const auto dst_memory_p = handler.AcquireDstMemory(out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto &astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_layout(framework::DataLayout::kMKLDNN);
// plain output formats are enforced inside handler
out->set_format(platform::MKLDNNFormatForSize(
out->dims().size(), dnnl::memory::format_tag::nchw));
}
private:
void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
const auto *y = ctx.Input<Tensor>("Y");
auto *out = ctx.Output<Tensor>("Out");
int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
const Tensor x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: *x;
const Tensor y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: *y;
// adding mb dim because MatMulV2 handler needs it
std::vector<int64_t> y_dims(3, 1);
std::vector<int64_t> x_dims(3, 1);
y_dims[1] = y_matrix.dims()[0];
y_dims[2] = y_matrix.dims()[1];
x_dims[1] = x_matrix.dims()[0];
x_dims[2] = x_matrix.dims()[1];
ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(), &x_matrix,
x_dims, false, &y_matrix, y_dims, false, out);
}
};
template <typename XT, typename YT>
class MulGradMKLDNNKernel : public MulMKLDNNKernel<XT, YT> {
public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
private:
template <typename OT = XT>
void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<LoDTensor>("X");
const auto *y = ctx.Input<LoDTensor>("Y");
const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<LoDTensor>(framework::GradVarName("Y"));
int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
const Tensor x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims)
: static_cast<const Tensor &>(*x);
const Tensor y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims)
: static_cast<const Tensor &>(*y);
Tensor dout_matrix = *dout;
dout_matrix.Resize(
{framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
// adding mb dim because MatMulV2 handler needs it
std::vector<int64_t> x_dims(3, 1);
std::vector<int64_t> y_dims(3, 1);
std::vector<int64_t> dout_dims(3, 1);
x_dims[1] = x_matrix.dims()[0];
x_dims[2] = x_matrix.dims()[1];
y_dims[1] = y_matrix.dims()[0];
y_dims[2] = y_matrix.dims()[1];
dout_dims[1] = dout_matrix.dims()[0];
dout_dims[2] = dout_matrix.dims()[1];
if (dx != nullptr) {
dx->set_lod(x->lod());
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(),
&dout_matrix, dout_dims, false, &y_matrix, y_dims,
true, static_cast<Tensor *>(dx));
}
if (dy != nullptr) {
dy->set_lod(y->lod());
this->ExecuteMatMul(ctx, dev_ctx, onednn_engine, ctx.GetPlace(),
&x_matrix, x_dims, true, &dout_matrix, dout_dims,
false, static_cast<Tensor *>(dy));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, MKLDNN, ::paddle::platform::CPUPlace,
U8, ops::kMULMKLDNNINT8,
ops::MulMKLDNNKernel<uint8_t, float>);
ops::MulMKLDNNINT8Kernel<uint8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, MKLDNN, ::paddle::platform::CPUPlace,
S8, ops::kMULMKLDNNINT8,
ops::MulMKLDNNKernel<int8_t, float>);
ops::MulMKLDNNINT8Kernel<int8_t, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, MKLDNN, ::paddle::platform::CPUPlace,
FP32, ops::kMULMKLDNNFP32,
ops::MulMKLDNNKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
mul, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kMULMKLDNNFP32,
ops::MulMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace,
ops::MulMKLDNNKernel<uint8_t, float>);
ops::MulMKLDNNINT8Kernel<uint8_t, float>,
ops::MulMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>,
ops::MulMKLDNNKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kMULMKLDNNFP32,
ops::MulGradMKLDNNKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
mul_grad, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kMULMKLDNNFP32,
ops::MulGradMKLDNNKernel<paddle::platform::bfloat16,
paddle::platform::bfloat16>,
ops::MulGradMKLDNNKernel<float, float>);
......@@ -113,6 +113,12 @@ class MulOp : public framework::OperatorWithKernel {
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
} else if (input_data_type ==
framework::DataTypeTrait<
paddle::platform::bfloat16>::DataType() ||
input_data_type ==
framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kMULMKLDNNFP32;
}
}
#endif
......@@ -233,6 +239,36 @@ class MulGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
customized_type_value = kMULMKLDNNINT8;
} else if (input_data_type ==
framework::DataTypeTrait<
paddle::platform::bfloat16>::DataType() ||
input_data_type ==
framework::DataTypeTrait<float>::DataType()) {
customized_type_value = kMULMKLDNNFP32;
}
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
}
};
template <typename T>
......
......@@ -25,6 +25,7 @@ namespace operators {
using Tensor = framework::Tensor;
constexpr int kMULMKLDNNINT8 = 1;
constexpr int kMULMKLDNNFP32 = 2;
template <typename DeviceContext, typename T>
class MulKernel : public framework::OpKernel<T> {
......
......@@ -192,7 +192,6 @@ register_unity_group(cc
pad_op.cc)
register_unity_group(cc
modified_huber_loss_op.cc
mkldnn/mul_mkldnn_op.cc
partial_sum_op.cc
pixel_shuffle_op.cc
pool_op.cc
......
......@@ -772,6 +772,114 @@ class ReductionMKLDNNHandler
}
};
template <typename T>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const dnnl::engine engine,
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y,
bool is_output_fused,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;
if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];
std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());
if (!x_strides_override.empty()) {
x_strides = x_strides_override;
} else {
if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
}
if (!y_strides_override.empty()) {
y_strides = y_strides_override;
} else {
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
}
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) {
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
if (is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data));
}
};
template <typename T>
class ActivationMKLDNNHandler
: public MKLDNNHandlerNoCachingT<T, dnnl::eltwise_forward,
......
......@@ -83,7 +83,7 @@ class AutoMixedPrecisionListsBF16(object):
bf16_initializer_list = {'fill_constant', 'uniform_random'}
# always bf16
bf16_list = {'elementwise_add', }
bf16_list = {'elementwise_add', 'mul'}
# depends on the prev_op type
gray_list = {
......
......@@ -37,6 +37,15 @@ def convert_uint16_to_float(in_list):
return numpy.reshape(out, in_list.shape)
def convert_float_to_uint16(in_list):
out = []
for x in numpy.nditer(in_list):
out.append(
numpy.uint16(struct.unpack('<I', struct.pack('<f', x))[0] >> 16))
out = numpy.reshape(out, in_list.shape).view(numpy.uint16)
return out
def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
......@@ -158,6 +167,10 @@ def infer(use_cuda, save_dirname=None, use_bf16=False):
test_data = next(test_reader())
test_feat = numpy.array(
[data[0] for data in test_data]).astype("float32")
if use_bf16:
test_feat = convert_float_to_uint16(test_feat)
test_label = numpy.array(
[data[1] for data in test_data]).astype("float32")
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
'''
......@@ -159,4 +160,5 @@ class TestMKLDNNMulOpS8U8WithFlatten(TestMKLDNNMulOpS8S8WithFlatten):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from numpy.matrixlib import defmatrix
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16, OpTestTool
@OpTestTool.skip_if_not_cpu_bf16()
class TestMulOneDNNOp(OpTest):
def setUp(self):
self.op_type = "mul"
self.attrs = {'use_mkldnn': True}
self.init_shapes_and_attrs()
self.x_fp32 = np.random.random(self.x_shape).astype(np.float32)
self.y_fp32 = np.random.random(self.y_shape).astype(np.float32)
self.x = self.x_fp32
self.y = self.y_fp32
self.init_inputs_dtype()
self.inputs = {'X': self.x, 'Y': self.y}
output = np.dot(
np.reshape(self.x_fp32, self.np_x_shape),
np.reshape(self.y_fp32, self.np_y_shape))
self.outputs = {'Out': np.reshape(output, self.out_shape)}
def init_shapes_and_attrs(self):
self.x_shape = (20, 5)
self.y_shape = (5, 21)
self.np_x_shape = (20, 5)
self.np_y_shape = (5, 21)
self.out_shape = (20, 21)
def init_inputs_dtype(self):
pass
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
self.check_grad_with_place(core.CPUPlace(), ['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
self.check_grad_with_place(core.CPUPlace(), ['Y'], 'Out', set('X'))
def test_check_grad_ingore_y(self):
self.check_grad_with_place(core.CPUPlace(), ['X'], 'Out', set('Y'))
class TestMulXNumColDims2OneDNNOp(TestMulOneDNNOp):
def init_shapes_and_attrs(self):
self.x_shape = (6, 7, 5)
self.y_shape = (5, 21)
self.np_x_shape = (42, 5)
self.np_y_shape = (5, 21)
self.out_shape = (6, 7, 21)
self.attrs["x_num_col_dims"] = 2
class TestMulYNumColDims2OneDNNOp(TestMulOneDNNOp):
def init_shapes_and_attrs(self):
self.x_shape = (20, 6)
self.y_shape = (2, 3, 21)
self.np_x_shape = (20, 6)
self.np_y_shape = (6, 21)
self.out_shape = (20, 21)
self.attrs["y_num_col_dims"] = 2
class TestMulYAndXNumColDims2OneDNNOp(TestMulOneDNNOp):
def init_shapes_and_attrs(self):
self.x_shape = (10, 5, 6)
self.y_shape = (2, 3, 21)
self.np_x_shape = (50, 6)
self.np_y_shape = (6, 21)
self.out_shape = (10, 5, 21)
self.attrs["x_num_col_dims"] = 2
self.attrs["y_num_col_dims"] = 2
class TestMulBF16OneDNNOp(TestMulOneDNNOp):
def init_inputs_dtype(self):
self.x = convert_float_to_uint16(self.x)
self.y = convert_float_to_uint16(self.y)
def calculate_grads(self):
x_np = np.reshape(self.x_fp32, self.np_x_shape)
y_np = np.reshape(self.y_fp32, self.np_y_shape)
self.dout = self.outputs['Out']
self.dout_np = np.reshape(self.dout, (x_np.shape[0], y_np.shape[1]))
y_np_trans = np.transpose(y_np, (1, 0))
x_np_trans = np.transpose(x_np, (1, 0))
self.dx = np.matmul(self.dout_np, y_np_trans)
self.dy = np.matmul(x_np_trans, self.dout_np)
def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ['X', 'Y'],
'Out',
user_defined_grads=[self.dx, self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
def test_check_grad_ingore_x(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ['Y'],
'Out',
set('X'),
user_defined_grads=[self.dy],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
def test_check_grad_ingore_y(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ['X'],
'Out',
set('Y'),
user_defined_grads=[self.dx],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册