Unverified commit 9ec1432d, authored by Feiyu Chan, committed by GitHub

disable copying of datatype when sharing buffer between two tensors. (#37247)

* disable copying of datatype when sharing buffer between two tensors.
* fix mkldnn operator kernels (elementwise_add, sum, softplus, softmax, scale, activation): manually set the data type when reusing memory via ShareBufferWith.
Parent 693c3c14
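The effect of the change can be summarized with a small self-contained sketch. This is a simplified mock written only for illustration: `MiniTensor`, `MutableData`, and `DataType` are made-up names, not paddle's real `framework::Tensor` API. It shows why, once `ShareBufferWith` stops copying the data type, a kernel that reuses a shared buffer has to set the dtype itself through a `mutable_data<T>`-style call.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

enum class DataType { kUndef, kFP32 };

struct MiniTensor {
  std::shared_ptr<std::vector<char>> holder;  // raw allocation
  std::size_t offset = 0;                     // offset into the allocation
  DataType type = DataType::kUndef;           // dtype metadata

  // After this commit: share only the allocation and the offset.
  // Shape, data type, layout and other metadata are NOT copied.
  void ShareBufferWith(const MiniTensor& other) {
    holder = other.holder;
    offset = other.offset;
    // type = other.type;  // intentionally no longer copied
  }

  // mutable_data<T>-like behaviour: if the existing storage is big enough,
  // only the data type is (re)set; otherwise a fresh buffer is allocated.
  template <typename T>
  T* MutableData(std::size_t numel, DataType dtype) {
    const std::size_t bytes = numel * sizeof(T);
    if (!holder || holder->size() < offset + bytes) {
      holder = std::make_shared<std::vector<char>>(bytes);
      offset = 0;
    }
    type = dtype;  // the dtype is always set here, shared buffer or not
    return reinterpret_cast<T*>(holder->data() + offset);
  }
};

int main() {
  MiniTensor x;
  x.MutableData<float>(8, DataType::kFP32);

  MiniTensor z;
  z.ShareBufferWith(x);  // z.type deliberately stays kUndef
  z.MutableData<float>(8, DataType::kFP32);  // kernel must set the dtype itself
  return 0;
}
```

The kernels changed below follow this rule: whenever the oneDNN destination memory aliases the source, they call `mutable_data<T>` on the output tensor so its data type is still recorded.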
@@ -254,7 +254,10 @@ class Tensor {
   void ShareBufferWith(const Tensor& tensor) {
     holder_ = tensor.holder_;
     offset_ = tensor.offset_;
-    type_ = tensor.type_;
+    // NOTE(chenfeiyu): when sharing a buffer, by definition only the holder
+    // of the memory allocation and the offset should be shared. Shape,
+    // data type, layout, and other metadata associated with a Tensor
+    // should not be copied.
   }

   bool IsSharedBufferWith(const Tensor& src) const {
......
@@ -62,9 +62,22 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
   // and they share a buffer (of
   // shape x) then this buffer is not big enough to hold result of elementwise
   // operation.
-  auto dst_memory = (x->numel() == z->numel() && x->IsSharedBufferWith(*z))
-                        ? src_x_memory
-                        : handler.AcquireDstMemory(z);
+  const bool reuse_x_memory =
+      x->numel() == z->numel() && x->IsSharedBufferWith(*z);
+  std::shared_ptr<dnnl::memory> dst_memory = nullptr;
+  if (reuse_x_memory) {
+    dst_memory = src_x_memory;
+    // NOTE(chenfeiyu): when the output reuses memory from another tensor
+    // rather than allocating its own, it still needs to have its data type
+    // set. Unfortunately, paddle's operator only infers the output's shape,
+    // but not its data type. mutable_data<T> normally takes care of both
+    // allocation and data type, but if the memory is already allocated and
+    // there is no need to re-allocate, it just sets the data type. So it is
+    // called here to get the right data type.
+    z->mutable_data<T>(ctx.GetPlace());
+  } else {
+    dst_memory = handler.AcquireDstMemory(z);
+  }

   const auto binary_prim = handler.AcquireForwardPrimitive();
......
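The same reuse-or-allocate branch recurs in every kernel below. Here is a minimal sketch of the idea using raw oneDNN calls instead of paddle's handler classes; the engine, the memory descriptor, and the hard-coded `is_inplaced` flag are illustrative assumptions, not paddle code.

```cpp
#include <dnnl.hpp>
#include <memory>

int main() {
  // Illustrative setup: a CPU engine and a small f32 memory object standing
  // in for the kernel's source memory (paddle obtains these from its MKLDNN
  // handler instead).
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::memory::desc md({2, 3}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nc);
  auto src_memory_p = std::make_shared<dnnl::memory>(md, eng);

  const bool is_inplaced = true;  // hard-coded for the sketch
  std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
  if (is_inplaced) {
    // Reuse the source memory object. In paddle, out->mutable_data<T>(place)
    // is called at this point so the output tensor still records the right
    // data type, since ShareBufferWith no longer copies it.
    dst_memory_p = src_memory_p;
  } else {
    dst_memory_p = std::make_shared<dnnl::memory>(md, eng);
  }
  return 0;
}
```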
@@ -91,7 +91,13 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
                                      ctx.GetPlace(), x);

   auto src_memory_p = handler.AcquireSrcMemory(x);
-  auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
+  std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
+  if (is_inplaced) {
+    dst_memory_p = src_memory_p;
+    y->mutable_data<T>(ctx.GetPlace());
+  } else {
+    dst_memory_p = handler.AcquireDstMemory(y);
+  }
   auto activation_p = handler.AcquireForwardPrimitive();

   auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
@@ -41,8 +41,13 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
                                             x);

     auto src_memory_p = handler.AcquireSrcMemory(x);
-    auto dst_memory_p =
-        is_inplaced ? src_memory_p : handler.AcquireDstMemory(out);
+    std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
+    if (is_inplaced) {
+      dst_memory_p = src_memory_p;
+      out->mutable_data<T>(ctx.GetPlace());
+    } else {
+      dst_memory_p = handler.AcquireDstMemory(out);
+    }
     auto activation_p = handler.AcquireForwardPrimitive();

     auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
@@ -103,9 +103,13 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
     // For inplace execution, src and dst are the same memory object
-    auto softmax_dst_memory_p =
-        is_inplaced ? softmax_src_memory_p : handler.AcquireDstMemory(output);
+    std::shared_ptr<dnnl::memory> softmax_dst_memory_p = nullptr;
+    if (is_inplaced) {
+      softmax_dst_memory_p = softmax_src_memory_p;
+      output->mutable_data<T>(ctx.GetPlace());
+    } else {
+      softmax_dst_memory_p = handler.AcquireDstMemory(output);
+    }
     auto softmax_p = handler.AcquireForwardPrimitive();

     auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#pragma once
 #include "paddle/fluid/platform/mkldnn_reuse.h"

 namespace paddle {
@@ -43,7 +44,7 @@ class SoftplusMKLDNNHandler
                                                 1.0f / beta, 0.0f);
     }

-    AppendFusedActivationIfExists(ctx, post_ops);
+    AppendFusedActivationIfExists(ctx, &post_ops);

     dnnl::primitive_attr attrs;
     attrs.set_post_ops(post_ops);
@@ -59,16 +60,16 @@
  private:
   void AppendFusedActivationIfExists(const framework::ExecutionContext& ctx,
-                                     dnnl::post_ops& post_ops) {
+                                     dnnl::post_ops* post_ops) {
     const auto& fused_activation_type =
         algo_map.find(ctx.Attr<std::string>("fuse_activation_type"));

     if (fused_activation_type != algo_map.end()) {
       auto scale_out =
           ctx.Attr<float>("fuse_activation_scale");  // for future int8 support
-      post_ops.append_eltwise(scale_out, fused_activation_type->second,
+      post_ops->append_eltwise(scale_out, fused_activation_type->second,
                               ctx.Attr<float>("fuse_activation_alpha"),
                               ctx.Attr<float>("fuse_activation_beta"));
     }
   }
@@ -109,8 +110,13 @@ void custom_softplus_eltwise_forward(const framework::ExecutionContext& ctx) {
   auto src_memory_p = handler.AcquireSrcMemory(x);
   auto beta_memory_p = handler.AcquireBetaMemory(&beta);
-  auto dst_memory_p =
-      is_inplaced ? src_memory_p : handler.AcquireDstMemory(out);
+  std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
+  if (is_inplaced) {
+    dst_memory_p = src_memory_p;
+    out->mutable_data<T>(ctx.GetPlace());
+  } else {
+    dst_memory_p = handler.AcquireDstMemory(out);
+  }
   auto binary_p = handler.AcquireForwardPrimitive();

   auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
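As a side note on the handler change above: `post_ops` is now passed by pointer instead of a non-const reference (a common style convention for output parameters), while the oneDNN calls themselves are unchanged. A small standalone sketch of that calling pattern, assuming the oneDNN 2.x `append_eltwise` signature that takes a scale, as used here; `AppendRelu` is a made-up helper, not paddle code:

```cpp
#include <dnnl.hpp>

namespace {

// Mirrors the pointer-passing style of AppendFusedActivationIfExists: the
// post_ops object is mutated through a pointer rather than a non-const
// reference.
void AppendRelu(dnnl::post_ops* post_ops, float scale, float alpha,
                float beta) {
  post_ops->append_eltwise(scale, dnnl::algorithm::eltwise_relu, alpha, beta);
}

}  // namespace

int main() {
  dnnl::post_ops post_ops;
  AppendRelu(&post_ops, 1.0f, 0.0f, 0.0f);

  dnnl::primitive_attr attrs;
  attrs.set_post_ops(post_ops);  // same call as in the handler above
  return 0;
}
```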
@@ -137,8 +137,13 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       ++input_index;
     }

-    auto dst_mem = in_place ? handler.AcquireDstMemory()
-                            : handler.AcquireDstMemory(output);
+    std::shared_ptr<dnnl::memory> dst_mem = nullptr;
+    if (in_place) {
+      dst_mem = handler.AcquireDstMemory();
+      output->mutable_data<T>(ctx.GetPlace());
+    } else {
+      dst_mem = handler.AcquireDstMemory(output);
+    }

     auto sum_p = handler.AcquireForwardPrimitive();
......