未验证 提交 0a5c99e8 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Fix to issue #34554 (#34623)

* - Added softmax without caching

* - Binary is no longer manually cached

* - Activation onednn caching removed

* - Removed manual caching of activation

* - modified UT

* - fix

* - fix

* - fixes to building

* - fix

* - fix

* - fix to UT

* - Faulty UT workaround

* - approval workaround

* - Fixes after review

* - compilation fixes

* - more lint fixes

* - more fixes after review

* - fixes after another round of review
上级 99f8f5c8
......@@ -47,13 +47,24 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
float scale_o = ctx.Attr<float>("Scale_out");
int axis = ctx.Attr<int>("axis");
platform::BinaryMKLDNNHandler<T> handler(
BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z,
scale_x, scale_y, scale_o, ctx.OutputName("Out"));
platform::BinaryMKLDNNHandler<T> handler(BINARY_OP, axis, mkldnn_engine,
ctx.GetPlace(), x, y, z, scale_x,
scale_y, scale_o);
const auto src_x_memory = handler.AcquireSrcMemory(x);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_memory = handler.AcquireDstMemory(z);
// (jczaja) For Inplace src and dst should be the same memory object.
// So x should share buffer with z. But UT mechanics is testing inplace
// execution for this op not checking that x can be bradcasted to match in
// shape y tensor.
// This is wrong as when x is to be broadcasted then z(out) will match the
// shape of y which is bigger than x. Hence if x is smaller in shape than z
// and they share a buffer (of
// shape x) then this buffer is not big enough to hold result of elementwise
// operation.
auto dst_memory = (x->numel() == z->numel() && x->IsSharedBufferWith(*z))
? src_x_memory
: handler.AcquireDstMemory(z);
const auto binary_prim = handler.AcquireForwardPrimitive();
......
......@@ -48,9 +48,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
if (dx) {
// dx = dout*y
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine,
ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f,
ctx.InputName(framework::GradVarName("Out")));
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, y, dx, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
......@@ -75,9 +74,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// Handler is having nullptr passed instead of output tensor as
// we want Dst buffer to be allocated by oneDNN not to use Tensor
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine,
ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f,
ctx.InputName(framework::GradVarName("Out")));
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, x, nullptr, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_x_memory = handler.AcquireSecondSrcMemory(x);
......
......@@ -79,15 +79,15 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL eletwise_forward must use CPUPlace"));
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Out");
bool is_inplaced = x->IsSharedBufferWith(*y);
platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, dev_ctx,
ctx.GetPlace(), x,
ctx.InputName("X"), is_inplaced);
platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
ctx.GetPlace(), x);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
......@@ -106,13 +106,14 @@ template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
mkldnn::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X"));
platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
ctx.GetPlace(), x, diff_y);
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
......
cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_mul_op elementwise_add_op activation_op softmax_op softmax scope device_context enforce)
cc_test(test_mkldnn_caching SRCS mkldnn/test_mkldnn_caching.cc DEPS op_registry elementwise_mul_op elementwise_add_op activation_op softmax_op conv_op im2col vol2col softmax scope device_context enforce)
......@@ -29,6 +29,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
......@@ -36,11 +37,12 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
bool is_inplaced = x->IsSharedBufferWith(*out);
platform::ActivationMKLDNNHandler<T> handler(
mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x,
ctx.InputName("X"), is_inplaced);
mkldnn::algorithm::eltwise_linear, ctx, mkldnn_engine, ctx.GetPlace(),
x);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto dst_memory_p =
is_inplaced ? src_memory_p : handler.AcquireDstMemory(out);
auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
......@@ -32,69 +32,56 @@ using platform::to_void_cast;
template <typename T>
class SoftmaxMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward> {
: public platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward> {
public:
SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
SoftmaxMKLDNNHandler(const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
Tensor* output, const int axis,
const std::string uniq_name, bool is_inplaced)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, mkldnn_engine, cpu_place,
// Softmax may be inplace then uniq_name is no longer unique
is_inplaced ? platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
axis, uniq_name)
: platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
uniq_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
platform::errors::InvalidArgument(
"The shape of input and output tensor must be identical."));
auto softmax_tz = framework::vectorize(input->dims());
auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
}
Tensor* output, const int axis)
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
mkldnn_engine, cpu_place) {
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
platform::errors::InvalidArgument(
"The shape of input and output tensor must be identical."));
auto softmax_tz = framework::vectorize(input->dims());
auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
}
SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* out,
const Tensor* out_grad, Tensor* in_x_grad,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(out->dims()),
unique_name)) {
if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ(
out_grad->dims(), in_x_grad->dims(),
platform::errors::InvalidArgument("The shape of softmax_grad's input "
"and output must be identical."));
auto dims = out_grad->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = framework::vectorize<int64_t>(dims);
auto data_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
auto diff_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
data_softmax_md, axis);
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis);
}
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
mkldnn_engine, cpu_place) {
PADDLE_ENFORCE_EQ(out_grad->dims(), in_x_grad->dims(),
platform::errors::InvalidArgument(
"The shape of softmax_grad's input "
"and output must be identical, but shapes differ, "
"out_grad: %s in_grad: %s",
out_grad->dims(), in_x_grad->dims()));
auto dims = out_grad->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
auto softmax_tz = framework::vectorize<int64_t>(dims);
auto data_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out->format());
auto diff_softmax_md = MKLDNNMemDesc(
softmax_tz, platform::MKLDNNGetDataType<T>(), out_grad->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
data_softmax_md, axis);
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis);
}
};
......@@ -111,9 +98,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
input, output, axis, ctx.OutputName("Out"),
is_inplaced);
SoftmaxMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), input,
output, axis);
auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
// For Inplace src and and dst are the same memory object
......@@ -149,11 +135,12 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL SoftmaxGrad must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const Tensor* output = ctx.Input<Tensor>("Out");
auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* in_x_grad = ctx.template Output<Tensor>(framework::GradVarName("X"));
SoftmaxMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), output,
SoftmaxMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), output,
out_grad, in_x_grad, ctx.InputName("Out"));
auto dst_memory_p = handler.AcquireDstMemory(output);
......
......@@ -33,6 +33,8 @@ USE_OP(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP(conv2d);
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);
namespace paddle {
namespace operators {
......@@ -64,16 +66,19 @@ class CacheTester {
template <typename T>
void RunOperator(const platform::Place &place, const std::string &op_type,
const framework::DDim &dims, const std::string &output_name,
bool inplace = false) {
const framework::DDim &dims, const std::string &first_input) {
framework::Scope scope;
std::map<const std::string, int> num_inputs = {{"softmax", 1},
{"relu", 1},
{"conv2d", 2},
{"elementwise_add", 2},
{"elementwise_mul", 2}};
std::string first_input = inplace == true ? output_name : "x";
std::string first_input_var_name = (op_type == "conv2d") ? "Input" : "X";
std::string second_input_var_name = (op_type == "conv2d") ? "Filter" : "Y";
std::string output_var_name = (op_type == "conv2d") ? "Output" : "Out";
std::string output_name = "output";
std::vector<InputVars> input_names = {
{first_input, scope.Var(first_input)->GetMutable<framework::LoDTensor>()},
......@@ -113,71 +118,40 @@ void RunOperator(const platform::Place &place, const std::string &op_type,
auto &pool = platform::DeviceContextPool::Instance();
auto op = num_inputs[op_type] > 1
? framework::OpRegistry::CreateOp(
op_type, {{"X", {first_input}}, {"Y", {"x1"}}},
{{"Out", {output_name}}}, {{"use_mkldnn", {true}}})
: framework::OpRegistry::CreateOp(
op_type, {{"X", {first_input}}}, {{"Out", {output_name}}},
{{"use_mkldnn", {true}}});
auto op =
num_inputs[op_type] > 1
? framework::OpRegistry::CreateOp(
op_type, {{first_input_var_name, {first_input}},
{second_input_var_name, {"x1"}}},
{{output_var_name, {output_name}}}, {{"use_mkldnn", {true}}})
: framework::OpRegistry::CreateOp(
op_type, {{first_input_var_name, {first_input}}},
{{output_var_name, {output_name}}}, {{"use_mkldnn", {true}}});
op->Run(scope, place);
pool.Get(place)->Wait();
}
TEST(test_softmax_reuse_cache, cpu_place) {
framework::DDim dims({32, 64});
TEST(test_conv2d_reuse_cache, cpu_place) {
framework::DDim dims({1, 16, 32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "softmax", dims, "softmax_out");
RunOperator<float>(p, "softmax", dims, "softmax_out");
PADDLE_ENFORCE_EQ(ct.Analyze(4), true,
RunOperator<float>(p, "conv2d", dims, "input_signal");
RunOperator<float>(p, "conv2d", dims, "input_signal");
PADDLE_ENFORCE_EQ(ct.Analyze(9), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
"Invalid number of cached oneDNN objects"));
}
TEST(test_softmax_noreuse_cache, cpu_place) {
framework::DDim dims({32, 64});
TEST(test_conv2d_noreuse_cache, cpu_place) {
framework::DDim dims({1, 16, 32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "softmax", dims, "softmax_out");
RunOperator<float>(p, "softmax", dims, "softmax_out2");
PADDLE_ENFORCE_EQ(ct.Analyze(8), true,
RunOperator<float>(p, "conv2d", dims, "input_signal");
RunOperator<float>(p, "conv2d", dims, "input_signal2");
PADDLE_ENFORCE_EQ(ct.Analyze(18), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
TEST(test_softmax_inplace_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "softmax", dims, "softmax_out");
RunOperator<float>(p, "softmax", dims, "softmax_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(7), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
TEST(test_relu_inplace_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "relu", dims, "relu_out");
RunOperator<float>(p, "relu", dims, "relu_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(7), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
}
TEST(test_elementwise_add_reuse_cache, cpu_place) {
framework::DDim dims({32, 64});
platform::CPUPlace p;
CacheTester ct;
RunOperator<float>(p, "elementwise_add", dims, "elementwise_add_out");
RunOperator<float>(p, "relu", dims, "elementwise_add_out", true);
PADDLE_ENFORCE_EQ(ct.Analyze(8), true,
platform::errors::InvalidArgument(
"Wrong number of cached oneDNN objects"));
"Invalid number of cached oneDNN objects"));
}
} // namespace operators
......
......@@ -34,6 +34,211 @@ using framework::Tensor;
using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = mkldnn::memory;
template <typename T, typename TForward,
typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive>
class MKLDNNHandlerNoCachingT {
public:
MKLDNNHandlerNoCachingT(mkldnn::engine engine, platform::Place cpu_place)
: engine_(engine), place_(cpu_place), fwd_pd_(nullptr), bwd_pd_(nullptr) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
return std::make_shared<TForward>(*fwd_pd_);
}
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
return std::make_shared<TBackward>(*bwd_pd_);
}
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_, platform::errors::Unavailable("BWD_PD should be set when "
"getting BWD prim ."));
return std::make_shared<TBackward_params>(*bwd_w_pd_);
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(fwd_pd_->src_desc(),
to_void_cast<T>(input_data));
}
template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
T_out* ptr =
output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), ptr);
}
template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc());
}
template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const framework::Tensor* output) {
const T_out* output_data = output->data<T_out>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(),
to_void_cast<T_out>(output_data));
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const framework::Tensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_desc(),
to_void_cast<T>(ptr));
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) {
T* ptr =
diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_desc(), ptr);
}
// Buffer of given Tensor is used for oneDNN computation
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(
framework::Tensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
ptr);
}
// Buffer is allocated by oneDNN to store computation results
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc());
}
protected:
// If your primitive descriptor requires attributes, pass them as a
// first argument and paramters to descriptor constructor in the following
// arguments. Otherwise, all arguments will be forwarded to descriptor
// constructor, including the first one.
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
CreateForwardPrimitiveDescriptor(first_arg, std::forward<Args>(args)...);
}
// Using sfinae to specialise variadic function. Workaround for not having
// if constexpr in C++ 11.
template <class First, class... Args>
typename std::enable_if<std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
fwd_desc, first, engine_);
}
template <class First, class... Args>
typename std::enable_if<!std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<First>(first),
std::forward<Args>(args)...);
fwd_pd_ =
std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(fwd_pd_,
platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed."));
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
}
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(fwd_pd_,
platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed."));
auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md, void* ptr) {
return std::make_shared<mkldnn::memory>(md, engine_, ptr);
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md) {
return std::make_shared<mkldnn::memory>(md, engine_);
}
void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p) {
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
template <typename F = T>
std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr,
const std::string& suffix, bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
std::shared_ptr<mkldnn::memory> target_memory_p;
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p = std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
} else {
target_memory_p = user_memory_p;
}
return target_memory_p;
}
mkldnn::engine engine_;
platform::Place place_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
};
template <typename T, typename TForward,
typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive>
......@@ -79,7 +284,7 @@ class MKLDNNHandlerT {
std::static_pointer_cast<TBackward_params>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_, platform::errors::Unavailable(
"Error: BWD_PD should be set when "
"BWD_PD should be set when "
"getting BWD prim witk key: %s .",
key_p));
backward_p = std::make_shared<TBackward_params>(*bwd_w_pd_);
......@@ -138,7 +343,7 @@ class MKLDNNHandlerT {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"Error: BWD_W_PD should be set when getting BWD grad of weights."));
"BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(), ptr,
......@@ -150,7 +355,7 @@ class MKLDNNHandlerT {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"Error: BWD_W_PD should be set when getting BWD grad of weights."));
"BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
"@diff_wei_mem_p");
}
......@@ -589,70 +794,70 @@ class MKLDNNHandler {
};
template <typename T>
class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
class BinaryMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public:
BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y, Tensor* z,
float scale_x, float scale_y, float scale_z,
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor."));
PADDLE_ENFORCE_NE(
x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for X tensor."));
PADDLE_ENFORCE_EQ(
y->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for Y tensor."));
PADDLE_ENFORCE_NE(
y->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for Y tensor."));
const auto src_x_tz = framework::vectorize(x->dims());
const auto src_y_tz = framework::vectorize(y->dims());
// if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer
auto rankdiff = x->dims().size() - y->dims().size();
const auto dst_tz = (z == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
: framework::vectorize(z->dims());
auto src0_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
auto src1_md = dnnl::memory::desc(
src_y_tz, platform::MKLDNNGetDataType<T>(), y->format());
if (rankdiff > 0) { // Second input is of smaller rank than first
std::vector<int64_t> dims1_ex(rankdiff, 1);
dims1_ex.insert(next(dims1_ex.begin(), (axis == -1 ? rankdiff : axis)),
src_y_tz.begin(), src_y_tz.end());
src1_md = src1_md.reshape(dims1_ex);
} else if (rankdiff < 0) { // First input is of smaller than second
std::vector<int64_t> dims0_ex(-rankdiff, 1);
dims0_ex.insert(next(dims0_ex.begin(), (axis == -1 ? -rankdiff : axis)),
src_x_tz.begin(), src_x_tz.end());
src0_md = src0_md.reshape(dims0_ex);
}
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_z);
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md,
src1_md, dst_md);
float scale_x, float scale_y, float scale_z)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for X tensor. Expected: %d (kMKLDNN), Actual: %d",
DataLayout::kMKLDNN, x->layout()));
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for X tensor : %d (undef)",
static_cast<unsigned int>(x->format())));
PADDLE_ENFORCE_EQ(
y->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Y tensor. Expected: %d (kMKLDNN), Actual: %d",
DataLayout::kMKLDNN, y->layout()));
PADDLE_ENFORCE_NE(y->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Y tensor : %d (undef)",
static_cast<unsigned int>(y->format())));
const auto src_x_tz = framework::vectorize(x->dims());
const auto src_y_tz = framework::vectorize(y->dims());
// if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer
auto rankdiff = x->dims().size() - y->dims().size();
const auto dst_tz = (z == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
: framework::vectorize(z->dims());
auto src0_md = dnnl::memory::desc(
src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
auto src1_md = dnnl::memory::desc(
src_y_tz, platform::MKLDNNGetDataType<T>(), y->format());
if (rankdiff > 0) { // Second input is of smaller rank than first
std::vector<int64_t> dims1_ex(rankdiff, 1);
dims1_ex.insert(next(dims1_ex.begin(), (axis == -1 ? rankdiff : axis)),
src_y_tz.begin(), src_y_tz.end());
src1_md = src1_md.reshape(dims1_ex);
} else if (rankdiff < 0) { // First input is of smaller than second
std::vector<int64_t> dims0_ex(-rankdiff, 1);
dims0_ex.insert(next(dims0_ex.begin(), (axis == -1 ? -rankdiff : axis)),
src_x_tz.begin(), src_x_tz.end());
src0_md = src0_md.reshape(dims0_ex);
}
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_z);
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
dst_md);
}
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->src1_desc(), to_void_cast<T>(input_data), "@src1_mem_p");
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(),
to_void_cast<T>(input_data));
}
private:
......@@ -775,111 +980,95 @@ class ReductionMKLDNNHandler
template <typename T>
class ActivationMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward> {
: public MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward> {
public:
ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx, Place cpu_place,
const framework::Tensor* in_x,
const std::string& unique_name, bool is_inplaced)
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
is_inplaced ? platform::CreateKey(
dev_ctx, framework::vectorize(in_x->dims()), "a",
algorithm, unique_name)
: platform::CreateKey(
dev_ctx, framework::vectorize(in_x->dims()), "a",
unique_name)) {
if (!this->isCached()) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// eltwise_linear means we are in scale op
if (algorithm == mkldnn::algorithm::eltwise_linear) {
bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
alpha = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
: (float)*(scale_tensor->data<T>());
beta = ctx.Attr<float>("bias");
// if bias_after_scale == true
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if (!bias_after_scale) beta *= alpha;
} else {
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
const mkldnn::engine engine, Place cpu_place,
const framework::Tensor* in_x)
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(engine,
cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// eltwise_linear means we are in scale op
if (algorithm == mkldnn::algorithm::eltwise_linear) {
bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
alpha = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
: (float)*(scale_tensor->data<T>());
beta = ctx.Attr<float>("bias");
// if bias_after_scale == true
// out = scale*X + bias
// else
// out = scale*(X + bias) = scale*X + scale*bias
if (!bias_after_scale) beta *= alpha;
} else {
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
}
PADDLE_ENFORCE(in_x->dims().size() >= 1 || in_x->dims().size() <= 6,
platform::errors::Unimplemented(
"Input dimension size can be 1, 2, 3, 4, "
"5, or 6, but now the dimension size is",
in_x->dims().size()));
PADDLE_ENFORCE(in_x->dims().size() >= 1 || in_x->dims().size() <= 6,
platform::errors::Unimplemented(
"Input dimension size can be 1, 2, 3, 4, "
"5, or 6, but now the dimension size is",
in_x->dims().size()));
auto src_tz = framework::vectorize<int64_t>(in_x->dims());
auto src_fmt =
src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto md = mkldnn::memory::desc(src_tz, platform::MKLDNNGetDataType<T>(),
src_fmt);
auto src_tz = framework::vectorize<int64_t>(in_x->dims());
auto src_fmt = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto md =
mkldnn::memory::desc(src_tz, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training, algorithm, md, alpha, beta);
}
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
algorithm, md, alpha, beta);
}
ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx, Place cpu_place,
const framework::Tensor* in_x, const Tensor* out_grad,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
"a", unique_name)) {
if (!this->isBwdCached()) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
const mkldnn::engine engine, Place cpu_place,
const framework::Tensor* in_x, const Tensor* out_grad)
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(engine,
cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold");
}
auto diff_dst_tz = framework::vectorize<int64_t>(out_grad->dims());
auto diff_dst_tz = framework::vectorize<int64_t>(out_grad->dims());
auto src_fmt =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto diff_fmt =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : out_grad->format();
auto src_fmt =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto diff_fmt =
diff_dst_tz.size() == 2 ? MKLDNNMemoryFormat::nc : out_grad->format();
auto dims = framework::vectorize(in_x->dims());
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), src_fmt);
auto dims = framework::vectorize(in_x->dims());
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training, algorithm, src_md, alpha, beta);
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta);
}
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
algorithm, src_md, alpha, beta);
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta);
}
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(),
to_void_cast<T>(input_data),
"@bwd-src_mem_p");
to_void_cast<T>(input_data));
}
};
......@@ -1430,11 +1619,6 @@ using ConvMKLDNNHandler =
mkldnn::convolution_backward_data,
mkldnn::convolution_backward_weights>;
using ConvTransposeMKLDNNHandler =
ConvMKLDNNTemplateHandler<mkldnn::deconvolution_forward,
mkldnn::deconvolution_backward_data,
mkldnn::deconvolution_backward_weights>;
template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册