未验证 提交 9ec1432d 编写于 作者: F Feiyu Chan 提交者: GitHub

disable copying of datatype when sharing buffer between two tensors. (#37247)

* disable copying of datatype when sharing buffer between two tensors.
* fix for mkldnn operator kernels (elementwise_add, sum, softplus, softmax, scale, activation), mannually set the data type when reusing memory by ShareBufferWith.
上级 693c3c14
......@@ -254,7 +254,10 @@ class Tensor {
void ShareBufferWith(const Tensor& tensor) {
holder_ = tensor.holder_;
offset_ = tensor.offset_;
type_ = tensor.type_;
// NOTE(chenfeiyu): when sharing buffer, by definition only holder
// to the memory allocation and offset should be shared. Shape,
// data type, layout, and other metadata associated with a Tensor
// should not be copied.
}
bool IsSharedBufferWith(const Tensor& src) const {
......
......@@ -62,9 +62,22 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
// and they share a buffer (of
// shape x) then this buffer is not big enough to hold result of elementwise
// operation.
auto dst_memory = (x->numel() == z->numel() && x->IsSharedBufferWith(*z))
? src_x_memory
: handler.AcquireDstMemory(z);
const bool reuse_x_memopry =
x->numel() == z->numel() && x->IsSharedBufferWith(*z);
std::shared_ptr<dnnl::memory> dst_memory = nullptr;
if (reuse_x_memopry) {
dst_memory = src_x_memory;
// NOTE(chenfeiyu): when the output reuses memory from other tensor rather
// than allocate its own, it's still need to take care of its data type.
// Unfortunately, paddle's operator only infers the output' shape, but not
// the data type. mutable_data<T> takes care of allocation and data type
// normally, but if the memory is already allocated and there is no need
// to re-allocate, it just set the data type. So this it added there to
// get the right data type.
z->mutable_data<T>(ctx.GetPlace());
} else {
dst_memory = handler.AcquireDstMemory(z);
}
const auto binary_prim = handler.AcquireForwardPrimitive();
......
......@@ -91,7 +91,13 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
ctx.GetPlace(), x);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
if (is_inplaced) {
dst_memory_p = src_memory_p;
y->mutable_data<T>(ctx.GetPlace());
} else {
dst_memory_p = handler.AcquireDstMemory(y);
}
auto activation_p = handler.AcquireForwardPrimitive();
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
......@@ -41,8 +41,13 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
x);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p =
is_inplaced ? src_memory_p : handler.AcquireDstMemory(out);
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
if (is_inplaced) {
dst_memory_p = src_memory_p;
out->mutable_data<T>(ctx.GetPlace());
} else {
dst_memory_p = handler.AcquireDstMemory(out);
}
auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
......@@ -103,9 +103,13 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
// For Inplace src and and dst are the same memory object
auto softmax_dst_memory_p =
is_inplaced ? softmax_src_memory_p : handler.AcquireDstMemory(output);
std::shared_ptr<dnnl::memory> softmax_dst_memory_p = nullptr;
if (is_inplaced) {
softmax_dst_memory_p = softmax_src_memory_p;
output->mutable_data<T>(ctx.GetPlace());
} else {
softmax_dst_memory_p = handler.AcquireDstMemory(output);
}
auto softmax_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
......@@ -43,7 +44,7 @@ class SoftplusMKLDNNHandler
1.0f / beta, 0.0f);
}
AppendFusedActivationIfExists(ctx, post_ops);
AppendFusedActivationIfExists(ctx, &post_ops);
dnnl::primitive_attr attrs;
attrs.set_post_ops(post_ops);
......@@ -59,16 +60,16 @@ class SoftplusMKLDNNHandler
private:
void AppendFusedActivationIfExists(const framework::ExecutionContext& ctx,
dnnl::post_ops& post_ops) {
dnnl::post_ops* post_ops) {
const auto& fused_activation_type =
algo_map.find(ctx.Attr<std::string>("fuse_activation_type"));
if (fused_activation_type != algo_map.end()) {
auto scale_out =
ctx.Attr<float>("fuse_activation_scale"); // for future int8 support
post_ops.append_eltwise(scale_out, fused_activation_type->second,
ctx.Attr<float>("fuse_activation_alpha"),
ctx.Attr<float>("fuse_activation_beta"));
post_ops->append_eltwise(scale_out, fused_activation_type->second,
ctx.Attr<float>("fuse_activation_alpha"),
ctx.Attr<float>("fuse_activation_beta"));
}
}
......@@ -109,8 +110,13 @@ void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
auto src_memory_p = handler.AcquireSrcMemory(x);
auto beta_memory_p = handler.AcquireBetaMemory(&beta);
auto dst_memory_p =
is_inplaced ? src_memory_p : handler.AcquireDstMemory(out);
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
if (is_inplaced) {
dst_memory_p = src_memory_p;
out->mutable_data<T>(ctx.GetPlace());
} else {
dst_memory_p = handler.AcquireDstMemory(out);
}
auto binary_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
......@@ -137,8 +137,13 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
++input_index;
}
auto dst_mem = in_place ? handler.AcquireDstMemory()
: handler.AcquireDstMemory(output);
std::shared_ptr<dnnl::memory> dst_mem = nullptr;
if (in_place) {
dst_mem = handler.AcquireDstMemory();
output->mutable_data<T>(ctx.GetPlace());
} else {
dst_mem = handler.AcquireDstMemory(output);
}
auto sum_p = handler.AcquireForwardPrimitive();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册