未验证 提交 3292f0ef 编写于 作者: J Jacek Czaja 提交者: GitHub

[onednn] elementwise add broadcasting support (#24594)

上级 560c8153
......@@ -20,7 +20,7 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY https://github.com/intel/mkl-dnn.git)
SET(MKLDNN_TAG 589c09728e34d09d79106cba0211e93caf142d54)
SET(MKLDNN_TAG fb95345126ade4c54f5507e580a5f5da8d30a515)
# Introduce variables:
# * CMAKE_INSTALL_LIBDIR
......
......@@ -100,9 +100,11 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// If broadcasting is needed, use native implementation
auto CanMKLDNNElementwiseAddBeUsed = [&]() {
return ctx.Input<Tensor>("X")->dims() == ctx.Input<Tensor>("Y")->dims();
int axis = ctx.Attr<int>("axis");
int rankdiff = ctx.Input<Tensor>("X")->dims().size() -
ctx.Input<Tensor>("Y")->dims().size();
return (axis == -1) || (axis == rankdiff);
};
if (platform::CanMKLDNNBeUsed(ctx) &&
......
......@@ -371,6 +371,13 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place,
platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) {
// bradcasting combined with in-place may require longer key
auto rankdiff = x->dims().size() - y->dims().size();
if (rankdiff > 0) {
this->key_ += std::to_string(rankdiff);
this->key_common_ += std::to_string(rankdiff);
}
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
......@@ -390,17 +397,19 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
const auto src_y_tz = framework::vectorize(y->dims());
const auto dst_tz = framework::vectorize(z->dims());
// TODO(jczaja): Add function checking if data already exists
const auto src0_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
const auto src1_md = dnnl::memory::desc(
auto src1_md = dnnl::memory::desc(
src_y_tz, platform::MKLDNNGetDataType<T>(), y->format());
if (rankdiff > 0) {
std::vector<int64_t> ones(rankdiff, 1);
std::vector<int64_t> dims1_ex(src_y_tz);
dims1_ex.insert(dims1_ex.begin(), ones.begin(), ones.end());
src1_md = src1_md.reshape(dims1_ex);
}
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
// Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
// TODO(jczaja): Binary primitive support broadcasting, so we can support
// this in kernel
this->AcquireForwardPrimitiveDescriptor(dnnl::algorithm::binary_add,
src0_md, src1_md, dst_md);
}
......@@ -410,7 +419,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->src_desc(), to_void_cast<T>(input_data), "@src1_mem_p");
this->fwd_pd_->src1_desc(), to_void_cast<T>(input_data), "@src1_mem_p");
}
};
......
......@@ -49,5 +49,22 @@ class TestMKLDNNElementwiseAddOp3(TestMKLDNNElementwiseAddOp):
self.out = np.add(self.x, self.y)
class TestMKLDNNElementwiseAddOp4(TestMKLDNNElementwiseAddOp):
def init_input_output(self):
self.x = np.random.uniform(1, 2, [2, 3, 4, 32]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [4, 32]).astype(self.dtype)
self.out = np.add(self.x, self.y)
# TODO(jczaja): Enable when grad is ready
def test_check_grad_normal(self):
pass
def test_check_grad_ingore_x(self):
pass
def test_check_grad_ingore_y(self):
pass
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册