提交 5f0422f4 编写于 作者: J Jacek Czaja

- Added softmax without caching

上级 09892118
......@@ -32,25 +32,14 @@ using platform::to_void_cast;
template <typename T>
class SoftmaxMKLDNNHandler
: public platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
: public platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward> {
public:
SoftmaxMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
SoftmaxMKLDNNHandler(const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
Tensor* output, const int axis,
const std::string uniq_name, bool is_inplaced)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, mkldnn_engine, cpu_place,
// Softmax may be inplace then uniq_name is no longer unique
is_inplaced ? platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
axis, uniq_name)
: platform::CreateKey(
dev_ctx, framework::vectorize(input->dims()),
uniq_name)) {
if (!this->isCached()) {
Tensor* output, const int axis)
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward, mkldnn::softmax_backward>(
mkldnn_engine, cpu_place) {
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
platform::errors::InvalidArgument(
......@@ -60,22 +49,17 @@ class SoftmaxMKLDNNHandler
auto md = memory::desc(softmax_tz, platform::MKLDNNGetDataType<T>(),
input->format());
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
axis);
}
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md, axis);
}
SoftmaxMKLDNNHandler(const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* out,
const Tensor* out_grad, Tensor* in_x_grad,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(out->dims()),
unique_name)) {
if (!this->isBwdCached()) {
dev_ctx, mkldnn_engine, cpu_place) {
PADDLE_ENFORCE_EQ(
out_grad->dims(), in_x_grad->dims(),
platform::errors::InvalidArgument("The shape of softmax_grad's input "
......@@ -95,7 +79,6 @@ class SoftmaxMKLDNNHandler
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis);
}
}
};
template <typename T>
......@@ -111,9 +94,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), input->dims().size());
SoftmaxMKLDNNHandler<T> handler(dev_ctx, mkldnn_engine, ctx.GetPlace(),
input, output, axis, ctx.OutputName("Out"),
is_inplaced);
SoftmaxMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), input, output, axis);
auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
// For Inplace src and and dst are the same memory object
......@@ -149,11 +130,12 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL SoftmaxGrad must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const Tensor* output = ctx.Input<Tensor>("Out");
auto* out_grad = ctx.template Input<Tensor>(framework::GradVarName("Out"));
auto* in_x_grad = ctx.template Output<Tensor>(framework::GradVarName("X"));
SoftmaxMKLDNNHandler<T> handler(ctx, dev_ctx, ctx.GetPlace(), output,
SoftmaxMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), output,
out_grad, in_x_grad, ctx.InputName("Out"));
auto dst_memory_p = handler.AcquireDstMemory(output);
......
......@@ -34,6 +34,219 @@ using framework::Tensor;
using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = mkldnn::memory;
template <typename T, typename TForward,
typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive>
class MKLDNNHandlerNoCachingT {
public:
MKLDNNHandlerNoCachingT(mkldnn::engine engine, platform::Place cpu_place)
: engine_(engine),
place_(cpu_place),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
return forward_p = std::make_shared<TForward>(*fwd_pd_);
}
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
return backward_p = std::make_shared<TBackward>(*bwd_pd_);
}
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_, platform::errors::Unavailable(
"Error: BWD_PD should be set when "
"getting BWD prim witk key: %s .",
key_p));
return std::make_shared<TBackward_params>(*bwd_w_pd_);
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(
fwd_pd_->src_desc(), to_void_cast<T>(input_data));
}
template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
T_out* ptr =
output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), ptr);
}
template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc());
}
template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const framework::Tensor* output) {
const T_out* output_data = output->data<T_out>();
return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(),
to_void_cast<T_out>(output_data));
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const framework::Tensor* diffdst) {
const T* ptr = diffdst->data<T>();
return this->AcquireMemoryFromPrimitive(
bwd_pd_->diff_dst_desc(), to_void_cast<T>(ptr));
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
framework::Tensor* diffsrc) {
T* ptr =
diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_desc(), ptr);
}
// Buffer of given Tensor is used for oneDNN computation
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(
framework::Tensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"Error: BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(), ptr);
}
// Buffer is allocated by oneDNN to store computation results
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"Error: BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc());
}
protected:
// If your primitive descriptor requires attributes, pass them as a
// first argument and paramters to descriptor constructor in the following
// arguments. Otherwise, all arguments will be forwarded to descriptor
// constructor, including the first one.
template <typename Arg, typename... Args>
void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) {
CreateForwardPrimitiveDescriptor(first_arg, std::forward<Args>(args)...);
}
// Using sfinae to specialise variadic function. Workaround for not having
// if constexpr in C++ 11.
template <class First, class... Args>
typename std::enable_if<std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<Args>(args)...);
fwd_pd_ = std::make_shared<typename TForward::primitive_desc>(
fwd_desc, first, engine_);
}
template <class First, class... Args>
typename std::enable_if<!std::is_same<typename std::decay<First>::type,
dnnl::primitive_attr>::value>::type
CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) {
auto fwd_desc = typename TForward::desc(std::forward<First>(first),
std::forward<Args>(args)...);
fwd_pd_ =
std::make_shared<typename TForward::primitive_desc>(fwd_desc, engine_);
}
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed."));
auto bwd_desc = typename TBackward::desc(std::forward<Args>(args)...);
bwd_pd_ = std::make_shared<typename TBackward::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
}
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptor(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptor
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed."));
auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md, void* ptr) {
return std::make_shared<mkldnn::memory>(md, engine_, ptr);
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md) {
return std::make_shared<mkldnn::memory>(md, engine_);
}
void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p) {
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
template <typename F = T>
std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr,
const std::string& suffix, bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
std::shared_ptr<mkldnn::memory> target_memory_p;
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
} else {
target_memory_p = user_memory_p;
}
return target_memory_p;
}
mkldnn::engine engine_;
platform::Place place_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
};
template <typename T, typename TForward,
typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册